class StorageMgr:
    """Storage manager.

    Chooses the storage backend named by ``storage.use`` in the instance
    configuration and exposes it as ``storage_provider``. The provider is
    created in :meth:`initialize`, not in ``__init__``, so the attribute is
    only available after ``initialize()`` has been awaited.
    """

    ap: app.Application

    # Assigned by initialize(); not set before that call.
    storage_provider: provider.StorageProvider

    def __init__(self, ap: app.Application):
        """Remember the application instance; no provider is created yet."""
        self.ap = ap

    @staticmethod
    def _resolve_storage_type(config_data) -> str:
        """Return the normalised backend name from the instance config mapping.

        A missing or null ``storage`` section, or a missing/empty ``use``
        value, yields ``'local'``. Any other value is returned stripped and
        lower-cased so that ``S3`` and ``s3`` select the same backend.
        """
        # `or {}` (rather than a .get default) also covers an explicit null
        # section, which YAML produces for a bare `storage:` key.
        storage_config = (config_data or {}).get('storage') or {}
        return str(storage_config.get('use') or 'local').strip().lower()

    async def initialize(self):
        """Create the configured storage provider and initialise it.

        Unknown backend names fall back to local storage; a warning is
        logged so that a typo in the configuration does not go unnoticed.
        """
        storage_type = self._resolve_storage_type(self.ap.instance_config.data)

        if storage_type == 's3':
            self.storage_provider = s3storage.S3StorageProvider(self.ap)
            self.ap.logger.info('Initialized S3 storage backend.')
        else:
            if storage_type != 'local':
                self.ap.logger.warning(
                    f"Unknown storage backend '{storage_type}', falling back to local storage."
                )
            self.storage_provider = localstorage.LocalStorageProvider(self.ap)
            self.ap.logger.info('Initialized local storage backend.')

        await self.storage_provider.initialize()
