feat: support for thumbor-aws result storage
Showing 6 changed files with 488 additions and 3 deletions.
108 changes: 108 additions & 0 deletions
src/thumbor_video_engine/result_storages/aws_storage.py
@@ -0,0 +1,108 @@
from datetime import datetime, timezone
from hashlib import sha1
from os.path import join
from deprecated import deprecated

from thumbor.engines import BaseEngine
from thumbor.result_storages import ResultStorageResult
from thumbor.utils import logger
import thumbor_aws.result_storage
from thumbor_aws.utils import normalize_path
from .base import BaseStorage


class Storage(BaseStorage, thumbor_aws.result_storage.Storage):
    def _normalize_path(self, path):
        path = normalize_path("", path).lstrip("/")
        segments = [path]

        root_path = getattr(
            self.context.config, "TC_AWS_RESULT_STORAGE_ROOT_PATH", None
        )

        if root_path:
            segments.insert(0, root_path)
        auto_component = self.get_auto_path_component()
        if auto_component:
            segments.append(auto_component)

        should_randomize_key = getattr(
            self.context.config, "TC_AWS_RANDOMIZE_KEYS", False
        )
        if should_randomize_key:
            segments.insert(0, self._generate_digest(segments))

        normalized_path = (
            join(segments[0], *segments[1:]).lstrip("/")
            if len(segments) > 1
            else segments[0]
        )
        if normalized_path.endswith("/"):
            normalized_path += self.context.config.TC_AWS_ROOT_IMAGE_NAME

        return normalized_path

    def _generate_digest(self, segments):
        return sha1(".".join(segments).encode("utf-8")).hexdigest()

    async def put(self, image_bytes: bytes) -> str:
        file_abspath = self._normalize_path(self.context.request.url)
        logger.debug("[RESULT_STORAGE] putting at %s", file_abspath)
        content_type = BaseEngine.get_mimetype(image_bytes)
        response = await self.upload(
            file_abspath,
            image_bytes,
            content_type,
            self.context.config.AWS_DEFAULT_LOCATION,
        )
        logger.info("[RESULT_STORAGE] Image uploaded successfully to %s", file_abspath)
        return response

    async def get(self) -> ResultStorageResult:
        path = self.context.request.url
        file_abspath = self._normalize_path(path)

        logger.debug("[RESULT_STORAGE] getting from %s", file_abspath)

        exists = await self.object_exists(file_abspath)
        if not exists:
            logger.debug("[RESULT_STORAGE] image not found at %s", file_abspath)
            return None

        status, body, last_modified = await self.get_data(
            self.bucket_name, file_abspath
        )

        if status != 200 or self._is_expired(last_modified):
            logger.debug(
                "[RESULT_STORAGE] cached image has expired (status %s)", status
            )
            return None

        logger.info(
            "[RESULT_STORAGE] Image retrieved successfully at %s.",
            file_abspath,
        )

        return ResultStorageResult(
            buffer=body,
            metadata={
                "LastModified": last_modified.replace(tzinfo=timezone.utc),
                "ContentLength": len(body),
                "ContentType": BaseEngine.get_mimetype(body),
            },
        )

    @deprecated(version="7.0.0", reason="Use result's last_modified instead")
    async def last_updated(  # pylint: disable=invalid-overridden-method
        self,
    ) -> datetime:
        path = self.context.request.url
        file_abspath = self._normalize_path(path)
        logger.debug("[RESULT_STORAGE] getting from %s", file_abspath)

        response = await self.get_object_metadata(file_abspath)
        return datetime.strptime(
            response["ResponseMetadata"]["HTTPHeaders"]["last-modified"],
            "%a, %d %b %Y %H:%M:%S %Z",
        )
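
A thumbor configuration can then point its result storage at this module. The snippet below is a sketch for illustration only: the RESULT_STORAGE module path and all values are assumptions rather than part of this commit, the usual thumbor-aws bucket/region settings are omitted, and only configuration keys actually read by the Storage class above are shown.

# thumbor.conf (hypothetical sketch)
RESULT_STORAGE = "thumbor_video_engine.result_storages.aws_storage"  # assumed module path

# Passed straight through to self.upload() in put(); see the thumbor-aws docs for its format.
AWS_DEFAULT_LOCATION = "https://{bucket_name}.s3.amazonaws.com"

# tc_aws-compatible keys consumed by _normalize_path() above.
TC_AWS_RESULT_STORAGE_ROOT_PATH = "thumbor/result"  # optional key prefix
TC_AWS_RANDOMIZE_KEYS = False  # when True, a sha1 digest of the path segments is prepended
TC_AWS_ROOT_IMAGE_NAME = "root_image"  # appended when the normalized key ends in "/"
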
@@ -0,0 +1,49 @@
import pytest
import pytest_asyncio
import aiobotocore.session
from aiobotocore.config import AioConfig

from tests.moto_server import MotoService


@pytest_asyncio.fixture
async def s3_server(monkeypatch, event_loop):
    monkeypatch.setenv("TEST_SERVER_MODE", "true")
    monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "")
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-key")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
    monkeypatch.setenv("AWS_SESSION_TOKEN", "test-session-token")
    async with MotoService("s3", ssl=False) as svc:
        yield svc.endpoint_url


@pytest.fixture
def session(event_loop):
    return aiobotocore.session.AioSession()


@pytest_asyncio.fixture
async def s3_client(
    session,
    s3_server,
):
    read_timeout = connect_timeout = 5
    region = "us-east-1"

    async with session.create_client(
        "s3",
        region_name=region,
        config=AioConfig(
            region_name=region,
            signature_version="s3",
            read_timeout=read_timeout,
            connect_timeout=connect_timeout,
        ),
        verify=False,
        endpoint_url=s3_server,
        aws_secret_access_key="xxx",
        aws_access_key_id="xxx",
    ) as client:
        yield client
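
A test can request the s3_client fixture to talk to the moto-backed endpoint. The test below is a hypothetical sketch, not part of this commit: the bucket and key names are made up, and it assumes pytest-asyncio's asyncio marker is enabled for the test suite.

import pytest


@pytest.mark.asyncio
async def test_s3_round_trip(s3_client):
    # Create a bucket on the local moto server, write an object, and read it back.
    await s3_client.create_bucket(Bucket="example-bucket")
    await s3_client.put_object(Bucket="example-bucket", Key="hello.txt", Body=b"hi")

    response = await s3_client.get_object(Bucket="example-bucket", Key="hello.txt")
    async with response["Body"] as stream:
        assert await stream.read() == b"hi"
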
@@ -0,0 +1,141 @@
import asyncio
import functools
import logging
import socket
import threading
import time

# Third Party
import aiohttp
import moto.server
import werkzeug.serving

host = "127.0.0.1"

_CONNECT_TIMEOUT = 10


def get_free_tcp_port(release_socket: bool = False):
    sckt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sckt.bind((host, 0))
    addr, port = sckt.getsockname()
    if release_socket:
        sckt.close()
        return port

    return sckt, port


class MotoService:
    """Will Create MotoService.
    Service is ref-counted so there will only be one per process. Real Service will
    be returned by `__aenter__`."""

    _services = dict()  # {name: instance}

    def __init__(self, service_name: str, port: int = None, ssl: bool = False):
        self._service_name = service_name

        if port:
            self._socket = None
            self._port = port
        else:
            self._socket, self._port = get_free_tcp_port()

        self._thread = None
        self._logger = logging.getLogger("MotoService")
        self._refcount = None
        self._ip_address = host
        self._server = None
        self._ssl_ctx = werkzeug.serving.generate_adhoc_ssl_context() if ssl else None
        self._schema = "http" if not self._ssl_ctx else "https"

    @property
    def endpoint_url(self):
        return f"{self._schema}://{self._ip_address}:{self._port}"

    def __call__(self, func):
        async def wrapper(*args, **kwargs):
            await self._start()
            try:
                result = await func(*args, **kwargs)
            finally:
                await self._stop()
            return result

        functools.update_wrapper(wrapper, func)
        wrapper.__wrapped__ = func
        return wrapper

    async def __aenter__(self):
        svc = self._services.get(self._service_name)
        if svc is None:
            self._services[self._service_name] = self
            self._refcount = 1
            await self._start()
            return self
        else:
            svc._refcount += 1
            return svc

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._refcount -= 1

        if self._socket:
            self._socket.close()
            self._socket = None

        if self._refcount == 0:
            del self._services[self._service_name]
            await self._stop()

    def _server_entry(self):
        self._main_app = moto.server.DomainDispatcherApplication(
            moto.server.create_backend_app, service=self._service_name
        )
        self._main_app.debug = True

        if self._socket:
            self._socket.close()  # release right before we use it
            self._socket = None

        self._server = werkzeug.serving.make_server(
            self._ip_address,
            self._port,
            self._main_app,
            True,
            ssl_context=self._ssl_ctx,
        )
        self._server.serve_forever()

    async def _start(self):
        self._thread = threading.Thread(target=self._server_entry, daemon=True)
        self._thread.start()

        async with aiohttp.ClientSession() as session:
            start = time.time()

            while time.time() - start < 10:
                if not self._thread.is_alive():
                    break

                try:
                    # we need to bypass the proxies due to monkeypatches
                    async with session.get(
                        self.endpoint_url + "/static",
                        timeout=_CONNECT_TIMEOUT,
                        verify_ssl=False,
                    ):
                        pass
                    break
                except (asyncio.TimeoutError, aiohttp.ClientConnectionError):
                    await asyncio.sleep(0.5)
            else:
                await self._stop()  # pytest.fail doesn't call stop_process
                raise Exception(f"Can not start service: {self._service_name}")

    async def _stop(self):
        if self._server:
            self._server.shutdown()

        self._thread.join()
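
Outside of the fixtures above, MotoService can also be used directly, either as an async context manager (as in the s3_server fixture) or as a decorator around a coroutine. A minimal sketch follows; the coroutine name and the print are purely illustrative.

import asyncio

from tests.moto_server import MotoService


async def main():
    # Starts (or reuses) a local moto "s3" server and reports its endpoint.
    async with MotoService("s3", ssl=False) as svc:
        print(svc.endpoint_url)  # e.g. http://127.0.0.1:<random port>


asyncio.run(main())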