Source code for astronomer.providers.amazon.aws.hooks.redshift_data

import asyncio
from io import StringIO
from typing import Any, Dict, Iterable, List, Tuple, Union

import botocore.exceptions
from airflow.exceptions import AirflowException
from airflow.models.param import ParamsDict
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from asgiref.sync import sync_to_async
from snowflake.connector.util_text import split_statements


class RedshiftDataHook(AwsBaseHook):
    """
    RedshiftDataHook inherits from AwsBaseHook to connect with AWS Redshift.
    By using the boto3 client_type ``redshift-data`` we can interact with the
    Redshift cluster database and execute queries.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param verify: Whether or not to verify SSL certificates.
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: boto3.client client_type. Eg 's3', 'emr' etc
    :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
    :param config: Configuration for botocore client.
        (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
    :param poll_interval: polling period in seconds to check for the status
    """

    def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
        aws_connection_type: str = "redshift-data"
        try:
            # for apache-airflow-providers-amazon>=3.0.0
            kwargs["client_type"] = aws_connection_type
            kwargs["resource_type"] = aws_connection_type
            super().__init__(*args, **kwargs)
        except ValueError:
            # for apache-airflow-providers-amazon>=4.1.0
            kwargs["client_type"] = aws_connection_type
            super().__init__(*args, **kwargs)
        self.client_type = aws_connection_type
        self.poll_interval = poll_interval
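    # Note on the try/except in __init__ above: our reading, inferred from
    # the version comments in the source, is that newer
    # apache-airflow-providers-amazon releases reject passing both
    # client_type and resource_type to AwsBaseHook with a ValueError, so the
    # hook retries with client_type only.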
    def get_conn_params(self) -> Dict[str, Union[str, int]]:
        """Helper method to retrieve connection args"""
        if not self.aws_conn_id:
            raise AirflowException("Required connection details is missing !")

        connection_object = self.get_connection(self.aws_conn_id)
        extra_config = connection_object.extra_dejson

        conn_params: Dict[str, Union[str, int]] = {}

        if "db_user" in extra_config:
            conn_params["db_user"] = extra_config.get("db_user", None)
        else:
            raise AirflowException("Required db user is missing !")

        if "database" in extra_config:
            conn_params["database"] = extra_config.get("database", None)
        elif connection_object.schema:
            conn_params["database"] = connection_object.schema
        else:
            raise AirflowException("Required Database name is missing !")

        if "access_key_id" in extra_config or "aws_access_key_id" in extra_config:
            conn_params["aws_access_key_id"] = (
                extra_config["access_key_id"]
                if "access_key_id" in extra_config
                else extra_config["aws_access_key_id"]
            )
            conn_params["aws_secret_access_key"] = (
                extra_config["secret_access_key"]
                if "secret_access_key" in extra_config
                else extra_config["aws_secret_access_key"]
            )
        elif connection_object.login:
            conn_params["aws_access_key_id"] = connection_object.login
            conn_params["aws_secret_access_key"] = connection_object.password
        else:
            raise AirflowException("Required access_key_id, aws_secret_access_key")

        if "region" in extra_config or "region_name" in extra_config:
            self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
            conn_params["region_name"] = (
                extra_config["region"] if "region" in extra_config else extra_config["region_name"]
            )
        else:
            raise AirflowException("Required Region name is missing !")

        if "aws_session_token" in extra_config:
            self.log.info(
                "session token retrieved from extra, please note you are responsible for renewing these.",
            )
            conn_params["aws_session_token"] = extra_config["aws_session_token"]

        if "cluster_identifier" in extra_config:
            self.log.info("Retrieving cluster_identifier from Connection.extra_config['cluster_identifier']")
            conn_params["cluster_identifier"] = extra_config["cluster_identifier"]
        else:
            raise AirflowException("Required Cluster identifier is missing !")

        return conn_params
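    # Illustrative Airflow connection extras that satisfy get_conn_params
    # (the field names are the ones checked above; all values are
    # placeholders):
    #
    #   {
    #       "db_user": "awsuser",
    #       "database": "dev",
    #       "access_key_id": "<AWS access key id>",
    #       "secret_access_key": "<AWS secret access key>",
    #       "region": "us-east-1",
    #       "cluster_identifier": "redshift-cluster-1"
    #   }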
    def execute_query(
        self, sql: Union[Dict[Any, Any], Iterable[Any]], params: Union[ParamsDict, Dict[Any, Any]]
    ) -> Tuple[List[str], Dict[str, str]]:
        """
        Runs an SQL statement, which can be data manipulation language (DML)
        or data definition language (DDL).

        :param sql: the SQL statement or list of SQL statements to run
        :param params: query parameters
        """
        if not sql:
            raise AirflowException("SQL query is None.")
        try:
            if isinstance(sql, str):
                split_statements_tuple = split_statements(StringIO(sql))
                sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]

            try:
                # for apache-airflow-providers-amazon>=3.0.0
                client = self.get_conn()
            except ValueError:
                # for apache-airflow-providers-amazon>=4.1.0
                self.resource_type = None
                client = self.get_conn()
            conn_params = self.get_conn_params()

            query_ids: List[str] = []
            for sql_statement in sql:
                self.log.info("Executing statement: %s", sql_statement)
                response = client.execute_statement(
                    Database=conn_params["database"],
                    ClusterIdentifier=conn_params["cluster_identifier"],
                    DbUser=conn_params["db_user"],
                    Sql=sql_statement,
                    WithEvent=True,
                )
                query_ids.append(response["Id"])
            return query_ids, {"status": "success", "message": "success"}
        except botocore.exceptions.ClientError as error:
            return [], {"status": "error", "message": str(error)}
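    # Minimal usage sketch for execute_query (connection id and SQL are
    # illustrative; a multi-statement string is split before submission):
    #
    #   hook = RedshiftDataHook(aws_conn_id="aws_default", poll_interval=5)
    #   query_ids, response = hook.execute_query(
    #       sql="CREATE TABLE IF NOT EXISTS t (a int); INSERT INTO t VALUES (1);",
    #       params={},
    #   )
    #   # query_ids holds one Redshift Data statement id per statement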
    async def get_query_status(self, query_ids: List[str]) -> Dict[str, Union[str, List[str]]]:
        """
        Async function to get the query status by query ids. The function
        takes a list of query_ids, makes an async connection to Redshift Data
        and fetches the status of each query by its id. On success it returns
        the list of query ids whose status is `FINISHED`. On partial failure,
        i.e. if any of the queries fail or are aborted by the user, an error
        is returned for the batch as a whole.

        :param query_ids: list of query ids
        """
        try:
            try:
                # for apache-airflow-providers-amazon>=3.0.0
                client = await sync_to_async(self.get_conn)()
            except ValueError:
                # for apache-airflow-providers-amazon>=4.1.0
                self.resource_type = None
                client = await sync_to_async(self.get_conn)()
            completed_ids: List[str] = []
            for query_id in query_ids:
                while await self.is_still_running(query_id):
                    await asyncio.sleep(self.poll_interval)
                res = client.describe_statement(Id=query_id)
                if res["Status"] == "FINISHED":
                    completed_ids.append(query_id)
                elif res["Status"] == "FAILED":
                    msg = "Error: " + res["QueryString"] + " query Failed due to, " + res["Error"]
                    return {"status": "error", "message": msg, "query_id": query_id, "type": res["Status"]}
                elif res["Status"] == "ABORTED":
                    return {
                        "status": "error",
                        "message": "The query run was stopped by the user.",
                        "query_id": query_id,
                        "type": res["Status"],
                    }
            return {"status": "success", "completed_ids": completed_ids}
        except botocore.exceptions.ClientError as error:
            return {"status": "error", "message": str(error), "type": "ERROR"}
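    # Sketch of awaiting the batch status from synchronous code (query_ids
    # as returned by execute_query above; asyncio.run is illustrative only):
    #
    #   result = asyncio.run(hook.get_query_status(query_ids))
    #   if result["status"] == "success":
    #       print("finished:", result["completed_ids"])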
    async def is_still_running(self, qid: str) -> Union[bool, Dict[str, str]]:
        """
        Async function to check whether the query is still running. Returns
        True while the query is in a "PICKED", "STARTED" or "SUBMITTED"
        state, and False once it has reached any other state.
        """
        try:
            try:
                # for apache-airflow-providers-amazon>=3.0.0
                client = await sync_to_async(self.get_conn)()
            except ValueError:
                # for apache-airflow-providers-amazon>=4.1.0
                self.resource_type = None
                client = await sync_to_async(self.get_conn)()
            desc = client.describe_statement(Id=qid)
            if desc["Status"] in ["PICKED", "STARTED", "SUBMITTED"]:
                return True
            return False
        except botocore.exceptions.ClientError as error:
            return {"status": "error", "message": str(error), "type": "ERROR"}
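
# A manual polling helper built directly on is_still_running, mirroring the
# loop inside get_query_status. This is a hypothetical convenience for
# illustration, not part of the hook's public API:
async def _wait_until_done(hook: RedshiftDataHook, query_id: str) -> None:
    # Sleep poll_interval seconds between checks until the statement leaves
    # the PICKED/STARTED/SUBMITTED states.
    while await hook.is_still_running(query_id):
        await asyncio.sleep(hook.poll_interval)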