# Source code for astronomer.providers.databricks.hooks.databricks
from __future__ import annotations
import asyncio
import base64
import warnings
from typing import Any, Dict, cast
import aiohttp
from aiohttp import ClientConnectorError, ClientResponseError
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
GET_RUN_ENDPOINT,
OUTPUT_RUNS_JOB_ENDPOINT,
DatabricksHook,
RunState,
)
from asgiref.sync import sync_to_async
USER_AGENT_HEADER = {"user-agent": f"airflow-{__version__}"}
class DatabricksHookAsync(DatabricksHook):
    """
    Async-capable Databricks hook.

    This class is deprecated and will be removed in 2.0.0.
    Use :class:`~airflow.providers.databricks.hooks.databricks.DatabricksHook` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Use DeprecationWarning (not the default UserWarning) so tooling and
        # warning filters recognize this as a deprecation, and stacklevel=2 so
        # the warning points at the caller instantiating this class.
        warnings.warn(
            "This class is deprecated and will be removed in 2.0.0. "
            "Use `airflow.providers.databricks.hooks.databricks.DatabricksHook` instead ",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

    async def get_run_state_async(self, run_id: str) -> RunState:
        """
        Retrieves run state of the run using an asynchronous api call.

        :param run_id: id of the run
        :return: state of the run
        """
        response = await self.get_run_response(run_id)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        self.log.info("Getting run state. ")
        return RunState(life_cycle_state, result_state, state_message)

    async def get_run_response(self, run_id: str) -> dict[str, Any]:
        """
        Makes Async API call to get the run state info.

        :param run_id: id of the run
        :return: raw ``runs/get`` response as a dict
        """
        json = {"run_id": run_id}
        response = await self._do_api_call_async(GET_RUN_ENDPOINT, json)
        return response

    async def get_run_output_response(self, task_run_id: str) -> dict[str, Any]:
        """
        Retrieves run output of the run.

        :param task_run_id: id of the run
        :return: raw ``runs/get-output`` response as a dict
        """
        json = {"run_id": task_run_id}
        run_output = await self._do_api_call_async(OUTPUT_RUNS_JOB_ENDPOINT, json)
        return run_output

    async def _do_api_call_async(
        self, endpoint_info: tuple[str, str], json: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Utility function to perform an asynchronous API call with retries

        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :return: If the api call returns a OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        :raises AirflowException: on an unsupported HTTP method, a
            non-retryable error, or when the retry limit is exhausted.
        """
        method, endpoint = endpoint_info
        # Copy the shared constant: assigning it directly and then setting
        # "Authorization" below would mutate USER_AGENT_HEADER for the whole
        # process, leaking credentials between hooks.
        headers = dict(USER_AGENT_HEADER)
        attempt_num = 1
        if not self.databricks_conn:
            self.databricks_conn = await sync_to_async(self.get_connection)(self.databricks_conn_id)
        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth. ")
            auth = self.databricks_conn.extra_dejson["token"]
            # aiohttp assumes basic auth for its 'auth' parameter, so we need to
            # set this manually in the header for both bearer token and basic auth.
            headers["Authorization"] = f"Bearer {auth}"
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            self.log.info("Using basic auth. ")
            auth_str = f"{self.databricks_conn.login}:{self.databricks_conn.password}"
            encoded_bytes = auth_str.encode("utf-8")
            auth = base64.b64encode(encoded_bytes).decode("utf-8")
            headers["Authorization"] = f"Basic {auth}"
            host = self.databricks_conn.host
        # Never log the credential itself (token / base64 basic auth) —
        # only the host is safe to emit.
        self.log.info("Host: %s", host)
        url = f"https://{self._parse_host(host)}/{endpoint}"
        async with aiohttp.ClientSession() as session:
            if method == "GET":
                request_func = session.get
            elif method == "POST":
                request_func = session.post
            elif method == "PATCH":
                request_func = session.patch
            else:
                raise AirflowException("Unexpected HTTP Method: " + method)
            while True:
                try:
                    response = await request_func(
                        url,
                        json=json if method in ("POST", "PATCH") else None,
                        params=json if method == "GET" else None,
                        headers=headers,
                        timeout=self.timeout_seconds,  # type: ignore[arg-type]
                    )
                    response.raise_for_status()
                    return cast(Dict[str, Any], await response.json())
                except (ClientConnectorError, ClientResponseError) as e:
                    if not self._retryable_error_async(e):
                        # In this case, the user probably made a mistake.
                        # Don't retry rather raise exception
                        raise AirflowException(str(e))
                    self._log_request_error(attempt_num, str(e))
                if attempt_num == self.retry_limit:
                    raise AirflowException(
                        ("API requests to Databricks failed {} times. " + "Giving up.").format(
                            self.retry_limit
                        )
                    )
                attempt_num += 1
                await asyncio.sleep(self.retry_delay)

    @staticmethod
    def _retryable_error_async(exception: ClientConnectorError | ClientResponseError) -> bool:
        """
        Determines whether or not an exception that was thrown might be successful
        on a subsequent attempt.

        Base Databricks operator considers the following to be retryable:
            - requests_exceptions.ConnectionError
            - requests_exceptions.Timeout
            - anything with a status code >= 500
            - status code == 403

        Most retryable errors are covered by status code >= 500.

        :return: if the status is retryable
        :rtype: bool
        """
        if isinstance(exception, ClientResponseError):
            status_code = exception.status
            # according to user feedback, 403 sometimes works after retry
            return status_code >= 500 or status_code == 403
        # ClientConnectorError (connection-level failure) is always retryable.
        return True