Source code for astronomer.providers.core.triggers.external_task

import asyncio
import datetime
import typing
import warnings
from typing import Any, AsyncIterator, Dict, List, Tuple

from airflow import AirflowException
from airflow.models import DagRun, TaskInstance
from airflow.providers.http.hooks.http import HttpHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from asgiref.sync import sync_to_async
from sqlalchemy import func
from sqlalchemy.orm import Session

from astronomer.providers.http.triggers.http import HttpTrigger


class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to reach one of the given states.

    The trigger repeatedly queries the database and fires a single event once the
    number of matching task instances equals ``len(execution_dates)``.

    :param dag_id: The dag_id that contains the task you want to wait for
    :param task_id: The task_id of the task you want to wait for
    :param states: task instance states that count as a match
    :param execution_dates: logical dates of the task instances to wait for
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5 sec.
    """

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        states: List[str],
        execution_dates: List[datetime.datetime],
        poll_interval: float = 5.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.task_id = task_id
        self.states = states
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes TaskStateTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.TaskStateTrigger",
            {
                "dag_id": self.dag_id,
                "task_id": self.task_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
        """
        Checks periodically in the database to see if the task exists, and has
        hit one of the states yet, or not.

        Yields exactly one ``TriggerEvent(True)`` when the match count equals
        ``len(self.execution_dates)`` and then stops.
        """
        while True:
            # mypy confuses typing here
            num_tasks = await self.count_tasks()  # type: ignore[call-arg]
            if num_tasks == len(self.execution_dates):
                yield TriggerEvent(True)
                # Stop after firing: without this the loop would keep polling the
                # database and could emit duplicate events.
                return
            await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_tasks(self, session: Session) -> typing.Optional[int]:
        """
        Count how many task instances in the database match our criteria.

        :param session: database session; ``run`` calls this method without one,
            relying on the ``provide_session`` decorator to supply it
        :return: number of task instances matching dag_id, task_id, states and
            execution_dates
        """
        count = (
            session.query(func.count())  # .count() is inefficient
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.state.in_(self.states),
                TaskInstance.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
class DagStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a different DAG to reach one of the given states.

    The trigger repeatedly queries the database and fires a single event once the
    number of matching DAG runs equals ``len(execution_dates)``. Instantiating it
    emits a ``DeprecationWarning``.

    :param dag_id: The dag_id of the DAG you want to wait for
    :param states: DAG run states that count as a match
    :param execution_dates: The logical date at which DAG run.
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5.0 sec.
    """

    def __init__(
        self,
        dag_id: str,
        states: List[str],
        execution_dates: List[datetime.datetime],
        poll_interval: float = 5.0,
    ):
        warnings.warn(
            (
                # The sentences are joined by implicit concatenation, so the
                # separator has to be part of the first literal.
                "This module is deprecated and will be removed in airflow>=2.9.0. "
                "Please use `airflow.triggers.external_task.WorkflowTrigger` "
                "and set deferrable to True instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.dag_id = dag_id
        self.states = states
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DagStateTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.DagStateTrigger",
            {
                "dag_id": self.dag_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
        """
        Checks periodically in the database to see if the dag run exists, and has
        hit one of the states yet, or not.

        Yields exactly one ``TriggerEvent(True)`` when the match count equals
        ``len(self.execution_dates)`` and then stops.
        """
        while True:
            # mypy confuses typing here
            num_dags = await self.count_dags()  # type: ignore[call-arg]
            if num_dags == len(self.execution_dates):
                yield TriggerEvent(True)
                # Stop after firing: without this the loop would keep polling the
                # database and could emit duplicate events.
                return
            await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_dags(self, session: Session) -> typing.Optional[int]:
        """
        Count how many dag runs in the database match our criteria.

        :param session: database session; ``run`` calls this method without one,
            relying on the ``provide_session`` decorator to supply it
        :return: number of DAG runs matching dag_id, states and execution_dates
        """
        count = (
            session.query(func.count())  # .count() is inefficient
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.state.in_(self.states),
                DagRun.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
class ExternalDeploymentTaskTrigger(HttpTrigger):
    """
    ExternalDeploymentTaskTrigger inherits from HttpTrigger and polls an HTTP
    endpoint to get the deployment state.
    """

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes ExternalDeploymentTaskTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.ExternalDeploymentTaskTrigger",
            {
                "endpoint": self.endpoint,
                "data": self.data,
                "headers": self.headers,
                "extra_options": self.extra_options,
                "http_conn_id": self.http_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Makes a series of http calls via an http hook to poll for the state of the
        job run until it reaches a finished state or the call fails.

        Yields one ``TriggerEvent`` carrying the response JSON once its ``state``
        is in ``State.finished``, or one ``{"state": "error", "message": ...}``
        event when the hook raises ``AirflowException``; the trigger stops after
        either event.
        """
        from airflow.utils.state import State

        hook = HttpHook(method="GET", http_conn_id=self.http_conn_id)
        # HttpHook.run is a synchronous call; running it through sync_to_async
        # keeps the request off the event loop so other triggers are not stalled
        # while it is in flight.
        run_request = sync_to_async(hook.run)
        while True:
            try:
                response = await run_request(
                    endpoint=self.endpoint,
                    data=self.data,
                    headers=self.headers,
                    extra_options=self.extra_options,
                )
                resp_json = response.json()
                if resp_json["state"] in State.finished:
                    yield TriggerEvent(resp_json)
                    return
                self.log.info(
                    "The current status is %s. Sleeping for %s seconds",
                    resp_json.get("state"),
                    self.poke_interval,
                )
                await asyncio.sleep(self.poke_interval)
            except AirflowException as exc:
                self.log.info("An error occur while calling API %s", str(exc))
                # NOTE(review): a 404 only delays the error event by one poke
                # interval; it does not retry the request. Confirm this is intended.
                if str(exc).startswith("404"):
                    await asyncio.sleep(self.poke_interval)
                yield TriggerEvent({"state": "error", "message": str(exc)})
                return