Source code for astronomer.providers.microsoft.azure.hooks.data_factory
from typing import Any, Optional, Union
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook
from asgiref.sync import sync_to_async
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import PipelineRun
# Union of the two azure.identity.aio credential classes imported above; used to
# annotate the credential that AzureDataFactoryHookAsync.get_async_conn builds.
Credentials = Union[ClientSecretCredential, DefaultAzureCredential]
class AzureDataFactoryHookAsync(AzureDataFactoryHook):
    """
    An Async Hook connects to Azure DataFactory to perform pipeline operations

    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
    """

    def __init__(self, azure_data_factory_conn_id: str):
        # Slot for an async management client; kept separate from the client
        # handled by the synchronous parent hook.
        self._async_conn: Optional[DataFactoryManagementClient] = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> DataFactoryManagementClient:
        """
        Get async connection and connect to azure data factory

        If ``self._async_conn`` has been set it is returned as is. Otherwise the
        Airflow connection is read and a new async client is built from it. When
        both login and password are present they are used as client id and client
        secret (a tenant id is then mandatory); otherwise ``DefaultAzureCredential``
        is used.

        :return: An async ``DataFactoryManagementClient``.
        :raises ValueError: If the connection extras have no subscription id, or if
            login and password are set but no tenant id is configured.
        """
        # Check the async slot, not ``self._conn``: the latter belongs to the
        # synchronous parent hook and must never be handed to async callers.
        if self._async_conn is not None:
            return self._async_conn

        # ``get_connection`` is synchronous, so run it through sync_to_async
        # instead of calling it directly from the coroutine.
        conn = await sync_to_async(self.get_connection)(self.conn_id)
        tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId")

        try:
            subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
        except KeyError:
            # The message already says what is missing; drop the KeyError context.
            raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") from None

        credential: Credentials
        if conn.login is not None and conn.password is not None:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            credential = DefaultAzureCredential()

        # Deliberately not stored on ``self._async_conn``: ``get_pipeline_run``
        # closes the client via ``async with``, so a cached instance would be
        # returned in a closed state on the next call.
        return DataFactoryManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

    async def get_pipeline_run(
        self,
        run_id: str,
        resource_group_name: Optional[str] = None,
        factory_name: Optional[str] = None,
        **config: Any,
    ) -> PipelineRun:
        """
        Connects to Azure Data Factory asynchronously to get the pipeline run details by run id

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :param config: Extra keyword arguments; accepted but not used by this method.
        :return: The pipeline run returned by ``client.pipeline_runs.get``.
        :raises AirflowException: If the ``pipeline_runs.get`` call fails. Errors
            raised while building the client are not wrapped here.
        """
        # The client is closed when the ``async with`` block exits.
        async with await self.get_async_conn() as client:
            try:
                return await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
            except Exception as e:
                raise AirflowException(e) from e

    async def get_adf_pipeline_run_status(
        self, run_id: str, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None
    ) -> str:
        """
        Connects to Azure Data Factory asynchronously and gets the pipeline status by run_id

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :return: The ``status`` attribute of the fetched pipeline run.
        :raises AirflowException: If fetching the run or reading its status fails
            for any reason.
        """
        try:
            pipeline_run = await self.get_pipeline_run(
                run_id=run_id,
                factory_name=factory_name,
                resource_group_name=resource_group_name,
            )
            status: str = pipeline_run.status
            return status
        except Exception as e:
            raise AirflowException(e) from e