Source code for astronomer.providers.amazon.aws.triggers.redshift_cluster

import asyncio
from typing import Any, AsyncIterator, Dict, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.redshift_cluster import RedshiftHookAsync


class RedshiftClusterTrigger(BaseTrigger):
    """
    Trigger that resumes or pauses a Redshift cluster from the trigger worker.

    :param task_id: Reference to task id of the Dag
    :param polling_period_seconds: polling period in seconds to check for the status
        (stored and serialized; ``run`` in this class does not read it)
    :param aws_conn_id: Reference to AWS connection id for redshift
    :param cluster_identifier: unique identifier of a cluster
    :param operation_type: operation to perform, either ``resume_cluster``
        or ``pause_cluster``
    """

    def __init__(
        self,
        task_id: str,
        polling_period_seconds: float,
        aws_conn_id: str,
        cluster_identifier: str,
        operation_type: str,
    ):
        super().__init__()
        self.task_id = task_id
        self.polling_period_seconds = polling_period_seconds
        self.aws_conn_id = aws_conn_id
        self.cluster_identifier = cluster_identifier
        self.operation_type = operation_type

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes RedshiftClusterTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
            {
                "task_id": self.task_id,
                "polling_period_seconds": self.polling_period_seconds,
                "aws_conn_id": self.aws_conn_id,
                "cluster_identifier": self.cluster_identifier,
                "operation_type": self.operation_type,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Run the requested cluster operation through RedshiftHookAsync and emit the outcome.

        ``resume_cluster`` awaits ``hook.resume_cluster`` and ``pause_cluster`` awaits
        ``hook.pause_cluster``. A truthy hook response is emitted as the event payload;
        a falsy response, an unsupported ``operation_type`` or any raised exception is
        emitted as a ``{"status": "error", "message": ...}`` event.
        """
        hook = RedshiftHookAsync(aws_conn_id=self.aws_conn_id)
        try:
            if self.operation_type == "resume_cluster":
                response = await hook.resume_cluster(cluster_identifier=self.cluster_identifier)
            elif self.operation_type == "pause_cluster":
                response = await hook.pause_cluster(cluster_identifier=self.cluster_identifier)
            else:
                # Refuse unknown operations instead of silently pausing the cluster.
                yield TriggerEvent(
                    {
                        "status": "error",
                        "message": f"{self.operation_type} is not a valid operation",
                    }
                )
                return

            if response:
                yield TriggerEvent(response)
            else:
                error_message = f"{self.task_id} failed"
                yield TriggerEvent({"status": "error", "message": error_message})
        except Exception as e:
            # Surface any failure as an error event rather than letting it escape the generator.
            yield TriggerEvent({"status": "error", "message": str(e)})
class RedshiftClusterSensorTrigger(BaseTrigger):
    """
    Trigger that polls a Redshift cluster from the trigger worker until it reaches a target status.

    :param task_id: Reference to task id of the Dag
    :param aws_conn_id: Reference to AWS connection id for redshift
    :param cluster_identifier: unique identifier of a cluster
    :param target_status: Reference to the status which needs to be checked
    :param polling_period_seconds: polling period in seconds to check for the status
    """

    def __init__(
        self,
        task_id: str,
        aws_conn_id: str,
        cluster_identifier: str,
        target_status: str,
        polling_period_seconds: float,
    ):
        super().__init__()
        self.task_id = task_id
        self.aws_conn_id = aws_conn_id
        self.cluster_identifier = cluster_identifier
        self.target_status = target_status
        self.polling_period_seconds = polling_period_seconds

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes RedshiftClusterSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterSensorTrigger",
            {
                "task_id": self.task_id,
                "aws_conn_id": self.aws_conn_id,
                "cluster_identifier": self.cluster_identifier,
                "target_status": self.target_status,
                "polling_period_seconds": self.polling_period_seconds,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Poll the cluster status until it matches the target status or the hook reports an error.

        The result of ``hook.cluster_status`` is emitted as the event payload once its
        ``status`` is ``"error"``, or its ``status`` is ``"success"`` and its
        ``cluster_state`` equals ``target_status``; the trigger then stops. Otherwise it
        sleeps ``polling_period_seconds`` and polls again. Any raised exception is emitted
        as a ``{"status": "error", "message": ...}`` event.
        """
        try:
            hook = RedshiftHookAsync(aws_conn_id=self.aws_conn_id)
            while True:
                res = await hook.cluster_status(self.cluster_identifier)
                reached_target = res["status"] == "success" and res["cluster_state"] == self.target_status
                if reached_target or res["status"] == "error":
                    yield TriggerEvent(res)
                    # Stop after the terminal event; without this the loop would keep
                    # polling and emit duplicate events if the generator is resumed.
                    return
                await asyncio.sleep(self.polling_period_seconds)
        except Exception as e:
            # Surface any failure as an error event rather than letting it escape the generator.
            yield TriggerEvent({"status": "error", "message": str(e)})