# Source code for astronomer.providers.amazon.aws.triggers.batch

import asyncio
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.batch_client import BatchClientHookAsync


class BatchOperatorTrigger(BaseTrigger):
    """
    Trigger that waits, inside the Triggerer, for an AWS Batch job to finish.

    The trigger is created when the operator defers; it carries the job
    parameters and delegates the waiting to ``BatchClientHookAsync.monitor_job``.

    :param job_id: the job ID, usually unknown (None) until the submit_job
        operation gets the jobId defined by AWS Batch
    :param job_name: the name for the job that will run on AWS Batch (templated)
    :param job_definition: the job definition name on AWS Batch
    :param job_queue: the queue name on AWS Batch
    :param container_overrides: the `containerOverrides` parameter for boto3
        (templated); a falsy value is stored as an empty dict
    :param array_properties: the `arrayProperties` parameter for boto3;
        a falsy value is stored as an empty dict
    :param parameters: the `parameters` for boto3 (templated);
        a falsy value is stored as an empty dict
    :param waiters: a :class:`.BatchWaiters` object; if None, polling is used
        with max_retries and status_retries
    :param tags: collection of tags to apply to the AWS Batch job submission;
        a falsy value is stored as an empty dict
    :param max_retries: exponential back-off retries, 4200 = 48 hours;
        polling is only used when waiters is None
    :param status_retries: number of HTTP retries to get job status, 10;
        polling is only used when waiters is None
    :param region_name: AWS region name to use.
        Override the region_name in connection (if provided)
    :param aws_conn_id: connection id of AWS credentials / region name.
        If None, credential boto3 strategy will be used.
    """

    def __init__(
        self,
        job_id: Optional[str],
        job_name: str,
        job_definition: str,
        job_queue: str,
        container_overrides: Dict[str, str],
        array_properties: Dict[str, str],
        parameters: Dict[str, str],
        waiters: Any,
        tags: Dict[str, str],
        max_retries: int,
        status_retries: int,
        region_name: Optional[str],
        aws_conn_id: Optional[str] = "aws_default",
    ):
        super().__init__()

        # Identity of the job being tracked.
        self.job_id = job_id
        self.job_name = job_name
        self.job_definition = job_definition
        self.job_queue = job_queue

        # Optional mappings: normalise any falsy value to a fresh empty dict.
        self.container_overrides = container_overrides if container_overrides else {}
        self.array_properties = array_properties if array_properties else {}
        self.parameters = parameters if parameters else {}
        self.tags = tags if tags else {}

        # Waiting / retry configuration.
        self.waiters = waiters
        self.max_retries = max_retries
        self.status_retries = status_retries

        # AWS connection settings.
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Return the trigger's classpath together with its constructor keyword arguments."""
        trigger_kwargs: Dict[str, Any] = {
            "job_id": self.job_id,
            "job_name": self.job_name,
            "job_definition": self.job_definition,
            "job_queue": self.job_queue,
            "container_overrides": self.container_overrides,
            "array_properties": self.array_properties,
            "parameters": self.parameters,
            "waiters": self.waiters,
            "tags": self.tags,
            "max_retries": self.max_retries,
            "status_retries": self.status_retries,
            "aws_conn_id": self.aws_conn_id,
            "region_name": self.region_name,
        }
        return "astronomer.providers.amazon.aws.triggers.batch.BatchOperatorTrigger", trigger_kwargs

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Wait for the Batch job through the async hook and emit a single event.

        The statuses that indicate job completion are 'SUCCEEDED'|'FAILED'; the
        job moves through
        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'.

        A truthy result from ``monitor_job`` is forwarded unchanged as the event
        payload; a falsy result or any raised exception produces an event with
        ``"status": "error"``.
        """
        # NOTE(review): region_name, max_retries and status_retries are stored
        # and serialized but are not passed to the hook here -- confirm intended.
        hook = BatchClientHookAsync(
            job_id=self.job_id,
            waiters=self.waiters,
            aws_conn_id=self.aws_conn_id,
        )
        try:
            payload = await hook.monitor_job()
            if not payload:
                payload = {"status": "error", "message": f"{self.job_id} failed"}
            yield TriggerEvent(payload)
        except Exception as exc:
            yield TriggerEvent({"status": "error", "message": str(exc)})
class BatchSensorTrigger(BaseTrigger):
    """
    Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
    BatchSensorTrigger is fired as deferred class with params to poll the job state in Triggerer

    :param job_id: the job ID, to poll for job completion or not
    :param region_name: AWS region name to use
        Override the region_name in connection (if provided)
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
    :param poke_interval: polling period in seconds to check for the status of the job
    """

    def __init__(
        self,
        job_id: str,
        region_name: Optional[str],
        aws_conn_id: Optional[str] = "aws_default",
        poke_interval: float = 5,
    ):
        super().__init__()
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.poke_interval = poke_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes BatchSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
            {
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "region_name": self.region_name,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Make async connection using aiobotocore library to AWS Batch,
        periodically poll for the Batch job status

        The status that indicates job completion are: 'SUCCEEDED'|'FAILED'.

        Exactly one event is yielded: a success event when the job status equals
        ``BatchClientHookAsync.SUCCESS_STATE``, an error event when it equals
        ``BatchClientHookAsync.FAILURE_STATE``, or an error event carrying the
        exception text if describing the job raises. Any other status causes a
        sleep of ``poke_interval`` seconds before the next poll.
        """
        # NOTE(review): region_name is stored and serialized but is not passed
        # to the hook here -- confirm whether that is intended.
        hook = BatchClientHookAsync(job_id=self.job_id, aws_conn_id=self.aws_conn_id)
        try:
            while True:
                response = await hook.get_job_description(self.job_id)
                state = response["status"]
                if state == BatchClientHookAsync.SUCCESS_STATE:
                    success_message = f"{self.job_id} was completed successfully"
                    yield TriggerEvent({"status": "success", "message": success_message})
                    # Terminal state: stop the generator instead of falling
                    # through to another sleep/poll cycle that would re-emit
                    # the same event forever.
                    return
                if state == BatchClientHookAsync.FAILURE_STATE:
                    error_message = f"{self.job_id} failed"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    # Terminal state: see the success branch.
                    return
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})