Source code for astronomer.providers.databricks.triggers.databricks

import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.providers.databricks.hooks.databricks import RunState
from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.databricks.hooks.databricks import DatabricksHookAsync


class DatabricksTrigger(BaseTrigger):
    """
    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead.
    """

    def __init__(
        self,
        conn_id: str,
        task_id: str,
        run_id: str,
        retry_limit: int,
        retry_delay: int,
        polling_period_seconds: int,
        job_id: Optional[int] = None,
        run_page_url: Optional[str] = None,
    ):
        warnings.warn(
            "This class is deprecated and will be removed in 2.0.0. "
            "Use `airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.conn_id = conn_id
        self.task_id = task_id
        self.run_id = run_id
        self.job_id = job_id
        self.run_page_url = run_page_url
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        self.polling_period_seconds = polling_period_seconds
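
    # Migration sketch (illustrative, not part of the original module): the
    # upstream replacement ships with apache-airflow-providers-databricks and
    # keys the connection off ``databricks_conn_id``; the parameter names
    # below are taken from that provider and should be checked against its
    # docs for the installed version:
    #
    #     from airflow.providers.databricks.triggers.databricks import (
    #         DatabricksExecutionTrigger,
    #     )
    #
    #     trigger = DatabricksExecutionTrigger(
    #         run_id=run_id,
    #         databricks_conn_id=conn_id,
    #         polling_period_seconds=polling_period_seconds,
    #     )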
    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serialize DatabricksTrigger arguments and classpath."""
        return (
            "astronomer.providers.databricks.triggers.databricks.DatabricksTrigger",
            {
                "conn_id": self.conn_id,
                "task_id": self.task_id,
                "run_id": self.run_id,
                "job_id": self.job_id,
                "run_page_url": self.run_page_url,
                "retry_limit": self.retry_limit,
                "retry_delay": self.retry_delay,
                "polling_period_seconds": self.polling_period_seconds,
            },
        )
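
    # Rehydration sketch (illustrative): the triggerer persists the classpath
    # and kwargs returned by ``serialize`` and later rebuilds an equivalent
    # trigger instance, roughly:
    #
    #     classpath, kwargs = trigger.serialize()
    #     module_path, class_name = classpath.rsplit(".", 1)
    #     cls = getattr(importlib.import_module(module_path), class_name)
    #     trigger = cls(**kwargs)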
    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Make a series of asynchronous HTTP calls via the Databricks hook.

        Polls the run until it reaches a terminal state, then yields a
        ``TriggerEvent``: a success event when the run succeeds, otherwise an
        error event carrying the failure message. Retryable API errors are
        retried by the hook up to the retry limit; any other exception is
        reported as an error event.
        """
        hook = self._get_async_hook()
        while True:
            try:
                run_info = await hook.get_run_response(self.run_id)
                run_state = RunState(**run_info["state"])
                if not run_state.is_terminal:
                    self.log.info("%s in run state: %s", self.task_id, run_state)
                    self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
                    await asyncio.sleep(self.polling_period_seconds)
                elif run_state.is_terminal and run_state.is_successful:
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "job_id": self.job_id,
                            "run_id": self.run_id,
                            "run_page_url": self.run_page_url,
                        }
                    )
                    return
                elif run_state.result_state == "FAILED":
                    # Prefer the error from the failed task's run output over
                    # the run-level state message.
                    notebook_error = run_state.state_message
                    tasks = run_info["tasks"] if "tasks" in run_info else []
                    for task in tasks:
                        if task.get("state", {}).get("result_state", "") == "FAILED":
                            task_run_id = task["run_id"]
                            run_output = await hook.get_run_output_response(task_run_id)
                            notebook_error = run_output["error"] if "error" in run_output else notebook_error
                    error_message = (
                        f"{self.task_id} failed with terminal state: {run_state} "
                        f"and with the error {notebook_error}"
                    )
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    error_message = (
                        f"{self.task_id} failed with terminal state: {run_state} "
                        f"and with the error {run_state.state_message}"
                    )
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})
                return
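
    # Shape of the polled payload that ``run`` inspects (abridged example;
    # field values are assumptions based on the Databricks runs/get API):
    #
    #     {
    #         "state": {
    #             "life_cycle_state": "TERMINATED",
    #             "result_state": "SUCCESS",
    #             "state_message": "",
    #         },
    #         "tasks": [{"run_id": 456, "state": {"result_state": "SUCCESS"}}],
    #     }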
    def _get_async_hook(self) -> DatabricksHookAsync:
        return DatabricksHookAsync(
            self.conn_id,
            retry_limit=self.retry_limit,
            retry_delay=self.retry_delay,
        )
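
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It drives the
# trigger the way Airflow's triggerer would, consuming the first terminal
# event. The run id and timing values below are assumptions for the example;
# "databricks_default" is the provider's default connection id.

async def _consume_trigger_events() -> None:
    trigger = DatabricksTrigger(
        conn_id="databricks_default",
        task_id="notebook_task",
        run_id="12345",
        retry_limit=3,
        retry_delay=10,
        polling_period_seconds=30,
    )
    async for event in trigger.run():
        # The triggerer acts on the first event: "success" resumes the
        # deferred task, "error" fails it.
        print(event.payload)
        break


if __name__ == "__main__":
    asyncio.run(_consume_trigger_events())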