Source code for astronomer.providers.amazon.aws.triggers.emr

import asyncio
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.emr import (
    EmrContainerHookAsync,
    EmrJobFlowHookAsync,
    EmrStepSensorHookAsync,
)


class EmrContainerSensorTrigger(BaseTrigger):
    """
    The EmrContainerSensorTrigger is fired as a deferred class with params to run the task in
    the trigger worker: once the EMR container job is created, it polls the AWS EMR EKS
    virtual cluster for the job status.

    :param virtual_cluster_id: Reference EMR virtual cluster id
    :param job_id: job_id to check the state of
    :param max_retries: maximum number of polls for the status
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        max_retries: Optional[int] = None,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 10,
    ):
        super().__init__()
        self.virtual_cluster_id = virtual_cluster_id
        self.job_id = job_id
        self.max_retries = max_retries
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrContainerSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "job_id": self.job_id,
                "max_retries": self.max_retries,
                "poll_interval": self.poll_interval,
                "aws_conn_id": self.aws_conn_id,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Make an async connection to the EMR container and poll for the job state."""
        hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
        try:
            try_number: int = 1
            while True:
                query_status = await hook.check_job_status(job_id=self.job_id)
                if query_status is None or query_status in ("PENDING", "SUBMITTED", "RUNNING"):
                    await asyncio.sleep(self.poll_interval)
                elif query_status in ("FAILED", "CANCELLED", "CANCEL_PENDING"):
                    msg = f"EMR Containers sensor failed {query_status}"
                    yield TriggerEvent({"status": "error", "message": msg})
                else:
                    msg = "EMR Containers sensor completed"
                    yield TriggerEvent({"status": "success", "message": msg})
                if self.max_retries and try_number >= self.max_retries:
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": "Timeout: Maximum retry limit exceeded",
                            "job_id": self.job_id,
                        }
                    )
                try_number += 1
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
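

# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal sketch of how a deferrable sensor could hand polling off to
# EmrContainerSensorTrigger via BaseOperator.defer(). ExampleEmrContainerSensor and the
# cluster/job IDs below are hypothetical placeholders; the shipped deferrable sensor
# lives in astronomer.providers.amazon.aws.sensors.emr.
from airflow.exceptions import AirflowException  # noqa: E402  (sketch-only import)
from airflow.sensors.base import BaseSensorOperator  # noqa: E402  (sketch-only import)


class ExampleEmrContainerSensor(BaseSensorOperator):
    """Hypothetical sensor that defers immediately to EmrContainerSensorTrigger."""

    def execute(self, context: Dict[str, Any]) -> None:
        # Defer right away; the triggerer process runs the trigger's async polling loop.
        self.defer(
            timeout=self.execution_timeout,
            trigger=EmrContainerSensorTrigger(
                virtual_cluster_id="vc-example",  # hypothetical virtual cluster ID
                job_id="jr-example",  # hypothetical job run ID
                max_retries=30,
                aws_conn_id="aws_default",
                poll_interval=10,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]] = None) -> None:
        # The trigger yields a TriggerEvent whose payload carries "status" and "message".
        if event is None or event.get("status") == "error":
            raise AirflowException(event["message"] if event else "Trigger returned no event")
        self.log.info(event["message"])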


class EmrContainerOperatorTrigger(BaseTrigger):
    """
    The EmrContainerOperatorTrigger is fired as a deferred class with params to run the task in
    the trigger worker: it polls the AWS EMR EKS virtual cluster job run until it reaches a
    terminal state.

    :param virtual_cluster_id: The EMR on EKS virtual cluster ID.
    :param name: The name of the job run.
    :param job_id: The ID of the job run to poll.
    :param aws_conn_id: Reference to AWS connection id.
    :param poll_interval: polling period in seconds to check for the status.
    :param max_tries: maximum number of status polls before timing out.
    """

    INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
    FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
    SUCCESS_STATES: List[str] = ["COMPLETED"]
    TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]

    def __init__(
        self,
        virtual_cluster_id: str,
        name: str,
        job_id: str,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 30,
        max_tries: Optional[int] = None,
    ):
        super().__init__()
        self.virtual_cluster_id = virtual_cluster_id
        self.name = name
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        self.max_tries = max_tries

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrContainerOperatorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "name": self.name,
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "max_tries": self.max_tries,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Run until the EMR container job reaches the desired state."""
        hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
        try:
            try_number: int = 1
            while True:
                query_state = await hook.check_job_status(self.job_id)
                if query_state is None:
                    self.log.info("Try %s: Invalid query state. Retrying", try_number)
                    await asyncio.sleep(self.poll_interval)
                elif query_state in self.FAILURE_STATES:
                    self.log.info(
                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
                    )
                    error_message = await hook.get_job_failure_reason(self.job_id)
                    message = (
                        f"EMR Containers job failed. Final state is {query_state}. "
                        f"query_execution_id is {self.job_id}. Error: {error_message}"
                    )
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": message,
                            "job_id": self.job_id,
                        }
                    )
                elif query_state in self.SUCCESS_STATES:
                    self.log.info(
                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
                    )
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "message": f"EMR Containers Operator success {query_state}",
                            "job_id": self.job_id,
                        }
                    )
                else:
                    self.log.info(
                        "Try %s: Query is still in non-terminal state - %s", try_number, query_state
                    )
                    await asyncio.sleep(self.poll_interval)
                if self.max_tries and try_number >= self.max_tries:
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": "Timeout: Maximum retry limit exceeded",
                            "job_id": self.job_id,
                        }
                    )
                try_number += 1
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
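

# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal sketch of an operator deferring to EmrContainerOperatorTrigger once a job run
# ID is known. ExampleEmrContainerOperator is hypothetical: it assumes the job run was
# submitted elsewhere and only waits for it to reach a terminal state.
from airflow.exceptions import AirflowException  # noqa: E402  (sketch-only import)
from airflow.models import BaseOperator  # noqa: E402  (sketch-only import)


class ExampleEmrContainerOperator(BaseOperator):
    """Hypothetical operator that waits for an already-submitted EMR on EKS job run."""

    def __init__(self, *, virtual_cluster_id: str, job_id: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.virtual_cluster_id = virtual_cluster_id
        self.job_id = job_id

    def execute(self, context: Dict[str, Any]) -> None:
        self.defer(
            timeout=self.execution_timeout,
            trigger=EmrContainerOperatorTrigger(
                virtual_cluster_id=self.virtual_cluster_id,
                name="example-job-run",  # hypothetical job run name
                job_id=self.job_id,
                aws_conn_id="aws_default",
                poll_interval=30,
                max_tries=120,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]] = None) -> str:
        # Terminal events from this trigger include "job_id" alongside "status" and "message".
        if event is None or event.get("status") == "error":
            raise AirflowException(event["message"] if event else "Trigger returned no event")
        self.log.info(event["message"])
        return str(event["job_id"])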


class EmrStepSensorTrigger(BaseTrigger):
    """
    A trigger that fires once an AWS EMR cluster step reaches either a target or a failed state.

    :param job_flow_id: job_flow_id of the cluster that contains the step to check the state of
    :param step_id: step to check the state of
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    :param poke_interval: Time in seconds to wait between two consecutive calls to check the EMR cluster step state
    :param target_states: the target states, sensor waits until step reaches any of these states
    :param failed_states: the failure states, sensor fails when step reaches any of these states
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        aws_conn_id: str,
        poke_interval: float,
        target_states: Optional[Iterable[str]] = None,
        failed_states: Optional[Iterable[str]] = None,
    ):
        super().__init__()
        self.job_flow_id = job_flow_id
        self.step_id = step_id
        self.aws_conn_id = aws_conn_id
        self.poke_interval = poke_interval
        self.target_states = target_states or ["COMPLETED"]
        self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrStepSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "step_id": self.step_id,
                "aws_conn_id": self.aws_conn_id,
                "poke_interval": self.poke_interval,
                "target_states": self.target_states,
                "failed_states": self.failed_states,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Run until the AWS EMR cluster step reaches a target or failed state."""
        hook = EmrStepSensorHookAsync(
            aws_conn_id=self.aws_conn_id, job_flow_id=self.job_flow_id, step_id=self.step_id
        )
        try:
            while True:
                response = await hook.emr_describe_step()
                state = hook.state_from_response(response)
                if state in self.target_states:
                    yield TriggerEvent({"status": "success", "message": f"Job flow currently {state}"})
                elif state in self.failed_states:
                    yield TriggerEvent(
                        {"status": "error", "message": hook.failure_message_from_response(response)}
                    )
                self.log.info("EMR step state is %s. Sleeping for %s seconds.", state, self.poke_interval)
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
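

# --- Serialization sketch (illustrative only, not part of this module) ---
# The triggerer rehydrates a trigger from the (classpath, kwargs) pair returned by
# serialize(); the sketch below shows that round trip for EmrStepSensorTrigger. The job
# flow and step IDs are hypothetical placeholders.
def _example_step_trigger_round_trip() -> EmrStepSensorTrigger:
    trigger = EmrStepSensorTrigger(
        job_flow_id="j-EXAMPLE1234567",  # hypothetical cluster (job flow) ID
        step_id="s-EXAMPLE7654321",  # hypothetical step ID
        aws_conn_id="aws_default",
        poke_interval=60,
    )
    classpath, kwargs = trigger.serialize()
    assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger"
    # The serialized kwargs map directly onto __init__, so the trigger can be rebuilt as-is.
    return EmrStepSensorTrigger(**kwargs)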


class EmrJobFlowSensorTrigger(BaseTrigger):
    """
    EmrJobFlowSensorTrigger is fired as a deferred class with params to run the task in the
    trigger worker once the EMR JobFlow is created.

    :param job_flow_id: job_flow_id to check the state of
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    :param target_states: the target states, sensor waits until job flow reaches any of these states
    :param failed_states: the failure states, sensor fails when job flow reaches any of these states
    """

    def __init__(
        self,
        job_flow_id: str,
        aws_conn_id: str,
        poll_interval: float,
        target_states: Optional[Iterable[str]] = None,
        failed_states: Optional[Iterable[str]] = None,
    ):
        super().__init__()
        self.job_flow_id = job_flow_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        self.target_states = target_states or ["TERMINATED"]
        self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrJobFlowSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrJobFlowSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "aws_conn_id": self.aws_conn_id,
                "target_states": self.target_states,
                "failed_states": self.failed_states,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Make an async connection to EMR and poll for the target job flow state."""
        hook = EmrJobFlowHookAsync(aws_conn_id=self.aws_conn_id)
        try:
            while True:
                cluster_details = await hook.get_cluster_details(self.job_flow_id)
                cluster_state = hook.state_from_response(cluster_details)
                if cluster_state in self.target_states:
                    yield TriggerEvent(
                        {"status": "success", "message": f"Job flow currently {cluster_state}"}
                    )
                elif cluster_state in self.failed_states:
                    final_message = "EMR job failed"
                    failure_message = hook.failure_message_from_response(cluster_details)
                    if failure_message:
                        final_message += " " + failure_message
                    yield TriggerEvent({"status": "error", "message": final_message})
                await asyncio.sleep(self.poll_interval)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
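

# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal sketch of deferring to EmrJobFlowSensorTrigger with non-default target states,
# e.g. waiting until a long-running cluster is idle ("WAITING") rather than terminated.
# ExampleEmrJobFlowSensor and the cluster ID are hypothetical placeholders.
from airflow.exceptions import AirflowException  # noqa: E402  (sketch-only import)
from airflow.sensors.base import BaseSensorOperator  # noqa: E402  (sketch-only import)


class ExampleEmrJobFlowSensor(BaseSensorOperator):
    """Hypothetical sensor that waits for an EMR cluster to become idle or terminate cleanly."""

    def execute(self, context: Dict[str, Any]) -> None:
        self.defer(
            timeout=self.execution_timeout,
            trigger=EmrJobFlowSensorTrigger(
                job_flow_id="j-EXAMPLE1234567",  # hypothetical cluster (job flow) ID
                aws_conn_id="aws_default",
                poll_interval=60,
                target_states=["WAITING", "TERMINATED"],
                failed_states=["TERMINATED_WITH_ERRORS"],
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]] = None) -> None:
        if event is None or event.get("status") == "error":
            raise AirflowException(event["message"] if event else "Trigger returned no event")
        self.log.info(event["message"])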