from __future__ import annotations
import os
from typing import Any
from urllib.parse import quote
import requests
from aiohttp import ClientSession
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
[docs]
class AstroHook(BaseHook):
"""
Custom Apache Airflow Hook for interacting with Astro Cloud API.
:param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
"""
conn_name_attr = "astro_cloud_conn_id"
default_conn_name = "astro_cloud_default"
conn_type = "Astro Cloud"
hook_name = "Astro Cloud"
def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"):
super().__init__()
self.astro_cloud_conn_id = astro_cloud_conn_id
[docs]
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""
Returns UI field behavior customization for the Astro Cloud connection.
This method defines hidden fields, relabeling, and placeholders for UI display.
"""
return {
"hidden_fields": ["login", "port", "schema", "extra"],
"relabeling": {
"password": "Astro Cloud API Token",
},
"placeholders": {
"host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
"password": "Astro API JWT Token",
},
}
[docs]
def get_conn(self) -> tuple[str, str]:
"""Retrieves the Astro Cloud connection details."""
conn = BaseHook.get_connection(self.astro_cloud_conn_id)
base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
if base_url is None:
raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
token = conn.password
if token is None:
raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
return base_url, token
@property
def _headers(self) -> dict[str, str]:
"""Generates and returns headers for Astro Cloud API requests."""
_, token = self.get_conn()
headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
return headers
[docs]
def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
"""
Retrieves information about running or queued DAG runs.
:param external_dag_id: External ID of the DAG.
"""
base_url, _ = self.get_conn()
path = f"/api/v1/dags/{external_dag_id}/dagRuns"
params: dict[str, int | str | list[str]] = {
"limit": 1,
"state": ["running", "queued"],
"order_by": "-execution_date",
}
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers, params=params)
response.raise_for_status()
data: dict[str, list[dict[str, str]]] = response.json()
return data["dag_runs"]
[docs]
def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
dr: dict[str, Any] = response.json()
return dr
[docs]
async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"
async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
dr: dict[str, Any] = await response.json()
return dr
[docs]
def get_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"
response = requests.get(url, headers=self._headers)
response.raise_for_status()
ti: dict[str, Any] = response.json()
return ti
[docs]
async def get_a_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"
async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
ti: dict[str, Any] = await response.json()
return ti