Source code for astronomer.providers.core.sensors.external_task

import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from airflow.exceptions import AirflowException
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.context import Context
from airflow.utils.session import provide_session

from astronomer.providers.core.triggers.external_task import (
    DagStateTrigger,
    TaskStateTrigger,
)

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session


class ExternalTaskSensorAsync(ExternalTaskSensor):
    """Deferrable flavour of ``ExternalTaskSensor``.

    Rather than poking, ``execute`` hands the wait off to a trigger
    (``DagStateTrigger`` for a whole DAG, ``TaskStateTrigger`` for a single
    task). When the trigger fires, the task resumes in ``execute_complete``,
    which decides pass/fail.
    """

    def execute(self, context: Context) -> None:
        """Pick the trigger that matches the sensor's target and defer to it.

        :param context: Airflow task context, used to resolve the execution
            dates to watch (see ``get_execution_dates``).

        Resumes in ``execute_complete``. ``execution_timeout`` is passed
        through as the deferral timeout.
        """
        execution_dates = self.get_execution_dates(context)
        # The triggers only report that one of the given states was reached;
        # they do not know which states mean success. So watch every state of
        # interest and let execute_complete sort success from failure.
        watched_states = self.allowed_states + self.failed_states

        if self.external_task_id:
            trigger = TaskStateTrigger(
                dag_id=self.external_dag_id,
                task_id=self.external_task_id,
                states=watched_states,
                execution_dates=execution_dates,
            )
        else:
            # Truthiness (not an explicit ``is None``) is intentional: any
            # falsy task id, e.g. "", also selects the DAG-level trigger.
            # NOTE(review): only ``external_task_id`` is consulted here. If the
            # base sensor was configured with several task ids instead, they
            # would fall through to this DAG-level branch. Confirm this is
            # acceptable.
            trigger = DagStateTrigger(
                dag_id=self.external_dag_id,
                states=watched_states,
                execution_dates=execution_dates,
            )

        self.defer(
            timeout=self.execution_timeout,
            trigger=trigger,
            method_name="execute_complete",
        )

    @provide_session
    def execute_complete(
        self,
        context: Context,
        session: "Session",
        event: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Resume after the trigger fires and fail unless every date is in an allowed state.

        The trigger's ``event`` payload is not inspected. The outcome is
        re-derived by comparing the inherited ``get_count`` result for
        ``allowed_states`` against the number of execution dates.

        :param context: Airflow task context.
        :param session: SQLAlchemy session, typically injected by
            ``provide_session`` when the caller does not pass one.
        :param event: payload emitted by the trigger (unused).
        :raises AirflowException: if the allowed-state count does not match
            the number of execution dates.
        """
        dates = self.get_execution_dates(context)
        allowed_count = self.get_count(dates, session, self.allowed_states)
        if allowed_count == len(dates):
            return None

        if self.external_task_id:
            message = (
                f"The external task {self.external_task_id} in DAG {self.external_dag_id} failed."
            )
        else:
            message = f"The external DAG {self.external_dag_id} failed."
        raise AirflowException(message)

    def get_execution_dates(self, context: Context) -> List[datetime.datetime]:
        """Resolve which external execution dates this run should check.

        Precedence:

        1. ``execution_delta`` (when truthy): the current ``execution_date``
           minus the delta.
        2. ``execution_date_fn`` (when set): whatever the inherited
           ``_handle_execution_date_fn`` returns.
        3. Otherwise: the current ``execution_date``.

        :param context: Airflow task context.
        :return: the resolved value itself if it is already a list; otherwise
            a one-element list wrapping it.
        """
        if self.execution_delta:
            resolved = context["execution_date"] - self.execution_delta
        elif self.execution_date_fn:
            resolved = self._handle_execution_date_fn(context=context)
        else:
            resolved = context["execution_date"]

        if isinstance(resolved, list):
            return resolved
        return [resolved]