import asyncio
import datetime
import typing
from typing import Any, AsyncIterator, Dict, List, Tuple
from airflow import AirflowException
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from asgiref.sync import sync_to_async
from sqlalchemy import func
from sqlalchemy.orm import Session
from astronomer.providers.http.triggers.http import HttpTrigger
class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to reach one of the
    given states for every one of the given logical dates.

    :param dag_id: The dag_id that contains the task you want to wait for
    :param task_id: The task_id of the task you want to wait for
    :param states: Task instance states that count as done, e.g. ``['success']``
    :param execution_dates: The logical dates whose task instances must all be
        in one of ``states`` before the trigger fires
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5 sec.
    """

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        states: List[str],
        execution_dates: List[datetime.datetime],
        poll_interval: float = 5.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.task_id = task_id
        self.states = states
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes TaskStateTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.TaskStateTrigger",
            {
                "dag_id": self.dag_id,
                "task_id": self.task_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
        """
        Poll the metadata database until the task has reached one of ``states``
        for every date in ``execution_dates``.

        Checks every ``poll_interval`` seconds; yields a single
        ``TriggerEvent(True)`` once the count of matching task instances equals
        ``len(execution_dates)`` and then stops.
        """
        while True:
            num_tasks = await self.count_tasks()
            # NOTE(review): equality assumes exactly one matching row per date
            # (no duplicate dates, one task instance per run) — confirm for
            # mapped tasks.
            if num_tasks == len(self.execution_dates):
                yield TriggerEvent(True)
                # Stop after the single event; otherwise the trigger would keep
                # polling and re-emitting until the triggerer cancels it.
                return
            await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_tasks(self, session: Session) -> int:
        """
        Count how many task instances in the database match our criteria.

        Matches rows of ``TaskInstance`` for ``dag_id``/``task_id`` whose state is
        in ``states`` and whose execution date is in ``execution_dates``.
        """
        count = (
            session.query(func.count())  # .count() is inefficient
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.state.in_(self.states),
                TaskInstance.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
class DagStateTrigger(BaseTrigger):
    """
    Waits asynchronously for runs of a different DAG to reach one of the
    given states for every one of the given logical dates.

    :param dag_id: The dag_id of the DAG whose runs you want to wait for
    :param states: DAG run states that count as done, e.g. ``['success']``
    :param execution_dates: The logical dates whose DAG runs must all be in
        one of ``states`` before the trigger fires
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5.0 sec.
    """

    def __init__(
        self,
        dag_id: str,
        states: List[str],
        execution_dates: List[datetime.datetime],
        poll_interval: float = 5.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.states = states
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DagStateTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.DagStateTrigger",
            {
                "dag_id": self.dag_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
        """
        Poll the metadata database until the DAG has a run in one of ``states``
        for every date in ``execution_dates``.

        Checks every ``poll_interval`` seconds; yields a single
        ``TriggerEvent(True)`` once the count of matching DAG runs equals
        ``len(execution_dates)`` and then stops.
        """
        while True:
            num_dags = await self.count_dags()
            # NOTE(review): equality assumes no duplicate dates in
            # ``execution_dates`` (IN matches each row once) — confirm callers.
            if num_dags == len(self.execution_dates):
                yield TriggerEvent(True)
                # Stop after the single event; otherwise the trigger would keep
                # polling and re-emitting until the triggerer cancels it.
                return
            await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_dags(self, session: Session) -> int:
        """
        Count how many dag runs in the database match our criteria.

        Matches rows of ``DagRun`` for ``dag_id`` whose state is in ``states``
        and whose execution date is in ``execution_dates``.
        """
        count = (
            session.query(func.count())  # .count() is inefficient
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.state.in_(self.states),
                DagRun.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
class ExternalDeploymentTaskTrigger(HttpTrigger):
    """
    Poll an external deployment over HTTP until the reported job state is finished.

    Inherits from ``HttpTrigger`` and reuses its settings (``endpoint``, ``data``,
    ``headers``, ``extra_options``, ``http_conn_id``, ``poke_interval``) and its
    async hook.
    """

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes ExternalDeploymentTaskTrigger arguments and classpath."""
        return (
            "astronomer.providers.core.triggers.external_task.ExternalDeploymentTaskTrigger",
            {
                "endpoint": self.endpoint,
                "data": self.data,
                "headers": self.headers,
                "extra_options": self.extra_options,
                "http_conn_id": self.http_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Poll the endpoint via the async HTTP hook until the job run reaches a finished state.

        Yields exactly one event and then stops:

        * the decoded JSON response, once its ``state`` is in ``State.finished``;
        * ``{"state": "error", "message": ...}`` if the hook raises an
          ``AirflowException`` that is not a 404.

        A 404 is treated as "not available yet": the trigger sleeps
        ``poke_interval`` seconds and polls again. Any other state also
        causes a ``poke_interval`` sleep before the next poll.
        """
        from airflow.utils.state import State

        hook = self._get_async_hook()
        while True:
            try:
                response = await hook.run(
                    endpoint=self.endpoint,
                    data=self.data,
                    headers=self.headers,
                    extra_options=self.extra_options,
                )
                resp_json = await response.json()
            except AirflowException as exc:
                if str(exc).startswith("404"):
                    # Resource not found yet: wait and retry instead of failing.
                    await asyncio.sleep(self.poke_interval)
                    continue
                yield TriggerEvent({"state": "error", "message": str(exc)})
                return

            if resp_json["state"] in State.finished:
                yield TriggerEvent(resp_json)
                # Stop here; looping again would re-query with no delay and
                # emit duplicate events.
                return
            await asyncio.sleep(self.poke_interval)