import asyncio
import datetime
import typing
import warnings
from typing import Any, AsyncIterator, Dict, List, Tuple
from airflow import AirflowException
from airflow.models import DagRun, TaskInstance
from astronomer.providers.http.hooks.http import HttpHookAsync
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from asgiref.sync import sync_to_async
from sqlalchemy import func
from sqlalchemy.orm import Session
from astronomer.providers.http.triggers.http import HttpTrigger
class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a
    specific logical date.

    :param dag_id: The dag_id of the DAG containing the task to wait for
    :param task_id: The task_id of the task to wait for
    :param states: the states the task instance must be in, e.g. ``["success"]``
    :param execution_dates: the logical dates of the task instances to wait for
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5 sec.
"""
def __init__(
self,
dag_id: str,
task_id: str,
states: List[str],
execution_dates: List[datetime.datetime],
poll_interval: float = 5.0,
):
super().__init__()
self.dag_id = dag_id
self.task_id = task_id
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval
def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes TaskStateTrigger arguments and classpath."""
return (
"astronomer.providers.core.triggers.external_task.TaskStateTrigger",
{
"dag_id": self.dag_id,
"task_id": self.task_id,
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
},
)
async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
"""
Checks periodically in the database to see if the task exists, and has
hit one of the states yet, or not.
"""
        while True:
            # mypy is confused by the sync_to_async/provide_session decorators
            num_tasks = await self.count_tasks()  # type: ignore[call-arg]
            if num_tasks == len(self.execution_dates):
                yield TriggerEvent(True)
                return
            await asyncio.sleep(self.poll_interval)
@sync_to_async
@provide_session
def count_tasks(self, session: Session) -> typing.Optional[int]:
"""Count how many task instances in the database match our criteria."""
        count = (
            session.query(func.count())  # Query.count() would wrap this in an extra subquery
.filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state.in_(self.states),
TaskInstance.execution_date.in_(self.execution_dates),
)
.scalar()
)
return typing.cast(int, count)
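
# Illustrative usage sketch (not part of the provider source): a deferrable
# sensor hands control to TaskStateTrigger via ``self.defer`` and resumes in
# ``execute_complete`` once the trigger fires. The operator name, dag_id and
# task_id below are hypothetical.
from airflow.sensors.base import BaseSensorOperator


class _ExampleExternalTaskSensorAsync(BaseSensorOperator):
    def execute(self, context: Dict[str, Any]) -> None:
        self.defer(
            trigger=TaskStateTrigger(
                dag_id="upstream_dag",  # hypothetical dag_id
                task_id="upstream_task",  # hypothetical task_id
                states=["success"],
                execution_dates=[context["logical_date"]],
                poll_interval=5.0,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> Any:
        # Runs in a worker process after the trigger yields TriggerEvent(True).
        return event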
class DagStateTrigger(BaseTrigger):
"""
Waits asynchronously for a different DAG to complete for a
    specific logical date.

    :param dag_id: The dag_id of the DAG to wait for
    :param states: the states the dag run must be in, e.g. ``["success"]``
    :param execution_dates: the logical dates of the DAG runs to wait for
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5.0 sec.
"""
def __init__(
self,
dag_id: str,
states: List[str],
execution_dates: List[datetime.datetime],
poll_interval: float = 5.0,
):
warnings.warn(
(
"This module is deprecated and will be removed in airflow>=2.9.0"
"Please use `airflow.triggers.external_task.WorkflowTrigger` "
"and set deferrable to True instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.dag_id = dag_id
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval
def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes DagStateTrigger arguments and classpath."""
return (
"astronomer.providers.core.triggers.external_task.DagStateTrigger",
{
"dag_id": self.dag_id,
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
},
)
async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
"""
Checks periodically in the database to see if the dag run exists, and has
hit one of the states yet, or not.
"""
        while True:
            # mypy is confused by the sync_to_async/provide_session decorators
            num_dags = await self.count_dags()  # type: ignore[call-arg]
            if num_dags == len(self.execution_dates):
                yield TriggerEvent(True)
                return
            await asyncio.sleep(self.poll_interval)
@sync_to_async
@provide_session
def count_dags(self, session: Session) -> typing.Optional[int]:
"""Count how many dag runs in the database match our criteria."""
        count = (
            session.query(func.count())  # Query.count() would wrap this in an extra subquery
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
DagRun.execution_date.in_(self.execution_dates),
)
.scalar()
)
return typing.cast(int, count)
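
# Illustrative sketch (not part of the provider source): the triggerer process
# recreates a trigger from the classpath/kwargs pair returned by
# ``serialize()``. The dag_id and dates below are hypothetical.
def _example_dag_state_trigger_round_trip() -> "DagStateTrigger":
    from importlib import import_module

    trigger = DagStateTrigger(
        dag_id="upstream_dag",  # hypothetical dag_id
        states=["success"],
        execution_dates=[datetime.datetime(2024, 1, 1)],
    )
    classpath, kwargs = trigger.serialize()
    module_path, class_name = classpath.rsplit(".", 1)
    # Rehydrating from the serialized form yields an equivalent trigger.
    return getattr(import_module(module_path), class_name)(**kwargs)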
class ExternalDeploymentTaskTrigger(HttpTrigger):
"""ExternalDeploymentTaskTrigger Inherits from HttpTrigger and make Async http call to get the deployment state"""
def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes ExternalDeploymentTaskTrigger arguments and classpath."""
return (
"astronomer.providers.core.triggers.external_task.ExternalDeploymentTaskTrigger",
{
"endpoint": self.endpoint,
"data": self.data,
"headers": self.headers,
"extra_options": self.extra_options,
"http_conn_id": self.http_conn_id,
"poke_interval": self.poke_interval,
},
)
async def run(self) -> AsyncIterator["TriggerEvent"]:
"""
Makes a series of http calls via an http hook poll for state of the job
run until it reaches a failure state or success state. It yields a Trigger if response state is successful.
"""
from airflow.utils.state import State
        # Use the async HTTP hook so polling does not block the triggerer's event loop.
        hook = HttpHookAsync(method="GET", http_conn_id=self.http_conn_id)
        while True:
            try:
                response = await hook.run(
                    endpoint=self.endpoint,
                    data=self.data,
                    headers=self.headers,
                    extra_options=self.extra_options,
                )
                resp_json = await response.json()
if resp_json["state"] in State.finished:
yield TriggerEvent(resp_json)
return
self.log.info(
"The current status is %s. Sleeping for %s seconds",
resp_json.get("state"),
self.poke_interval,
)
await asyncio.sleep(self.poke_interval)
            except AirflowException as exc:
                self.log.info("An error occurred while calling the API: %s", str(exc))
                if str(exc).startswith("404"):
                    # A 404 likely means the resource does not exist yet;
                    # retry after the poke interval instead of failing.
                    await asyncio.sleep(self.poke_interval)
                    continue
                yield TriggerEvent({"state": "error", "message": str(exc)})
                return
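
# Illustrative sketch (not part of the provider source): trigger ``run()``
# methods are async generators, so they can be driven directly with asyncio
# for ad-hoc testing outside the triggerer. The endpoint and connection id
# below are hypothetical, and this assumes the HttpTrigger constructor
# accepts the same keys that ``serialize()`` emits.
async def _example_poll_external_deployment() -> typing.Optional[TriggerEvent]:
    trigger = ExternalDeploymentTaskTrigger(
        endpoint="api/v1/example/deployment/state",  # hypothetical endpoint
        http_conn_id="http_default",
        poke_interval=5.0,
    )
    async for event in trigger.run():
        return event  # first event: finished-state payload or an error payload
    return None


# e.g. first_event = asyncio.run(_example_poll_external_deployment())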