import asyncio
import time
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Cluster
from google.cloud.dataproc_v1.types import JobStatus, clusters
from astronomer.providers.google.cloud.hooks.dataproc import DataprocHookAsync
class DataprocCreateClusterTrigger(BaseTrigger):
    """
    Asynchronously check the status of a cluster until it is running, fails, or the deadline passes.

    :param project_id: The ID of the Google Cloud project the cluster belongs to
    :param region: The Cloud Dataproc region in which to handle the request
    :param cluster_name: The name of the cluster
    :param end_time: Deadline for polling; it is compared against ``time.time()``, and once it has
        passed the trigger stops polling and emits a ``Timeout`` error event
    :param metadata: Additional metadata that is provided to the method
    :param delete_on_error: If true, the cluster is deleted when it is found in ERROR state
    :param cluster_config: Cluster configuration passed to ``create_cluster`` when the cluster has
        to be re-created after it was found in DELETING state
    :param labels: Labels passed to ``create_cluster`` when the cluster has to be re-created
    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :param polling_interval: Time in seconds to sleep between checks of cluster status
    """

    def __init__(
        self,
        *,
        project_id: Optional[str] = None,
        region: Optional[str] = None,
        cluster_name: str,
        end_time: float,
        metadata: Sequence[Tuple[str, str]] = (),
        delete_on_error: bool = True,
        cluster_config: Optional[Union[Dict[str, Any], clusters.Cluster]] = None,
        labels: Optional[Dict[str, str]] = None,
        gcp_conn_id: str = "google_cloud_default",
        polling_interval: float = 5.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.cluster_name = cluster_name
        self.end_time = end_time
        self.metadata = metadata
        self.delete_on_error = delete_on_error
        self.cluster_config = cluster_config
        self.labels = labels
        self.gcp_conn_id = gcp_conn_id
        self.polling_interval = polling_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DataprocCreateClusterTrigger arguments and classpath."""
        return (
            "astronomer.providers.google.cloud.triggers.dataproc.DataprocCreateClusterTrigger",
            {
                "project_id": self.project_id,
                "region": self.region,
                "cluster_name": self.cluster_name,
                "end_time": self.end_time,
                "metadata": self.metadata,
                "delete_on_error": self.delete_on_error,
                "cluster_config": self.cluster_config,
                "labels": self.labels,
                "gcp_conn_id": self.gcp_conn_id,
                "polling_interval": self.polling_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Check the status of cluster until reach the terminal state

        Emits exactly one event and then stops: ``success`` once the cluster is RUNNING,
        ``error`` if any step raises (including the ERROR-state handling), or an ``error``
        with message ``Timeout`` when ``end_time`` passes first.
        """
        while self.end_time > time.time():
            try:
                cluster = await self._get_cluster()
                state = cluster.status.state
                if state == cluster.status.State.RUNNING:
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "data": Cluster.to_dict(cluster),
                            "cluster_name": self.cluster_name,
                        }
                    )
                    # Terminal event: stop here so no further polling or events follow.
                    return
                if state == cluster.status.State.DELETING:
                    # A cluster with this name is being torn down; wait for it to disappear
                    # and then create it again.
                    await self._wait_for_deleting()
                    self._create_cluster()
                # No-op unless the cluster is in ERROR state, in which case it raises.
                await self._handle_error(cluster)

                self.log.info(
                    "Cluster status is %s. Sleeping for %s seconds.",
                    state,
                    self.polling_interval,
                )
                await asyncio.sleep(self.polling_interval)
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})
                # Terminal event: do not keep polling after reporting a failure.
                return

        yield TriggerEvent({"status": "error", "message": "Timeout"})

    async def _handle_error(self, cluster: clusters.Cluster) -> None:
        """
        Diagnose (and optionally delete) a cluster that is in ERROR state.

        Returns immediately when the cluster is not in ERROR state.

        :param cluster: The cluster object whose status is inspected
        :raises AirflowException: always, when the cluster is in ERROR state
        """
        if cluster.status.state != cluster.status.State.ERROR:
            return
        self.log.info("Cluster is in ERROR state")
        gcs_uri = self._diagnose_cluster()
        self.log.info("Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri)
        # The messages are formatted eagerly: exceptions do not apply %-style arguments,
        # so passing them separately would surface a raw tuple in ``str(exc)``.
        if self.delete_on_error:
            self._delete_cluster()
            await self._wait_for_deleting()
            raise AirflowException(
                "Cluster was created but was in ERROR state. \n"
                f" Diagnostic information for cluster {self.cluster_name} available at: {gcs_uri}"
            )
        raise AirflowException(
            "Cluster was created but is in ERROR state. \n "
            f"Diagnostic information for cluster {self.cluster_name} available at: {gcs_uri}"
        )

    def _delete_cluster(self) -> None:
        """Request deletion of the cluster through ``DataprocHook``."""
        # NOTE(review): DataprocHook looks like the synchronous hook, so this call presumably
        # blocks the event loop while it runs - confirm whether it should be offloaded.
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
        hook.delete_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
            metadata=self.metadata,
        )

    async def _wait_for_deleting(self) -> None:
        """
        Poll the cluster until it no longer exists or ``end_time`` passes.

        Returns once ``get_cluster`` raises ``NotFound``; any other exception propagates to
        the caller. If the deadline passes first the method returns without raising.
        """
        while self.end_time > time.time():
            try:
                cluster = await self._get_cluster()
            except NotFound:
                return
            # The cluster still exists, whatever state it reports: keep waiting.
            self.log.info(
                "Cluster status is %s. Sleeping for %s seconds.",
                cluster.status.state,
                self.polling_interval,
            )
            await asyncio.sleep(self.polling_interval)

    def _create_cluster(self) -> Any:
        """Create the cluster through ``DataprocHook`` and return the hook's result."""
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
        return hook.create_cluster(
            region=self.region,
            project_id=self.project_id,
            cluster_name=self.cluster_name,
            cluster_config=self.cluster_config,
            labels=self.labels,
            metadata=self.metadata,
        )

    async def _get_cluster(self) -> clusters.Cluster:
        """Fetch the current cluster object through ``DataprocHookAsync``."""
        hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
        return await hook.get_cluster(
            region=self.region,  # type: ignore[arg-type]
            cluster_name=self.cluster_name,
            project_id=self.project_id,  # type: ignore[arg-type]
            metadata=self.metadata,
        )

    def _diagnose_cluster(self) -> Any:
        """Run cluster diagnosis through ``DataprocHook``; the result is reported as a GCS URI."""
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
        return hook.diagnose_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
            metadata=self.metadata,
        )
