# -*- coding: utf-8 -*- # Copyright (c) 2025 relakkes@gmail.com # # This file is part of MediaCrawler project. # Repository: https://github.com/NanmiCoder/MediaCrawler/blob/main/database/db_session.py # GitHub: https://github.com/NanmiCoder # Licensed under NON-COMMERCIAL LEARNING LICENSE 1.1 # # 声明:本代码仅供学习和研究目的使用。使用者应遵守以下原则: # 1. 不得用于任何商业用途。 # 2. 使用时应遵守目标平台的使用条款和robots.txt规则。 # 3. 不得进行大规模爬取或对平台造成运营干扰。 # 4. 应合理控制请求频率,避免给目标平台带来不必要的负担。 # 5. 不得用于任何非法或不当的用途。 # # 详细许可条款请参阅项目根目录下的LICENSE文件。 # 使用本代码即表示您同意遵守上述原则和LICENSE中的所有条款。 from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from contextlib import asynccontextmanager from .models import Base import config from config.db_config import mysql_db_config, sqlite_db_config # Keep a cache of engines _engines = {} async def create_database_if_not_exists(db_type: str): if db_type == "mysql" or db_type == "db": # Connect to the server without a database server_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}" engine = create_async_engine(server_url, echo=False) async with engine.connect() as conn: await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {mysql_db_config['db_name']}")) await engine.dispose() def get_async_engine(db_type: str = None): if db_type is None: db_type = config.SAVE_DATA_OPTION if db_type in _engines: return _engines[db_type] if db_type in ["json", "csv"]: return None if db_type == "sqlite": db_url = f"sqlite+aiosqlite:///{sqlite_db_config['db_path']}" elif db_type == "mysql" or db_type == "db": db_url = f"mysql+asyncmy://{mysql_db_config['user']}:{mysql_db_config['password']}@{mysql_db_config['host']}:{mysql_db_config['port']}/{mysql_db_config['db_name']}" else: raise ValueError(f"Unsupported database type: {db_type}") engine = create_async_engine(db_url, echo=False) _engines[db_type] = engine return engine async def create_tables(db_type: str = None): if db_type is None: db_type = config.SAVE_DATA_OPTION await create_database_if_not_exists(db_type) engine = get_async_engine(db_type) if engine: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @asynccontextmanager async def get_session() -> AsyncSession: engine = get_async_engine(config.SAVE_DATA_OPTION) if not engine: yield None return AsyncSessionFactory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) session = AsyncSessionFactory() try: yield session await session.commit() except Exception as e: await session.rollback() raise e finally: await session.close()