Source code for astronomer.providers.microsoft.azure.sensors.wasb

import warnings
from datetime import timedelta
from typing import Any, Dict, List, Optional

from airflow import AirflowException
from airflow.providers.microsoft.azure.sensors.wasb import (
    WasbBlobSensor,
    WasbPrefixSensor,
)

from astronomer.providers.microsoft.azure.triggers.wasb import (
    WasbBlobSensorTrigger,
    WasbPrefixSensorTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class WasbBlobSensorAsync(WasbBlobSensor):
    """
    Polls asynchronously for the existence of a blob in a WASB container.

    :param container_name: name of the container in which the blob should be searched for
    :param blob_name: name of the blob to check existence for
    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
    :param public_read: whether an anonymous public read access should be used. Default is False
    :param poll_interval: deprecated alias of ``poke_interval`` (polling period in seconds).
        Only takes effect when explicitly supplied and ``poke_interval`` is not also given.
    """

    def __init__(
        self,
        *,
        container_name: str,
        blob_name: str,
        wasb_conn_id: str = "wasb_default",
        public_read: bool = False,
        poll_interval: Optional[float] = None,
        **kwargs: Any,
    ):
        # TODO: Remove once deprecated
        # The default is ``None`` so the warning is only emitted when the caller
        # actually passed the deprecated argument.
        if poll_interval:
            # Hand the value to the parent initializer instead of assigning the
            # attribute up front. NOTE(review): Airflow's base sensor initializer
            # assigns ``poke_interval`` itself, so an attribute set before
            # ``super().__init__()`` would be overwritten - confirm against the
            # supported Airflow versions. ``setdefault`` keeps an explicitly
            # supplied ``poke_interval`` authoritative.
            kwargs.setdefault("poke_interval", poll_interval)
            warnings.warn(
                "Argument `poll_interval` is deprecated and will be removed "
                "in a future release. Please use `poke_interval` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        super().__init__(container_name=container_name, blob_name=blob_name, **kwargs)
        self.container_name = container_name
        self.blob_name = blob_name
        self.wasb_conn_id = wasb_conn_id
        self.public_read = public_read

    def execute(self, context: Context) -> None:
        """
        Defer the task to ``WasbBlobSensorTrigger``.

        The deferral is bounded by ``self.timeout`` seconds, and the task is
        resumed in ``execute_complete`` once the trigger fires.
        """
        self.defer(
            timeout=timedelta(seconds=self.timeout),
            trigger=WasbBlobSensorTrigger(
                container_name=self.container_name,
                blob_name=self.blob_name,
                wasb_conn_id=self.wasb_conn_id,
                public_read=self.public_read,
                poke_interval=self.poke_interval,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
        """
        Callback invoked when the trigger fires.

        Logs the event message on success.

        :raises AirflowException: if no event was received, or if the event
            reports an ``"error"`` status.
        """
        if not event:
            raise AirflowException("Did not receive valid event from the triggerer")
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])
class WasbPrefixSensorAsync(WasbPrefixSensor):
    """
    Polls asynchronously for the existence of a blob having the given prefix in a WASB container.

    :param container_name: name of the container in which the blob should be searched for
    :param prefix: prefix of the blob to check existence for
    :param include: specifies one or more additional datasets to include in the
        response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
        ``copy``, ``deleted``
    :param delimiter: filters objects based on the delimiter (for e.g '.csv')
    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
    :param public_read: whether an anonymous public read access should be used. Default is False
    :param poll_interval: deprecated alias of ``poke_interval`` (polling period in seconds).
        Only takes effect when explicitly supplied and ``poke_interval`` is not also given.
    """

    def __init__(
        self,
        *,
        container_name: str,
        prefix: str,
        include: Optional[List[str]] = None,
        delimiter: Optional[str] = "/",
        wasb_conn_id: str = "wasb_default",
        public_read: bool = False,
        poll_interval: Optional[float] = None,
        **kwargs: Any,
    ):
        # TODO: Remove once deprecated
        # The default is ``None`` so the warning is only emitted when the caller
        # actually passed the deprecated argument.
        if poll_interval:
            # Hand the value to the parent initializer instead of assigning the
            # attribute up front. NOTE(review): Airflow's base sensor initializer
            # assigns ``poke_interval`` itself, so an attribute set before
            # ``super().__init__()`` would be overwritten - confirm against the
            # supported Airflow versions. ``setdefault`` keeps an explicitly
            # supplied ``poke_interval`` authoritative.
            kwargs.setdefault("poke_interval", poll_interval)
            warnings.warn(
                "Argument `poll_interval` is deprecated and will be removed "
                "in a future release. Please use `poke_interval` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        super().__init__(container_name=container_name, prefix=prefix, **kwargs)
        self.container_name = container_name
        self.prefix = prefix
        self.include = include
        self.delimiter = delimiter
        self.wasb_conn_id = wasb_conn_id
        self.public_read = public_read

    def execute(self, context: Context) -> None:
        """
        Defer the task to ``WasbPrefixSensorTrigger``.

        The deferral is bounded by ``self.timeout`` seconds, and the task is
        resumed in ``execute_complete`` once the trigger fires.
        """
        self.defer(
            timeout=timedelta(seconds=self.timeout),
            trigger=WasbPrefixSensorTrigger(
                container_name=self.container_name,
                prefix=self.prefix,
                include=self.include,
                delimiter=self.delimiter,
                wasb_conn_id=self.wasb_conn_id,
                public_read=self.public_read,
                poke_interval=self.poke_interval,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
        """
        Callback invoked when the trigger fires.

        Logs the event message on success.

        :raises AirflowException: if no event was received, or if the event
            reports an ``"error"`` status.
        """
        if not event:
            raise AirflowException("Did not receive valid event from the triggerer")
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])