Source code for astronomer.providers.databricks.triggers.databricks
import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.triggers.base import BaseTrigger, TriggerEvent
from astronomer.providers.databricks.hooks.databricks import DatabricksHookAsync
class DatabricksTrigger(BaseTrigger):
"""
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead.
"""
    def __init__(
        self,
        conn_id: str,
        task_id: str,
        run_id: str,
        retry_limit: int,
        retry_delay: int,
        polling_period_seconds: int,
        job_id: Optional[int] = None,
        run_page_url: Optional[str] = None,
    ):
        warnings.warn(
            "This class is deprecated and will be removed in 2.0.0. "
            "Use `airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.conn_id = conn_id
        self.task_id = task_id
        self.run_id = run_id
        self.job_id = job_id
        self.run_page_url = run_page_url
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        self.polling_period_seconds = polling_period_seconds
    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DatabricksTrigger arguments and classpath."""
        return (
            "astronomer.providers.databricks.triggers.databricks.DatabricksTrigger",
            {
                "conn_id": self.conn_id,
                "task_id": self.task_id,
                "run_id": self.run_id,
                "job_id": self.job_id,
                "run_page_url": self.run_page_url,
                "retry_limit": self.retry_limit,
                "retry_delay": self.retry_delay,
                "polling_period_seconds": self.polling_period_seconds,
            },
        )
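    # Round-trip sketch (illustrative): the Airflow triggerer rebuilds the trigger
    # from this (classpath, kwargs) pair, roughly equivalent to:
    #
    #     from airflow.utils.module_loading import import_string
    #
    #     classpath, kwargs = trigger.serialize()
    #     rehydrated = import_string(classpath)(**kwargs)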
    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Make a series of asynchronous HTTP calls to the Databricks API via the async hook,
        polling until the run reaches a terminal state. Yields a success ``TriggerEvent``
        when the run completes successfully; otherwise yields an error ``TriggerEvent``
        carrying the run's (or the failed task's) error message. Retryable HTTP errors
        are retried inside the hook up to ``retry_limit``; any other exception is
        surfaced as an error event rather than raised.
        """
        hook = self._get_async_hook()
        while True:
            try:
                run_info = await hook.get_run_response(self.run_id)
                run_state = RunState(**run_info["state"])
                if not run_state.is_terminal:
                    self.log.info("%s in run state: %s", self.task_id, run_state)
                    self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
                    await asyncio.sleep(self.polling_period_seconds)
                elif run_state.is_successful:
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "job_id": self.job_id,
                            "run_id": self.run_id,
                            "run_page_url": self.run_page_url,
                        }
                    )
                    return
                elif run_state.result_state == "FAILED":
                    # Prefer the error output of the failed task over the generic
                    # run-level state message when it is available.
                    notebook_error = run_state.state_message
                    for task in run_info.get("tasks", []):
                        if task.get("state", {}).get("result_state", "") == "FAILED":
                            task_run_id = task["run_id"]
                            run_output = await hook.get_run_output_response(task_run_id)
                            notebook_error = run_output.get("error", notebook_error)
                    error_message = (
                        f"{self.task_id} failed with terminal state: {run_state} "
                        f"and with the error {notebook_error}"
                    )
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    error_message = (
                        f"{self.task_id} failed with terminal state: {run_state} "
                        f"and with the error {run_state.state_message}"
                    )
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})
                return
    def _get_async_hook(self) -> DatabricksHookAsync:
        """Return an async Databricks hook configured with this trigger's retry settings."""
        return DatabricksHookAsync(
            self.conn_id,
            retry_limit=self.retry_limit,
            retry_delay=self.retry_delay,
        )
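# Minimal local smoke test (illustrative only; in practice this trigger runs
# inside the Airflow triggerer). Assumes a configured Databricks connection
# named ``databricks_default`` and an existing run id.
if __name__ == "__main__":

    async def _demo() -> None:
        trigger = DatabricksTrigger(
            conn_id="databricks_default",
            task_id="notebook_task",
            run_id="123",  # hypothetical run id
            retry_limit=3,
            retry_delay=10,
            polling_period_seconds=30,
        )
        # Consume the first (terminal) event and print it.
        async for event in trigger.run():
            print(event)
            break

    asyncio.run(_demo())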