Source code for astronomer.providers.amazon.aws.triggers.s3
from __future__ import annotations
import asyncio
import warnings
from typing import Any, AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
from astronomer.providers.amazon.aws.hooks.s3 import S3HookAsync
[docs]
class S3KeyTrigger(BaseTrigger):
"""
S3KeyTrigger is fired as deferred class with params to run the task in trigger worker
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger` instead
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url.
:param bucket_key: The key being waited on. Supports full s3:// style url
or relative path from root level. When it's specified as a full s3://
url, please leave bucket_name as `None`.
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:param use_regex: whether to use regex to check bucket
:param aws_conn_id: reference to the s3 connection
:param hook_params: params for hook its optional
:param soft_fail: Set to true to mark the task as SKIPPED on failure
"""
def __init__(
self,
bucket_name: str,
bucket_key: list[str],
wildcard_match: bool = False,
use_regex: bool = False,
aws_conn_id: str = "aws_default",
poke_interval: float = 5.0,
soft_fail: bool = False,
should_check_fn: bool = False,
**hook_params: Any,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.use_regex = use_regex
self.aws_conn_id = aws_conn_id
self.hook_params = hook_params
self.poke_interval = poke_interval
self.soft_fail = soft_fail
self.should_check_fn = should_check_fn
[docs]
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize S3KeyTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.s3.S3KeyTrigger",
{
"bucket_name": self.bucket_name,
"bucket_key": self.bucket_key,
"wildcard_match": self.wildcard_match,
"use_regex": self.use_regex,
"aws_conn_id": self.aws_conn_id,
"hook_params": self.hook_params,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
"should_check_fn": self.should_check_fn,
},
)
[docs]
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make an asynchronous connection using S3HookAsync."""
try:
hook = self._get_async_hook()
async with await hook.get_client_async() as client:
while True:
if await hook.check_key(
client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
):
if self.should_check_fn:
s3_objects = await hook.get_files(
client, self.bucket_name, self.bucket_key, self.wildcard_match
)
await asyncio.sleep(self.poke_interval)
files = [{"Size": s3_object["Size"]} for s3_object in s3_objects]
yield TriggerEvent({"status": "running", "files": files})
else:
yield TriggerEvent({"status": "success"})
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "soft_fail": self.soft_fail})
def _get_async_hook(self) -> S3HookAsync:
return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))