Merge pull request #768 from yangtao210/main

优化mongodb配置获取逻辑,移动存储基类位置。集成测试
This commit is contained in:
程序员阿江-Relakkes
2025-11-07 05:44:07 -05:00
committed by GitHub
9 changed files with 408 additions and 85 deletions

View File

@@ -1,8 +1,5 @@
# -*- coding: utf-8 -*-
"""
MongoDB存储基类
提供MongoDB连接管理和通用存储方法
"""
"""MongoDB存储基类提供连接管理和通用存储方法"""
import asyncio
from typing import Dict, List, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase, AsyncIOMotorCollection
@@ -11,7 +8,7 @@ from tools import utils
class MongoDBConnection:
"""MongoDB连接管理单例"""
"""MongoDB连接管理单例模式)"""
_instance = None
_client: Optional[AsyncIOMotorClient] = None
_db: Optional[AsyncIOMotorDatabase] = None
@@ -23,7 +20,7 @@ class MongoDBConnection:
return cls._instance
async def get_client(self) -> AsyncIOMotorClient:
"""获取MongoDB客户端"""
"""获取客户端"""
if self._client is None:
async with self._lock:
if self._client is None:
@@ -31,7 +28,7 @@ class MongoDBConnection:
return self._client
async def get_db(self) -> AsyncIOMotorDatabase:
"""获取MongoDB数据库"""
"""获取数据库"""
if self._db is None:
async with self._lock:
if self._db is None:
@@ -39,135 +36,93 @@ class MongoDBConnection:
return self._db
async def _connect(self):
"""建立MongoDB连接"""
"""建立连接"""
try:
mongo_config = db_config.mongodb_config
host = mongo_config.get("host", "localhost")
port = mongo_config.get("port", 27017)
user = mongo_config.get("user", "")
password = mongo_config.get("password", "")
db_name = mongo_config.get("db_name", "media_crawler")
host = mongo_config["host"]
port = mongo_config["port"]
user = mongo_config["user"]
password = mongo_config["password"]
db_name = mongo_config["db_name"]
# 构建连接URL
# 构建连接URL(有认证/无认证)
if user and password:
connection_url = f"mongodb://{user}:{password}@{host}:{port}/"
else:
connection_url = f"mongodb://{host}:{port}/"
self._client = AsyncIOMotorClient(connection_url, serverSelectionTimeoutMS=5000)
# 测试连接
await self._client.server_info()
await self._client.server_info() # 测试连接
self._db = self._client[db_name]
utils.logger.info(f"[MongoDBConnection] Successfully connected to MongoDB at {host}:{port}, database: {db_name}")
utils.logger.info(f"[MongoDBConnection] Connected to {host}:{port}/{db_name}")
except Exception as e:
utils.logger.error(f"[MongoDBConnection] Failed to connect to MongoDB: {e}")
utils.logger.error(f"[MongoDBConnection] Connection failed: {e}")
raise
async def close(self):
"""关闭MongoDB连接"""
"""关闭连接"""
if self._client is not None:
self._client.close()
self._client = None
self._db = None
utils.logger.info("[MongoDBConnection] MongoDB connection closed")
utils.logger.info("[MongoDBConnection] Connection closed")
class MongoDBStoreBase:
"""MongoDB存储基类"""
"""MongoDB存储基类提供通用的CRUD操作"""
def __init__(self, collection_prefix: str):
"""
初始化MongoDB存储基类
"""初始化存储基类
Args:
collection_prefix: 集合名称前缀xhs, douyin, bilibili等
collection_prefix: 平台前缀xhs/douyin/bilibili等
"""
self.collection_prefix = collection_prefix
self._connection = MongoDBConnection()
async def get_collection(self, collection_suffix: str) -> AsyncIOMotorCollection:
"""
获取MongoDB集合
Args:
collection_suffix: 集合名称后缀contents, comments, creators
Returns:
MongoDB集合对象
"""
"""获取集合:{prefix}_{suffix}"""
db = await self._connection.get_db()
collection_name = f"{self.collection_prefix}_{collection_suffix}"
return db[collection_name]
async def save_or_update(self, collection_suffix: str, query: Dict, data: Dict) -> bool:
"""
保存或更新数据upsert操作
Args:
collection_suffix: 集合名称后缀
query: 查询条件
data: 要保存的数据
Returns:
是否成功
"""
"""保存或更新数据upsert"""
try:
collection = await self.get_collection(collection_suffix)
result = await collection.update_one(
query,
{"$set": data},
upsert=True
)
await collection.update_one(query, {"$set": data}, upsert=True)
return True
except Exception as e:
utils.logger.error(f"[MongoDBStoreBase.save_or_update] Failed to save data to {self.collection_prefix}_{collection_suffix}: {e}")
utils.logger.error(f"[MongoDBStoreBase] Save failed ({self.collection_prefix}_{collection_suffix}): {e}")
return False
async def find_one(self, collection_suffix: str, query: Dict) -> Optional[Dict]:
"""
查询单条数据
Args:
collection_suffix: 集合名称后缀
query: 查询条件
Returns:
查询结果
"""
"""查询单条数据"""
try:
collection = await self.get_collection(collection_suffix)
result = await collection.find_one(query)
return result
return await collection.find_one(query)
except Exception as e:
utils.logger.error(f"[MongoDBStoreBase.find_one] Failed to query from {self.collection_prefix}_{collection_suffix}: {e}")
utils.logger.error(f"[MongoDBStoreBase] Find one failed ({self.collection_prefix}_{collection_suffix}): {e}")
return None
async def find_many(self, collection_suffix: str, query: Dict, limit: int = 0) -> List[Dict]:
"""
查询多条数据
Args:
collection_suffix: 集合名称后缀
query: 查询条件
limit: 限制返回数量0表示不限制
Returns:
查询结果列表
"""
"""查询多条数据limit=0表示不限制"""
try:
collection = await self.get_collection(collection_suffix)
cursor = collection.find(query)
if limit > 0:
cursor = cursor.limit(limit)
results = await cursor.to_list(length=None)
return results
return await cursor.to_list(length=None)
except Exception as e:
utils.logger.error(f"[MongoDBStoreBase.find_many] Failed to query from {self.collection_prefix}_{collection_suffix}: {e}")
utils.logger.error(f"[MongoDBStoreBase] Find many failed ({self.collection_prefix}_{collection_suffix}): {e}")
return []
async def create_index(self, collection_suffix: str, keys: List[tuple], unique: bool = False):
"""
创建索引
Args:
collection_suffix: 集合名称后缀
keys: 索引键列表例如[("note_id", 1)]
unique: 是否创建唯一索引
"""
"""创建索引keys=[("field", 1)]"""
try:
collection = await self.get_collection(collection_suffix)
await collection.create_index(keys, unique=unique)
utils.logger.info(f"[MongoDBStoreBase.create_index] Created index on {self.collection_prefix}_{collection_suffix}")
utils.logger.info(f"[MongoDBStoreBase] Index created on {self.collection_prefix}_{collection_suffix}")
except Exception as e:
utils.logger.error(f"[MongoDBStoreBase.create_index] Failed to create index: {e}")
utils.logger.error(f"[MongoDBStoreBase] Create index failed: {e}")

View File

@@ -31,7 +31,7 @@ from database.models import BilibiliVideoComment, BilibiliVideo, BilibiliUpInfo,
from tools.async_file_writer import AsyncFileWriter
from tools import utils, words
from var import crawler_type_var
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
class BiliCsvStoreImplement(AbstractStore):

View File

@@ -28,7 +28,7 @@ from database.models import DouyinAweme, DouyinAwemeComment, DyCreator
from tools import utils, words
from tools.async_file_writer import AsyncFileWriter
from var import crawler_type_var
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
class DouyinCsvStoreImplement(AbstractStore):

View File

@@ -30,7 +30,7 @@ from database.db_session import get_session
from database.models import KuaishouVideo, KuaishouVideoComment
from tools import utils, words
from var import crawler_type_var
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
def calculate_number_of_files(file_store_path: str) -> int:

View File

@@ -31,7 +31,7 @@ from tools import utils, words
from database.db_session import get_session
from var import crawler_type_var
from tools.async_file_writer import AsyncFileWriter
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
def calculate_number_of_files(file_store_path: str) -> int:

View File

@@ -31,7 +31,7 @@ from tools import utils, words
from tools.async_file_writer import AsyncFileWriter
from database.db_session import get_session
from var import crawler_type_var
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
def calculate_number_of_files(file_store_path: str) -> int:

View File

@@ -18,7 +18,7 @@ from database.models import XhsNote, XhsNoteComment, XhsCreator
from tools.async_file_writer import AsyncFileWriter
from tools.time_util import get_current_timestamp
from var import crawler_type_var
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
from tools import utils
class XhsCsvStoreImplement(AbstractStore):

View File

@@ -31,7 +31,7 @@ from database.models import ZhihuContent, ZhihuComment, ZhihuCreator
from tools import utils, words
from var import crawler_type_var
from tools.async_file_writer import AsyncFileWriter
from store.mongodb_store_base import MongoDBStoreBase
from database.mongodb_store_base import MongoDBStoreBase
def calculate_number_of_files(file_store_path: str) -> int:
"""计算数据保存文件的前部分排序数字,支持每次运行代码不写到同一个文件中

