mirror of
https://github.com/NanmiCoder/MediaCrawler.git
synced 2025-11-25 03:15:17 +08:00
89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (c) 2025 relakkes@gmail.com
|
|
#
|
|
# This file is part of MediaCrawler project.
|
|
# Repository: https://github.com/NanmiCoder/MediaCrawler/blob/main/database/db_session.py
|
|
# GitHub: https://github.com/NanmiCoder
|
|
# Licensed under NON-COMMERCIAL LEARNING LICENSE 1.1
|
|
#
|
|
# 声明:本代码仅供学习和研究目的使用。使用者应遵守以下原则:
|
|
# 1. 不得用于任何商业用途。
|
|
# 2. 使用时应遵守目标平台的使用条款和robots.txt规则。
|
|
# 3. 不得进行大规模爬取或对平台造成运营干扰。
|
|
# 4. 应合理控制请求频率,避免给目标平台带来不必要的负担。
|
|
# 5. 不得用于任何非法或不当的用途。
|
|
#
|
|
# 详细许可条款请参阅项目根目录下的LICENSE文件。
|
|
# 使用本代码即表示您同意遵守上述原则和LICENSE中的所有条款。
|
|
|
|
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from urllib.parse import quote

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

from .models import Base
import config
from config.db_config import mysql_db_config, sqlite_db_config
|
|
|
|
# Module-level cache of async engines, keyed by the `db_type` string passed to
# get_async_engine(), so repeated calls for the same backend reuse one engine.
_engines: dict = {}
|
|
|
|
|
|
async def create_database_if_not_exists(db_type: str):
    """Create the configured MySQL database if it does not already exist.

    Only acts when ``db_type`` is ``"mysql"`` or ``"db"``; any other value is a
    no-op. It opens a temporary engine connected to the server without selecting
    a database, issues ``CREATE DATABASE IF NOT EXISTS`` for
    ``mysql_db_config['db_name']``, and then disposes that engine.

    Args:
        db_type: Storage backend name (e.g. ``config.SAVE_DATA_OPTION``).
    """
    if db_type not in ("mysql", "db"):
        return

    # Percent-encode credentials so characters such as '@', ':' or '/' in the
    # password cannot break the URL.
    user = quote(str(mysql_db_config['user']), safe="")
    password = quote(str(mysql_db_config['password']), safe="")
    # No database in the path: the target database may not exist yet.
    server_url = f"mysql+asyncmy://{user}:{password}@{mysql_db_config['host']}:{mysql_db_config['port']}"

    # Quote the identifier with backticks (doubling any embedded backtick) so
    # names containing '-' or other special characters are accepted.
    db_name = str(mysql_db_config['db_name']).replace("`", "``")

    engine = create_async_engine(server_url, echo=False)
    try:
        async with engine.connect() as conn:
            await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{db_name}`"))
    finally:
        # Dispose even when connecting or the DDL fails, so the pool is not leaked.
        await engine.dispose()
|
|
|
|
|
|
def get_async_engine(db_type: Optional[str] = None):
    """Return the cached async engine for ``db_type``, creating it on first use.

    Args:
        db_type: Storage backend name; defaults to ``config.SAVE_DATA_OPTION``.
            ``"sqlite"`` builds an aiosqlite URL from ``sqlite_db_config``;
            ``"mysql"`` and ``"db"`` build an asyncmy URL from ``mysql_db_config``.

    Returns:
        The engine (cached in ``_engines`` under ``db_type``), or ``None`` for
        ``"json"`` / ``"csv"``, which need no database.

    Raises:
        ValueError: If ``db_type`` is not one of the supported options.
    """
    if db_type is None:
        db_type = config.SAVE_DATA_OPTION

    if db_type in _engines:
        return _engines[db_type]

    # File-based storage options: no engine to create.
    if db_type in ("json", "csv"):
        return None

    if db_type == "sqlite":
        db_url = f"sqlite+aiosqlite:///{sqlite_db_config['db_path']}"
    elif db_type in ("mysql", "db"):
        # Percent-encode credentials so characters such as '@', ':' or '/' in
        # the password cannot break the URL.
        user = quote(str(mysql_db_config['user']), safe="")
        password = quote(str(mysql_db_config['password']), safe="")
        db_url = (
            f"mysql+asyncmy://{user}:{password}"
            f"@{mysql_db_config['host']}:{mysql_db_config['port']}"
            f"/{mysql_db_config['db_name']}"
        )
    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    engine = create_async_engine(db_url, echo=False)
    _engines[db_type] = engine
    return engine
|
|
|
|
|
|
async def create_tables(db_type: str = None):
    """Make sure the target database exists, then create every ORM table in it.

    Args:
        db_type: Storage backend name; ``config.SAVE_DATA_OPTION`` is used when
            omitted.

    For backends where ``get_async_engine`` yields no engine (``"json"`` /
    ``"csv"``), nothing beyond ``create_database_if_not_exists`` runs.
    """
    target = config.SAVE_DATA_OPTION if db_type is None else db_type

    await create_database_if_not_exists(target)

    engine = get_async_engine(target)
    if not engine:
        return

    # Run the synchronous metadata DDL on an async connection inside a transaction.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
|
|
|
|
|
|
@asynccontextmanager
async def get_session() -> AsyncIterator[Optional[AsyncSession]]:
    """Async context manager yielding a session for the configured backend.

    The engine comes from ``get_async_engine(config.SAVE_DATA_OPTION)``. If that
    returns no engine (file-based backends), ``None`` is yielded, so callers must
    check for it.

    On a normal exit from the ``async with`` body the session is committed. If
    the body or the commit raises, the session is rolled back and the exception
    is re-raised unchanged. The session is always closed.
    """
    engine = get_async_engine(config.SAVE_DATA_OPTION)
    if not engine:
        yield None
        return

    # Same arguments the previous per-call sessionmaker factory passed through.
    # expire_on_commit=False keeps loaded attributes available after commit.
    session = AsyncSession(engine, expire_on_commit=False)
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
|