Source code for astronomer.providers.databricks.triggers.databricks

import asyncio
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.databricks.hooks.databricks import DatabricksHookAsync


class DatabricksTrigger(BaseTrigger):
    """
    Wait asynchronously for a Databricks job run to reach a terminal state.

    :param conn_id: The databricks connection id.
    :param task_id: The task id; used in log lines and in the error message.
    :param run_id: The databricks job run id whose state is polled.
    :param retry_limit: Amount of times retry if the Databricks backend is unreachable.
        Its value must be greater than or equal to 1. Passed through to the async hook.
    :param retry_delay: Number of seconds to wait between retries (it might be a
        floating point number). Passed through to the async hook.
    :param polling_period_seconds: Number of seconds to sleep between two
        consecutive polls of the run state.
    :param job_id: The databricks job id; echoed back in the success event.
    :param run_page_url: The databricks run page url; echoed back in the success event.
    """

    def __init__(
        self,
        conn_id: str,
        task_id: str,
        run_id: str,
        retry_limit: int,
        retry_delay: int,
        polling_period_seconds: int,
        job_id: Optional[int] = None,
        run_page_url: Optional[str] = None,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.task_id = task_id
        self.run_id = run_id
        self.job_id = job_id
        self.run_page_url = run_page_url
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        self.polling_period_seconds = polling_period_seconds

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """
        Serialize the trigger's classpath and constructor arguments.

        :return: A ``(classpath, kwargs)`` tuple; ``kwargs`` holds every
            argument accepted by ``__init__`` so the trigger can be rebuilt.
        """
        return (
            "astronomer.providers.databricks.triggers.databricks.DatabricksTrigger",
            {
                "conn_id": self.conn_id,
                "task_id": self.task_id,
                "run_id": self.run_id,
                "job_id": self.job_id,
                "run_page_url": self.run_page_url,
                "retry_limit": self.retry_limit,
                "retry_delay": self.retry_delay,
                "polling_period_seconds": self.polling_period_seconds,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Poll the run state through the async Databricks hook until it is terminal.

        Yields exactly one ``TriggerEvent`` and then finishes:

        * ``{"status": "success", "job_id", "run_id", "run_page_url"}`` when the
          terminal state is successful;
        * ``{"status": "error", "message"}`` for any other terminal state.

        While the run is not terminal, the state is logged and the coroutine
        sleeps for ``polling_period_seconds`` before polling again. Exceptions
        raised by the hook are not caught here and propagate to the caller.
        """
        hook = self._get_async_hook()
        while True:
            run_state = await hook.get_run_state_async(self.run_id)
            if run_state.is_terminal:
                if run_state.is_successful:
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "job_id": self.job_id,
                            "run_id": self.run_id,
                            "run_page_url": self.run_page_url,
                        }
                    )
                else:
                    error_message = f"{self.task_id} failed with terminal state: {run_state}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                # Stop after the terminal event. Without this, resuming the
                # generator would re-enter the loop and poll Databricks again
                # immediately, with no sleep, re-emitting the same event.
                return

            self.log.info("%s in run state: %s", self.task_id, run_state)
            self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
            await asyncio.sleep(self.polling_period_seconds)

    def _get_async_hook(self) -> DatabricksHookAsync:
        """Build the async Databricks hook from this trigger's connection and retry settings."""
        return DatabricksHookAsync(
            self.conn_id,
            retry_limit=self.retry_limit,
            retry_delay=self.retry_delay,
        )