Skip to content

Commit

Permalink
feat: support for thumbor-aws result storage
Browse files Browse the repository at this point in the history
  • Loading branch information
fdintino committed Jan 4, 2024
1 parent 52da121 commit 7961510
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 3 deletions.
108 changes: 108 additions & 0 deletions src/thumbor_video_engine/result_storages/aws_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from datetime import datetime, timezone
from hashlib import sha1
from os.path import join
from deprecated import deprecated

from thumbor.engines import BaseEngine
from thumbor.result_storages import ResultStorageResult
from thumbor.utils import logger
import thumbor_aws.result_storage
from thumbor_aws.utils import normalize_path
from .base import BaseStorage


class Storage(BaseStorage, thumbor_aws.result_storage.Storage):
    """Result storage that keeps rendered results in S3 through thumbor_aws.

    The object key is built by ``_normalize_path`` from the request URL, the
    optional ``TC_AWS_*`` settings and the value returned by
    ``get_auto_path_component()`` (not defined in this class; presumably
    inherited from ``BaseStorage`` -- confirm).
    """

    def _normalize_path(self, path):
        """Return the S3 object key used for ``path``.

        The key is assembled, in order, from:

        * a SHA-1 digest of the other segments, when ``TC_AWS_RANDOMIZE_KEYS``
          is truthy;
        * ``TC_AWS_RESULT_STORAGE_ROOT_PATH``, when set;
        * the normalized request path;
        * the auto path component, when ``get_auto_path_component()``
          returns a truthy value.

        If the resulting key ends with ``/``, ``TC_AWS_ROOT_IMAGE_NAME`` is
        appended to it.
        """
        config = self.context.config

        path = normalize_path("", path).lstrip("/")
        segments = [path]

        root_path = getattr(config, "TC_AWS_RESULT_STORAGE_ROOT_PATH", None)
        if root_path:
            segments.insert(0, root_path)

        auto_component = self.get_auto_path_component()
        if auto_component:
            segments.append(auto_component)

        if getattr(config, "TC_AWS_RANDOMIZE_KEYS", False):
            # The digest is computed before it is inserted, so it only covers
            # the root path, request path and auto component.
            segments.insert(0, self._generate_digest(segments))

        # With a single segment this is a no-op: ``path`` was already stripped
        # of leading slashes above.
        normalized_path = join(*segments).lstrip("/")

        if normalized_path.endswith("/"):
            # Read with a default, like the other TC_AWS_* settings, so that a
            # configuration which does not define this key cannot raise
            # AttributeError for URLs ending in a slash.
            normalized_path += getattr(config, "TC_AWS_ROOT_IMAGE_NAME", "")

        return normalized_path

    def _generate_digest(self, segments):
        """Return the hex SHA-1 digest of ``segments`` joined with ``"."``."""
        return sha1(".".join(segments).encode("utf-8")).hexdigest()

    async def put(self, image_bytes: bytes) -> str:
        """Upload ``image_bytes`` under the key derived from the request URL.

        The content type is detected from the bytes themselves. Returns
        whatever ``self.upload`` returns.
        """
        file_abspath = self._normalize_path(self.context.request.url)
        logger.debug("[RESULT_STORAGE] putting at %s", file_abspath)
        content_type = BaseEngine.get_mimetype(image_bytes)
        response = await self.upload(
            file_abspath,
            image_bytes,
            content_type,
            self.context.config.AWS_DEFAULT_LOCATION,
        )
        logger.info("[RESULT_STORAGE] Image uploaded successfully to %s", file_abspath)
        return response

    async def get(self) -> ResultStorageResult:
        """Return the stored result for the current request, if usable.

        Returns ``None`` when the object does not exist, when fetching it does
        not yield status 200, or when ``_is_expired`` reports it as expired.
        Otherwise returns a ``ResultStorageResult`` whose metadata carries the
        last-modified time (tagged as UTC), the body length and the detected
        content type.
        """
        path = self.context.request.url
        file_abspath = self._normalize_path(path)

        logger.debug("[RESULT_STORAGE] getting from %s", file_abspath)

        exists = await self.object_exists(file_abspath)
        if not exists:
            logger.debug("[RESULT_STORAGE] image not found at %s", file_abspath)
            return None

        status, body, last_modified = await self.get_data(
            self.bucket_name, file_abspath
        )

        if status != 200 or self._is_expired(last_modified):
            logger.debug(
                "[RESULT_STORAGE] cached image has expired (status %s)", status
            )
            return None

        logger.info(
            "[RESULT_STORAGE] Image retrieved successfully at %s.",
            file_abspath,
        )

        return ResultStorageResult(
            buffer=body,
            metadata={
                "LastModified": last_modified.replace(tzinfo=timezone.utc),
                "ContentLength": len(body),
                "ContentType": BaseEngine.get_mimetype(body),
            },
        )

    @deprecated(version="7.0.0", reason="Use result's last_modified instead")
    async def last_updated(  # pylint: disable=invalid-overridden-method
        self,
    ) -> datetime:
        """Return the stored object's ``last-modified`` header as a datetime.

        The value is parsed from the HTTP headers of the object metadata
        response; the resulting datetime is naive (no tzinfo is attached).
        """
        path = self.context.request.url
        file_abspath = self._normalize_path(path)
        logger.debug("[RESULT_STORAGE] getting from %s", file_abspath)

        response = await self.get_object_metadata(file_abspath)
        return datetime.strptime(
            response["ResponseMetadata"]["HTTPHeaders"]["last-modified"],
            "%a, %d %b %Y %H:%M:%S %Z",
        )
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
except ImportError:
from thumbor.utils import which

