import asyncio
import time
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
from airflow.providers.amazon.aws.hooks.sagemaker import LogState
from airflow.triggers.base import BaseTrigger, TriggerEvent
from astronomer.providers.amazon.aws.hooks.sagemaker import SageMakerHookAsync
class SagemakerProcessingTrigger(BaseTrigger):
    """
    Trigger that polls a SageMaker processing job until it leaves the in-progress states.

    The trigger is created by a deferred task and executed in the triggerer. It emits
    exactly one ``TriggerEvent`` and then stops.

    :param job_name: name of the job to check status
    :param poll_interval: polling period in seconds to check for the status
    :param end_time: deadline that is compared against ``time.time()``; once the
        current time is past it, a ``"Timeout"`` error event is emitted. A falsy
        value (``None`` or ``0``) disables the deadline check.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # Statuses for which polling continues.
    NON_TERMINAL_STATES = ("InProgress", "Stopping")
    # Statuses reported as a failure; any status in neither tuple is treated as success.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        poll_interval: float,
        end_time: Optional[float],
        aws_conn_id: str = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.poll_interval = poll_interval
        self.aws_conn_id = aws_conn_id
        self.end_time = end_time

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes SagemakerProcessingTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerProcessingTrigger",
            {
                "job_name": self.job_name,
                "poll_interval": self.poll_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Poll the processing job through the async hook and emit a single event.

        Yields an ``"error"`` event on timeout, on a ``Failed`` job status, or when any
        exception is raised while polling; yields a ``"success"`` event carrying the
        describe response for any status outside the known failed / in-progress sets.
        The generator returns right after yielding so that a consumer which keeps
        iterating does not cause further polling or duplicate events.
        """
        hook = self._get_async_hook()
        while True:
            try:
                # check if time limit is set and timeout has happened or not
                if self.end_time and time.time() > self.end_time:
                    yield TriggerEvent({"status": "error", "message": "Timeout"})
                    return
                response = await hook.describe_processing_job_async(self.job_name)
                status = response["ProcessingJobStatus"]
                if status in self.TERMINAL_STATE:
                    # Use .get so a missing FailureReason does not mask the failure
                    # with a KeyError message.
                    reason = response.get("FailureReason", "(No reason provided)")
                    error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                elif status in self.NON_TERMINAL_STATES:
                    self.log.info("Job still running current status is %s", status)
                    await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent({"status": "success", "message": response})
                    return
            except Exception as e:
                # Report the error once and stop instead of re-entering the loop
                # without any delay.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        """Build the async SageMaker hook for the configured AWS connection."""
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)
class SagemakerTrigger(BaseTrigger):
    """
    Common trigger for SageMaker transform and training jobs, executed in the triggerer.

    It polls the job description until the status leaves the in-progress states,
    emits exactly one ``TriggerEvent`` and then stops.

    :param job_name: name of the job to check status
    :param job_type: Type of the sagemaker job whether it is Transform or Training
    :param response_key: The key in the describe response that holds the job status.
    :param poke_interval: polling period in seconds to check for the status
    :param end_time: deadline that is compared against ``time.time()``; once the
        current time is past it, a ``"Timeout"`` error event is emitted. A falsy
        value (``None`` or ``0``) disables the deadline check.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # Statuses for which polling continues.
    # NOTE(review): "Stopped" is treated as non-terminal here, so a job reporting it is
    # polled until end_time (or indefinitely without one) - confirm this is intended.
    NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")
    # Statuses reported as a failure; any status in neither tuple is treated as success.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        job_type: str,
        response_key: str,
        poke_interval: float,
        end_time: Optional[float] = None,
        aws_conn_id: str = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.job_type = job_type
        self.response_key = response_key
        self.poke_interval = poke_interval
        self.end_time = end_time
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes SagemakerTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrigger",
            {
                "poke_interval": self.poke_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
                "job_name": self.job_name,
                "job_type": self.job_type,
                "response_key": self.response_key,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Poll the job through the async hook and emit a single event.

        Yields an ``"error"`` event on timeout, on a ``Failed`` job status, or when any
        exception is raised while polling; yields a ``"success"`` event carrying the
        describe response for any status outside the known failed / in-progress sets.
        The generator returns right after yielding so that a consumer which keeps
        iterating does not cause further polling or duplicate events.
        """
        hook = self._get_async_hook()
        while True:
            try:
                if self.end_time and time.time() > self.end_time:
                    yield TriggerEvent({"status": "error", "message": "Timeout"})
                    return
                response = await self.get_job_status(hook, self.job_name, self.job_type)
                status = response[self.response_key]
                if status in self.NON_TERMINAL_STATES:
                    await asyncio.sleep(self.poke_interval)
                elif status in self.TERMINAL_STATE:
                    # Use .get so a missing FailureReason does not mask the failure
                    # with a KeyError message.
                    reason = response.get("FailureReason", "(No reason provided)")
                    error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    yield TriggerEvent({"status": "success", "message": response})
                    return
            except Exception as e:
                # Report the error once and stop instead of re-entering the loop
                # without any delay.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        """Build the async SageMaker hook for the configured AWS connection."""
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)

    @staticmethod
    async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> Dict[str, Any]:
        """
        Fetch the job description with the hook method matching ``job_type``.

        :param hook: async SageMaker hook used to call the describe API
        :param job_name: name of the job to describe
        :param job_type: ``"Transform"`` or ``"Training"``
        :return: the describe response returned by the hook
        :raises ValueError: if ``job_type`` is neither ``"Transform"`` nor ``"Training"``
        """
        if job_type == "Transform":
            return await hook.describe_transform_job_async(job_name)
        if job_type == "Training":
            return await hook.describe_training_job_async(job_name)
        # Fail with an explicit message instead of an UnboundLocalError on `response`.
        raise ValueError(f"Unsupported job_type {job_type!r}; expected 'Transform' or 'Training'")
