# Source code for astronomer.providers.amazon.aws.triggers.sagemaker

from __future__ import annotations

import asyncio
import time
import warnings
from typing import Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.sagemaker import LogState
from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.sagemaker import SageMakerHookAsync


class SagemakerProcessingTrigger(BaseTrigger):
    """
    SagemakerProcessingTrigger is fired as deferred class with params to run the task in triggerer.

    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead

    :param job_name: name of the job to check status
    :param poll_interval: polling period in seconds to check for the status
    :param aws_conn_id: AWS connection ID for sagemaker
    :param end_time: the end time in seconds. Any SageMaker jobs that run longer than this will fail.
    """

    # States in which the processing job is still running and should be polled again.
    NON_TERMINAL_STATES = ("InProgress", "Stopping")
    # States that indicate the job finished unsuccessfully.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        poll_interval: float,
        end_time: float | None,
        aws_conn_id: str = "aws_default",
    ):
        warnings.warn(
            (
                "This module is deprecated and will be removed in 2.0.0."
                # NOTE: the replacement lives in the ``triggers`` module, matching the
                # class docstring (the old message wrongly said ``hooks``).
                "Please use `airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger`"
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.job_name = job_name
        self.poll_interval = poll_interval
        self.aws_conn_id = aws_conn_id
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes SagemakerProcessingTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerProcessingTrigger",
            {
                "job_name": self.job_name,
                "poll_interval": self.poll_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.

        Trigger returns a failure event if any error and success in state return the success event.
        """
        hook = self._get_async_hook()
        while True:
            try:
                # check if time limit is set and timeout has happened or not
                if self.end_time and time.time() > self.end_time:
                    yield TriggerEvent({"status": "error", "message": "Timeout"})
                    # Terminal event emitted; stop polling instead of looping forever.
                    return
                response = await hook.describe_processing_job_async(self.job_name)
                status = response["ProcessingJobStatus"]
                if status in self.TERMINAL_STATE:
                    error_message = f"SageMaker job failed because {response['FailureReason']}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                elif status in self.NON_TERMINAL_STATES:
                    self.log.info("Job still running current status is %s", status)
                    await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent({"status": "success", "message": response})
                    return
            except Exception as e:
                # Surface any hook/describe failure as an error event and stop.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        # Lazily build the async hook so no AWS connection is made at trigger construction.
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)
class SagemakerTrigger(BaseTrigger):
    """
    SagemakerTrigger is common trigger for both transform and training sagemaker job and it is
    fired as deferred class with params to run the task in triggerer.

    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead

    :param job_name: name of the job to check status
    :param job_type: Type of the sagemaker job whether it is Transform or Training
    :param response_key: The key which needs to be look in the response.
    :param poke_interval: polling period in seconds to check for the status
    :param end_time: Time in seconds to wait for a job run to reach a terminal status.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # States in which the job is still in progress and should be polled again.
    NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")
    # States that indicate the job finished unsuccessfully.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        job_type: str,
        response_key: str,
        poke_interval: float,
        end_time: float | None = None,
        aws_conn_id: str = "aws_default",
    ):
        warnings.warn(
            (
                "This module is deprecated and will be removed in 2.0.0."
                # NOTE: the replacement lives in the ``triggers`` module, matching the
                # class docstring (the old message wrongly said ``hooks``).
                "Please use `airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger`"
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.job_name = job_name
        self.job_type = job_type
        self.response_key = response_key
        self.poke_interval = poke_interval
        self.end_time = end_time
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes SagemakerTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrigger",
            {
                "poke_interval": self.poke_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
                "job_name": self.job_name,
                "job_type": self.job_type,
                "response_key": self.response_key,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.

        Trigger returns a failure event if any error and success in state return the success event.
        """
        hook = self._get_async_hook()
        while True:
            try:
                if self.end_time and time.time() > self.end_time:
                    yield TriggerEvent({"status": "error", "message": "Timeout"})
                    # Terminal event emitted; stop polling instead of looping forever.
                    return
                response = await self.get_job_status(hook, self.job_name, self.job_type)
                status = response[self.response_key]
                if status in self.NON_TERMINAL_STATES:
                    await asyncio.sleep(self.poke_interval)
                elif status in self.TERMINAL_STATE:
                    error_message = f"SageMaker job failed because {response['FailureReason']}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    yield TriggerEvent({"status": "success", "message": response})
                    return
            except Exception as e:
                # Surface any hook/describe failure as an error event and stop.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        # Lazily build the async hook so no AWS connection is made at trigger construction.
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)

    @staticmethod
    async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> dict[str, Any]:
        """
        Based on the job type the SageMakerHookAsync connect to sagemaker related function
        and get the response of the job and return it.

        :raises ValueError: if ``job_type`` is neither ``Transform`` nor ``Training``
            (previously this path crashed with an opaque ``UnboundLocalError``).
        """
        if job_type == "Transform":
            return await hook.describe_transform_job_async(job_name)
        if job_type == "Training":
            return await hook.describe_training_job_async(job_name)
        raise ValueError(f"Unsupported job_type {job_type!r}; expected 'Transform' or 'Training'")
class SagemakerTrainingWithLogTrigger(BaseTrigger):
    """
    SagemakerTrainingWithLogTrigger is fired as deferred class with params to run the task in triggerer.

    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` instead

    :param job_name: name of the job to check status
    :param instance_count: count of the instance created for running the training job
    :param status: The status of the training job created.
    :param poke_interval: polling period in seconds to check for the status
    :param end_time: Time in seconds to wait for a job run to reach a terminal status.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # States in which the training job is still in progress and should be polled again.
    NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")
    # States that indicate the job finished unsuccessfully.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        instance_count: int,
        status: str,
        poke_interval: float,
        end_time: float | None = None,
        aws_conn_id: str = "aws_default",
    ):
        warnings.warn(
            (
                "This module is deprecated and will be removed in 2.0.0."
                # NOTE: the replacement lives in the ``triggers`` module, matching the
                # class docstring (the old message wrongly said ``hooks``).
                "Please use `airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger`"
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.job_name = job_name
        self.instance_count = instance_count
        self.status = status
        self.poke_interval = poke_interval
        self.end_time = end_time
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes SagemakerTrainingWithLogTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrainingWithLogTrigger",
            {
                "poke_interval": self.poke_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
                "job_name": self.job_name,
                "status": self.status,
                "instance_count": self.instance_count,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.

        Trigger returns a failure event if any error and success in state return the success event.
        """
        hook = self._get_async_hook()
        last_description = await hook.describe_training_job_async(self.job_name)

        stream_names: list[str] = []  # The list of log streams
        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position

        job_already_completed = self.status not in self.NON_TERMINAL_STATES
        state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
        last_describe_job_call = time.time()
        while True:
            try:
                if self.end_time and time.time() > self.end_time:
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": f"SageMaker job took more than {self.end_time} seconds",
                        }
                    )
                    # Terminal event emitted; stop polling instead of looping forever.
                    return
                state, last_description, last_describe_job_call = await hook.describe_training_job_with_log(
                    self.job_name,
                    positions,
                    stream_names,
                    self.instance_count,
                    state,
                    last_description,
                    last_describe_job_call,
                )
                status = last_description["TrainingJobStatus"]
                if status in self.NON_TERMINAL_STATES:
                    await asyncio.sleep(self.poke_interval)
                elif status in self.TERMINAL_STATE:
                    reason = last_description.get("FailureReason", "(No reason provided)")
                    error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    # TrainingEndTime/TrainingStartTime are datetimes, so the difference is a
                    # timedelta; multiply by the instance count for total billable seconds.
                    billable_time = (
                        last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
                    ) * self.instance_count
                    self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
                    yield TriggerEvent({"status": "success", "message": last_description})
                    return
            except Exception as e:
                # Surface any hook/describe failure as an error event and stop.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        # Lazily build the async hook so no AWS connection is made at trigger construction.
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)