Source code for astronomer.providers.amazon.aws.triggers.emr
from __future__ import annotations
import asyncio
import warnings
from typing import Any, AsyncIterator, Iterable
from airflow.triggers.base import BaseTrigger, TriggerEvent
from astronomer.providers.amazon.aws.hooks.emr import (
EmrJobFlowHookAsync,
)
[docs]
class EmrContainerBaseTrigger(BaseTrigger):
    """
    Base trigger for polling an EMR container job until it reaches a terminal state.

    The constructor only records the polling parameters on the instance and
    forwards any additional keyword arguments to ``BaseTrigger.__init__``.

    :param virtual_cluster_id: Reference Emr cluster id
    :param job_id: job_id to check the state
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    :param max_tries: maximum try attempts for polling the status
    :param kwargs: extra keyword arguments passed through to ``BaseTrigger``
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 10,
        max_tries: int | None = None,
        **kwargs: Any,
    ):
        # Identify the job to watch.
        self.virtual_cluster_id, self.job_id = virtual_cluster_id, job_id
        # Connection used to reach AWS.
        self.aws_conn_id = aws_conn_id
        # Polling cadence and the optional cap on polling attempts.
        self.poll_interval, self.max_tries = poll_interval, max_tries
        # The base initializer runs after the attributes above are set.
        super().__init__(**kwargs)
[docs]
class EmrJobFlowSensorTrigger(BaseTrigger):
    """
    EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when
    EMR JobFlow is created. It polls the job flow until it reaches a target or a failed state.

    Instantiating this class emits a ``DeprecationWarning``.

    :param job_flow_id: job_flow_id to check the state of
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    :param target_states: the target states, sensor waits until
        job flow reaches any of these states (defaults to ``["TERMINATED"]``)
    :param failed_states: the failure states, sensor fails when
        job flow reaches any of these states (defaults to ``["TERMINATED_WITH_ERRORS"]``)
    """

    def __init__(
        self,
        job_flow_id: str,
        aws_conn_id: str,
        poll_interval: float,
        target_states: Iterable[str] | None = None,
        failed_states: Iterable[str] | None = None,
    ):
        # The two literals are concatenated, so the separating space and the
        # closing backtick of the RST role must be spelled out explicitly.
        warnings.warn(
            (
                "This module is deprecated and will be removed in 2.0.0. "
                "Please use :class:`~airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger`."
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.job_flow_id = job_flow_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        # A missing or empty collection falls back to the default terminal states.
        self.target_states = target_states or ["TERMINATED"]
        self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes EmrJobFlowSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrJobFlowSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "aws_conn_id": self.aws_conn_id,
                "target_states": self.target_states,
                "failed_states": self.failed_states,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the EMR job flow state until it reaches a target or a failed state.

        Yields exactly one ``TriggerEvent`` and then stops:

        * ``status == "success"`` when the state is in ``target_states``;
        * ``status == "error"`` when the state is in ``failed_states`` (the
          hook's failure message is appended when available) or when any
          exception is raised while polling.
        """
        hook = EmrJobFlowHookAsync(aws_conn_id=self.aws_conn_id)
        try:
            while True:
                cluster_details = await hook.get_cluster_details(self.job_flow_id)
                cluster_state = hook.state_from_response(cluster_details)
                if cluster_state in self.target_states:
                    yield TriggerEvent(
                        {"status": "success", "message": f"Job flow currently {cluster_state}"}
                    )
                    # Terminal state reached: stop instead of sleeping and
                    # polling again, which would emit duplicate events.
                    return
                if cluster_state in self.failed_states:
                    final_message = "EMR job failed"
                    failure_message = hook.failure_message_from_response(cluster_details)
                    if failure_message:
                        final_message += " " + failure_message
                    yield TriggerEvent({"status": "error", "message": final_message})
                    # Terminal state reached: stop polling.
                    return
                # Non-terminal state: wait before the next poll.
                await asyncio.sleep(self.poll_interval)
        except Exception as e:
            # Report any polling failure as an error event rather than raising.
            yield TriggerEvent({"status": "error", "message": str(e)})