Source code for astronomer.providers.snowflake.operators.snowflake

from __future__ import annotations

import typing
from datetime import timedelta
from typing import Any, Callable, List

from airflow.exceptions import AirflowException

try:
    from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
except ImportError:  # pragma: no cover
    # For apache-airflow-providers-snowflake > 3.3.0
    # The type: ignore[assignment] and pragma: no cover below are currently needed because this
    # import path isn't available in the current setup
    from airflow.providers.common.sql.operators.sql import (  # type: ignore[assignment]
        SQLExecuteQueryOperator as SnowflakeOperator,
    )

from astronomer.providers.snowflake.hooks.snowflake import (
    SnowflakeHookAsync,
    fetch_all_snowflake_handler,
)
from astronomer.providers.snowflake.hooks.snowflake_sql_api import (
    SnowflakeSqlApiHookAsync,
)
from astronomer.providers.snowflake.triggers.snowflake_trigger import (
    SnowflakeSqlApiTrigger,
    SnowflakeTrigger,
    get_db_hook,
)
from astronomer.providers.utils.typing_compat import Context


class SnowflakeOperatorAsync(SnowflakeOperator):
    """
    SnowflakeOperatorAsync uses the Snowflake Python connector ``execute_async`` method to submit a database
    command for asynchronous execution.

    - Submits multiple queries in parallel without waiting for each query to complete.
    - Accepts a list of queries, or multiple queries as a single ';'-separated string, along with params.
      It loops through the queries and executes them in sequence, using the ``execute_async`` method to run
      each query.
    - Once a query is submitted, the operator executes the query from one connection, gets the query IDs
      from the response, passes them to the triggerer, and closes the connection (so that the worker slots
      can be freed up).
    - The trigger gets the list of query IDs as input and, from a different connection, polls Snowflake
      every few seconds to check the status of each query by its query ID.

    Where can this operator fit in?
        - Execute long-running queries that can run in parallel
        - Batch-based operations such as copying or inserting data in parallel

    Best practices:
        - Ensure that you know which queries are dependent upon other queries before you run any queries in
          parallel. Some queries are interdependent and order sensitive, and therefore not suitable for
          parallelizing. For example, obviously an INSERT statement should not start until after the
          corresponding CREATE TABLE statement has finished.
        - Ensure that you do not run too many queries for the memory that you have available. Running
          multiple queries in parallel typically consumes more memory, especially if more than one set of
          results is stored in memory at the same time.
        - Ensure that transaction control statements (BEGIN, COMMIT, and ROLLBACK) do not execute in
          parallel with other statements.

    .. seealso::
        - `Snowflake Async Python connector <https://docs.snowflake.com/en/user-guide/python-connector-example.html#label-python-connector-querying-data-asynchronous>`_
        - `Best Practices <https://docs.snowflake.com/en/user-guide/python-connector-example.html#best-practices-for-asynchronous-queries>`_

    :param snowflake_conn_id: Reference to Snowflake connection id
    :param sql: the SQL code to be executed. (templated)
    :param autocommit: if True, each command is automatically committed. (default value: True)
    :param parameters: (optional) the parameters to render the SQL query with.
    :param warehouse: name of warehouse (will overwrite any warehouse defined in the connection's extra JSON)
    :param database: name of database (will overwrite database defined in connection)
    :param schema: name of schema (will overwrite schema defined in connection)
    :param role: name of role (will overwrite any role defined in connection's extra JSON)
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and Okta, ADFS or any other SAML 2.0-compliant
        identity provider (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate through native Okta.
    :param session_parameters: You can set session-level parameters at the time you connect to Snowflake
    :param handler: (optional) the function that will be applied to the cursor
        (default: fetch_all_snowflake_handler).
    :param return_last: (optional) whether to return the result of only the last statement (default: True).
    :param poll_interval: the interval in seconds to poll the query
    """  # noqa

    def __init__(
        self,
        *,
        snowflake_conn_id: str = "snowflake_default",
        warehouse: str | None = None,
        database: str | None = None,
        role: str | None = None,
        schema: str | None = None,
        authenticator: str | None = None,
        session_parameters: dict[str, Any] | None = None,
        poll_interval: int = 5,
        handler: Callable[[Any], Any] = fetch_all_snowflake_handler,
        return_last: bool = True,
        **kwargs: Any,
    ) -> None:
        self.poll_interval = poll_interval
        self.warehouse = warehouse
        self.database = database
        self.role = role
        self.schema = schema
        self.authenticator = authenticator
        self.session_parameters = session_parameters
        self.snowflake_conn_id = snowflake_conn_id
        if self.__class__.__base__.__name__ != "SnowflakeOperator":
            # It's better to do a string check of the parent class name because SnowflakeOperator
            # is currently deprecated, and the OSS SnowflakeOperator may be removed in the future
            if any(
                [warehouse, database, role, schema, authenticator, session_parameters]
            ):  # pragma: no cover
                hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
                kwargs["hook_params"] = {
                    "warehouse": warehouse,
                    "database": database,
                    "role": role,
                    "schema": schema,
                    "authenticator": authenticator,
                    "session_parameters": session_parameters,
                    **hook_params,
                }  # pragma: no cover
            super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover
        else:
            super().__init__(**kwargs)
        self.handler = handler
        self.return_last = return_last
    def get_db_hook(self) -> SnowflakeHookAsync:
        """Get the Snowflake Hook"""
        return get_db_hook(self.snowflake_conn_id)
    def execute(self, context: Context) -> None:
        """
        Make a sync connection to Snowflake, run the queries via ``execute_async``, and close the
        connection. Then defer to the SnowflakeTrigger, passing along the query IDs so the trigger can
        fetch the status of the queries.
        """
        self.log.info("Executing: %s", self.sql)
        default_query_tag = f"airflow_openlineage_{self.task_id}_{self.dag.dag_id}_{context['ti'].try_number}"
        query_tag = (
            self.parameters.get("query_tag", default_query_tag)
            if isinstance(self.parameters, dict)
            else default_query_tag
        )
        session_query_tag = f"ALTER SESSION SET query_tag = '{query_tag}';"
        if isinstance(self.sql, str):
            self.sql = "\n".join([session_query_tag, self.sql])
        else:
            session_query_list = [session_query_tag]
            session_query_list.extend(self.sql)
            self.sql = session_query_list

        self.log.info("SQL after adding query tag: %s", self.sql)

        hook = self.get_db_hook()
        hook.run(self.sql, parameters=self.parameters)  # type: ignore[arg-type]
        self.query_ids = hook.query_ids

        if self.do_xcom_push:
            context["ti"].xcom_push(key="query_ids", value=self.query_ids)

        self.defer(
            timeout=self.execution_timeout,
            trigger=SnowflakeTrigger(
                task_id=self.task_id,
                poll_interval=self.poll_interval,
                query_ids=self.query_ids,
                snowflake_conn_id=self.snowflake_conn_id,
            ),
            method_name="execute_complete",
        )
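    # For illustration (not part of the original module): with ``sql="SELECT 1"`` and no "query_tag" key
    # in ``parameters``, the SQL actually submitted by ``execute`` above becomes:
    #
    #     ALTER SESSION SET query_tag = 'airflow_openlineage_<task_id>_<dag_id>_<try_number>';
    #     SELECT 1
    #
    # where the bracketed placeholders are filled from the task instance at runtime.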
    def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> Any:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was successful.
        """
        self.log.info("SQL in execute_complete: %s", self.sql)
        if event:
            if "status" in event and event["status"] == "error":
                msg = "{}: {}".format(event["type"], event["message"])
                raise AirflowException(msg)
            elif "status" in event and event["status"] == "success":
                hook = self.get_db_hook()
                qids = typing.cast(List[str], event["query_ids"])
                results = hook.check_query_output(qids, self.handler, self.return_last)
                self.log.info("%s completed successfully.", self.task_id)
                if self.do_xcom_push:
                    return results
        else:
            raise AirflowException("Did not receive valid event from the triggerer")
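# A minimal usage sketch (not part of this module): deferral happens inside ``execute``, so from a DAG the
# operator is used like any other. The DAG id, schedule, and SQL below are illustrative assumptions.
#
#     from datetime import datetime
#
#     from airflow import DAG
#
#     from astronomer.providers.snowflake.operators.snowflake import SnowflakeOperatorAsync
#
#     with DAG(dag_id="example_snowflake_async", start_date=datetime(2023, 1, 1), schedule=None):
#         run_queries = SnowflakeOperatorAsync(
#             task_id="run_queries",
#             snowflake_conn_id="snowflake_default",
#             # multiple ';'-separated statements are submitted via execute_async
#             sql="CREATE TABLE IF NOT EXISTS t (a INT); INSERT INTO t SELECT 1;",
#             poll_interval=5,
#         )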
class SnowflakeSqlApiOperatorAsync(SnowflakeOperator):
    """
    Async Snowflake SQL API Operator. Whereas the SnowflakeOperator executes multiple SQL statements
    sequentially, the Snowflake SQL API allows submitting multiple SQL statements in a single request.
    In combination with aiohttp, it makes a POST request to submit SQL statements for execution and polls
    to check the status of the execution of each statement. Query results are fetched concurrently.

    This operator currently uses key pair authentication, so you need to provide the raw content of the
    private key, or the private key file path, in the Snowflake connection along with other details.

    .. seealso::

        `Snowflake SQL API key pair Authentication <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating.html#label-sql-api-authenticating-key-pair>`_

    Where can this operator fit in?
        - To execute multiple SQL statements in a single request
        - To execute SQL statements asynchronously, including standard queries and most DDL and DML statements
        - To develop custom applications and integrations that perform queries
        - To provision users and roles, create tables, etc.

    The following commands are not supported:
        - The PUT command (in Snowflake SQL)
        - The GET command (in Snowflake SQL)
        - The CALL command with stored procedures that return a table (stored procedures with the
          RETURNS TABLE clause)

    .. seealso::

        - `Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#introduction-to-the-sql-api>`_
        - `API Reference <https://docs.snowflake.com/en/developer-guide/sql-api/reference.html#snowflake-sql-api-reference>`_
        - `Limitations of the Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#limitations-of-the-sql-api>`_

    :param snowflake_conn_id: Reference to Snowflake connection id
    :param sql: the SQL code to be executed. (templated)
    :param autocommit: if True, each command is automatically committed. (default value: True)
    :param parameters: (optional) the parameters to render the SQL query with.
    :param warehouse: name of warehouse (will overwrite any warehouse defined in the connection's extra JSON)
    :param database: name of database (will overwrite database defined in connection)
    :param schema: name of schema (will overwrite schema defined in connection)
    :param role: name of role (will overwrite any role defined in connection's extra JSON)
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and Okta, ADFS or any other SAML 2.0-compliant
        identity provider (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate through native Okta.
    :param session_parameters: You can set session-level parameters at the time you connect to Snowflake
    :param poll_interval: the interval in seconds to poll the query
    :param statement_count: Number of SQL statements to be executed
    :param token_life_time: lifetime of the JWT Token
    :param token_renewal_delta: Renewal time of the JWT Token
    :param bindings: (Optional) Values of bind variables in the SQL statement. When executing the statement,
        Snowflake replaces placeholders (? and :name) in the statement with these specified values.
    """  # noqa

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def __init__(
        self,
        *,
        snowflake_conn_id: str = "snowflake_default",
        warehouse: str | None = None,
        database: str | None = None,
        role: str | None = None,
        schema: str | None = None,
        authenticator: str | None = None,
        session_parameters: dict[str, Any] | None = None,
        poll_interval: int = 5,
        statement_count: int = 0,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        bindings: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        self.warehouse = warehouse
        self.database = database
        self.role = role
        self.schema = schema
        self.authenticator = authenticator
        self.session_parameters = session_parameters
        self.snowflake_conn_id = snowflake_conn_id
        self.poll_interval = poll_interval
        self.statement_count = statement_count
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
        self.bindings = bindings
        self.execute_async = False
        if self.__class__.__base__.__name__ != "SnowflakeOperator":
            # It's better to do a string check of the parent class name because SnowflakeOperator
            # is currently deprecated, and the OSS SnowflakeOperator may be removed in the future
            if any(
                [warehouse, database, role, schema, authenticator, session_parameters]
            ):  # pragma: no cover
                hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
                kwargs["hook_params"] = {
                    "warehouse": warehouse,
                    "database": database,
                    "role": role,
                    "schema": schema,
                    "authenticator": authenticator,
                    "session_parameters": session_parameters,
                    **hook_params,
                }
            super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover
        else:
            super().__init__(**kwargs)
    def execute(self, context: Context) -> None:
        """
        Make a POST API request to Snowflake using the SQL API hook, execute the query, and collect the
        query IDs. Then defer to the SnowflakeSqlApiTrigger, passing along those query IDs.
        """
        self.log.info("Executing: %s", self.sql)
        hook = SnowflakeSqlApiHookAsync(
            snowflake_conn_id=self.snowflake_conn_id,
            token_life_time=self.token_life_time,
            token_renewal_delta=self.token_renewal_delta,
        )
        hook.execute_query(
            self.sql, statement_count=self.statement_count, bindings=self.bindings  # type: ignore[arg-type]
        )
        self.query_ids = hook.query_ids
        self.log.info("List of query ids %s", self.query_ids)

        if self.do_xcom_push:
            context["ti"].xcom_push(key="query_ids", value=self.query_ids)

        self.defer(
            timeout=self.execution_timeout,
            trigger=SnowflakeSqlApiTrigger(
                poll_interval=self.poll_interval,
                query_ids=self.query_ids,
                snowflake_conn_id=self.snowflake_conn_id,
                token_life_time=self.token_life_time,
                token_renewal_delta=self.token_renewal_delta,
            ),
            method_name="execute_complete",
        )
    def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was successful.
        """
        if event:
            if "status" in event and event["status"] == "error":
                msg = "{}: {}".format(event["status"], event["message"])
                raise AirflowException(msg)
            elif "status" in event and event["status"] == "success":
                hook = SnowflakeSqlApiHookAsync(snowflake_conn_id=self.snowflake_conn_id)
                query_ids = typing.cast(List[str], event["statement_query_ids"])
                hook.check_query_output(query_ids)
                self.log.info("%s completed successfully.", self.task_id)
        else:
            self.log.info("%s completed successfully.", self.task_id)
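# A minimal usage sketch (not part of this module), assuming a Snowflake connection configured with key pair
# authentication. The DAG id, SQL, and binding values are illustrative assumptions; ``statement_count`` must
# match the number of statements submitted in the request.
#
#     from datetime import datetime
#
#     from airflow import DAG
#
#     from astronomer.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperatorAsync
#
#     with DAG(dag_id="example_snowflake_sql_api", start_date=datetime(2023, 1, 1), schedule=None):
#         multi_stmt = SnowflakeSqlApiOperatorAsync(
#             task_id="multi_stmt",
#             snowflake_conn_id="snowflake_default",
#             sql="CREATE TABLE IF NOT EXISTS t (a INT); INSERT INTO t VALUES (?);",
#             statement_count=2,
#             bindings={"1": {"type": "FIXED", "value": "1"}},
#         )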