class DataprocDeleteClusterTrigger(BaseTrigger):
    """
    Asynchronously wait for a cluster to be deleted.

    :param cluster_name: The name of the cluster
    :param end_time: Deadline for polling; it is compared against ``time.time()``, and once it has
        passed the trigger stops polling and emits a ``Timeout`` error event
    :param project_id: The ID of the Google Cloud project the cluster belongs to
    :param region: The Cloud Dataproc region in which to handle the request
    :param metadata: Additional metadata that is provided to the method
    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :param polling_interval: Time in seconds to sleep between checks of cluster status
    """

    def __init__(
        self,
        cluster_name: str,
        end_time: float,
        project_id: Optional[str] = None,
        region: Optional[str] = None,
        metadata: Sequence[Tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        polling_interval: float = 5.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.cluster_name = cluster_name
        self.end_time = end_time
        self.project_id = project_id
        self.region = region
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.polling_interval = polling_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
        return (
            "astronomer.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger",
            {
                "cluster_name": self.cluster_name,
                "end_time": self.end_time,
                "project_id": self.project_id,
                "region": self.region,
                "metadata": self.metadata,
                "gcp_conn_id": self.gcp_conn_id,
                "polling_interval": self.polling_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Wait until cluster is deleted completely

        Emits exactly one event and then stops: ``success`` once ``get_cluster`` raises
        ``NotFound``, ``error`` on any other exception, or an ``error`` with message
        ``Timeout`` when ``end_time`` passes first.
        """
        hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
        while self.end_time > time.time():
            try:
                cluster = await hook.get_cluster(
                    region=self.region,  # type: ignore[arg-type]
                    cluster_name=self.cluster_name,
                    project_id=self.project_id,  # type: ignore[arg-type]
                    metadata=self.metadata,
                )
                # The cluster still exists: report its state and poll again later.
                self.log.info(
                    "Cluster status is %s. Sleeping for %s seconds.",
                    cluster.status.state,
                    self.polling_interval,
                )
                await asyncio.sleep(self.polling_interval)
            except NotFound:
                # The cluster is gone, which is the outcome being waited for.
                yield TriggerEvent({"status": "success", "message": ""})
                # Terminal event: stop so no further polling or "Timeout" event follows.
                return
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})
                # Terminal event: do not keep polling after reporting a failure.
                return

        yield TriggerEvent({"status": "error", "message": "Timeout"})
class DataProcSubmitTrigger(BaseTrigger):
    """
    Check for the state of a previously submitted Dataproc job.

    :param dataproc_job_id: The Dataproc job ID to poll.
    :param region: The Cloud Dataproc region in which to handle the request.
    :param project_id: The ID of the google cloud project the job is looked up in.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
    :param polling_interval: Time in seconds to sleep between checks of the job status.
    """

    def __init__(
        self,
        *,
        dataproc_job_id: str,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        gcp_conn_id: str = "google_cloud_default",
        polling_interval: float = 5.0,
    ) -> None:
        super().__init__()
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.dataproc_job_id = dataproc_job_id
        self.region = region
        self.polling_interval = polling_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DataProcSubmitTrigger arguments and classpath."""
        return (
            "astronomer.providers.google.cloud.triggers.dataproc.DataProcSubmitTrigger",
            {
                "project_id": self.project_id,
                "dataproc_job_id": self.dataproc_job_id,
                "region": self.region,
                "polling_interval": self.polling_interval,
                "gcp_conn_id": self.gcp_conn_id,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """
        Simple loop until the job running on Google Cloud DataProc is completed or not

        Emits exactly one event and then stops: the status dictionary built by
        ``_get_job_status`` once it reports ``success`` or ``error``, or an ``error``
        event carrying the exception text if polling raises.
        """
        try:
            hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
            while True:
                job_status = await self._get_job_status(hook)
                if job_status.get("status") in ("success", "error"):
                    yield TriggerEvent(job_status)
                    # Terminal event: stop instead of re-polling and re-emitting it.
                    return
                await asyncio.sleep(self.polling_interval)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
            return

    async def _get_job_status(self, hook: DataprocHookAsync) -> Dict[str, str]:
        """
        Gets the status of the given job_id from the Google Cloud DataProc

        :param hook: Async hook used to fetch the job
        :return: A dictionary with ``status`` (``success``, ``error`` or ``pending``),
            a human-readable ``message`` and the ``job_id``
        """
        job = await hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
        state = job.status.state
        if state == JobStatus.State.ERROR:
            return {"status": "error", "message": "Job Failed", "job_id": self.dataproc_job_id}
        if state in {
            JobStatus.State.CANCELLED,
            JobStatus.State.CANCEL_PENDING,
            JobStatus.State.CANCEL_STARTED,
        }:
            return {"status": "error", "message": "Job got cancelled", "job_id": self.dataproc_job_id}
        if state == JobStatus.State.DONE:
            return {
                "status": "success",
                "message": "Job completed successfully",
                "job_id": self.dataproc_job_id,
            }
        # Every other state, ATTEMPT_FAILURE included, is reported as pending so that
        # the caller keeps polling.
        return {"status": "pending", "message": "Job is in pending state", "job_id": self.dataproc_job_id}