import warnings
from abc import ABC
from typing import Any, Optional, Sequence, Tuple, Union
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from google.api_core import gapic_v1
from google.api_core.client_options import ClientOptions
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
ClusterControllerAsyncClient,
Job,
JobControllerAsyncClient,
)
from google.cloud.dataproc_v1.types import clusters
# Alias used as the return annotation of ``DataprocHookAsync.get_job``: a Dataproc ``Job``,
# with ``Any`` included so the annotation stays permissive about what the client returns.
JobType = Union[Job, Any]
class DataprocHookAsync(GoogleBaseHook, ABC):
    """Async Hook for Google Cloud Dataproc APIs."""

    def get_cluster_client(
        self, region: Optional[str] = None, location: Optional[str] = None
    ) -> ClusterControllerAsyncClient:
        """
        Get async cluster controller client for GCP Dataproc.

        :param region: The Cloud Dataproc region in which to handle the request.
        :param location: (To be deprecated). The Cloud Dataproc region in which to handle the request.
            When given, it takes precedence over ``region``.
        :return: A ``ClusterControllerAsyncClient`` built with this hook's credentials and the client
            options derived from ``region``/``location``.
        """
        # Only the client options are needed here; the resolved region is intentionally discarded.
        client_options, _ = self._get_client_options_and_region(region=region, location=location)
        return ClusterControllerAsyncClient(
            credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options
        )

    def get_job_client(
        self, region: Optional[str] = None, location: Optional[str] = None
    ) -> JobControllerAsyncClient:
        """
        Get async job controller for GCP Dataproc.

        :param region: The Cloud Dataproc region in which to handle the request.
        :param location: (To be deprecated). The Cloud Dataproc region in which to handle the request.
            When given, it takes precedence over ``region``.
        :return: A ``JobControllerAsyncClient`` built with this hook's credentials and the client
            options derived from ``region``/``location``.
        """
        # Only the client options are needed here; the resolved region is intentionally discarded.
        client_options, _ = self._get_client_options_and_region(region=region, location=location)
        return JobControllerAsyncClient(
            credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options
        )

    async def get_cluster(
        self,
        region: str,
        cluster_name: str,
        project_id: str,
        retry: Union[Retry, gapic_v1.method._MethodDefault] = gapic_v1.method.DEFAULT,
        metadata: Sequence[Tuple[str, str]] = (),
    ) -> clusters.Cluster:
        """
        Get a cluster details from GCP using `ClusterControllerAsyncClient`.

        :param region: The Cloud Dataproc region in which to handle the request
        :param cluster_name: The name of the cluster
        :param project_id: The ID of the Google Cloud project the cluster belongs to
        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
            retried
        :param metadata: Additional metadata that is provided to the method
        :return: The cluster returned by ``ClusterControllerAsyncClient.get_cluster``.
        """
        client = self.get_cluster_client(region=region)
        return await client.get_cluster(
            region=region,
            cluster_name=cluster_name,
            project_id=project_id,
            retry=retry,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    async def get_job(
        self,
        job_id: str,
        project_id: str,
        timeout: float = 5,
        region: Optional[str] = None,
        location: Optional[str] = None,
        retry: Union[Retry, gapic_v1.method._MethodDefault] = gapic_v1.method.DEFAULT,
        metadata: Sequence[Tuple[str, str]] = (),
    ) -> JobType:
        """
        Gets the resource representation for a job using `JobControllerAsyncClient`.

        :param job_id: Id of the Dataproc job
        :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
            ``retry`` is specified, the timeout applies to each individual attempt.
        :param region: Required. The Cloud Dataproc region in which to handle the request.
        :param location: (To be deprecated). The Cloud Dataproc region in which to handle the request.
            When given, it takes precedence over ``region``.
        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
            retried.
        :param metadata: Additional metadata that is provided to the method.
        :return: The job returned by ``JobControllerAsyncClient.get_job``.
        """
        # Building the client also emits the deprecation warning when ``location`` is supplied.
        client = self.get_job_client(region=region, location=location)
        # The client endpoint is chosen with ``location`` overriding ``region``; the request must
        # carry the same resolved value, otherwise a caller that passes only ``location`` would
        # send ``region=None`` to a regional endpoint.
        resolved_region = location if location is not None else region
        return await client.get_job(
            request={"project_id": project_id, "region": resolved_region, "job_id": job_id},
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    def _get_client_options_and_region(
        self, region: Optional[str] = None, location: Optional[str] = None
    ) -> Tuple[ClientOptions, Optional[str]]:
        """
        Resolve the effective region and build the matching client options.

        If ``location`` is provided, a ``DeprecationWarning`` is issued and its value replaces
        ``region``. A regional API endpoint is configured for any non-empty region other than
        ``"global"``; otherwise default ``ClientOptions`` are returned.

        :param region: Required. The Cloud Dataproc region in which to handle the request.
        :param location: (To be deprecated). The Cloud Dataproc region in which to handle the request.
        :return: A tuple of the client options and the resolved region.
        """
        if location is not None:
            warnings.warn(
                "Parameter `location` will be deprecated. "
                "Please provide value through `region` parameter instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            region = location
        client_options = ClientOptions()
        if region and region != "global":
            client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
        return client_options, region