# Source code for astronomer.providers.google.cloud.triggers.dataproc

import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent
from google.cloud.dataproc_v1.types import JobStatus

from astronomer.providers.google.cloud.hooks.dataproc import DataprocHookAsync

log = logging.getLogger(__name__)


class DataProcSubmitTrigger(BaseTrigger):
    """
    Check for the state of a previously submitted Dataproc job.

    :param dataproc_job_id: The Dataproc job ID to poll. (templated)
    :param region: Required. The Cloud Dataproc region in which to handle the request. (templated)
    :param project_id: The ID of the google cloud project in which to create the cluster. (templated)
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
    :param polling_interval: Seconds to sleep between status polls of the job.
    """

    def __init__(
        self,
        *,
        dataproc_job_id: str,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        gcp_conn_id: str = "google_cloud_default",
        polling_interval: float = 5.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.dataproc_job_id = dataproc_job_id
        self.region = region
        self.polling_interval = polling_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes DataProcSubmitTrigger arguments and classpath."""
        return (
            "astronomer.providers.google.cloud.triggers.dataproc.DataProcSubmitTrigger",
            {
                "project_id": self.project_id,
                "dataproc_job_id": self.dataproc_job_id,
                "region": self.region,
                "polling_interval": self.polling_interval,
                "gcp_conn_id": self.gcp_conn_id,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Poll until the Dataproc job reaches a terminal state, then emit one TriggerEvent.

        Yields a single ``TriggerEvent`` carrying the terminal status dict from
        :meth:`_get_job_status`, or an error event if polling itself raises.
        """
        try:
            hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
            while True:
                job_status = await self._get_job_status(hook)
                # BUGFIX: stop after the terminal event. The original loop had
                # no return/break after yield, so it kept polling (and could
                # yield again) even after the job had finished.
                if job_status.get("status") in ("success", "error"):
                    yield TriggerEvent(job_status)
                    return
                await asyncio.sleep(self.polling_interval)
        except Exception as e:
            # Surface any polling failure to the triggerer as an error event.
            yield TriggerEvent({"status": "error", "message": str(e)})
            return

    async def _get_job_status(self, hook: DataprocHookAsync) -> Dict[str, str]:
        """Gets the status of the given job_id from the Google Cloud DataProc

        :param hook: async Dataproc hook used to fetch the job resource.
        :return: dict with ``status`` ("success"/"error"/"pending") and a message.
        """
        job = await hook.get_job(job_id=self.dataproc_job_id, region=self.region, project_id=self.project_id)
        state = job.status.state
        if state == JobStatus.State.ERROR:
            return {"status": "error", "message": "Job Failed"}
        if state in {
            JobStatus.State.CANCELLED,
            JobStatus.State.CANCEL_PENDING,
            JobStatus.State.CANCEL_STARTED,
        }:
            return {"status": "error", "message": "Job got cancelled"}
        if state == JobStatus.State.DONE:
            return {"status": "success", "message": "Job completed successfully"}
        # ATTEMPT_FAILURE and every other non-terminal state keep the trigger
        # polling; the original had a duplicate branch returning this same dict.
        return {"status": "pending", "message": "Job is in pending state"}