class SagemakerTrainingWithLogTrigger(BaseTrigger):
    """
    Trigger that follows a SageMaker training job together with its logs in the triggerer.

    It repeatedly calls the hook's ``describe_training_job_with_log`` until the job
    status leaves the in-progress states, emits exactly one ``TriggerEvent`` and stops.

    :param job_name: name of the job to check status
    :param instance_count: count of the instance created for running the training job
    :param status: The status of the training job created.
    :param poke_interval: polling period in seconds to check for the status
    :param end_time: deadline that is compared against ``time.time()``; once the
        current time is past it, an error event is emitted. A falsy value
        (``None`` or ``0``) disables the deadline check.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # Statuses for which polling continues.
    # NOTE(review): "Stopped" is treated as non-terminal here, so a job reporting it is
    # polled until end_time (or indefinitely without one) - confirm this is intended.
    NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")
    # Statuses reported as a failure; any status in neither tuple is treated as success.
    TERMINAL_STATE = ("Failed",)

    def __init__(
        self,
        job_name: str,
        instance_count: int,
        status: str,
        poke_interval: float,
        end_time: Optional[float] = None,
        aws_conn_id: str = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.instance_count = instance_count
        self.status = status
        self.poke_interval = poke_interval
        self.end_time = end_time
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes SagemakerTrainingWithLogTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrainingWithLogTrigger",
            {
                "poke_interval": self.poke_interval,
                "aws_conn_id": self.aws_conn_id,
                "end_time": self.end_time,
                "job_name": self.job_name,
                "status": self.status,
                "instance_count": self.instance_count,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Poll the training job (with log state) through the async hook and emit a single event.

        Yields an ``"error"`` event on timeout, on a ``Failed`` job status, or when any
        exception is raised inside the polling loop; yields a ``"success"`` event
        carrying the last job description otherwise. The generator returns right after
        yielding so that a consumer which keeps iterating does not cause further
        polling or duplicate events.
        """
        hook = self._get_async_hook()
        # This initial describe call is outside the try block below, so an exception
        # raised here propagates out of run() instead of becoming an error event.
        last_description = await hook.describe_training_job_async(self.job_name)
        stream_names: List[str] = []  # The list of log streams
        positions: Dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
        job_already_completed = self.status not in self.NON_TERMINAL_STATES
        state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
        last_describe_job_call = time.time()
        while True:
            try:
                if self.end_time and time.time() > self.end_time:
                    # NOTE(review): end_time is compared with time.time() above, yet the
                    # message reports it as a number of seconds - confirm the wording.
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": f"SageMaker job took more than {self.end_time} seconds",
                        }
                    )
                    return
                state, last_description, last_describe_job_call = await hook.describe_training_job_with_log(
                    self.job_name,
                    positions,
                    stream_names,
                    self.instance_count,
                    state,
                    last_description,
                    last_describe_job_call,
                )
                status = last_description["TrainingJobStatus"]
                if status in self.NON_TERMINAL_STATES:
                    await asyncio.sleep(self.poke_interval)
                elif status in self.TERMINAL_STATE:
                    reason = last_description.get("FailureReason", "(No reason provided)")
                    error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    billable_time = (
                        last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
                    ) * self.instance_count
                    self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
                    yield TriggerEvent({"status": "success", "message": last_description})
                    return
            except Exception as e:
                # Report the error once and stop instead of re-entering the loop
                # without any delay.
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    def _get_async_hook(self) -> SageMakerHookAsync:
        """Build the async SageMaker hook for the configured AWS connection."""
        return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)