# Source code for astronomer.providers.databricks.hooks.databricks

import asyncio
import base64
from typing import Any, Dict, Tuple, cast

import aiohttp
from aiohttp import ClientResponseError
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
    GET_RUN_ENDPOINT,
    OUTPUT_RUNS_JOB_ENDPOINT,
    DatabricksHook,
    RunState,
)
from asgiref.sync import sync_to_async

# Header sent with every Databricks API request, identifying the calling
# Airflow version. NOTE: this is a shared, mutable module-level dict — copy it
# (e.g. dict(USER_AGENT_HEADER)) before adding per-request headers.
USER_AGENT_HEADER = {"user-agent": f"airflow-{__version__}"}


class DatabricksHookAsync(DatabricksHook):
    """
    Interact with Databricks.

    :param databricks_conn_id: Reference to the Databricks connection.
    :type databricks_conn_id: str
    :param timeout_seconds: The amount of time in seconds the requests library
        will wait before timing-out.
    :type timeout_seconds: int
    :param retry_limit: The number of times to retry the connection in case of
        service outages.
    :type retry_limit: int
    :param retry_delay: The number of seconds to wait between retries (it
        might be a floating point number).
    :type retry_delay: float
    """

    async def get_run_state_async(self, run_id: str) -> RunState:
        """
        Retrieves run state of the run using an asynchronous api call.

        :param run_id: id of the run
        :return: state of the run
        """
        response = await self.get_run_response(run_id)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        self.log.info("Getting run state. ")
        return RunState(life_cycle_state, result_state, state_message)

    async def get_run_response(self, run_id: str) -> Dict[str, Any]:
        """
        Makes Async API call to get the run state info.

        :param run_id: id of the run
        :return: JSON response of the ``get run`` API call
        """
        json = {"run_id": run_id}
        response = await self._do_api_call_async(GET_RUN_ENDPOINT, json)
        return response

    async def get_run_output_response(self, task_run_id: str) -> Dict[str, Any]:
        """
        Retrieves run output of the run.

        :param task_run_id: id of the run
        :return: JSON response of the ``get run output`` API call
        """
        json = {"run_id": task_run_id}
        run_output = await self._do_api_call_async(OUTPUT_RUNS_JOB_ENDPOINT, json)
        return run_output

    async def _do_api_call_async(
        self, endpoint_info: Tuple[str, str], json: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Utility function to perform an asynchronous API call with retries.

        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :return: If the api call returns a OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        :raises AirflowException: for an unsupported HTTP method, a
            non-retryable response error, or once the retry limit is exhausted.
        """
        method, endpoint = endpoint_info
        # BUGFIX: copy the shared module-level header dict rather than aliasing
        # it. Previously headers = USER_AGENT_HEADER followed by the
        # Authorization assignment below mutated the constant itself, leaking
        # the credential header into every subsequent request and hook instance.
        headers: Dict[str, str] = dict(USER_AGENT_HEADER)
        attempt_num = 1
        if not self.databricks_conn:
            # get_connection is synchronous (DB access), so hop to a thread.
            self.databricks_conn = await sync_to_async(self.get_connection)(self.databricks_conn_id)

        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth. ")
            auth = self.databricks_conn.extra_dejson["token"]
            # aiohttp assumes basic auth for its 'auth' parameter, so we need to
            # set this manually in the header for both bearer token and basic auth.
            headers["Authorization"] = f"Bearer {auth}"
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            self.log.info("Using basic auth. ")
            auth_str = f"{self.databricks_conn.login}:{self.databricks_conn.password}"
            encoded_bytes = auth_str.encode("utf-8")
            auth = base64.b64encode(encoded_bytes).decode("utf-8")
            headers["Authorization"] = f"Basic {auth}"
            host = self.databricks_conn.host

        # SECURITY FIX: log only the host. The previous message also logged the
        # raw bearer token / base64-encoded credentials, exposing secrets in
        # task logs.
        self.log.info("Host: %s", host)

        url = f"https://{self._parse_host(host)}/{endpoint}"

        async with aiohttp.ClientSession() as session:
            if method == "GET":
                request_func = session.get
            elif method == "POST":
                request_func = session.post
            elif method == "PATCH":
                request_func = session.patch
            else:
                raise AirflowException("Unexpected HTTP Method: " + method)
            while True:
                try:
                    response = await request_func(
                        url,
                        json=json if method in ("POST", "PATCH") else None,
                        params=json if method == "GET" else None,
                        headers=headers,
                        timeout=self.timeout_seconds,
                    )
                    response.raise_for_status()
                    return cast(Dict[str, Any], await response.json())
                except ClientResponseError as e:
                    if not self._retryable_error_async(e):
                        # In this case, the user probably made a mistake.
                        # Don't retry rather raise exception
                        raise AirflowException(str(e))
                    self._log_request_error(attempt_num, str(e))

                if attempt_num == self.retry_limit:
                    raise AirflowException(
                        f"API requests to Databricks failed {self.retry_limit} times. Giving up."
                    )

                attempt_num += 1
                await asyncio.sleep(self.retry_delay)

    def _retryable_error_async(self, exception: ClientResponseError) -> bool:
        """
        Determines whether or not an exception that was thrown might be
        successful on a subsequent attempt.

        Base Databricks operator considers the following to be retryable:
            - requests_exceptions.ConnectionError
            - requests_exceptions.Timeout
            - anything with a status code >= 500

        Most retryable errors are covered by status code >= 500.

        :return: if the status is retryable
        :rtype: bool
        """
        return exception.status >= 500