# Source code for astronomer.providers.amazon.aws.hooks.redshift_data
from io import StringIO
from typing import Any, Dict, Iterable, List, Tuple, Union
import botocore.exceptions
from airflow.exceptions import AirflowException
from airflow.models.param import ParamsDict
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from snowflake.connector.util_text import split_statements
class RedshiftDataHook(AwsBaseHook):
    """
    Hook for the Amazon Redshift Data API, built on top of ``AwsBaseHook``.

    The hook always uses the ``redshift-data`` boto3 client type, so SQL statements can be
    submitted to a Redshift cluster database through the client returned by ``get_conn()``.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
        Note that ``get_conn_params`` (and therefore ``execute_query``) requires a
        non-empty ``aws_conn_id`` and raises ``AirflowException`` otherwise.
    :param verify: Whether or not to verify SSL certificates.
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: boto3.client client_type. Eg 's3', 'emr' etc.
        Any value passed here is overwritten with ``redshift-data``.
    :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc.
        Any value passed here is overwritten with ``redshift-data``.
    :param config: Configuration for botocore client.
        (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the hook, forcing ``client_type`` and ``resource_type`` to ``redshift-data``."""
        client_type: str = "redshift-data"
        # Caller-supplied values for these two keys are overwritten.
        kwargs["client_type"] = client_type
        # NOTE(review): the same value is also passed as resource_type; confirm that the
        # installed AwsBaseHook accepts both client_type and resource_type being set.
        kwargs["resource_type"] = client_type
        super().__init__(*args, **kwargs)
        self.client_type = client_type

    @staticmethod
    def _lookup_extra(extra_config: Dict[str, Any], *keys: str) -> Tuple[bool, Any]:
        """Return ``(True, value)`` for the first of ``keys`` present in ``extra_config``, else ``(False, None)``."""
        for key in keys:
            if key in extra_config:
                return True, extra_config[key]
        return False, None

    def get_conn_params(self) -> Dict[str, Union[str, int]]:
        """
        Build the Redshift connection arguments from the Airflow connection.

        Values are read from the connection's ``extra`` JSON. Where two spellings are
        accepted, the first one listed wins:

        * ``db_user``
        * ``database`` (falls back to the connection's ``schema`` when the key is absent)
        * ``access_key_id`` / ``aws_access_key_id``
        * ``secret_access_key`` / ``aws_secret_access_key``
        * ``region`` / ``region_name``
        * ``cluster_identifier``

        :return: dict with the keys ``db_user``, ``database``, ``aws_access_key_id``,
            ``aws_secret_access_key``, ``region_name`` and ``cluster_identifier``.
        :raises AirflowException: if ``aws_conn_id`` is empty or any of the settings
            above is missing.
        """
        if not self.aws_conn_id:
            raise AirflowException("Required connection details is missing !")

        connection_object = self.get_connection(self.aws_conn_id)
        extra_config = connection_object.extra_dejson
        conn_params: Dict[str, Union[str, int]] = {}

        found, db_user = self._lookup_extra(extra_config, "db_user")
        if not found:
            raise AirflowException("Required db user is missing !")
        conn_params["db_user"] = db_user

        found, database = self._lookup_extra(extra_config, "database")
        if found:
            conn_params["database"] = database
        elif connection_object.schema:
            conn_params["database"] = connection_object.schema
        else:
            raise AirflowException("Required Database name is missing !")

        has_key, access_key = self._lookup_extra(extra_config, "access_key_id", "aws_access_key_id")
        has_secret, secret_key = self._lookup_extra(
            extra_config, "secret_access_key", "aws_secret_access_key"
        )
        # Both halves of the credential pair are required. Checking them together means a
        # missing secret is reported as an AirflowException (like every other missing
        # setting) instead of leaking a bare KeyError.
        if not (has_key and has_secret):
            raise AirflowException("Required access_key_id, aws_secret_access_key")
        conn_params["aws_access_key_id"] = access_key
        conn_params["aws_secret_access_key"] = secret_key

        found, region_name = self._lookup_extra(extra_config, "region", "region_name")
        if not found:
            raise AirflowException("Required Region name is missing !")
        self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
        conn_params["region_name"] = region_name

        found, cluster_identifier = self._lookup_extra(extra_config, "cluster_identifier")
        if not found:
            raise AirflowException("Required Cluster identifier is missing !")
        self.log.info("Retrieving cluster_identifier from Connection.extra_config['cluster_identifier']")
        conn_params["cluster_identifier"] = cluster_identifier

        return conn_params

    def execute_query(
        self, sql: Union[Dict[Any, Any], Iterable[Any]], params: Union[ParamsDict, Dict[Any, Any]]
    ) -> Tuple[List[str], Dict[str, str]]:
        """
        Submit one or more SQL statements (DML or DDL) through the Redshift Data API.

        :param sql: the SQL to run. A single string is split into individual statements
            with ``split_statements`` (empty fragments are dropped); any other iterable
            is treated as one statement per item.
        :param params: accepted for interface compatibility; not used by this method.
        :return: a tuple ``(query_ids, status)``. On success ``query_ids`` holds the ``Id``
            of every submitted statement and ``status`` is
            ``{"status": "success", "message": "success"}``. If the client raises
            ``botocore.exceptions.ClientError`` the result is
            ``([], {"status": "error", "message": <error text>})``; ids of statements
            submitted before the failure are not returned.
        :raises AirflowException: if ``sql`` is empty. Exceptions other than
            ``ClientError`` (for example those raised by ``get_conn_params``) propagate.
        """
        if not sql:
            raise AirflowException("SQL query is None.")

        try:
            if isinstance(sql, str):
                statements: Iterable[Any] = [
                    statement for statement, _ in split_statements(StringIO(sql)) if statement
                ]
            else:
                statements = sql

            client = self.get_conn()
            conn_params = self.get_conn_params()

            query_ids: List[str] = []
            for statement in statements:
                self.log.info("Executing statement: %s", statement)
                response = client.execute_statement(
                    Database=conn_params["database"],
                    ClusterIdentifier=conn_params["cluster_identifier"],
                    DbUser=conn_params["db_user"],
                    Sql=statement,
                    WithEvent=True,
                )
                query_ids.append(response["Id"])
            return query_ids, {"status": "success", "message": "success"}
        except botocore.exceptions.ClientError as error:
            # Report the failure through the status dict rather than raising.
            return [], {"status": "error", "message": str(error)}