View File

@@ -0,0 +1,368 @@
# -*- coding: utf-8 -*-
import asyncio
import unittest
import sys
import os
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from database.mongodb_store_base import MongoDBConnection, MongoDBStoreBase
from store.xhs._store_impl import XhsMongoStoreImplement
from store.douyin._store_impl import DouyinMongoStoreImplement
from config import db_config
class TestMongoDBRealConnection(unittest.TestCase):
@classmethod
def setUpClass(cls):
try:
conn = MongoDBConnection()
asyncio.run(conn._connect())
cls.mongodb_available = True
print("\n✓ MongoDB连接成功")
except Exception as e:
cls.mongodb_available = False
print(f"\n✗ MongoDB连接失败: {e}")
def setUp(self):
if not self.mongodb_available:
self.skipTest("MongoDB不可用")
MongoDBConnection._instance = None
MongoDBConnection._client = None
MongoDBConnection._db = None
def tearDown(self):
if self.mongodb_available:
conn = MongoDBConnection()
asyncio.run(conn.close())
@classmethod
def tearDownClass(cls):
if cls.mongodb_available:
async def cleanup():
conn = MongoDBConnection()
db = await conn.get_db()
test_collections = [
"test_xhs_contents",
"test_xhs_comments",
"test_xhs_creators",
"test_douyin_contents",
"test_douyin_comments",
"test_douyin_creators"
]
for collection_name in test_collections:
try:
await db[collection_name].drop()
except:
pass
await conn.close()
try:
asyncio.run(cleanup())
print("\n✓ 测试数据清理完成")
except Exception as e:
print(f"\n✗ 清理测试数据时出错: {e}")
def test_real_connection(self):
async def test():
conn = MongoDBConnection()
client = await conn.get_client()
db = await conn.get_db()
self.assertIsNotNone(client)
self.assertIsNotNone(db)
result = await db.command("ping")
self.assertEqual(result.get("ok"), 1.0)
asyncio.run(test())
def test_real_save_and_query(self):
async def test():
store = MongoDBStoreBase(collection_prefix="test_xhs")
test_data = {
"note_id": "test_note_001",
"title": "测试笔记",
"content": "这是一条测试内容",
"created_at": datetime.now().isoformat()
}
result = await store.save_or_update(
"contents",
{"note_id": "test_note_001"},
test_data
)
self.assertTrue(result)
found = await store.find_one(
"contents",
{"note_id": "test_note_001"}
)
self.assertIsNotNone(found)
self.assertEqual(found["note_id"], "test_note_001")
self.assertEqual(found["title"], "测试笔记")
asyncio.run(test())
def test_real_update(self):
async def test():
store = MongoDBStoreBase(collection_prefix="test_xhs")
initial_data = {
"note_id": "test_note_002",
"title": "初始标题",
"likes": 10
}
await store.save_or_update(
"contents",
{"note_id": "test_note_002"},
initial_data
)
updated_data = {
"note_id": "test_note_002",
"title": "更新后的标题",
"likes": 100
}
await store.save_or_update(
"contents",
{"note_id": "test_note_002"},
updated_data
)
found = await store.find_one(
"contents",
{"note_id": "test_note_002"}
)
self.assertEqual(found["title"], "更新后的标题")
self.assertEqual(found["likes"], 100)
asyncio.run(test())
def test_real_find_many(self):
async def test():
store = MongoDBStoreBase(collection_prefix="test_xhs")
test_user_id = "test_user_123"
for i in range(5):
data = {
"note_id": f"test_note_{i:03d}",
"user_id": test_user_id,
"title": f"测试笔记{i}",
"likes": i * 10
}
await store.save_or_update(
"contents",
{"note_id": data["note_id"]},
data
)
results = await store.find_many(
"contents",
{"user_id": test_user_id}
)
self.assertGreaterEqual(len(results), 5)
limited_results = await store.find_many(
"contents",
{"user_id": test_user_id},
limit=3
)
self.assertEqual(len(limited_results), 3)
asyncio.run(test())
def test_real_create_index(self):
async def test():
store = MongoDBStoreBase(collection_prefix="test_xhs")
await store.create_index(
"contents",
[("note_id", 1)],
unique=True
)
collection = await store.get_collection("contents")
indexes = await collection.index_information()
self.assertIn("note_id_1", indexes)
asyncio.run(test())
def test_xhs_store_implementation(self):
async def test():
store = XhsMongoStoreImplement()
note_data = {
"note_id": "xhs_test_001",
"user_id": "user_001",
"nickname": "测试用户",
"title": "小红书测试笔记",
"desc": "这是一条测试笔记",
"type": "normal",
"liked_count": "100",
"collected_count": "50",
"comment_count": "20"
}
await store.store_content(note_data)
comment_data = {
"comment_id": "comment_001",
"note_id": "xhs_test_001",
"user_id": "user_002",
"nickname": "评论用户",
"content": "这是一条测试评论",
"like_count": "10"
}
await store.store_comment(comment_data)
creator_data = {
"user_id": "user_001",
"nickname": "测试创作者",
"desc": "这是一个测试创作者",
"fans": "1000",
"follows": "100"
}
await store.store_creator(creator_data)
mongo_store = store.mongo_store
note = await mongo_store.find_one("contents", {"note_id": "xhs_test_001"})
self.assertIsNotNone(note)
self.assertEqual(note["title"], "小红书测试笔记")
comment = await mongo_store.find_one("comments", {"comment_id": "comment_001"})
self.assertIsNotNone(comment)
self.assertEqual(comment["content"], "这是一条测试评论")
creator = await mongo_store.find_one("creators", {"user_id": "user_001"})
self.assertIsNotNone(creator)
self.assertEqual(creator["nickname"], "测试创作者")
asyncio.run(test())
def test_douyin_store_implementation(self):
async def test():
store = DouyinMongoStoreImplement()
video_data = {
"aweme_id": "dy_test_001",
"user_id": "user_001",
"nickname": "测试用户",
"title": "抖音测试视频",
"desc": "这是一条测试视频",
"liked_count": "1000",
"comment_count": "100"
}
await store.store_content(video_data)
comment_data = {
"comment_id": "dy_comment_001",
"aweme_id": "dy_test_001",
"user_id": "user_002",
"nickname": "评论用户",
"content": "这是一条测试评论"
}
await store.store_comment(comment_data)
creator_data = {
"user_id": "user_001",
"nickname": "测试创作者",
"desc": "这是一个测试创作者"
}
await store.store_creator(creator_data)
mongo_store = store.mongo_store
video = await mongo_store.find_one("contents", {"aweme_id": "dy_test_001"})
self.assertIsNotNone(video)
self.assertEqual(video["title"], "抖音测试视频")
comment = await mongo_store.find_one("comments", {"comment_id": "dy_comment_001"})
self.assertIsNotNone(comment)
creator = await mongo_store.find_one("creators", {"user_id": "user_001"})
self.assertIsNotNone(creator)
asyncio.run(test())
def test_concurrent_operations(self):
async def test():
store = MongoDBStoreBase(collection_prefix="test_xhs")
tasks = []
for i in range(10):
data = {
"note_id": f"concurrent_note_{i:03d}",
"title": f"并发测试笔记{i}",
"content": f"内容{i}"
}
task = store.save_or_update(
"contents",
{"note_id": data["note_id"]},
data
)
tasks.append(task)
results = await asyncio.gather(*tasks)
self.assertTrue(all(results))
for i in range(10):
found = await store.find_one(
"contents",
{"note_id": f"concurrent_note_{i:03d}"}
)
self.assertIsNotNone(found)
asyncio.run(test())
def run_integration_tests():
loader = unittest.TestLoader()
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(TestMongoDBRealConnection))
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result
if __name__ == "__main__":
print("="*70)
print("MongoDB存储集成测试")
print("="*70)
print(f"MongoDB配置:")
print(f" Host: {db_config.MONGODB_HOST}")
print(f" Port: {db_config.MONGODB_PORT}")
print(f" Database: {db_config.MONGODB_DB_NAME}")
print("="*70)
result = run_integration_tests()
print("\n" + "="*70)
print("测试统计:")
print(f"总测试数: {result.testsRun}")
print(f"成功: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f"失败: {len(result.failures)}")
print(f"错误: {len(result.errors)}")
print(f"跳过: {len(result.skipped)}")
print("="*70)
sys.exit(0 if result.wasSuccessful() else 1)