Source code for astronomer.providers.microsoft.azure.hooks.data_factory
import inspect
from functools import wraps
from typing import Any, Optional, TypeVar, Union, cast
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook
from asgiref.sync import sync_to_async
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import PipelineRun
# Type alias covering the two async Azure identity credential classes imported
# above; used to annotate the credential built in ``get_async_conn``.
Credentials = Union[ClientSecretCredential, DefaultAzureCredential]
# Type variable that lets the decorator below advertise "returns the same type it was given".
T = TypeVar("T", bound=Any)


def provide_targeted_factory_async(func: T) -> T:
    """
    Provide the targeted factory to the async decorated function in case it isn't specified.

    If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in
    the connection extras.

    The decorated callable is expected to be a coroutine method: its first positional argument must
    expose ``conn_id`` and ``get_connection``, which are used to load the connection extras.

    :param func: The coroutine function to wrap.
    :return: An async wrapper with the same signature as ``func``.
    :raises AirflowException: If a missing argument has no (non-empty) default in the connection extras.
    """
    signature = inspect.signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound_args = signature.bind(*args, **kwargs)
        # Connection extras are loaded lazily and at most once per call, so a call that
        # is missing both arguments does not look the connection up twice.
        extras: Optional[Any] = None

        async def bind_argument(arg: Any, default_key: str) -> None:
            nonlocal extras
            # Nothing to do when the caller supplied a non-None value for this argument.
            if bound_args.arguments.get(arg) is not None:
                return
            if extras is None:
                hook = args[0]
                # ``get_connection`` is synchronous, so it is run through ``sync_to_async``.
                conn = await sync_to_async(hook.get_connection)(hook.conn_id)
                extras = conn.extra_dejson
            default_value = extras.get(default_key)
            if not default_value:
                raise AirflowException("Could not determine the targeted data factory.")
            bound_args.arguments[arg] = default_value

        await bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name")
        await bind_argument("factory_name", "extra__azure_data_factory__factory_name")
        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
class AzureDataFactoryHookAsync(AzureDataFactoryHook):
    """
    An Async Hook connects to Azure DataFactory to perform pipeline operations

    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
    """

    def __init__(self, azure_data_factory_conn_id: str):
        # Slot for a pre-built async management client; ``get_async_conn`` returns it when set.
        self._async_conn: Optional[DataFactoryManagementClient] = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> DataFactoryManagementClient:
        """
        Get async connection and connect to azure data factory

        Returns the client stored in ``_async_conn`` when one is set; otherwise builds a new async
        client from the connection's login/password and extras.

        :return: An async ``DataFactoryManagementClient``.
        :raises ValueError: If the subscription id extra is missing, or if a login and password are
            set but no tenant id is configured.
        """
        # Check the async slot declared in ``__init__``. ``_conn`` is not assigned anywhere in
        # this class (presumably it holds the parent hook's synchronous client - confirm against
        # the base class), so it must not be handed out as an async client.
        if self._async_conn is not None:
            return self._async_conn

        # ``get_connection`` is synchronous, so it is run through ``sync_to_async``.
        conn = await sync_to_async(self.get_connection)(self.conn_id)
        extras = conn.extra_dejson
        tenant = extras.get("extra__azure_data_factory__tenantId")

        try:
            subscription_id = extras["extra__azure_data_factory__subscriptionId"]
        except KeyError as err:
            raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") from err

        credential: Credentials
        if conn.login is not None and conn.password is not None:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            credential = DefaultAzureCredential()

        # The new client is intentionally not stored on the instance: ``get_pipeline_run`` uses it
        # as an async context manager, so a cached instance could be returned after it was exited.
        return DataFactoryManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

    @provide_targeted_factory_async
    async def get_pipeline_run(
        self,
        run_id: str,
        resource_group_name: Optional[str] = None,
        factory_name: Optional[str] = None,
        **config: Any,
    ) -> PipelineRun:
        """
        Connects to Azure Data Factory asynchronously to get the pipeline run details by run id

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :param config: Extra keyword arguments; accepted but not used by this method.
        :return: The pipeline run returned by the client.
        :raises AirflowException: If the client call fails for any reason.
        """
        async with await self.get_async_conn() as client:
            try:
                pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
                return pipeline_run
            except Exception as e:
                # Keep the original error attached as the explicit cause.
                raise AirflowException(e) from e

    async def get_adf_pipeline_run_status(
        self, run_id: str, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None
    ) -> str:
        """
        Connects to Azure Data Factory asynchronously and gets the pipeline status by run_id

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :return: The ``status`` attribute of the pipeline run.
        :raises AirflowException: If fetching the run or reading its status fails.
        """
        try:
            pipeline_run = await self.get_pipeline_run(
                run_id=run_id,
                factory_name=factory_name,
                resource_group_name=resource_group_name,
            )
            status: str = pipeline_run.status
            return status
        except Exception as e:
            # Keep the original error attached as the explicit cause.
            raise AirflowException(e) from e