Source code for astronomer.providers.microsoft.azure.hooks.data_factory

"""This module contains the Azure Data Factory hook's asynchronous implementation."""
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, TypeVar, Union, cast

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook
from asgiref.sync import sync_to_async
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import PipelineRun

Credentials = Union[ClientSecretCredential, DefaultAzureCredential]

T = TypeVar("T", bound=Any)


def get_field(extras: dict[str, Any], field_name: str, strict: bool = False) -> Any:
    """
    Look up a connection extra by its short name, falling back to the legacy prefixed name.

    The short name (e.g. ``tenantId``) is checked first; for backward compatibility the
    ``extra__azure_data_factory__``-prefixed variant is checked second. The first key that is
    present decides the result, even when its value is empty.

    :param extras: The connection extras to search.
    :param field_name: The short (un-prefixed) field name.
    :param strict: When true, raise ``KeyError`` if neither key is present instead of returning ``None``.
    :return: The stored value, or ``None`` when the value is falsy or (non-strict) the field is absent.
    :raises ValueError: If ``field_name`` starts with ``extra__``.
    :raises KeyError: If ``strict`` is true and the field is not present under either name.
    """
    legacy_prefix = "extra__azure_data_factory__"
    if field_name.startswith("extra__"):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the '{legacy_prefix}' prefix "
            "when using this method."
        )

    # Short name takes precedence over the legacy prefixed name.
    for candidate in (field_name, f"{legacy_prefix}{field_name}"):
        if candidate in extras:
            # Falsy stored values (empty string, 0, ...) are normalised to None.
            return extras[candidate] or None

    if strict:
        raise KeyError(f"Field {field_name} not found in extras")
    return None
def provide_targeted_factory_async(func: T) -> T:
    """
    Provide the targeted factory to the async decorated function in case it isn't specified.

    If ``resource_group_name`` or ``factory_name`` is not provided (or is ``None``) it defaults to the
    value specified in the connection extras. Each default is looked up under its short name first
    (``resource_group_name`` / ``factory_name``) and then, for backward compatibility, under the
    ``extra__azure_data_factory__``-prefixed name.

    :param func: The coroutine function to wrap. Its first positional argument must expose
        ``get_connection`` and ``conn_id`` (i.e. it is a hook method).
    :return: The wrapped coroutine function, typed as the original.
    :raises AirflowException: If a missing argument has no non-empty default in the connection extras.
    """
    signature = inspect.signature(func)
    legacy_prefix = "extra__azure_data_factory__"

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound_args = signature.bind(*args, **kwargs)
        # Connection extras are fetched lazily, and at most once per call.
        extras: dict[str, Any] | None = None

        for name in ("resource_group_name", "factory_name"):
            # Skip arguments the caller supplied with a non-None value.
            if bound_args.arguments.get(name) is not None:
                continue

            if extras is None:
                hook = args[0]
                # get_connection is synchronous; run it through sync_to_async before awaiting.
                conn = await sync_to_async(hook.get_connection)(hook.conn_id)
                extras = conn.extra_dejson

            default_value = extras.get(name) or extras.get(f"{legacy_prefix}{name}")
            if not default_value:
                raise AirflowException("Could not determine the targeted data factory.")
            bound_args.arguments[name] = default_value

        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
class AzureDataFactoryHookAsync(AzureDataFactoryHook):
    """
    An Async Hook connects to Azure DataFactory to perform pipeline operations.

    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
    """

    def __init__(self, azure_data_factory_conn_id: str):
        """Initialize the hook instance."""
        # Optional pre-built async client; get_async_conn() returns it as-is when set.
        self._async_conn: DataFactoryManagementClient | None = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> DataFactoryManagementClient:
        """
        Get async connection and connect to azure data factory.

        Reads ``tenantId`` and ``subscriptionId`` from the connection extras. When both login and
        password are set on the connection a ``ClientSecretCredential`` is used, otherwise
        ``DefaultAzureCredential``.

        :return: An async ``DataFactoryManagementClient``.
        :raises ValueError: If the subscription ID is missing, or if client-secret authentication is
            requested without a tenant ID.
        """
        # Check the async slot created in __init__. The original checked ``self._conn``, which this
        # class never assigns.
        # NOTE(review): ``_conn`` presumably holds the parent hook's *synchronous* client, which must
        # not be handed to ``async with`` callers - confirm against the parent hook.
        if self._async_conn is not None:
            return self._async_conn

        conn = await sync_to_async(self.get_connection)(self.conn_id)
        extras = conn.extra_dejson
        tenant = get_field(extras, "tenantId")

        try:
            subscription_id = get_field(extras, "subscriptionId", strict=True)
        except KeyError as err:
            raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") from err

        credential: Credentials
        if conn.login is not None and conn.password is not None:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            credential = DefaultAzureCredential()

        # Deliberately not stored on the instance: get_pipeline_run() uses the returned client as an
        # async context manager, so a cached client could be reused after its context has exited.
        return DataFactoryManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

    @provide_targeted_factory_async
    async def get_pipeline_run(
        self,
        run_id: str,
        resource_group_name: str | None = None,
        factory_name: str | None = None,
        **config: Any,
    ) -> PipelineRun:
        """
        Connect to Azure Data Factory asynchronously to get the pipeline run details by run id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :param config: Extra keyword arguments; accepted but not used by this method.
        :return: The pipeline run returned by the service.
        :raises AirflowException: If the request fails for any reason.
        """
        async with await self.get_async_conn() as client:
            try:
                pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
                return pipeline_run
            except Exception as e:
                raise AirflowException(e) from e

    async def get_adf_pipeline_run_status(
        self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None
    ) -> str:
        """
        Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :return: The ``status`` attribute of the pipeline run.
        :raises AirflowException: If the pipeline run cannot be retrieved or has no status.
        """
        try:
            pipeline_run = await self.get_pipeline_run(
                run_id=run_id,
                factory_name=factory_name,
                resource_group_name=resource_group_name,
            )
            status: str = pipeline_run.status
            return status
        except Exception as e:
            raise AirflowException(e) from e