Source code for astronomer.providers.snowflake.triggers.snowflake_trigger

import asyncio
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple

from airflow import AirflowException
from airflow.triggers.base import BaseTrigger, TriggerEvent
from asgiref.sync import sync_to_async

from astronomer.providers.snowflake.hooks.snowflake import (
    SnowflakeHookAsync,
    fetch_all_snowflake_handler,
)
from astronomer.providers.snowflake.hooks.snowflake_sql_api import (
    SnowflakeSqlApiHookAsync,
)


def get_db_hook(snowflake_conn_id: str) -> SnowflakeHookAsync:
    """
    Create and return a SnowflakeHookAsync.

    :return: a SnowflakeHookAsync instance.
    """
    return SnowflakeHookAsync(snowflake_conn_id=snowflake_conn_id)


class SnowflakeTrigger(BaseTrigger):
    """
    SnowflakeTrigger inherits from BaseTrigger. It is fired as a deferred class with params
    to run the task in the trigger worker and fetch the status for the query ids passed.

    :param task_id: Reference to the task id of the DAG
    :param poll_interval: polling period in seconds to check for the status
    :param query_ids: List of query ids to run and poll for the status
    :param snowflake_conn_id: Reference to the Snowflake connection id
    """

    def __init__(
        self,
        task_id: str,
        poll_interval: float,
        query_ids: List[str],
        snowflake_conn_id: str,
    ):
        super().__init__()
        self.task_id = task_id
        self.poll_interval = poll_interval
        self.query_ids = query_ids
        self.snowflake_conn_id = snowflake_conn_id
[docs] def serialize(self) -> Tuple[str, Dict[str, Any]]: """Serializes SnowflakeTrigger arguments and classpath.""" return ( "astronomer.providers.snowflake.triggers.snowflake_trigger.SnowflakeTrigger", { "task_id": self.task_id, "poll_interval": self.poll_interval, "query_ids": self.query_ids, "snowflake_conn_id": self.snowflake_conn_id, }, )
[docs] async def run(self) -> AsyncIterator["TriggerEvent"]: """ Makes a series of connections to snowflake to get the status of the query by async get_query_status function """ hook = get_db_hook(self.snowflake_conn_id) try: run_state = await hook.get_query_status(self.query_ids, self.poll_interval) if run_state: yield TriggerEvent(run_state) else: error_message = f"{self.task_id} failed with terminal state: {run_state}" yield TriggerEvent({"status": "error", "message": error_message}) except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)})


class SnowflakeSqlApiTrigger(BaseTrigger):
    """
    SnowflakeSqlApiTrigger inherits from BaseTrigger. It is fired as a deferred class with params
    to run the task in the trigger worker and fetch the status for the query ids passed.

    :param poll_interval: polling period in seconds to check for the status
    :param query_ids: List of query ids to run and poll for the status
    :param snowflake_conn_id: Reference to the Snowflake connection id
    :param token_life_time: lifetime of the JWT token used against the Snowflake SQL API
    :param token_renewal_delta: renewal time of the JWT token used against the Snowflake SQL API
    """

    def __init__(
        self,
        poll_interval: float,
        query_ids: List[str],
        snowflake_conn_id: str,
        token_life_time: timedelta,
        token_renewal_delta: timedelta,
    ):
        super().__init__()
        self.poll_interval = poll_interval
        self.query_ids = query_ids
        self.snowflake_conn_id = snowflake_conn_id
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
[docs] def serialize(self) -> Tuple[str, Dict[str, Any]]: """Serializes SnowflakeSqlApiTrigger arguments and classpath.""" return ( "astronomer.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger", { "poll_interval": self.poll_interval, "query_ids": self.query_ids, "snowflake_conn_id": self.snowflake_conn_id, "token_life_time": self.token_life_time, "token_renewal_delta": self.token_renewal_delta, }, )
[docs] async def run(self) -> AsyncIterator["TriggerEvent"]: """ Makes a GET API request to snowflake with query_id to get the status of the query by get_sql_api_query_status async function """ hook = SnowflakeSqlApiHookAsync( self.snowflake_conn_id, self.token_life_time, self.token_renewal_delta, ) try: statement_query_ids: List[str] = [] for query_id in self.query_ids: while await self.is_still_running(query_id): await asyncio.sleep(self.poll_interval) statement_status = await hook.get_sql_api_query_status(query_id) if statement_status["status"] == "error": yield TriggerEvent(statement_status) if statement_status["status"] == "success": statement_query_ids.extend(statement_status["statement_handles"]) yield TriggerEvent( { "status": "success", "statement_query_ids": statement_query_ids, } ) except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)})

    async def is_still_running(self, query_id: str) -> bool:
        """
        Async function to check whether the query statement submitted via the SQL API is still
        in a running state; returns True if it is still running, otherwise False.
        """
        hook = SnowflakeSqlApiHookAsync(
            self.snowflake_conn_id,
            self.token_life_time,
            self.token_renewal_delta,
        )
        statement_status = await hook.get_sql_api_query_status(query_id)
        if statement_status["status"] in ["running"]:
            return True
        return False
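
# A minimal construction sketch (illustrative only): an operator deferring to this trigger
# would pass token lifetimes matching the JWT it uses against the Snowflake SQL API. The
# concrete values below are placeholders, not defaults defined by this module.
#
#     trigger = SnowflakeSqlApiTrigger(
#         poll_interval=5,
#         query_ids=["<statement-handle>"],
#         snowflake_conn_id="snowflake_default",
#         token_life_time=timedelta(minutes=59),
#         token_renewal_delta=timedelta(minutes=54),
#     )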


class SnowflakeSensorTrigger(BaseTrigger):
    """
    This trigger validates the result of a query (asynchronously).

    An Airflow Trigger asynchronously polls for a certain condition to be true (which yields a
    ``TriggerEvent``), after which a synchronous piece of code can be used to complete the logic
    (set by ``method_name`` on AsyncOperator/Sensor.defer()).

    Docs: https://airflow.apache.org/docs/apache-airflow/stable/concepts/deferring.html#triggering-deferral

    :param sql: the SQL statement to run and validate
    :param dag_id: Reference to the DAG id
    :param task_id: Reference to the task id of the DAG
    :param run_id: Reference to the run id of the DAG run
    :param snowflake_conn_id: Reference to the Snowflake connection id
    :param parameters: the parameters to render the SQL query with, if any
    :param success: optional callable applied to the first cell of the first row; a truthy return marks success
    :param failure: optional callable applied to the first cell of the first row; a truthy return raises an error
    :param fail_on_empty: whether to raise an exception if the query returns no rows
    :param poke_interval: polling period in seconds to check for the condition
    """

    def __init__(
        self,
        sql: str,
        dag_id: str,
        task_id: str,
        run_id: str,
        snowflake_conn_id: str,
        parameters: Optional[str] = None,
        success: Optional[Callable[[Any], bool]] = None,
        failure: Optional[Callable[[Any], bool]] = None,
        fail_on_empty: bool = False,
        poke_interval: float = 60,
    ):
        super().__init__()
        self._sql = sql
        self._parameters = parameters
        self._success = success
        self._failure = failure
        self._fail_on_empty = fail_on_empty
        self._dag_id = dag_id
        self._task_id = task_id
        self._run_id = run_id
        self._conn_id = snowflake_conn_id
        self._poke_interval = poke_interval
[docs] def serialize(self) -> Tuple[str, Dict[str, Any]]: """Serializes SqlTrigger arguments and classpath..""" return ( "astronomer.providers.snowflake.triggers.snowflake_trigger.SnowflakeSensorTrigger", { "sql": self._sql, "parameters": self._parameters, "poke_interval": self._poke_interval, "success": self._success, "failure": self._failure, "fail_on_empty": self._fail_on_empty, "dag_id": self._dag_id, "task_id": self._task_id, "run_id": self._run_id, "snowflake_conn_id": self._conn_id, }, )

    def validate_result(self, result: List[Tuple[Any]]) -> Any:
        """Validates the query result and verifies whether it returns a row."""
        if not result:
            if self._fail_on_empty:
                raise AirflowException("No rows returned, raising as per fail_on_empty flag")
            else:
                return False
        first_cell = result[0][0]
        if self._failure is not None:
            if callable(self._failure):
                if self._failure(first_cell):
                    raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
            else:
                raise AirflowException(f"self.failure is present, but not callable -> {self._failure}")
        if self._success is not None:
            if callable(self._success):
                return self._success(first_cell)
            else:
                raise AirflowException(f"self.success is present, but not callable -> {self._success}")
        return bool(first_cell)
[docs] async def run(self) -> AsyncIterator["TriggerEvent"]: """ Make an asynchronous connection to Snowflake and defer until query returns a result """ try: hook = get_db_hook(self._conn_id) while True: query_ids = await sync_to_async(hook.run)(self._sql, parameters=self._parameters) run_state = await hook.get_query_status(query_ids, 5) if run_state: result = await sync_to_async(hook.check_query_output)( query_ids=query_ids, handler=fetch_all_snowflake_handler, ) self.log.info( "Raw query result = %s <DAG id = %s, task id = %s, run id = %s>", result, self._dag_id, self._task_id, self._run_id, ) if await sync_to_async(self.validate_result)(result): yield TriggerEvent( { "status": "success", "message": "Found expected markers.", } ) else: self.log.info( ( "No success yet. Checking again in %s seconds. " "<DAG id = %s, task id = %s, run id = %s>" ), self._poke_interval, self._dag_id, self._task_id, self._run_id, ) await asyncio.sleep(self._poke_interval) else: error_message = f"{self._task_id} failed with terminal state: {run_state}" yield TriggerEvent({"status": "error", "message": error_message}) except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)})