# Source code for astronomer.providers.core.triggers.astro

from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.core.hooks.astro import AstroHook


class AstroDeploymentTrigger(BaseTrigger):
    """
    Custom Apache Airflow trigger for monitoring the completion status of an
    external deployment using Astro Cloud.

    The trigger polls either a single task instance (when ``external_task_id``
    is given) or a whole DAG run through ``AstroHook`` and fires exactly one
    event once a terminal state is observed.

    :param external_dag_id: External ID of the DAG being monitored.
    :param dag_run_id: ID of the DAG run being monitored.
    :param external_task_id: External ID of the task being monitored.
        If None, monitors the entire DAG.
    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
        Defaults to "astro_cloud_default".
    :param poke_interval: Time in seconds to wait between consecutive checks
        for completion status.
    :param kwargs: Additional keyword arguments passed to the BaseTrigger constructor.
    """

    def __init__(
        self,
        external_dag_id: str,
        dag_run_id: str,
        external_task_id: str | None = None,
        astro_cloud_conn_id: str = "astro_cloud_default",
        poke_interval: float = 5.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.external_dag_id = external_dag_id
        self.dag_run_id = dag_run_id
        self.external_task_id = external_task_id
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize the trigger for storage in the database.

        :return: A ``(classpath, kwargs)`` pair; the kwargs mirror the
            constructor parameters stored on this instance.
        """
        return (
            "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "external_task_id": self.external_task_id,
                "dag_run_id": self.dag_run_id,
                "astro_cloud_conn_id": self.astro_cloud_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Asynchronously run the trigger and yield a single completion event.

        Polls the task instance (if ``external_task_id`` is set) or the DAG run
        every ``poke_interval`` seconds. Yields a TriggerEvent with status
        "done" on success ("success"/"skipped" for a task, "success" for a DAG
        run) or "failed" on failure ("failed"/"upstream_failed" for a task,
        "failed" for a DAG run), then stops.

        :return: An async iterator producing exactly one TriggerEvent.
        """
        hook = AstroHook(self.astro_cloud_conn_id)
        while True:
            if self.external_task_id is not None:
                task_instance = await hook.get_a_task_instance(
                    self.external_dag_id, self.dag_run_id, self.external_task_id
                )
                # A falsy response is treated as "state not known yet".
                state = task_instance.get("state") if task_instance else None
                if state in ("success", "skipped"):
                    yield TriggerEvent({"status": "done"})
                    # Stop here: without returning, the loop would keep polling
                    # and emit the same event again on every iteration.
                    return
                if state in ("failed", "upstream_failed"):
                    yield TriggerEvent({"status": "failed"})
                    return
            else:
                dag_run = await hook.get_a_dag_run(self.external_dag_id, self.dag_run_id)
                state = dag_run.get("state") if dag_run else None
                if state == "success":
                    yield TriggerEvent({"status": "done"})
                    return
                if state == "failed":
                    yield TriggerEvent({"status": "failed"})
                    return

            # Only reached while the monitored run/task is still non-terminal.
            self.log.info("Job status is %s sleeping for %s seconds.", state, self.poke_interval)
            await asyncio.sleep(self.poke_interval)