Source code for astronomer.providers.amazon.aws.triggers.redshift_sql

from typing import Any, AsyncIterator, Dict, List, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHookAsync


class RedshiftSQLTrigger(BaseTrigger):
    """
    Deferred-task trigger that reports the status of a set of Redshift queries.

    The trigger can be rebuilt from the output of :meth:`serialize`, and
    :meth:`run` emits exactly one :class:`TriggerEvent`.

    :param task_id: Reference to task id of the Dag. It is used in the error
        message emitted when the status lookup returns an empty/falsy result.
    :param polling_period_seconds: polling period in seconds to check for the status.
        It is stored and serialized, but :meth:`run` does not read it directly.
        NOTE(review): presumably polling cadence is handled elsewhere (e.g. by
        the hook) — confirm.
    :param aws_conn_id: Reference to AWS connection id for redshift. It is
        passed to :class:`RedshiftSQLHookAsync`.
    :param query_ids: list of Query ids whose status is checked.
    """

    def __init__(
        self,
        task_id: str,
        polling_period_seconds: float,
        aws_conn_id: str,
        query_ids: List[str],
    ):
        super().__init__()
        self.task_id = task_id
        self.polling_period_seconds = polling_period_seconds
        self.aws_conn_id = aws_conn_id
        self.query_ids = query_ids

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """
        Return this trigger's classpath together with its constructor kwargs.

        The keys of the returned dict match the ``__init__`` parameter names, so
        ``RedshiftSQLTrigger(**kwargs)`` reconstructs an equivalent trigger.
        """
        classpath = "astronomer.providers.amazon.aws.triggers.redshift_sql.RedshiftSQLTrigger"
        trigger_kwargs: Dict[str, Any] = {
            "task_id": self.task_id,
            "polling_period_seconds": self.polling_period_seconds,
            "aws_conn_id": self.aws_conn_id,
            "query_ids": self.query_ids,
        }
        return classpath, trigger_kwargs

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Fetch the status of ``self.query_ids`` and yield exactly one event.

        The event payload is chosen as follows:

        - A truthy response from ``get_query_status`` is forwarded unchanged.
        - A falsy response is replaced by a ``{"status": "error", ...}`` payload
          naming ``task_id``.
        - Any ``Exception`` raised while awaiting the status or building the
          event is reported as an error event instead of propagating.

        Constructing the hook happens outside that guard, so a failure there is
        not converted into an event.
        """
        redshift_hook = RedshiftSQLHookAsync(aws_conn_id=self.aws_conn_id)
        try:
            status_payload = await redshift_hook.get_query_status(self.query_ids)
            # A falsy result carries no usable status, so synthesise an error payload.
            event_payload = (
                status_payload
                if status_payload
                else {"status": "error", "message": f"{self.task_id} failed"}
            )
            yield TriggerEvent(event_payload)
        except Exception as exc:
            # Report the failure as an error event rather than raising it.
            yield TriggerEvent({"status": "error", "message": str(exc)})