# Source code for astronomer.providers.amazon.aws.hooks.base_aws

"""This module contains async AWS Base Hook for deferrable operators and sensors."""

from __future__ import annotations

from typing import Any

import boto3.session
from aiobotocore.client import AioBaseClient
from aiobotocore.session import AioSession, get_session
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from asgiref.sync import sync_to_async


class AwsBaseHookAsync(AwsBaseHook):
    """
    Interacts with AWS using aiobotocore asynchronously.

    .. note::
        AwsBaseHookAsync uses aiobotocore to create asynchronous hooks. Hence, AwsBaseHookAsync
        only supports the authentication mechanism that aiobotocore supports. Currently,
        AwsBaseHookAsync supports only the AWS STS client method ``assume_role`` provided in
        the Airflow connection extra args via aiobotocore.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param verify: Whether or not to verify SSL certificates.
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: boto3.client client_type. Eg 's3', 'emr' etc
    :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
    :param config: Configuration for botocore client.

    .. seealso::
        `AWS API <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html>`_
    """

    async def get_client_async(self) -> AioBaseClient:
        """
        Create an Async Client object to communicate with AWS services.

        The Airflow connection is fetched through ``sync_to_async`` so the synchronous
        lookup runs outside the event loop's thread. If the connection defines a
        ``role_arn``, the client uses refreshable assume-role credentials attached to the
        session. Otherwise, the static key/secret/token from the connection (possibly
        None) are passed straight to ``create_client``.

        NOTE(review): aiobotocore's ``create_client`` result is normally entered with
        ``async with``; confirm callers do so.

        :return: the result of ``AioSession.create_client`` for ``self.client_type``.
        """
        # With no connection id, skip the lookup entirely; AwsConnectionWrapper accepts
        # ``conn=None``, which lets the default credential chain apply as documented above.
        connection_object = None
        if self.aws_conn_id:
            connection_object = await sync_to_async(self.get_connection)(self.aws_conn_id)

        conn_config = AwsConnectionWrapper(
            conn=connection_object,
            region_name=self.region_name,
            botocore_config=self.config,
            verify=self.verify,
        )

        async_session = get_session()

        if conn_config.role_arn:
            async_session = self._update_session_with_assume_role(async_session, conn_config)
            # Read the endpoint from the wrapper built above (as the non-role branch does)
            # rather than ``self.conn_config``, which is defined on the base hook and would
            # presumably perform a second, synchronous connection lookup from inside this
            # coroutine.
            return async_session.create_client(
                service_name=self.client_type,
                verify=self.verify,
                endpoint_url=conn_config.endpoint_url,
                config=self.config,
            )

        return async_session.create_client(
            service_name=self.client_type,
            region_name=conn_config.region_name,
            aws_secret_access_key=conn_config.aws_secret_access_key,
            aws_access_key_id=conn_config.aws_access_key_id,
            aws_session_token=conn_config.aws_session_token,
            verify=self.verify,
            config=self.config,
            endpoint_url=conn_config.endpoint_url,
        )

    @staticmethod
    def _create_basic_session(session_kwargs: dict[str, Any]) -> boto3.session.Session:
        """
        Create a basic (synchronous) boto3 session.

        :param session_kwargs: keyword arguments forwarded unchanged to ``boto3.session.Session``.
        :return: the new boto3 session.
        """
        return boto3.session.Session(**session_kwargs)

    @staticmethod
    def _assume_role(sts_client: boto3.client, conn_config: AwsConnectionWrapper) -> Any:
        """
        Assume the role using the STS client.

        ``RoleSessionName`` defaults to ``"RoleSession"``. Any key in
        ``conn_config.assume_role_kwargs`` (including ``RoleSessionName`` or ``RoleArn``)
        overrides the defaults, because it is unpacked last.

        :param sts_client: a synchronous boto3 STS client.
        :param conn_config: wrapped Airflow connection providing ``role_arn`` and ``assume_role_kwargs``.
        :return: the raw ``assume_role`` response.
        """
        kw = {
            "RoleSessionName": "RoleSession",
            "RoleArn": conn_config.role_arn,
            **conn_config.assume_role_kwargs,
        }
        return sts_client.assume_role(**kw)

    def _refresh_credentials(self) -> dict[str, str]:
        """
        Refresh the credentials using the STS client.

        This is synchronous: it re-reads the Airflow connection, builds a plain boto3 STS
        client and calls ``assume_role``.

        :return: a credentials metadata dict with keys ``access_key``, ``secret_key``,
            ``token`` and ``expiry_time`` (ISO-8601 string), the shape passed to
            ``AioRefreshableCredentials.create_from_metadata``.
        :raises RuntimeError: if the STS response HTTP status code is not 200.
        """
        conn_config = AwsConnectionWrapper(
            conn=self.get_connection(self.aws_conn_id),  # type: ignore[arg-type]
            region_name=self.region_name,
            botocore_config=self.config,
            verify=self.verify,
        )
        sts_client = self._create_basic_session(conn_config.session_kwargs).client("sts")
        sts_response = self._assume_role(sts_client=sts_client, conn_config=conn_config)

        sts_response_http_status = sts_response["ResponseMetadata"]["HTTPStatusCode"]
        if sts_response_http_status != 200:
            raise RuntimeError(f"sts_response_http_status={sts_response_http_status}")

        creds = sts_response["Credentials"]
        expiry_time = creds["Expiration"].isoformat()
        credentials: dict[str, str] = {
            "access_key": creds["AccessKeyId"],
            "secret_key": creds["SecretAccessKey"],
            "token": creds["SessionToken"],
            "expiry_time": expiry_time,
        }
        return credentials

    def _update_session_with_assume_role(
        self, async_session: AioSession, conn_config: AwsConnectionWrapper
    ) -> AioSession:
        """
        Update the session with refreshable assume-role credentials.

        :param async_session: the aiobotocore session to mutate (its ``_credentials`` are replaced).
        :param conn_config: wrapped connection; its ``region_name`` becomes the session region.
        :return: the same session object, after mutation.
        """
        from aiobotocore.credentials import AioRefreshableCredentials

        # Refreshable credentials need an initial set, so the first assume_role call is made
        # eagerly here (synchronously).
        # AioRefreshableCredentials *awaits* ``refresh_using()`` whenever it refreshes. The
        # synchronous refresher is therefore wrapped into a coroutine function; passing it
        # bare would make every refresh fail once the assumed-role credentials near expiry.
        params = {
            "metadata": self._refresh_credentials(),
            "refresh_using": sync_to_async(self._refresh_credentials),
            "method": "sts-assume-role",
        }

        credentials = AioRefreshableCredentials.create_from_metadata(**params)
        # NOTE(review): ``_credentials`` is a private attribute of AioSession; there is no
        # public setter used here, so this may need revisiting on aiobotocore upgrades.
        async_session._credentials = credentials
        async_session.set_config_variable("region", conn_config.region_name)
        return async_session

    @staticmethod
    async def get_role_credentials(
        async_session: AioSession, conn_config: AwsConnectionWrapper
    ) -> dict[str, str] | None:
        """
        Get the role credentials for the role described by the connection.

        Only ``assume_role_method`` values of ``"assume_role"`` or ``None`` are handled. For
        any other method, nothing is requested and ``None`` is returned.

        :param async_session: aiobotocore session used to create a temporary STS client.
        :param conn_config: wrapped connection providing the static keys, ``role_arn`` and
            ``assume_role_kwargs``.
        :return: the ``Credentials`` section of the STS ``assume_role`` response, or ``None``.
        """
        async with async_session.create_client(
            "sts",
            aws_access_key_id=conn_config.aws_access_key_id,
            aws_secret_access_key=conn_config.aws_secret_access_key,
        ) as client:
            return_response = None
            if conn_config.assume_role_method == "assume_role" or conn_config.assume_role_method is None:
                response: dict[str, dict[str, str]] = await client.assume_role(
                    RoleArn=conn_config.role_arn,
                    RoleSessionName="RoleSession",
                    **conn_config.assume_role_kwargs,
                )
                return_response = response["Credentials"]
            return return_response