"""Async dbt Cloud hook (source module: ``astronomer.providers.dbt.cloud.hooks.dbt``)."""
import warnings
from functools import wraps
from inspect import signature
from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast
import aiohttp
from aiohttp import ClientResponseError
from airflow import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from asgiref.sync import sync_to_async
from astronomer.providers.package import get_provider_info
T = TypeVar("T", bound=Any)


def provide_account_id(func: T) -> T:
    """
    Decorator which provides a fallback value for ``account_id``. If the ``account_id`` is None or not passed
    to the decorated function, the value will be taken from the configured dbt Cloud Airflow Connection.

    The decorated callable must be a coroutine function whose first positional argument is the hook
    instance (exposing ``dbt_cloud_conn_id`` and ``get_connection``) and which declares an ``account_id``
    parameter. The fallback value is the connection's ``login`` converted with ``int()``.

    :raises AirflowException: if no ``account_id`` was supplied and none can be derived, i.e. the hook has
        no connection id configured or the connection's ``login`` is empty.
    """
    function_signature = signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # ``bind`` maps positional and keyword arguments onto parameter names, so ``account_id`` is found
        # however it was passed. Defaults are not applied, so an omitted argument is simply absent.
        bound_args = function_signature.bind(*args, **kwargs)
        if bound_args.arguments.get("account_id") is None:
            self = args[0]
            default_account_id = None
            if self.dbt_cloud_conn_id:
                # ``get_connection`` is synchronous; ``sync_to_async`` keeps it off the event loop.
                connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
                default_account_id = connection.login
            # Previously a missing connection id left ``account_id`` as None, which produced request
            # URLs such as ``.../accounts/None/runs/...``. Fail fast with a clear error instead.
            if not default_account_id:
                raise AirflowException("Could not determine the dbt Cloud account.")
            bound_args.arguments["account_id"] = int(default_account_id)
        return await func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
class DbtCloudHookAsync(BaseHook):
    """
    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead.

    Asynchronous (aiohttp-based) hook for the dbt Cloud v2 REST API.

    :param dbt_cloud_conn_id: The Airflow connection id of the dbt Cloud connection.
    """

    conn_name_attr = "dbt_cloud_conn_id"
    default_conn_name = "dbt_cloud_default"
    conn_type = "dbt_cloud"
    hook_name = "dbt Cloud"

    def __init__(self, dbt_cloud_conn_id: str):
        warnings.warn(
            (
                "This class is deprecated. "
                "Use `airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        # Let BaseHook run its own initialiser, which this class previously skipped.
        super().__init__()
        self.dbt_cloud_conn_id = dbt_cloud_conn_id

    @staticmethod
    def get_request_url_params(
        tenant: str, endpoint: str, include_related: Optional[List[str]] = None
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Build the request URL and query parameters for a dbt Cloud API call.

        :param tenant: The tenant domain name substituted into the base URL.
        :param endpoint: Endpoint path appended to ``https://{tenant}/api/v2/accounts/``.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: A ``(url, params)`` tuple. ``params`` is ``{"include_related": include_related}`` when
            ``include_related`` is non-empty, otherwise an empty dict.
        """
        base_url = f"https://{tenant}/api/v2/accounts/"
        url = base_url + (endpoint or "")
        data: Dict[str, Any] = {"include_related": include_related} if include_related else {}
        return url, data

    @provide_account_id
    async def get_job_details(
        self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
    ) -> Any:
        """
        Uses Http async call to retrieve metadata for a specific run of a dbt Cloud job.

        :param run_id: The ID of a dbt Cloud job run.
        :param account_id: Optional. The ID of a dbt Cloud account. When omitted, ``provide_account_id``
            fills it in from the connection.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: The decoded JSON response body.
        :raises AirflowException: if aiohttp raises ``ClientResponseError`` (e.g. an HTTP error status);
            the message is ``"<status>:<message>"``.
        """
        endpoint = f"{account_id}/runs/{run_id}/"
        # NOTE(review): ``get_headers_tenants_from_connection`` is not defined in this class as shown;
        # confirm it exists and returns a ``(headers, tenant)`` pair.
        headers, tenant = await self.get_headers_tenants_from_connection()
        # NOTE(review): ``include_related`` is a list-valued query parameter; confirm aiohttp serializes
        # it in the form the dbt Cloud API expects.
        url, params = self.get_request_url_params(tenant, endpoint, include_related)
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(url, params=params) as response:
                try:
                    response.raise_for_status()
                    return await response.json()
                except ClientResponseError as e:
                    # Chain the aiohttp error so its traceback and details are preserved.
                    raise AirflowException(f"{e.status}:{e.message}") from e

    async def get_job_status(
        self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
    ) -> int:
        """
        Retrieves the status for a specific run of a dbt Cloud job.

        :param run_id: The ID of a dbt Cloud job run.
        :param account_id: Optional. The ID of a dbt Cloud account.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: The value of ``response["data"]["status"]`` from :meth:`get_job_details`.
        """
        # Exceptions from get_job_details propagate unchanged; the former ``except ...: raise e`` was a no-op.
        self.log.info("Getting the status of job run %s.", run_id)
        response = await self.get_job_details(run_id, account_id=account_id, include_related=include_related)
        job_run_status: int = response["data"]["status"]
        return job_run_status