# Source code for astronomer.providers.core.sensors.astro

from __future__ import annotations

import datetime
from typing import Any, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

from astronomer.providers.core.hooks.astro import AstroHook
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger
from astronomer.providers.utils.typing_compat import Context


class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud.

    The sensor first checks the external DAG (or a single task of it) synchronously
    via :class:`AstroHook`. If that check does not report success, it defers to an
    :class:`AstroDeploymentTrigger` and resumes in :meth:`execute_complete`.

    :param external_dag_id: External ID of the DAG being monitored.
    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
        Defaults to "astro_cloud_default".
    :param external_task_id: External ID of the task being monitored.
        If None, monitors the entire DAG.
    :param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor.
    """

    def __init__(
        self,
        external_dag_id: str,
        astro_cloud_conn_id: str = "astro_cloud_default",
        external_task_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.external_task_id = external_task_id
        self.external_dag_id = external_dag_id
        # Set by ``poke`` once a DAG run has been found; handed to the trigger
        # in ``execute``. Stays empty until then.
        self._dag_run_id: str = ""

    def poke(self, context: Context) -> bool | PokeReturnValue:
        """
        Check the status of a DAG/task in another deployment.

        Fetches the DAG runs of ``external_dag_id`` through :class:`AstroHook` and
        inspects the first run returned. When ``external_task_id`` is set, the state
        of that task instance in the run is checked; otherwise the state of the DAG
        run itself is checked.

        As a side effect, the ID of the inspected run is stored in ``_dag_run_id``.

        :param context: The task execution context.
        :return: True if the inspected DAG run / task instance is in state
            ``"success"`` **or if no DAG runs exist at all**; False otherwise.
        """
        hook = AstroHook(self.astro_cloud_conn_id)
        dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id)

        if not dag_runs:
            # NOTE(review): "nothing to wait for" is treated as success, so the
            # sensor passes immediately when the external DAG has never run.
            # Confirm this is the intended contract before relying on it.
            self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
            return True

        # Only the first run returned by the hook is considered; its ordering is
        # decided by ``AstroHook.get_dag_runs`` (presumably newest first — verify).
        first_run = dag_runs[0]
        self._dag_run_id = cast(str, first_run["dag_run_id"])

        if self.external_task_id is not None:
            task_instance = hook.get_task_instance(
                self.external_dag_id, self._dag_run_id, self.external_task_id
            )
            # A missing/empty task instance counts as "not successful yet".
            state = task_instance.get("state") if task_instance else None
        else:
            state = first_run.get("state")

        return state == "success"

    def execute(self, context: Context) -> Any:
        """
        Executes the sensor.

        Runs a single synchronous :meth:`poke`. If it does not report success, the
        task is deferred to an :class:`AstroDeploymentTrigger`, with ``self.timeout``
        (interpreted as seconds) as the deferral timeout; execution then resumes in
        :meth:`execute_complete`.

        :param context: The task execution context.
        """
        if self.poke(context):
            return

        self.defer(
            timeout=datetime.timedelta(seconds=self.timeout),
            trigger=AstroDeploymentTrigger(
                astro_cloud_conn_id=self.astro_cloud_conn_id,
                external_task_id=self.external_task_id,
                external_dag_id=self.external_dag_id,
                poke_interval=self.poke_interval,
                # Run ID discovered by the ``poke`` call above.
                dag_run_id=self._dag_run_id,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, str]) -> None:
        """
        Handles the completion event from the deferred execution.

        Any event whose ``"status"`` is not ``"failed"`` is treated as success and
        the method returns normally.

        :param context: The task execution context.
        :param event: The event dictionary received from the deferred execution.
        :raises AirflowSkipException: If the event status is ``"failed"`` and
            ``soft_fail`` is True.
        :raises AirflowException: If the event status is ``"failed"`` and
            ``soft_fail`` is False.
        """
        if event.get("status") != "failed":
            return

        if self.soft_fail:
            raise AirflowSkipException("Upstream job failed. Skipping the task.")
        raise AirflowException("Upstream job failed.")