Source code for astronomer.providers.amazon.aws.triggers.batch

from __future__ import annotations

import warnings
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.batch_client import BatchClientHookAsync


[docs] class BatchOperatorTrigger(BaseTrigger): """ Checks for the state of a previously submitted job to AWS Batch. BatchOperatorTrigger is fired as deferred class with params to poll the job state in Triggerer This class is deprecated and will be removed in 2.0.0. Use :class: `~airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger` instead :param job_id: the job ID, usually unknown (None) until the submit_job operation gets the jobId defined by AWS Batch :param waiters: a :class:`.BatchWaiters` object (see note below); if None, polling is used with max_retries and status_retries. :param max_retries: exponential back-off retries, 4200 = 48 hours; polling is only used when waiters is None :param aws_conn_id: connection id of AWS credentials / region name. If None, credential boto3 strategy will be used. :param region_name: AWS region name to use . Override the region_name in connection (if provided) """ def __init__( self, job_id: str | None, waiters: Any, max_retries: int, region_name: str | None, aws_conn_id: str | None = "aws_default", ): warnings.warn( ( "This module is deprecated and will be removed in 2.0.0." "Please use `airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger`" ), DeprecationWarning, stacklevel=2, ) super().__init__() self.job_id = job_id self.waiters = waiters self.max_retries = max_retries self.aws_conn_id = aws_conn_id self.region_name = region_name
[docs] def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes BatchOperatorTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.batch.BatchOperatorTrigger", { "job_id": self.job_id, "waiters": self.waiters, "max_retries": self.max_retries, "aws_conn_id": self.aws_conn_id, "region_name": self.region_name, }, )
[docs] async def run(self) -> AsyncIterator[TriggerEvent]: """ Make async connection using aiobotocore library to AWS Batch, periodically poll for the job status on the Triggerer The status that indicates job completion are: 'SUCCEEDED'|'FAILED'. So the status options that this will poll for are the transitions from: 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED' """ hook = BatchClientHookAsync(job_id=self.job_id, waiters=self.waiters, aws_conn_id=self.aws_conn_id) try: response = await hook.monitor_job() if response: yield TriggerEvent(response) else: error_message = f"{self.job_id} failed" yield TriggerEvent({"status": "error", "message": error_message}) except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)})