try:
    from tests.mock_aio_server import s3_server, s3_client, session  # noqa
except Exception:  # noqa
    # The S3 mock fixtures pull in optional packages (pytest_asyncio,
    # aiobotocore, moto, aiohttp, werkzeug). Any failure while importing them
    # falls back to inert placeholders so the rest of the suite still loads.
    # ``Exception`` (rather than a bare ``except``) keeps KeyboardInterrupt
    # and SystemExit propagating.

    @pytest.fixture
    def s3_server():
        yield "http://does.not.exist"

    @pytest.fixture
    def s3_client():
        return None


CURR_DIR = os.path.abspath(os.path.dirname(__file__))

Expand Down
49 changes: 49 additions & 0 deletions tests/mock_aio_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import pytest_asyncio
import aiobotocore.session
from aiobotocore.config import AioConfig

from tests.moto_server import MotoService


@pytest_asyncio.fixture
async def s3_server(monkeypatch, event_loop):
    """Yield the endpoint URL of a moto-backed S3 service.

    Placeholder AWS settings are exported through ``monkeypatch`` before the
    service is entered, and the service context is left when the fixture is
    torn down.
    """
    fake_environment = {
        "TEST_SERVER_MODE": "true",
        "AWS_SHARED_CREDENTIALS_FILE": "",
        "AWS_ACCESS_KEY_ID": "test-key",
        "AWS_SECRET_ACCESS_KEY": "test-secret-key",
        "AWS_SESSION_TOKEN": "test-session-token",
    }
    for variable, value in fake_environment.items():
        monkeypatch.setenv(variable, value)

    async with MotoService("s3", ssl=False) as moto_service:
        yield moto_service.endpoint_url


@pytest.fixture
def session(event_loop):
    """Provide a new aiobotocore ``AioSession`` for each test."""
    aio_session = aiobotocore.session.AioSession()
    return aio_session


@pytest_asyncio.fixture
async def s3_client(
    session,
    s3_server,
):
    """Yield an S3 client pointed at the ``s3_server`` endpoint.

    The client uses placeholder credentials, skips certificate verification
    and applies 5 second read and connect timeouts.
    """
    # NOTE(review): an earlier comment here said this fixture depends on
    # `mock_attributes`; nothing by that name is visible in this file --
    # confirm whether that dependency still exists.
    aws_region = "us-east-1"
    timeout_seconds = 5

    client_config = AioConfig(
        region_name=aws_region,
        signature_version="s3",
        read_timeout=timeout_seconds,
        connect_timeout=timeout_seconds,
    )
    client_context = session.create_client(
        "s3",
        region_name=aws_region,
        config=client_config,
        verify=False,
        endpoint_url=s3_server,
        aws_secret_access_key="xxx",
        aws_access_key_id="xxx",
    )

    async with client_context as client:
        yield client
141 changes: 141 additions & 0 deletions tests/moto_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import asyncio
import functools
import logging
import socket
import threading
import time

# Third Party
import aiohttp
import moto.server
import werkzeug.serving

host = "127.0.0.1"

_CONNECT_TIMEOUT = 10


def get_free_tcp_port(release_socket: bool = False):
    """Ask the OS for an unused TCP port on ``host``.

    Args:
        release_socket: When true, the probing socket is closed and only the
            port number is returned. When false (the default), the socket is
            kept open so the port stays bound until the caller closes it.

    Returns:
        ``port`` if ``release_socket`` is true, otherwise ``(socket, port)``.

    Raises:
        OSError: If binding fails; the socket is closed before re-raising.
    """
    sckt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the OS choose any free port.
        sckt.bind((host, 0))
        port = sckt.getsockname()[1]
    except OSError:
        sckt.close()
        raise

    if release_socket:
        sckt.close()
        return port

    return sckt, port


