Files
MediaCrawler/database/db_session.py
persist-1 b04f5bcd6f feat(database): 优化数据库模型和连接管理
- 添加cryptography依赖用于加密功能
- 重构数据库模型字段类型和约束
- 增加数据库自动创建功能
- 改进数据库连接管理和错误处理
- 更新相关依赖文件(pyproject.toml, requirements.txt)
2025-09-06 06:08:28 +08:00

70 lines
2.3 KiB
Python

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from contextlib import asynccontextmanager
from .models import Base
import config
from config.db_config import mysql_db_config, sqlite_db_config
# Cache of engines built by get_async_engine, keyed by the db_type string
# they were requested with, so each backend's engine is created only once.
_engines = {}
async def create_database_if_not_exists(db_type: str):
    """Create the configured MySQL database if it does not exist yet.

    Only acts when ``db_type`` is ``"mysql"`` or ``"db"``; any other value
    is a no-op. A temporary engine is opened against the server (no database
    selected) purely to issue ``CREATE DATABASE IF NOT EXISTS`` and is always
    disposed afterwards, even when connecting or executing fails.

    Args:
        db_type: Storage backend identifier.
    """
    if db_type not in ("mysql", "db"):
        return

    # Connect to the server without a database, since the target database
    # may not exist yet.
    server_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}"

    # Backtick-quote the identifier (doubling any embedded backtick) so names
    # containing hyphens or reserved words are valid and the configured value
    # cannot change the shape of the statement.
    quoted_name = str(mysql_db_config['db_name']).replace("`", "``")

    engine = create_async_engine(server_url, echo=False)
    try:
        async with engine.connect() as conn:
            await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{quoted_name}`"))
    finally:
        # Release the temporary engine's pool on success and on failure.
        await engine.dispose()
def get_async_engine(db_type: str = None):
    """Return the async engine for a storage backend, creating it on first use.

    Args:
        db_type: Backend identifier; when ``None`` the value of
            ``config.SAVE_DATA_OPTION`` is used.

    Returns:
        The cached or newly created engine, or ``None`` for the ``"json"``
        and ``"csv"`` backends.

    Raises:
        ValueError: If ``db_type`` is not one of the handled identifiers.
    """
    db_type = config.SAVE_DATA_OPTION if db_type is None else db_type

    # Reuse an engine that was already built for this identifier.
    try:
        return _engines[db_type]
    except KeyError:
        pass

    if db_type in ("json", "csv"):
        return None

    if db_type == "sqlite":
        url = f"sqlite+aiosqlite:///{sqlite_db_config['db_path']}"
    elif db_type in ("mysql", "db"):
        url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}/{mysql_db_config['db_name']}"
    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    new_engine = create_async_engine(url, echo=False)
    _engines[db_type] = new_engine
    return new_engine
async def create_tables(db_type: str = None):
    """Create every table declared on ``Base.metadata`` for a backend.

    First calls ``create_database_if_not_exists`` for the backend, then runs
    ``Base.metadata.create_all`` inside a transaction on that backend's
    engine. Nothing is created when no engine is available for the backend.

    Args:
        db_type: Backend identifier; when ``None`` the value of
            ``config.SAVE_DATA_OPTION`` is used.
    """
    db_type = config.SAVE_DATA_OPTION if db_type is None else db_type

    await create_database_if_not_exists(db_type)

    engine = get_async_engine(db_type)
    if not engine:
        return

    async with engine.begin() as connection:
        # create_all is a synchronous API, so it is bridged via run_sync.
        await connection.run_sync(Base.metadata.create_all)
@asynccontextmanager
async def get_session() -> AsyncSession:
    """Provide a session for the backend named by ``config.SAVE_DATA_OPTION``.

    Used as ``async with get_session() as session``. The session is committed
    when the ``with`` body finishes normally, rolled back when the body (or
    the commit) raises an ``Exception``, and closed in every case. When no
    engine is available for the configured backend, ``None`` is yielded
    instead of a session.
    """
    # NOTE(review): the return annotation names the yielded value; the
    # decorated function itself is an async generator.
    engine = get_async_engine(config.SAVE_DATA_OPTION)
    if not engine:
        yield None
        return

    session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)()
    try:
        yield session
        await session.commit()
    except Exception as exc:
        # Undo partial work, then let the caller see the original error.
        await session.rollback()
        raise exc
    finally:
        await session.close()