from __future__ import annotations
import time
import warnings
from typing import Any, AsyncGenerator
from airflow.providers.amazon.aws.hooks.sagemaker import (
    LogState,
    Position,
    argmin,
    secondary_training_status_changed,
    secondary_training_status_message,
)
from asgiref.sync import sync_to_async
from astronomer.providers.amazon.aws.hooks.aws_logs import AwsLogsHookAsync
from astronomer.providers.amazon.aws.hooks.base_aws import AwsBaseHookAsync
from astronomer.providers.amazon.aws.hooks.s3 import S3HookAsync
class SageMakerHookAsync(AwsBaseHookAsync):
"""
Interact with Amazon SageMaker async using aiobotocore python library.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHookAsync.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook` instead
"""
NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")

    def __init__(self, *args: Any, **kwargs: Any):
        warnings.warn(
            (
                "This module is deprecated and will be removed in 2.0.0. "
                "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook`"
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs["client_type"] = "sagemaker"
        super().__init__(*args, **kwargs)
        self.s3_hook = S3HookAsync(aws_conn_id=self.aws_conn_id)
        self.logs_hook_async = AwsLogsHookAsync(aws_conn_id=self.aws_conn_id)

    async def describe_processing_job_async(self, job_name: str) -> dict[str, Any]:
        """
        Return the processing job info associated with the given job name.

        :param job_name: the name of the processing job
        """
        async with await self.get_client_async() as client:
            response: dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
            return response
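
    # Illustrative usage sketch, not part of the original module. The connection
    # id "aws_default" and the job name are assumptions:
    #
    #     hook = SageMakerHookAsync(aws_conn_id="aws_default")
    #     response = await hook.describe_processing_job_async("my-processing-job")
    #     print(response["ProcessingJobStatus"])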

    async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
        """
        Return the training job info associated with the given job name.

        :param job_name: the name of the training job
        """
        async with await self.get_client_async() as client:
            response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
            return response
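
    # Illustrative polling sketch, not part of the original module. The job name
    # and the 30-second interval are assumptions; ``asyncio`` would need importing:
    #
    #     while True:
    #         desc = await hook.describe_training_job_async("my-training-job")
    #         if desc["TrainingJobStatus"] not in SageMakerHookAsync.NON_TERMINAL_STATES:
    #             break
    #         await asyncio.sleep(30)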

    async def describe_training_job_with_log(
        self,
        job_name: str,
        positions: dict[str, Any],
        stream_names: list[str],
        instance_count: int,
        state: int,
        last_description: dict[str, Any],
        last_describe_job_call: float,
    ) -> tuple[int, dict[str, Any], float]:
        """
        Return the training job info associated with job_name and print CloudWatch logs.

        :param job_name: name of the training job to check the status of
        :param positions: a dict mapping each stream name to a ``Position(timestamp, skip)``
            pair recording the last record read from that stream
        :param stream_names: a list of the log stream names. The position of a stream in
            this list is its stream number.
        :param instance_count: number of instances originally created for the job
        :param state: the current ``LogState`` value
        :param last_description: the most recent description of the training job
        :param last_describe_job_call: timestamp of the previous ``describe_training_job`` call
        """
        log_group = "/aws/sagemaker/TrainingJobs"

        if len(stream_names) < instance_count:
            # Log streams are created only once a container starts writing output,
            # so keep re-describing until every instance has a stream.
            streams = await self.logs_hook_async.describe_log_streams_async(
                log_group=log_group,
                stream_prefix=job_name + "/",
                order_by="LogStreamName",
                count=instance_count,
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
            positions.update(
                [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions]
            )

        if len(stream_names) > 0:
            async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
                self.log.info(event["message"])
                ts, count = positions[stream_names[idx]]
                if event["timestamp"] == ts:
                    # Several events can share one timestamp; bump ``skip`` so the
                    # next fetch resumes after the records already emitted.
                    positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)  # pragma: no cover
                else:
                    positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)

        if state == LogState.COMPLETE:
            return state, last_description, last_describe_job_call

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Throttle describe_training_job calls to at most one every 30 seconds.
            description = await self.describe_training_job_async(job_name)
            last_describe_job_call = time.time()

            if await sync_to_async(secondary_training_status_changed)(description, last_description):
                self.log.info(
                    await sync_to_async(secondary_training_status_message)(description, last_description)
                )  # pragma: no cover
                last_description = description  # pragma: no cover

            status = description["TrainingJobStatus"]
            if status not in self.NON_TERMINAL_STATES:
                state = LogState.JOB_COMPLETE

        return state, last_description, last_describe_job_call
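
    # Illustrative driver sketch, not part of the original module, showing how the
    # returned (state, last_description, last_describe_job_call) triple is threaded
    # back into the next call. The job name, instance count, and poll interval are
    # assumptions; ``asyncio`` would need importing:
    #
    #     state = LogState.TAILING
    #     last_description = await hook.describe_training_job_async("my-training-job")
    #     last_describe_job_call = time.time()
    #     positions: dict[str, Any] = {}
    #     while state != LogState.COMPLETE:
    #         state, last_description, last_describe_job_call = (
    #             await hook.describe_training_job_with_log(
    #                 "my-training-job", positions, [], 1,
    #                 state, last_description, last_describe_job_call,
    #             )
    #         )
    #         await asyncio.sleep(30)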

    async def get_multi_stream(
        self, log_group: str, streams: list[str], positions: dict[str, Any]
    ) -> AsyncGenerator[tuple[int, Any | None], None]:
        """
        Iterate over the available events from a set of log streams in a single log group,
        interleaving the events from each stream so they are yielded in timestamp order.

        :param log_group: the name of the log group
        :param streams: a list of the log stream names. The position of a stream in
            this list is its stream number.
        :param positions: a dict mapping each stream name to a ``Position(timestamp, skip)``
            pair recording the last record read from that stream
        """
        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
        events: list[Any | None] = []
        event_iters = [
            self.logs_hook_async.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
            for s in streams
        ]

        # Prime the merge with the first event from each stream; empty or missing
        # streams contribute ``None``.
        for event_stream in event_iters:
            if not event_stream:
                events.append(None)  # pragma: no cover
                continue  # pragma: no cover
            try:
                events.append(await event_stream.__anext__())
            except StopAsyncIteration:  # pragma: no cover
                events.append(None)  # pragma: no cover

        while any(events):
            # Yield the buffered event with the smallest timestamp, then refill that
            # slot from its stream; exhausted streams are marked with ``None``.
            i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
            yield i, events[i]
            try:
                events[i] = await event_iters[i].__anext__()
            except StopAsyncIteration:
                events[i] = None
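
    # Illustrative consumption sketch, not part of the original module. The log
    # group and stream names are assumptions:
    #
    #     streams = ["my-training-job/algo-1-0000000000"]
    #     positions = {s: Position(timestamp=0, skip=0) for s in streams}
    #     async for idx, event in hook.get_multi_stream(
    #         "/aws/sagemaker/TrainingJobs", streams, positions
    #     ):
    #         print(streams[idx], event["timestamp"], event["message"])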