class S3StorageProvider(provider.StorageProvider):
    """Storage provider that keeps objects in an S3-compatible bucket.

    Connection settings come from the ``storage.s3`` section of the instance
    configuration and are read in :meth:`initialize`. boto3 is synchronous,
    so every network call is handed to the event loop's default executor
    instead of being run directly inside the coroutine.
    """

    # Error codes treated as "bucket/object does not exist".
    _NOT_FOUND_CODES = ('404', 'NoSuchBucket', 'NoSuchKey', 'NotFound')

    def __init__(self, ap: app.Application):
        super().__init__(ap)
        self.s3_client = None
        self.bucket_name = None

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` executed in the default executor."""
        # Local imports keep this block self-contained; the module header
        # does not import asyncio/functools.
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    @staticmethod
    def _error_code(error: ClientError):
        """Return the service error code carried by a ClientError, if any."""
        return error.response.get('Error', {}).get('Code')

    def _read_object(self, key: str) -> bytes:
        """Blocking helper: fetch an object and read its whole body."""
        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
        return response['Body'].read()

    def _delete_prefix(self, prefix: str) -> list:
        """Blocking helper: delete every object under ``prefix``.

        Returns the per-object error entries reported by ``delete_objects``
        (an empty list when everything was removed).
        """
        failures = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            objects_to_delete = [{'Key': obj['Key']} for obj in page.get('Contents', [])]
            if not objects_to_delete:
                continue
            response = self.s3_client.delete_objects(
                Bucket=self.bucket_name,
                # Quiet mode: the response lists only the keys that failed.
                Delete={'Objects': objects_to_delete, 'Quiet': True},
            )
            failures.extend(response.get('Errors', []))
        return failures

    async def initialize(self):
        """Create the S3 client from config.yaml and make sure the bucket exists.

        Raises:
            ClientError: the bucket cannot be accessed for a reason other
                than "not found".
            Exception: bucket creation failed (logged, then re-raised).
        """
        # `or {}` also covers sections that are present but null in the YAML.
        storage_config = self.ap.instance_config.data.get('storage') or {}
        s3_config = storage_config.get('s3') or {}

        # Empty strings are normalised to None: an empty key would be used
        # by boto3 as an explicit (invalid) credential, whereas None lets it
        # fall back to its default credential chain.
        endpoint_url = s3_config.get('endpoint_url') or None
        access_key_id = s3_config.get('access_key_id') or None
        secret_access_key = s3_config.get('secret_access_key') or None
        region_name = s3_config.get('region') or 'us-east-1'
        self.bucket_name = s3_config.get('bucket') or 'langbot-storage'

        session = boto3.session.Session()
        self.s3_client = session.client(
            service_name='s3',
            region_name=region_name,
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
        )

        # Ensure bucket exists
        try:
            await self._run_blocking(self.s3_client.head_bucket, Bucket=self.bucket_name)
            return
        except ClientError as e:
            if self._error_code(e) not in self._NOT_FOUND_CODES:
                self.ap.logger.error(f'Failed to access S3 bucket: {e}')
                raise

        # Bucket doesn't exist, create it. Outside us-east-1 the target
        # region has to be passed as a location constraint.
        create_kwargs = {'Bucket': self.bucket_name}
        if region_name != 'us-east-1':
            create_kwargs['CreateBucketConfiguration'] = {'LocationConstraint': region_name}
        try:
            await self._run_blocking(self.s3_client.create_bucket, **create_kwargs)
            self.ap.logger.info(f'Created S3 bucket: {self.bucket_name}')
        except Exception as create_error:
            self.ap.logger.error(f'Failed to create S3 bucket: {create_error}')
            raise

    async def save(
        self,
        key: str,
        value: bytes,
    ):
        """Save bytes to S3 under ``key``; errors are logged and re-raised."""
        try:
            await self._run_blocking(
                self.s3_client.put_object,
                Bucket=self.bucket_name,
                Key=key,
                Body=value,
            )
        except Exception as e:
            self.ap.logger.error(f'Failed to save to S3: {e}')
            raise

    async def load(
        self,
        key: str,
    ) -> bytes:
        """Load and return the bytes stored under ``key``; errors are logged and re-raised."""
        try:
            return await self._run_blocking(self._read_object, key)
        except Exception as e:
            self.ap.logger.error(f'Failed to load from S3: {e}')
            raise

    async def exists(
        self,
        key: str,
    ) -> bool:
        """Return True if an object is stored under ``key``, False if it is not found.

        Raises:
            ClientError: the lookup failed for a reason other than "not found".
        """
        try:
            await self._run_blocking(
                self.s3_client.head_object,
                Bucket=self.bucket_name,
                Key=key,
            )
        except ClientError as e:
            if self._error_code(e) in self._NOT_FOUND_CODES:
                return False
            self.ap.logger.error(f'Failed to check existence in S3: {e}')
            raise
        return True

    async def delete(
        self,
        key: str,
    ):
        """Delete the object stored under ``key``; errors are logged and re-raised."""
        try:
            await self._run_blocking(
                self.s3_client.delete_object,
                Bucket=self.bucket_name,
                Key=key,
            )
        except Exception as e:
            self.ap.logger.error(f'Failed to delete from S3: {e}')
            raise

    async def delete_dir_recursive(
        self,
        dir_path: str,
    ):
        """Delete all objects with the given prefix (directory).

        Raises:
            RuntimeError: S3 reported that some objects could not be deleted.
            Exception: any listing/deletion error (logged, then re-raised).
        """
        try:
            # A trailing slash keeps "a/b" from also matching "a/bc/...".
            if not dir_path.endswith('/'):
                dir_path = dir_path + '/'

            failures = await self._run_blocking(self._delete_prefix, dir_path)
            if failures:
                raise RuntimeError(f'{len(failures)} object(s) could not be deleted: {failures}')
        except Exception as e:
            self.ap.logger.error(f'Failed to delete directory from S3: {e}')
            raise
class TestStorageProviderSelection:
    """Test storage provider selection based on configuration"""

    @staticmethod
    def _make_app(config_data):
        """Build a mock application whose instance config holds ``config_data``."""
        mock_app = Mock()
        mock_app.instance_config = Mock()
        mock_app.instance_config.data = config_data
        mock_app.logger = Mock()
        return mock_app

    @classmethod
    async def _assert_selects(cls, config_data, provider_cls):
        """Initialise a StorageMgr with ``config_data`` and check the chosen provider.

        The provider's own ``initialize`` is patched out so no real backend
        is touched; the patch is asserted to have been awaited, not merely
        called, because a never-awaited coroutine would do no work.
        """
        storage_mgr = StorageMgr(cls._make_app(config_data))

        with patch.object(provider_cls, 'initialize', new_callable=AsyncMock) as mock_init:
            await storage_mgr.initialize()
            assert isinstance(storage_mgr.storage_provider, provider_cls)
            mock_init.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_default_to_local_storage(self):
        """Test that local storage is used by default when no config is provided"""
        await self._assert_selects({}, LocalStorageProvider)

    @pytest.mark.asyncio
    async def test_explicit_local_storage(self):
        """Test that local storage is used when explicitly configured"""
        await self._assert_selects({'storage': {'use': 'local'}}, LocalStorageProvider)

    @pytest.mark.asyncio
    async def test_s3_storage_provider_selection(self):
        """Test that S3 storage is used when configured"""
        config_data = {
            'storage': {
                'use': 's3',
                's3': {
                    'endpoint_url': 'https://s3.amazonaws.com',
                    'access_key_id': 'test_key',
                    'secret_access_key': 'test_secret',
                    'region': 'us-east-1',
                    'bucket': 'test-bucket',
                },
            }
        }
        await self._assert_selects(config_data, S3StorageProvider)

    @pytest.mark.asyncio
    async def test_invalid_storage_type_defaults_to_local(self):
        """Test that invalid storage type defaults to local storage"""
        await self._assert_selects({'storage': {'use': 'invalid_type'}}, LocalStorageProvider)


if __name__ == '__main__':
    pytest.main([__file__, '-v'])