class MotoService:
    """Run a moto backend for one service on a background thread.

    Running services are shared per service name through the class-level
    ``_services`` registry and are reference counted: the first
    ``async with`` for a name starts the server, later ones reuse it, and the
    server is stopped when the last user exits. ``__aenter__`` returns the
    shared, running instance, which may be a different object from the one
    the ``async with`` statement was given.
    """

    _services = {}  # {service_name: running MotoService instance}

    # Seconds to wait for the server thread to start answering requests.
    _STARTUP_TIMEOUT = 10

    def __init__(self, service_name: str, port: int = None, ssl: bool = False):
        """Prepare (but do not start) a service.

        Args:
            service_name: Name of the moto backend to serve, e.g. ``"s3"``.
            port: Port to listen on. When falsy, a free port is reserved.
            ssl: Serve over HTTPS with an ad-hoc certificate when true.
        """
        self._service_name = service_name

        if port:
            self._socket = None
            self._port = port
        else:
            # The bound socket is kept open so the port stays reserved until
            # the server thread releases it right before binding.
            self._socket, self._port = get_free_tcp_port()

        self._thread = None
        self._logger = logging.getLogger("MotoService")
        self._refcount = None
        self._ip_address = host
        self._server = None
        self._ssl_ctx = werkzeug.serving.generate_adhoc_ssl_context() if ssl else None
        self._schema = "http" if not self._ssl_ctx else "https"

    @property
    def endpoint_url(self):
        """Base URL (scheme, address and port) the service listens on."""
        return f"{self._schema}://{self._ip_address}:{self._port}"

    def __call__(self, func):
        """Decorate coroutine function ``func`` to run with the service up.

        The service is started before ``func`` is awaited and stopped
        afterwards, even if ``func`` raises. This path does not use the
        shared registry or the reference count.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            await self._start()
            try:
                return await func(*args, **kwargs)
            finally:
                await self._stop()

        return wrapper

    async def __aenter__(self):
        """Return the running service for this name, starting it if needed."""
        svc = self._services.get(self._service_name)
        if svc is not None:
            svc._refcount += 1
            return svc

        # Register before starting so that an entry made while the server is
        # still starting up reuses this instance instead of starting another.
        self._services[self._service_name] = self
        self._refcount = 1
        try:
            await self._start()
        except BaseException:
            # Do not leave a service that failed to start in the registry.
            self._services.pop(self._service_name, None)
            self._refcount = None
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Release one reference to the shared service for this name.

        ``async with`` calls this on the instance it was given, which is not
        necessarily the shared instance returned by ``__aenter__``, so the
        reference count is looked up in the registry rather than on ``self``.
        """
        # A port reserved by this instance is never needed past this point.
        if self._socket:
            self._socket.close()
            self._socket = None

        svc = self._services.get(self._service_name)
        if svc is None:
            return

        svc._refcount -= 1
        if svc._refcount == 0:
            del self._services[self._service_name]
            await svc._stop()

    def _server_entry(self):
        """Thread target: build the moto WSGI app and serve it until shutdown."""
        self._main_app = moto.server.DomainDispatcherApplication(
            moto.server.create_backend_app, service=self._service_name
        )
        self._main_app.debug = True

        if self._socket:
            self._socket.close()  # release right before we use it
            self._socket = None

        self._server = werkzeug.serving.make_server(
            self._ip_address,
            self._port,
            self._main_app,
            threaded=True,
            ssl_context=self._ssl_ctx,
        )
        self._server.serve_forever()

    async def _start(self):
        """Start the server thread and wait until it answers HTTP requests.

        Raises:
            RuntimeError: If the server thread exits, or the endpoint does
                not become reachable within ``_STARTUP_TIMEOUT`` seconds.
        """
        self._thread = threading.Thread(target=self._server_entry, daemon=True)
        self._thread.start()

        started = False
        async with aiohttp.ClientSession() as session:
            # A monotonic clock keeps the deadline immune to wall-clock jumps.
            deadline = time.monotonic() + self._STARTUP_TIMEOUT

            while time.monotonic() < deadline:
                if not self._thread.is_alive():
                    # The server thread died; polling further is pointless.
                    break

                try:
                    # we need to bypass the proxies due to monkeypatches
                    async with session.get(
                        self.endpoint_url + "/static",
                        timeout=_CONNECT_TIMEOUT,
                        verify_ssl=False,
                    ):
                        pass
                    started = True
                    break
                except (asyncio.TimeoutError, aiohttp.ClientConnectionError):
                    await asyncio.sleep(0.5)

        if not started:
            await self._stop()  # pytest.fail doesn't call stop_process
            raise RuntimeError(f"Can not start service: {self._service_name}")

    async def _stop(self):
        """Shut the server down and wait for its thread to finish.

        Safe to call when the service was never started.
        """
        server, thread = self._server, self._thread

        if server:
            server.shutdown()

        if thread is not None:
            thread.join()

        if server:
            # Release the listening socket once serve_forever() has returned.
            server.server_close()

        self._server = None
        self._thread = None
Loading

0 comments on commit 7961510

Please sign in to comment.