# Source code for astronomer.providers.amazon.aws.operators.emr

from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator

from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
from astronomer.providers.utils.typing_compat import Context


class EmrContainerOperatorAsync(EmrContainerOperator):
    """
    An async operator that submits jobs to EMR on EKS virtual clusters.

    ``execute`` submits the job run through ``EmrContainerHook`` and then defers
    to ``EmrContainerOperatorTrigger``; ``execute_complete`` is invoked with the
    trigger's event once it fires.

    :param name: The name of the job run.
    :param virtual_cluster_id: The EMR on EKS virtual cluster ID
    :param execution_role_arn: The IAM role ARN associated with the job run.
    :param release_label: The Amazon EMR release version to use for the job run.
    :param job_driver: Job configuration details, e.g. the Spark job parameters.
    :param configuration_overrides: The configuration overrides for the job run,
        specifically either application configuration or monitoring configuration.
    :param client_request_token: The client idempotency token of the job run request.
        Use this if you want to specify a unique ID to prevent two jobs from getting started.
        If no token is provided, a UUIDv4 token will be generated for you.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param poll_interval: Time (in seconds) to wait between two consecutive calls
        to check query status on EMR
    :param max_tries: Deprecated - use max_polling_attempts instead.
    :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
        Defaults to None, which will poll until the job is *not* in a pending,
        submitted, or running state.
    :param tags: The tags assigned to job runs. Defaults to None
    """

    def execute(self, context: Context) -> None:
        """
        Submit the job run and defer to the trigger that polls for its completion.

        :param context: The Airflow task context (not read by this method).
        """
        hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

        submit_kwargs: Dict[str, Any] = {
            "name": self.name,
            "execution_role_arn": self.execution_role_arn,
            "release_label": self.release_label,
            "job_driver": self.job_driver,
            "configuration_overrides": self.configuration_overrides,
            "client_request_token": self.client_request_token,
        }
        # The documented ``tags`` parameter must reach the job run; previously it
        # was silently dropped. It is forwarded only when set, so the call is
        # unchanged for users who pass no tags.
        # NOTE(review): assumes any provider version whose operator accepts
        # ``tags`` also accepts it in ``EmrContainerHook.submit_job`` - confirm
        # against the minimum supported apache-airflow-providers-amazon version.
        tags = getattr(self, "tags", None)
        if tags:
            submit_kwargs["tags"] = tags
        job_id = hook.submit_job(**submit_kwargs)

        try:
            # for apache-airflow-providers-amazon<6.0.0
            polling_attempts = self.max_tries  # type: ignore[attr-defined]
        except AttributeError:  # pragma: no cover
            # for apache-airflow-providers-amazon>=6.0.0
            # max_tries is deprecated so instead of max_tries using self.max_polling_attempts
            polling_attempts = self.max_polling_attempts

        self.defer(
            timeout=self.execution_timeout,
            trigger=EmrContainerOperatorTrigger(
                virtual_cluster_id=self.virtual_cluster_id,
                job_id=job_id,
                aws_conn_id=self.aws_conn_id,
                poll_interval=self.poll_interval,
                max_tries=polling_attempts,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: Dict[str, Any]) -> str:
        """
        Callback for when the trigger fires - returns immediately.

        Relies on trigger to throw an exception, otherwise it assumes execution
        was successful.

        :param context: The Airflow task context (not read by this method).
        :param event: The trigger's payload; ``message`` and ``job_id`` are read,
            plus ``status`` when present.
        :return: The job run ID taken from ``event["job_id"]``.
        :raises AirflowException: If ``event["status"]`` is ``"error"``; the
            exception carries ``event["message"]``.
        """
        if event.get("status") == "error":
            raise AirflowException(event["message"])

        self.log.info(event["message"])
        job_id: str = event["job_id"]
        return job_id