diff --git a/.env.example b/.env.example index 6d986ae..6879932 100644 --- a/.env.example +++ b/.env.example @@ -112,73 +112,79 @@ ENABLE_GRAPH_RAG=false # COS_ENABLE_OLD_DOMAIN=true 表示启用旧的域名格式,默认为 true COS_ENABLE_OLD_DOMAIN=true -# 初始化默认租户与知识库 -# 租户ID,通常是一个字符串 -INIT_TEST_TENANT_ID=1 -# 知识库ID,通常是一个字符串 -INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001 +############################################################## -# LLM Model -# 使用的LLM模型名称 -# 默认使用 Ollama 的 Qwen3 8B 模型,ollama 会自动处理模型下载和加载 -# 如果需要使用其他模型,请替换为实际的模型名称 -INIT_LLM_MODEL_NAME=qwen3:8b +###### 注意: 以下配置不再生效,已在Web“配置初始化”阶段完成 ######### -# LLM模型的访问地址 -# 支持第三方模型服务的URL -# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 -# INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url -# LLM模型的API密钥,如果需要身份验证,可以设置 -# 支持第三方模型服务的API密钥 -# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 -# INIT_LLM_MODEL_API_KEY=your_llm_model_api_key +# # 初始化默认租户与知识库 +# # 租户ID,通常是一个字符串 +# INIT_TEST_TENANT_ID=1 -# Embedding Model -# 使用的Embedding模型名称 -# 默认使用 nomic-embed-text 模型,支持文本嵌入 -# 如果需要使用其他模型,请替换为实际的模型名称 -INIT_EMBEDDING_MODEL_NAME=nomic-embed-text +# # 知识库ID,通常是一个字符串 +# INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001 -# Embedding模型向量维度 -INIT_EMBEDDING_MODEL_DIMENSION=768 +# # LLM Model +# # 使用的LLM模型名称 +# # 默认使用 Ollama 的 Qwen3 8B 模型,ollama 会自动处理模型下载和加载 +# # 如果需要使用其他模型,请替换为实际的模型名称 +# INIT_LLM_MODEL_NAME=qwen3:8b -# Embedding模型的ID,通常是一个字符串 -INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768 +# # LLM模型的访问地址 +# # 支持第三方模型服务的URL +# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 +# # INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url -# Embedding模型的访问地址 -# 支持第三方模型服务的URL -# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 -# INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url +# # LLM模型的API密钥,如果需要身份验证,可以设置 +# # 支持第三方模型服务的API密钥 +# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 +# # INIT_LLM_MODEL_API_KEY=your_llm_model_api_key -# Embedding模型的API密钥,如果需要身份验证,可以设置 -# 支持第三方模型服务的API密钥 -# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 -# INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key +# # Embedding Model +# # 使用的Embedding模型名称 
+# # 默认使用 nomic-embed-text 模型,支持文本嵌入 +# # 如果需要使用其他模型,请替换为实际的模型名称 +# INIT_EMBEDDING_MODEL_NAME=nomic-embed-text -# Rerank Model(可选) -# 对于rag来说,使用Rerank模型对提升文档搜索的准确度有着重要作用 -# 目前 ollama 暂不支持运行 Rerank 模型 -# 使用的Rerank模型名称 -# INIT_RERANK_MODEL_NAME=your_rerank_model_name +# # Embedding模型向量维度 +# INIT_EMBEDDING_MODEL_DIMENSION=768 -# Rerank模型的访问地址 -# 支持第三方模型服务的URL -# INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url +# # Embedding模型的ID,通常是一个字符串 +# INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768 -# Rerank模型的API密钥,如果需要身份验证,可以设置 -# 支持第三方模型服务的API密钥 -# INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key +# # Embedding模型的访问地址 +# # 支持第三方模型服务的URL +# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 +# # INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url -# VLM_MODEL_NAME 使用的多模态模型名称 -# 用于解析图片数据 -# VLM_MODEL_NAME=your_vlm_model_name +# # Embedding模型的API密钥,如果需要身份验证,可以设置 +# # 支持第三方模型服务的API密钥 +# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理 +# # INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key -# VLM_MODEL_BASE_URL 使用的多模态模型访问地址 -# 支持第三方模型服务的URL -# VLM_MODEL_BASE_URL=your_vlm_model_base_url +# # Rerank Model(可选) +# # 对于rag来说,使用Rerank模型对提升文档搜索的准确度有着重要作用 +# # 目前 ollama 暂不支持运行 Rerank 模型 +# # 使用的Rerank模型名称 +# # INIT_RERANK_MODEL_NAME=your_rerank_model_name -# VLM_MODEL_API_KEY 使用的多模态模型API密钥 -# 支持第三方模型服务的API密钥 -# VLM_MODEL_API_KEY=your_vlm_model_api_key \ No newline at end of file +# # Rerank模型的访问地址 +# # 支持第三方模型服务的URL +# # INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url + +# # Rerank模型的API密钥,如果需要身份验证,可以设置 +# # 支持第三方模型服务的API密钥 +# # INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key + +# # VLM_MODEL_NAME 使用的多模态模型名称 +# # 用于解析图片数据 +# # VLM_MODEL_NAME=your_vlm_model_name + +# # VLM_MODEL_BASE_URL 使用的多模态模型访问地址 +# # 支持第三方模型服务的URL +# # VLM_MODEL_BASE_URL=your_vlm_model_base_url + +# # VLM_MODEL_API_KEY 使用的多模态模型API密钥 +# # 支持第三方模型服务的API密钥 +# # VLM_MODEL_API_KEY=your_vlm_model_api_key \ No newline at end of file diff --git a/README.md b/README.md index 5da87a8..c37ae1e 100644 --- 
a/README.md +++ b/README.md @@ -140,6 +140,38 @@ WeKnora 作为[微信对话开放平台](https://chatbot.weixin.qq.com)的核心 - **高效问题管理**:支持高频问题的独立分类管理,提供丰富的数据工具,确保回答精准可靠且易于维护 - **微信生态覆盖**:通过微信对话开放平台,WeKnora 的智能问答能力可无缝集成到公众号、小程序等微信场景中,提升用户交互体验 +## 🔧 初始化配置引导 + +为了方便用户快速配置各类模型,降低试错成本,我们改进了原来的配置文件初始化方式,增加了Web UI界面进行各种模型的配置。在使用之前,请确保代码更新到最新版本。具体使用步骤如下: +如果是第一次使用本项目,可跳过①②步骤,直接进入③④步骤。 + +### ① 关闭服务 + +```bash +./scripts/start_all.sh --stop +``` + +### ② 清空原有数据表(建议在没有重要数据的情况下使用) + +```bash +make clean-db +``` + +### ③ 编译并启动服务 + +```bash +./scripts/start_all.sh +``` + +### ④ 访问Web UI + +http://localhost + +首次访问会自动跳转到初始化配置页面,配置完成后会自动跳转到知识库页面。请按照页面提示信息完成模型的配置。 + +![配置页面](./docs/images/config.png) + + ## 📱 功能展示 ### Web UI 界面 diff --git a/cmd/server/main.go b/cmd/server/main.go index 403f889..b253f35 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -14,7 +14,6 @@ import ( "github.com/gin-gonic/gin" - "github.com/Tencent/WeKnora/internal/application/service" "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/container" "github.com/Tencent/WeKnora/internal/runtime" @@ -42,7 +41,6 @@ func main() { cfg *config.Config, router *gin.Engine, tracer *tracing.Tracer, - testDataService *service.TestDataService, resourceCleaner interfaces.ResourceCleaner, ) error { // Create context for resource cleanup @@ -58,13 +56,6 @@ func main() { return tracer.Cleanup(cleanupCtx) }) - // Initialize test data - if testDataService != nil { - if err := testDataService.InitializeTestData(context.Background()); err != nil { - log.Printf("Failed to initialize test data: %v", err) - } - } - // Create HTTP server server := &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), diff --git a/config/config.yaml b/config/config.yaml index f88b3f0..d739a83 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -151,26 +151,27 @@ conversation: Output: generate_summary_prompt: | - 你是一个总结助手,你的任务是总结文章或者片段内容。 + 你是一个精准的文章总结专家。你的任务是提取并总结用户提供的文章或片段的核心内容。 - ## 准则要求 - - 
总结结果长度不能超过100个字 - - 保持客观,不添加个人观点或评价 - - 使用第三人称陈述语气 - - 不要基于任何先验知识回答用户的问题,只基于文章内容生成摘要 - - 直接输出总结结果,不要有任何前缀或解释 - - 使用中文输出总结结果 - - 不能输出“无法生成”、“无法总结”等字眼 + ## 核心要求 + - 总结结果长度为50-100个字,根据内容复杂度灵活调整 + - 完全基于提供的文章内容生成总结,不添加任何未在文章中出现的信息 + - 确保总结包含文章的关键信息点和主要结论 + - 即使文章内容较复杂或专业,也必须尝试提取核心要点进行总结 + - 直接输出总结结果,不包含任何引言、前缀或解释 - ## Few-shot示例 + ## 格式与风格 + - 使用客观、中立的第三人称陈述语气 + - 使用清晰简洁的中文表达 + - 保持逻辑连贯性,确保句与句之间有合理过渡 + - 避免重复使用相同的表达方式或句式结构 - 用户给出的文章内容: - 随着5G技术的快速发展,各行各业正经历数字化转型。5G网络凭借高速率、低延迟和大连接的特性,正在推动智慧城市、工业互联网和远程医疗等领域的创新。专家预测,到2025年,5G将为全球经济贡献约2.2万亿美元。然而,5G建设也面临基础设施投入大、覆盖不均等挑战。 - - 文章总结: - 5G技术凭借高速率、低延迟和大连接特性推动各行业数字化转型,促进智慧城市、工业互联网和远程医疗创新。预计2025年将贡献约2.2万亿美元经济价值,但仍面临基础设施投入大和覆盖不均等挑战。 - - ## 用户给出的文章内容是: + ## 注意事项 + - 绝对不输出"无法生成"、"无法总结"、"内容不足"等拒绝回应的词语 + - 不要照抄或参考示例中的任何内容,确保总结完全基于用户提供的新文章 + - 对于任何文本都尽最大努力提取重点并总结,无论长度或复杂度 + + ## 以下是用户给出的文章相关信息: generate_session_title_prompt: | 你是一个专业的会话标题生成助手,你的任务是为用户提问创建简洁、精准且具描述性的标题。 diff --git a/docker-compose.yml b/docker-compose.yml index 9521ae6..420582b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -87,6 +87,8 @@ services: build: context: . 
dockerfile: docker/Dockerfile.docreader + args: + - PLATFORM=${PLATFORM:-linux/amd64} container_name: WeKnora-docreader ports: - "50051:50051" @@ -104,6 +106,8 @@ services: networks: - WeKnora-network restart: unless-stopped + extra_hosts: + - "host.docker.internal:host-gateway" jaeger: image: jaegertracing/all-in-one:latest diff --git a/docker/Dockerfile.docreader b/docker/Dockerfile.docreader index 3cec8a2..b49ad55 100644 --- a/docker/Dockerfile.docreader +++ b/docker/Dockerfile.docreader @@ -93,14 +93,29 @@ RUN apt-get update && apt-get install -y \ # 下载并安装最新版本的 LibreOffice 25.2.4 RUN mkdir -p /tmp/libreoffice && \ cd /tmp/libreoffice && \ - wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/stable/25.2.4/deb/x86_64/LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \ - tar -xzf LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \ - cd LibreOffice_25.2.4*_Linux_x86-64_deb/DEBS && \ - dpkg -i *.deb && \ + if [ "$(uname -m)" = "x86_64" ]; then \ + wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/stable/25.2.4/deb/x86_64/LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \ + tar -xzf LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \ + cd LibreOffice_25.2.4*_Linux_x86-64_deb/DEBS && \ + dpkg -i *.deb; \ + elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \ + wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/testing/25.8.0/deb/aarch64/LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \ + tar -xzf LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \ + cd LibreOffice_25.8.0*_Linux_aarch64_deb/DEBS && \ + dpkg -i *.deb; \ + else \ + echo "Unsupported architecture: $(uname -m)" && exit 1; \ + fi && \ cd / && \ rm -rf /tmp/libreoffice # 设置 LibreOffice 环境变量 +RUN if [ "$(uname -m)" = "x86_64" ]; then \ + echo 'export LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice' >> /etc/environment; \ + elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \ + echo 'export 
LIBREOFFICE_PATH=/opt/libreoffice25.8/program/soffice' >> /etc/environment; \ + fi + ENV LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice # 从构建阶段复制已安装的依赖和生成的代码 diff --git a/docs/images/config.png b/docs/images/config.png new file mode 100644 index 0000000..9c1580a Binary files /dev/null and b/docs/images/config.png differ diff --git a/frontend/src/api/initialization/index.ts b/frontend/src/api/initialization/index.ts new file mode 100644 index 0000000..875179b --- /dev/null +++ b/frontend/src/api/initialization/index.ts @@ -0,0 +1,278 @@ +import { get, post } from '../../utils/request'; + +// 初始化配置数据类型 +export interface InitializationConfig { + llm: { + source: string; + modelName: string; + baseUrl?: string; + apiKey?: string; + }; + embedding: { + source: string; + modelName: string; + baseUrl?: string; + apiKey?: string; + dimension?: number; // 添加embedding维度字段 + }; + rerank: { + modelName: string; + baseUrl: string; + apiKey?: string; + }; + multimodal: { + enabled: boolean; + vlm?: { + modelName: string; + baseUrl: string; + apiKey?: string; + interfaceType?: string; // "ollama" or "openai" + }; + cos?: { + secretId: string; + secretKey: string; + region: string; + bucketName: string; + appId: string; + pathPrefix?: string; + }; + }; + documentSplitting: { + chunkSize: number; + chunkOverlap: number; + separators: string[]; + }; +} + +// 下载任务状态类型 +export interface DownloadTask { + id: string; + modelName: string; + status: 'pending' | 'downloading' | 'completed' | 'failed'; + progress: number; + message: string; + startTime: string; + endTime?: string; +} + +// 系统初始化状态检查 +export function checkInitializationStatus(): Promise<{ initialized: boolean }> { + return new Promise((resolve, reject) => { + get('/api/v1/initialization/status') + .then((response: any) => { + resolve(response.data || { initialized: false }); + }) + .catch((error: any) => { + console.warn('检查初始化状态失败,假设需要初始化:', error); + resolve({ initialized: false }); + }); + }); +} + +// 执行系统初始化 
+export function initializeSystem(config: InitializationConfig): Promise { + return new Promise((resolve, reject) => { + console.log('开始系统初始化...', config); + post('/api/v1/initialization/initialize', config) + .then((response: any) => { + console.log('系统初始化完成', response); + // 设置本地初始化状态标记 + localStorage.setItem('system_initialized', 'true'); + resolve(response); + }) + .catch((error: any) => { + console.error('系统初始化失败:', error); + reject(error); + }); + }); +} + +// 检查Ollama服务状态 +export function checkOllamaStatus(): Promise<{ available: boolean; version?: string; error?: string }> { + return new Promise((resolve, reject) => { + get('/api/v1/initialization/ollama/status') + .then((response: any) => { + resolve(response.data || { available: false }); + }) + .catch((error: any) => { + console.error('检查Ollama状态失败:', error); + resolve({ available: false, error: error.message || '检查失败' }); + }); + }); +} + +// 检查Ollama模型状态 +export function checkOllamaModels(models: string[]): Promise<{ models: Record }> { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/ollama/models/check', { models }) + .then((response: any) => { + resolve(response.data || { models: {} }); + }) + .catch((error: any) => { + console.error('检查Ollama模型状态失败:', error); + reject(error); + }); + }); +} + +// 启动Ollama模型下载(异步) +export function downloadOllamaModel(modelName: string): Promise<{ taskId: string; modelName: string; status: string; progress: number }> { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/ollama/models/download', { modelName }) + .then((response: any) => { + resolve(response.data || { taskId: '', modelName, status: 'failed', progress: 0 }); + }) + .catch((error: any) => { + console.error('启动Ollama模型下载失败:', error); + reject(error); + }); + }); +} + +// 查询下载进度 +export function getDownloadProgress(taskId: string): Promise { + return new Promise((resolve, reject) => { + get(`/api/v1/initialization/ollama/download/progress/${taskId}`) + 
.then((response: any) => { + resolve(response.data); + }) + .catch((error: any) => { + console.error('查询下载进度失败:', error); + reject(error); + }); + }); +} + +// 获取所有下载任务 +export function listDownloadTasks(): Promise { + return new Promise((resolve, reject) => { + get('/api/v1/initialization/ollama/download/tasks') + .then((response: any) => { + resolve(response.data || []); + }) + .catch((error: any) => { + console.error('获取下载任务列表失败:', error); + reject(error); + }); + }); +} + +// 获取当前系统配置 +export function getCurrentConfig(): Promise { + return new Promise((resolve, reject) => { + get('/api/v1/initialization/config') + .then((response: any) => { + resolve(response.data || {}); + }) + .catch((error: any) => { + console.error('获取当前配置失败:', error); + reject(error); + }); + }); +} + +// 检查远程API模型 +export function checkRemoteModel(modelConfig: { + modelName: string; + baseUrl: string; + apiKey?: string; +}): Promise<{ + available: boolean; + message?: string; +}> { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/remote/check', modelConfig) + .then((response: any) => { + resolve(response.data || {}); + }) + .catch((error: any) => { + console.error('检查远程模型失败:', error); + reject(error); + }); + }); +} + +export function checkRerankModel(modelConfig: { + modelName: string; + baseUrl: string; + apiKey?: string; +}): Promise<{ + available: boolean; + message?: string; +}> { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/rerank/check', modelConfig) + .then((response: any) => { + resolve(response.data || {}); + }) + .catch((error: any) => { + console.error('检查Rerank模型失败:', error); + reject(error); + }); + }); +} + +export function testMultimodalFunction(testData: { + image: File; + vlm_model: string; + vlm_base_url: string; + vlm_api_key?: string; + vlm_interface_type?: string; + cos_secret_id: string; + cos_secret_key: string; + cos_region: string; + cos_bucket_name: string; + cos_app_id: string; + cos_path_prefix?: string; 
+ chunk_size: number; + chunk_overlap: number; + separators: string[]; +}): Promise<{ + success: boolean; + caption?: string; + ocr?: string; + processing_time?: number; + message?: string; +}> { + return new Promise((resolve, reject) => { + const formData = new FormData(); + formData.append('image', testData.image); + formData.append('vlm_model', testData.vlm_model); + formData.append('vlm_base_url', testData.vlm_base_url); + if (testData.vlm_api_key) { + formData.append('vlm_api_key', testData.vlm_api_key); + } + if (testData.vlm_interface_type) { + formData.append('vlm_interface_type', testData.vlm_interface_type); + } + formData.append('cos_secret_id', testData.cos_secret_id); + formData.append('cos_secret_key', testData.cos_secret_key); + formData.append('cos_region', testData.cos_region); + formData.append('cos_bucket_name', testData.cos_bucket_name); + formData.append('cos_app_id', testData.cos_app_id); + if (testData.cos_path_prefix) { + formData.append('cos_path_prefix', testData.cos_path_prefix); + } + formData.append('chunk_size', testData.chunk_size.toString()); + formData.append('chunk_overlap', testData.chunk_overlap.toString()); + formData.append('separators', JSON.stringify(testData.separators)); + + // 使用原生fetch因为需要发送FormData + fetch('/api/v1/initialization/multimodal/test', { + method: 'POST', + body: formData + }) + .then(response => response.json()) + .then((data: any) => { + if (data.success) { + resolve(data.data || {}); + } else { + resolve({ success: false, message: data.message || '测试失败' }); + } + }) + .catch((error: any) => { + console.error('多模态测试失败:', error); + reject(error); + }); + }); +} \ No newline at end of file diff --git a/frontend/src/components/menu.vue b/frontend/src/components/menu.vue index b656c04..cc1d7f1 100644 --- a/frontend/src/components/menu.vue +++ b/frontend/src/components/menu.vue @@ -173,7 +173,12 @@ const getIcon = (path) => { getIcon(route.name) const gotopage = (path) => { pathPrefix.value = path; - 
router.push(`/platform/${path}`); + // 如果是系统设置,跳转到初始化配置页面 + if (path === 'settings') { + router.push('/initialization'); + } else { + router.push(`/platform/${path}`); + } getIcon(path) } diff --git a/frontend/src/hooks/useKnowledgeBase.ts b/frontend/src/hooks/useKnowledgeBase.ts index 08ce546..63ba904 100644 --- a/frontend/src/hooks/useKnowledgeBase.ts +++ b/frontend/src/hooks/useKnowledgeBase.ts @@ -23,7 +23,7 @@ export default function () { }); const getKnowled = (query = { page: 1, page_size: 35 }) => { getKnowledgeBase(query) - .then((result: object) => { + .then((result: any) => { let { data, total: totalResult } = result; let cardList_ = data.map((item) => { item["file_name"] = item.file_name.substring( @@ -50,7 +50,7 @@ export default function () { cardList.value[index].isMore = false; moreIndex.value = -1; delKnowledgeDetails(item.id) - .then((result) => { + .then((result: any) => { if (result.success) { MessagePlugin.info("知识删除成功!"); getKnowled(); @@ -76,17 +76,29 @@ export default function () { return; } uploadKnowledgeBase({ file }) - .then((result) => { + .then((result: any) => { if (result.success) { MessagePlugin.info("上传成功!"); getKnowled(); } else { - MessagePlugin.error("上传失败!"); + // 检查错误码,如果是重复文件则显示特定提示 + if (result.code === 'duplicate_file') { + MessagePlugin.error("文件已存在"); + } else { + MessagePlugin.error(result.message || (result.error && result.error.message) || "上传失败!"); + } } uploadInput.value.value = ""; }) - .catch((err) => { - MessagePlugin.error("上传失败!"); + .catch((err: any) => { + // 检查错误响应中的错误码 + if (err.code === 'duplicate_file') { + MessagePlugin.error("文件已存在"); + } else if (err.message) { + MessagePlugin.error(err.message); + } else { + MessagePlugin.error("上传失败!"); + } uploadInput.value.value = ""; }); } else { @@ -101,7 +113,7 @@ export default function () { id: "", }); getKnowledgeDetails(item.id) - .then((result) => { + .then((result: any) => { if (result.success && result.data) { let { data } = result; Object.assign(details, 
{ @@ -116,7 +128,7 @@ }; const getfDetails = (id, page) => { getKnowledgeDetailsCon(id, page) - .then((result) => { + .then((result: any) => { if (result.success && result.data) { let { data, total: totalResult } = result; if (page == 1) { diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 2e37e21..6e1d7b8 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -1,4 +1,5 @@ import { createRouter, createWebHistory } from 'vue-router' +import { checkInitializationStatus } from '@/api/initialization' const router = createRouter({ history: createWebHistory(import.meta.env.BASE_URL), @@ -7,40 +8,80 @@ const router = createRouter({ path: "/", redirect: "/platform", }, + { + path: "/initialization", + name: "initialization", + component: () => import("../views/initialization/InitializationConfig.vue"), + meta: { requiresInit: false } // 初始化页面不需要检查初始化状态 + }, { path: "/knowledgeBase", name: "home", component: () => import("../views/knowledge/KnowledgeBase.vue"), + meta: { requiresInit: true } }, { path: "/platform", name: "Platform", redirect: "/platform/knowledgeBase", component: () => import("../views/platform/index.vue"), + meta: { requiresInit: true }, children: [ { path: "knowledgeBase", name: "knowledgeBase", component: () => import("../views/knowledge/KnowledgeBase.vue"), + meta: { requiresInit: true } }, { path: "creatChat", name: "creatChat", component: () => import("../views/creatChat/creatChat.vue"), + meta: { requiresInit: true } }, { path: "chat/:chatid", name: "chat", component: () => import("../views/chat/index.vue"), + meta: { requiresInit: true } }, { path: "settings", name: "settings", component: () => import("../views/settings/Settings.vue"), + meta: { requiresInit: true } }, ], }, ], }); +// 路由守卫:检查系统初始化状态 +router.beforeEach(async (to, from, next) => { + // 如果访问的是初始化页面,直接放行 + if (to.meta.requiresInit === false) { + next(); + return; + } + + try { + // 检查系统是否已初始化 + const 
{ initialized } = await checkInitializationStatus(); + + if (initialized) { + // 系统已初始化,记录到本地存储并正常跳转 + localStorage.setItem('system_initialized', 'true'); + next(); + } else { + // 系统未初始化,跳转到初始化页面 + console.log('系统未初始化,跳转到初始化页面'); + next('/initialization'); + } + } catch (error) { + console.error('检查初始化状态失败:', error); + // 如果检查失败,默认认为需要初始化 + next('/initialization'); + } +}); + export default router diff --git a/frontend/src/views/initialization/InitializationConfig.vue b/frontend/src/views/initialization/InitializationConfig.vue new file mode 100644 index 0000000..4a2a71e --- /dev/null +++ b/frontend/src/views/initialization/InitializationConfig.vue @@ -0,0 +1,2626 @@ + + + + + \ No newline at end of file diff --git a/go.mod b/go.mod index ef5699a..d9206bd 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hibiken/asynq v0.25.1 github.com/minio/minio-go/v7 v7.0.90 - github.com/ollama/ollama v0.9.6 + github.com/ollama/ollama v0.11.4 github.com/panjf2000/ants/v2 v2.11.2 github.com/parquet-go/parquet-go v0.25.0 github.com/pgvector/pgvector-go v0.3.0 diff --git a/go.sum b/go.sum index b0c93ea..96e533f 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSr github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/ollama/ollama v0.9.6 h1:HZNJmB52pMt6zLkGkkheBuXBXM5478eiSAj7GR75AMc= -github.com/ollama/ollama v0.9.6/go.mod h1:zLwx3iZ3AI4Rc/egsrx3u1w4RU2MHQ/Ylxse48jvyt4= +github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI= +github.com/ollama/ollama v0.11.4/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/panjf2000/ants/v2 v2.11.2 
h1:AVGpMSePxUNpcLaBO34xuIgM1ZdKOiGnpxLXixLi5Jo= github.com/panjf2000/ants/v2 v2.11.2/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= github.com/parquet-go/parquet-go v0.25.0 h1:GwKy11MuF+al/lV6nUsFw8w8HCiPOSAx1/y8yFxjH5c= diff --git a/internal/application/service/knowledge.go b/internal/application/service/knowledge.go index b093cf5..038c40e 100644 --- a/internal/application/service/knowledge.go +++ b/internal/application/service/knowledge.go @@ -19,6 +19,7 @@ import ( "github.com/Tencent/WeKnora/internal/application/service/retriever" "github.com/Tencent/WeKnora/internal/config" + werrors "github.com/Tencent/WeKnora/internal/errors" "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/models/chat" "github.com/Tencent/WeKnora/internal/models/utils" @@ -105,6 +106,28 @@ func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context, return nil, err } + // 检查多模态配置完整性 - 只在图片文件时校验 + // 检查是否为图片文件 + if !IsImageType(getFileType(file.Filename)) { + logger.Info(ctx, "Non-image file with multimodal enabled, skipping COS/VLM validation") + } else { + // 检查COS配置 + if kb.COSConfig.SecretID == "" || kb.COSConfig.SecretKey == "" || + kb.COSConfig.Region == "" || kb.COSConfig.BucketName == "" || + kb.COSConfig.AppID == "" { + logger.Error(ctx, "COS configuration incomplete for image multimodal processing") + return nil, werrors.NewBadRequestError("上传图片文件需要完整的COS配置信息") + } + + // 检查VLM配置 + if kb.VLMConfig.ModelName == "" || kb.VLMConfig.BaseURL == "" { + logger.Error(ctx, "VLM configuration incomplete for image multimodal processing") + return nil, werrors.NewBadRequestError("上传图片文件需要完整的VLM配置信息") + } + + logger.Info(ctx, "Image multimodal configuration validation passed") + } + // Validate file type logger.Infof(ctx, "Checking file type: %s", file.Filename) if !isValidFileType(file.Filename) { @@ -647,6 +670,20 @@ func (s *knowledgeService) processDocument(ctx context.Context, ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap), 
Separators: kb.ChunkingConfig.Separators, EnableMultimodal: enableMultimodel, + CosConfig: &proto.COSConfig{ + SecretId: kb.COSConfig.SecretID, + SecretKey: kb.COSConfig.SecretKey, + Region: kb.COSConfig.Region, + BucketName: kb.COSConfig.BucketName, + AppId: kb.COSConfig.AppID, + PathPrefix: kb.COSConfig.PathPrefix, + }, + VlmConfig: &proto.VLMConfig{ + ModelName: kb.VLMConfig.ModelName, + BaseUrl: kb.VLMConfig.BaseURL, + ApiKey: kb.VLMConfig.APIKey, + InterfaceType: kb.VLMConfig.InterfaceType, + }, }, RequestId: ctx.Value(types.RequestIDContextKey).(string), }) @@ -687,6 +724,20 @@ func (s *knowledgeService) processDocumentFromURL(ctx context.Context, ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap), Separators: kb.ChunkingConfig.Separators, EnableMultimodal: enableMultimodel, + CosConfig: &proto.COSConfig{ + SecretId: kb.COSConfig.SecretID, + SecretKey: kb.COSConfig.SecretKey, + Region: kb.COSConfig.Region, + BucketName: kb.COSConfig.BucketName, + AppId: kb.COSConfig.AppID, + PathPrefix: kb.COSConfig.PathPrefix, + }, + VlmConfig: &proto.VLMConfig{ + ModelName: kb.VLMConfig.ModelName, + BaseUrl: kb.VLMConfig.BaseURL, + ApiKey: kb.VLMConfig.APIKey, + InterfaceType: kb.VLMConfig.InterfaceType, + }, }, RequestId: ctx.Value(types.RequestIDContextKey).(string), }) @@ -1145,7 +1196,7 @@ func (s *knowledgeService) getSummary(ctx context.Context, chunkContents = chunkContents + imageAnnotations } - if len(chunkContents) < 30 { + if len(chunkContents) < 300 { return chunkContents, nil } diff --git a/internal/application/service/knowledgebase.go b/internal/application/service/knowledgebase.go index 326290a..1710d31 100644 --- a/internal/application/service/knowledgebase.go +++ b/internal/application/service/knowledgebase.go @@ -47,7 +47,9 @@ func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context, kb *types.KnowledgeBase, ) (*types.KnowledgeBase, error) { // Generate UUID and set creation timestamps - kb.ID = uuid.New().String() + if kb.ID == "" { + kb.ID 
= uuid.New().String() + } kb.CreatedAt = time.Now() kb.TenantID = ctx.Value(types.TenantIDContextKey).(uint) kb.UpdatedAt = time.Now() @@ -245,6 +247,9 @@ func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context, ImageProcessingConfig: sourceKB.ImageProcessingConfig, EmbeddingModelID: sourceKB.EmbeddingModelID, SummaryModelID: sourceKB.SummaryModelID, + RerankModelID: sourceKB.RerankModelID, + VLMModelID: sourceKB.VLMModelID, + COSConfig: sourceKB.COSConfig, } if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil { return nil, nil, err diff --git a/internal/application/service/test_data.go b/internal/application/service/test_data.go index b0679b3..49e098d 100644 --- a/internal/application/service/test_data.go +++ b/internal/application/service/test_data.go @@ -144,6 +144,7 @@ func (s *TestDataService) initKnowledgeBase(ctx context.Context) error { }, EmbeddingModelID: s.EmbedModel.GetModelID(), SummaryModelID: s.LLMModel.GetModelID(), + RerankModelID: s.RerankModel.GetModelID(), } // 初始化测试知识库 diff --git a/internal/container/container.go b/internal/container/container.go index 6808760..614ff7b 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -115,6 +115,7 @@ func BuildContainer(container *dig.Container) *dig.Container { must(container.Provide(handler.NewTestDataHandler)) must(container.Provide(handler.NewModelHandler)) must(container.Provide(handler.NewEvaluationHandler)) + must(container.Provide(handler.NewInitializationHandler)) // Router configuration must(container.Provide(router.NewRouter)) diff --git a/internal/handler/initialization.go b/internal/handler/initialization.go new file mode 100644 index 0000000..3e79478 --- /dev/null +++ b/internal/handler/initialization.go @@ -0,0 +1,1718 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" + + "strconv" + + "github.com/Tencent/WeKnora/internal/config" + 
"github.com/Tencent/WeKnora/internal/errors" + "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/models/utils/ollama" + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/Tencent/WeKnora/services/docreader/src/client" + "github.com/Tencent/WeKnora/services/docreader/src/proto" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/ollama/ollama/api" +) + +// DownloadTask 下载任务信息 +type DownloadTask struct { + ID string `json:"id"` + ModelName string `json:"modelName"` + Status string `json:"status"` // pending, downloading, completed, failed + Progress float64 `json:"progress"` + Message string `json:"message"` + StartTime time.Time `json:"startTime"` + EndTime *time.Time `json:"endTime,omitempty"` +} + +// 全局下载任务管理器 +var ( + downloadTasks = make(map[string]*DownloadTask) + tasksMutex sync.RWMutex +) + +// InitializationHandler 初始化处理器 +type InitializationHandler struct { + config *config.Config + tenantService interfaces.TenantService + modelService interfaces.ModelService + kbService interfaces.KnowledgeBaseService + kbRepository interfaces.KnowledgeBaseRepository + knowledgeService interfaces.KnowledgeService + ollamaService *ollama.OllamaService + docReaderClient *client.Client +} + +// NewInitializationHandler 创建初始化处理器 +func NewInitializationHandler( + config *config.Config, + tenantService interfaces.TenantService, + modelService interfaces.ModelService, + kbService interfaces.KnowledgeBaseService, + kbRepository interfaces.KnowledgeBaseRepository, + knowledgeService interfaces.KnowledgeService, + ollamaService *ollama.OllamaService, + docReaderClient *client.Client, +) *InitializationHandler { + return &InitializationHandler{ + config: config, + tenantService: tenantService, + modelService: modelService, + kbService: kbService, + kbRepository: kbRepository, + knowledgeService: knowledgeService, + ollamaService: ollamaService, + docReaderClient: 
docReaderClient,
	}
}

// InitializationRequest is the payload for the one-shot system initialization
// endpoint. It carries the LLM / embedding / rerank / multimodal model
// configuration plus the document-splitting settings for the default
// knowledge base.
type InitializationRequest struct {
	LLM struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"llm" binding:"required"`

	Embedding struct {
		Source    string `json:"source" binding:"required"`
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
		Dimension int    `json:"dimension"` // embedding vector dimension
	} `json:"embedding" binding:"required"`

	Rerank struct {
		Enabled   bool   `json:"enabled"`
		ModelName string `json:"modelName"`
		BaseURL   string `json:"baseUrl"`
		APIKey    string `json:"apiKey"`
	} `json:"rerank"`

	Multimodal struct {
		Enabled bool `json:"enabled"`
		// VLM and COS are pointers: nil when multimodal is disabled.
		VLM *struct {
			ModelName     string `json:"modelName"`
			BaseURL       string `json:"baseUrl"`
			APIKey        string `json:"apiKey"`
			InterfaceType string `json:"interfaceType"` // "ollama" or "openai"
		} `json:"vlm,omitempty"`
		COS *struct {
			SecretID   string `json:"secretId"`
			SecretKey  string `json:"secretKey"`
			Region     string `json:"region"`
			BucketName string `json:"bucketName"`
			AppID      string `json:"appId"`
			PathPrefix string `json:"pathPrefix"`
		} `json:"cos,omitempty"`
	} `json:"multimodal"`

	DocumentSplitting struct {
		ChunkSize    int      `json:"chunkSize" binding:"required,min=100,max=10000"`
		ChunkOverlap int      `json:"chunkOverlap" binding:"required,min=0"`
		Separators   []string `json:"separators" binding:"required,min=1"`
	} `json:"documentSplitting" binding:"required"`
}

// CheckStatus reports whether the system has already been initialized.
// The system counts as initialized when the default tenant exists AND at
// least one model is registered for it. All outcomes return HTTP 200 with
// {"data":{"initialized":bool}} so the frontend can branch on the flag.
func (h *InitializationHandler) CheckStatus(c *gin.Context) {
	ctx := c.Request.Context()
	logger.Info(ctx, "Checking system initialization status")

	// Check whether the default tenant exists. A lookup error is treated the
	// same as "not initialized" rather than surfaced as a server error.
	tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": false,
			},
		})
		return
	}

	// No tenant means the system was never initialized.
	if tenant == nil {
		logger.Info(ctx, "No tenants found, system not initialized")
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": false,
			},
		})
		return
	}
	ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	// A tenant without any models is still considered uninitialized.
	models, err := h.modelService.ListModels(ctx)
	if err != nil || len(models) == 0 {
		logger.Info(ctx, "No models found, system not initialized")
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"initialized": false,
			},
		})
		return
	}

	logger.Info(ctx, "System is already initialized")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"initialized": true,
		},
	})
}

// Initialize performs (or re-runs) system initialization: it validates the
// request, creates/updates the default tenant, creates/updates/deletes the
// configured models, and creates/updates the default knowledge base.
// The endpoint is idempotent: re-running it reconciles existing records
// instead of failing.
func (h *InitializationHandler) Initialize(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Starting system initialization")

	var req InitializationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse initialization request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Validate the multimodal configuration. When multimodal is enabled both
	// the VLM and COS sections are mandatory.
	if req.Multimodal.Enabled {
		if req.Multimodal.VLM == nil || req.Multimodal.COS == nil {
			logger.Error(ctx, "Multimodal enabled but missing VLM or COS configuration")
			c.Error(errors.NewBadRequestError("启用多模态时需要配置VLM和COS信息"))
			return
		}
		// For the Ollama interface the base URL is derived from the
		// environment, overriding whatever the client sent.
		if req.Multimodal.VLM.InterfaceType == "ollama" {
			req.Multimodal.VLM.BaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
		}
		if req.Multimodal.VLM.ModelName == "" || req.Multimodal.VLM.BaseURL == "" {
			logger.Error(ctx, "VLM configuration incomplete")
			c.Error(errors.NewBadRequestError("VLM配置不完整"))
			return
		}
		if req.Multimodal.COS.SecretID == "" || req.Multimodal.COS.SecretKey == "" ||
			req.Multimodal.COS.Region == "" || req.Multimodal.COS.BucketName == "" ||
			req.Multimodal.COS.AppID == "" {
			logger.Error(ctx, "COS configuration incomplete")
			c.Error(errors.NewBadRequestError("COS配置不完整"))
			return
		}
	}

	// Validate the rerank configuration (only when enabled).
	if req.Rerank.Enabled {
		if req.Rerank.ModelName == "" || req.Rerank.BaseURL == "" {
			logger.Error(ctx, "Rerank configuration incomplete")
			c.Error(errors.NewBadRequestError("启用Rerank时需要配置模型名称和Base URL"))
			return
		}
	}

	// Validate the document-splitting configuration.
	if req.DocumentSplitting.ChunkOverlap >= req.DocumentSplitting.ChunkSize {
		logger.Error(ctx, "Chunk overlap must be less than chunk size")
		c.Error(errors.NewBadRequestError("分块重叠大小必须小于分块大小"))
		return
	}
	if len(req.DocumentSplitting.Separators) == 0 {
		logger.Error(ctx, "Document separators cannot be empty")
		c.Error(errors.NewBadRequestError("文档分隔符不能为空"))
		return
	}

	var err error
	// Step 1: tenant — create it if missing, otherwise reconcile its fields.
	tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
	if tenant == nil {
		logger.Info(ctx, "Tenant not found, creating tenant")
		tenant = &types.Tenant{
			ID:          types.InitDefaultTenantID,
			Name:        "Default Tenant",
			Description: "System Default Tenant",
			RetrieverEngines: types.RetrieverEngines{
				Engines: []types.RetrieverEngineParams{
					{
						RetrieverType:       types.KeywordsRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
					{
						RetrieverType:       types.VectorRetrieverType,
						RetrieverEngineType: types.PostgresRetrieverEngineType,
					},
				},
			},
		}
		logger.Info(ctx, "Creating default tenant")
		tenant, err = h.tenantService.CreateTenant(ctx, tenant)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
			return
		}
	} else {
		logger.Info(ctx, "Tenant exists, updating if needed")
		// Only write back when a field actually changed.
		updated := false
		if tenant.Name != "Default Tenant" {
			tenant.Name = "Default Tenant"
			updated = true
		}
		if tenant.Description != "System Default Tenant" {
			tenant.Description = "System Default Tenant"
			updated = true
		}

		if updated {
			_, err = h.tenantService.UpdateTenant(ctx, tenant)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
				return
			}
			logger.Info(ctx, "Tenant updated successfully")
		}
	}

	// All tenant-scoped calls below use a context carrying the tenant ID.
	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	// Step 2: models — update existing ones in place, create missing ones.
	existingModels, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		// Listing failed: fall through and treat every model as new.
		existingModels = []*types.Model{}
	}

	// Index existing models by type for reconciliation.
	modelMap := make(map[types.ModelType]*types.Model)
	for _, model := range existingModels {
		modelMap[model.Type] = model
	}

	// Models to reconcile; LLM and embedding are always present.
	modelsToProcess := []struct {
		modelType   types.ModelType
		name        string
		source      types.ModelSource
		description string
		baseURL     string
		apiKey      string
		dimension   int
	}{
		{
			modelType:   types.ModelTypeKnowledgeQA,
			name:        req.LLM.ModelName,
			source:      types.ModelSource(req.LLM.Source),
			description: "LLM Model for Knowledge QA",
			baseURL:     req.LLM.BaseURL,
			apiKey:      req.LLM.APIKey,
		},
		{
			modelType:   types.ModelTypeEmbedding,
			name:        req.Embedding.ModelName,
			source:      types.ModelSource(req.Embedding.Source),
			description: "Embedding Model",
			baseURL:     req.Embedding.BaseURL,
			apiKey:      req.Embedding.APIKey,
			dimension:   req.Embedding.Dimension,
		},
	}

	// Rerank model is optional.
	if req.Rerank.Enabled {
		modelsToProcess = append(modelsToProcess, struct {
			modelType   types.ModelType
			name        string
			source      types.ModelSource
			description string
			baseURL     string
			apiKey      string
			dimension   int
		}{
			modelType:   types.ModelTypeRerank,
			name:        req.Rerank.ModelName,
			source:      types.ModelSourceRemote,
			description: "Rerank Model",
			baseURL:     req.Rerank.BaseURL,
			apiKey:      req.Rerank.APIKey,
		})
	}

	// VLM model is optional (multimodal only).
	if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
		modelsToProcess = append(modelsToProcess, struct {
			modelType   types.ModelType
			name        string
			source      types.ModelSource
			description string
			baseURL     string
			apiKey      string
			dimension   int
		}{
			modelType:   types.ModelTypeVLLM,
			name:        req.Multimodal.VLM.ModelName,
			source:      types.ModelSourceRemote,
			description: "Vision Language Model",
			baseURL:     req.Multimodal.VLM.BaseURL,
			apiKey:      req.Multimodal.VLM.APIKey,
		})
	}

	// Reconcile each configured model against the existing records.
	var processedModels []*types.Model
	for _, modelConfig := range modelsToProcess {
		existingModel, exists := modelMap[modelConfig.modelType]

		if exists {
			logger.Infof(ctx, "Updating existing model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			existingModel.Name = modelConfig.name
			existingModel.Source = modelConfig.source
			existingModel.Description = modelConfig.description
			existingModel.Parameters = types.ModelParameters{
				BaseURL: modelConfig.baseURL,
				APIKey:  modelConfig.apiKey,
				EmbeddingParameters: types.EmbeddingParameters{
					Dimension: modelConfig.dimension,
				},
			}
			existingModel.IsDefault = true
			existingModel.Status = types.ModelStatusActive

			err := h.modelService.UpdateModel(newCtx, existingModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("更新模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, existingModel)
		} else {
			logger.Infof(ctx,
				"Creating new model: %s (%s)",
				modelConfig.name, modelConfig.modelType,
			)
			newModel := &types.Model{
				TenantID:    types.InitDefaultTenantID,
				Name:        modelConfig.name,
				Type:        modelConfig.modelType,
				Source:      modelConfig.source,
				Description: modelConfig.description,
				Parameters: types.ModelParameters{
					BaseURL: modelConfig.baseURL,
					APIKey:  modelConfig.apiKey,
					EmbeddingParameters: types.EmbeddingParameters{
						Dimension: modelConfig.dimension,
					},
				},
				IsDefault: true,
				Status:    types.ModelStatusActive,
			}

			err := h.modelService.CreateModel(newCtx, newModel)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_name": modelConfig.name,
					"model_type": modelConfig.modelType,
				})
				c.Error(errors.NewInternalServerError("创建模型失败: " + err.Error()))
				return
			}
			processedModels = append(processedModels, newModel)
		}
	}

	// Drop a stale VLM model when multimodal was turned off.
	// Deletion failures are logged but do not abort initialization.
	if !req.Multimodal.Enabled {
		if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
			logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
			err := h.modelService.DeleteModel(newCtx, existingVLM.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingVLM.ID,
				})
				logger.Warn(ctx, "Failed to delete VLM model, but continuing")
			}
		}
	}

	// Drop a stale rerank model when rerank was turned off (same best-effort
	// semantics as the VLM deletion above).
	if !req.Rerank.Enabled {
		if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
			logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
			err := h.modelService.DeleteModel(newCtx, existingRerank.ID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, map[string]interface{}{
					"model_id": existingRerank.ID,
				})
				logger.Warn(ctx, "Failed to delete Rerank model, but continuing")
			}
		}
	}

	// Step 3: knowledge base — create when missing, otherwise update.
	kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)

	// Resolve the IDs of the models processed above by type.
	var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
	for _, model := range processedModels {
		if model.Type == types.ModelTypeEmbedding {
			embeddingModelID = model.ID
		}
		if model.Type == types.ModelTypeKnowledgeQA {
			llmModelID = model.ID
		}
		if model.Type == types.ModelTypeRerank && req.Rerank.Enabled {
			rerankModelID = model.ID
		}
		if model.Type == types.ModelTypeVLLM {
			vlmModelID = model.ID
		}
	}

	if kb == nil {
		// Create a new default knowledge base.
		logger.Info(ctx, "Creating default knowledge base")
		kb = &types.KnowledgeBase{
			ID:          types.InitDefaultKnowledgeBaseID,
			Name:        "Default Knowledge Base",
			Description: "System Default Knowledge Base",
			TenantID:    types.InitDefaultTenantID,
			ChunkingConfig: types.ChunkingConfig{
				ChunkSize:        req.DocumentSplitting.ChunkSize,
				ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
				Separators:       req.DocumentSplitting.Separators,
				EnableMultimodal: req.Multimodal.Enabled,
			},
			EmbeddingModelID: embeddingModelID,
			SummaryModelID:   llmModelID,
			RerankModelID:    rerankModelID,
			VLMModelID:       vlmModelID,
		}
		// BUG FIX: the original dereferenced req.Multimodal.VLM and
		// req.Multimodal.COS unconditionally here, panicking on first-time
		// initialization when multimodal is disabled (both pointers are nil).
		// Only populate the configs when multimodal is enabled; the request
		// validation above guarantees the pointers are non-nil in that case.
		if req.Multimodal.Enabled && req.Multimodal.VLM != nil {
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
		}
		if req.Multimodal.Enabled && req.Multimodal.COS != nil {
			kb.COSConfig = types.COSConfig{
				SecretID:   req.Multimodal.COS.SecretID,
				SecretKey:  req.Multimodal.COS.SecretKey,
				Region:     req.Multimodal.COS.Region,
				BucketName: req.Multimodal.COS.BucketName,
				AppID:      req.Multimodal.COS.AppID,
				PathPrefix: req.Multimodal.COS.PathPrefix,
			}
		}

		_, err = h.kbService.CreateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
			return
		}
	} else {
		// Update the existing knowledge base.
		logger.Info(ctx, "Updating existing knowledge base")

		// If the KB already contains files, the embedding model must not be
		// changed (existing vectors would become inconsistent).
		knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
			newCtx, types.InitDefaultKnowledgeBaseID,
		)
		hasFiles := err == nil && len(knowledgeList) > 0

		// Apply the new model IDs and configs on the object.
		kb.SummaryModelID = llmModelID
		if req.Rerank.Enabled {
			kb.RerankModelID = rerankModelID
		} else {
			kb.RerankModelID = "" // clear the rerank model ID
		}
		if req.Multimodal.Enabled {
			kb.VLMModelID = vlmModelID
			kb.VLMConfig = types.VLMConfig{
				ModelName:     req.Multimodal.VLM.ModelName,
				BaseURL:       req.Multimodal.VLM.BaseURL,
				APIKey:        req.Multimodal.VLM.APIKey,
				InterfaceType: req.Multimodal.VLM.InterfaceType,
			}
			kb.COSConfig = types.COSConfig{
				SecretID:   req.Multimodal.COS.SecretID,
				SecretKey:  req.Multimodal.COS.SecretKey,
				Region:     req.Multimodal.COS.Region,
				BucketName: req.Multimodal.COS.BucketName,
				AppID:      req.Multimodal.COS.AppID,
				PathPrefix: req.Multimodal.COS.PathPrefix,
			}
		} else {
			kb.VLMModelID = "" // clear the VLM model ID
			kb.VLMConfig = types.VLMConfig{}
			kb.COSConfig = types.COSConfig{}
		}
		if !hasFiles {
			kb.EmbeddingModelID = embeddingModelID
		}
		kb.ChunkingConfig = types.ChunkingConfig{
			ChunkSize:        req.DocumentSplitting.ChunkSize,
			ChunkOverlap:     req.DocumentSplitting.ChunkOverlap,
			Separators:       req.DocumentSplitting.Separators,
			EnableMultimodal: req.Multimodal.Enabled,
		}

		// Persist basic info and configuration.
		err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
			return
		}

		// NOTE(review): kb.SummaryModelID was already set to llmModelID above,
		// so this condition reduces to !hasFiles; preserved as written.
		if !hasFiles || kb.SummaryModelID != llmModelID {
			// Reload the KB to work on the freshest copy before the second write.
			kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
				return
			}

			kb.SummaryModelID = llmModelID
			if req.Rerank.Enabled {
				kb.RerankModelID = rerankModelID
			} else {
				kb.RerankModelID = "" // clear the rerank model ID
			}

			// Write the model IDs directly through the repository.
			err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
			if err != nil {
				logger.ErrorWithFields(ctx, err, nil)
				c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
				return
			}

			logger.Info(ctx, "Model IDs updated successfully")
		}
	}

	logger.Info(ctx, "System initialization completed successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "系统初始化成功",
		"data": gin.H{
			"tenant":         tenant,
			"models":         processedModels,
			"knowledge_base": kb,
		},
	})
}

// CheckOllamaStatus reports whether the local Ollama service is reachable,
// returning its availability flag and version. Failures to start the service
// are reported in-band (available=false), not as HTTP errors.
func (h *InitializationHandler) CheckOllamaStatus(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking Ollama service status")

	err := h.ollamaService.StartService(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"available": false,
				"error":     err.Error(),
			},
		})
		return
	}

	version, err := h.ollamaService.GetVersion(ctx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		version = "unknown"
	}

	logger.Info(ctx, "Ollama service is available")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": h.ollamaService.IsAvailable(),
			"version":   version,
		},
	})
}

// CheckOllamaModels reports, for each requested model name, whether it is
// already present in the local Ollama installation.
func (h *InitializationHandler) CheckOllamaModels(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking Ollama models status")

	var req struct {
		Models []string `json:"models" binding:"required"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse models check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Lazily start the Ollama service if it is not already up.
	if !h.ollamaService.IsAvailable() {
		err := h.ollamaService.StartService(ctx)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	modelStatus := make(map[string]bool)

	// A per-model lookup failure is reported as "not available" rather than
	// failing the whole request.
	for _, modelName := range req.Models {
		available, err := h.ollamaService.IsModelAvailable(ctx, modelName)
		if err != nil {
			logger.ErrorWithFields(ctx, err, map[string]interface{}{
				"model_name": modelName,
			})
			modelStatus[modelName] = false
		} else {
			modelStatus[modelName] = available
		}

		logger.Infof(ctx, "Model %s availability: %v", modelName, modelStatus[modelName])
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"models": modelStatus,
		},
	})
}

// DownloadOllamaModel starts an asynchronous pull of an Ollama model and
// returns a task ID for polling via GetDownloadProgress. If the model is
// already present, or a pending/downloading task for it exists, that state
// is returned immediately without starting a new download.
func (h *InitializationHandler) DownloadOllamaModel(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Starting async Ollama model download")

	var req struct {
		ModelName string `json:"modelName" binding:"required"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse model download request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	// Lazily start the Ollama service if it is not already up.
	if !h.ollamaService.IsAvailable() {
		err := h.ollamaService.StartService(ctx)
		if err != nil {
			logger.ErrorWithFields(ctx, err, nil)
			c.Error(errors.NewInternalServerError("Ollama服务不可用: " + err.Error()))
			return
		}
	}

	// Short-circuit when the model already exists locally.
	available, err := h.ollamaService.IsModelAvailable(ctx, req.ModelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": req.ModelName,
		})
		c.Error(errors.NewInternalServerError("检查模型状态失败: " + err.Error()))
		return
	}

	if available {
		logger.Infof(ctx, "Model %s already exists", req.ModelName)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "模型已存在",
			"data": gin.H{
				"modelName": req.ModelName,
				"status":    "completed",
				"progress":  100.0,
			},
		})
		return
	}

	// Deduplicate: reuse an in-flight task for the same model if one exists.
	tasksMutex.RLock()
	for _, task := range downloadTasks {
		if task.ModelName == req.ModelName && (task.Status == "pending" || task.Status == "downloading") {
			tasksMutex.RUnlock()
			c.JSON(http.StatusOK, gin.H{
				"success": true,
				"message": "模型下载任务已存在",
				"data": gin.H{
					"taskId":    task.ID,
					"modelName": task.ModelName,
					"status":    task.Status,
					"progress":  task.Progress,
				},
			})
			return
		}
	}
	tasksMutex.RUnlock()

	// Register a new download task.
	taskID := uuid.New().String()
	task := &DownloadTask{
		ID:        taskID,
		ModelName: req.ModelName,
		Status:    "pending",
		Progress:  0.0,
		Message:   "准备下载",
		StartTime: time.Now(),
	}

	tasksMutex.Lock()
	downloadTasks[taskID] = task
	tasksMutex.Unlock()

	// Run the download in the background, detached from the request context
	// (the HTTP request ends immediately); capped at 12 hours.
	newCtx, cancel := context.WithTimeout(context.Background(), 12*time.Hour)
	go func() {
		defer cancel()
		h.downloadModelAsync(newCtx, taskID, req.ModelName)
	}()

	logger.Infof(ctx, "Created download task for model: %s, task ID: %s", req.ModelName, taskID)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "模型下载任务已创建",
		"data": gin.H{
			"taskId":    taskID,
			"modelName": req.ModelName,
			"status":    "pending",
			"progress":  0.0,
		},
	})
}

// GetDownloadProgress returns the current state of one download task by ID.
func (h *InitializationHandler) GetDownloadProgress(c *gin.Context) {
	taskID := c.Param("taskId")

	if taskID == "" {
		c.Error(errors.NewBadRequestError("任务ID不能为空"))
		return
	}

	tasksMutex.RLock()
	task, exists := downloadTasks[taskID]
	tasksMutex.RUnlock()

	if !exists {
		c.Error(errors.NewNotFoundError("下载任务不存在"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    task,
	})
}

// ListDownloadTasks returns all known download tasks (any status).
func (h *InitializationHandler) ListDownloadTasks(c *gin.Context) {
	tasksMutex.RLock()
	tasks := make([]*DownloadTask, 0, len(downloadTasks))
	for _, task := range downloadTasks {
		tasks = append(tasks, task)
	}
	tasksMutex.RUnlock()

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    tasks,
	})
}

// downloadModelAsync drives one background model download, keeping the task's
// status/progress/message up to date until it completes or fails.
func (h *InitializationHandler) downloadModelAsync(ctx context.Context,
	taskID, modelName string,
) {
	logger.Infof(ctx, "Starting async download for model: %s, task: %s", modelName, taskID)

	h.updateTaskStatus(taskID, "downloading", 0.0, "开始下载模型")

	err := h.pullModelWithProgress(ctx, modelName, func(progress float64, message string) {
		h.updateTaskStatus(taskID, "downloading", progress, message)
	})

	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
			"task_id":    taskID,
		})
		h.updateTaskStatus(taskID, "failed", 0.0, fmt.Sprintf("下载失败: %v", err))
		return
	}

	logger.Infof(ctx, "Model %s downloaded successfully, task: %s", modelName, taskID)
	h.updateTaskStatus(taskID, "completed", 100.0, "下载完成")
}

// pullModelWithProgress pulls a model through the Ollama client, invoking
// progressCallback with a percentage (0-100) and a human-readable message.
// It is a no-op (immediate 100%) when the model is already present.
func (h *InitializationHandler) pullModelWithProgress(ctx context.Context,
	modelName string,
	progressCallback func(float64, string),
) error {
	// Ensure the Ollama service is running.
	if err := h.ollamaService.StartService(ctx); err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		return err
	}

	// Skip the pull when the model already exists.
	available, err := h.ollamaService.IsModelAvailable(ctx, modelName)
	if err != nil {
		logger.ErrorWithFields(ctx, err, map[string]interface{}{
			"model_name": modelName,
		})
		return err
	}
	if available {
		progressCallback(100.0, "模型已存在")
		return nil
	}

	logger.GetLogger(ctx).Infof("Pulling model %s...", modelName)

	pullReq := &api.PullRequest{
		Name: modelName,
	}

	// Stream the pull; Ollama invokes the callback repeatedly with byte counts.
	err = h.ollamaService.GetClient().Pull(ctx, pullReq, func(progress api.ProgressResponse) error {
		var progressPercent float64 = 0.0
		var message string = "下载中"

		if progress.Total > 0 && progress.Completed > 0 {
			progressPercent = float64(progress.Completed) / float64(progress.Total) * 100
			message = fmt.Sprintf("下载中: %.1f%% (%s)", progressPercent, progress.Status)
		} else if progress.Status != "" {
			message = progress.Status
		}

		progressCallback(progressPercent, message)

		logger.Infof(ctx,
			"Download progress for %s: %.2f%% - %s",
			modelName, progressPercent, message,
		)
		return nil
	})

	if err != nil {
		return fmt.Errorf("failed to pull model: %w", err)
	}

	return nil
}

// updateTaskStatus atomically updates one task's status, progress and message,
// stamping EndTime on terminal states ("completed"/"failed").
func (h *InitializationHandler) updateTaskStatus(
	taskID, status string, progress float64, message string,
) {
	tasksMutex.Lock()
	defer tasksMutex.Unlock()

	if task, exists := downloadTasks[taskID]; exists {
		task.Status = status
		task.Progress = progress
		task.Message = message

		if status == "completed" || status == "failed" {
			now := time.Now()
			task.EndTime = &now
		}
	}
}

// cleanupExpiredTasks removes finished tasks older than 24 hours.
// Intended to be run periodically in the background.
func (h *InitializationHandler) cleanupExpiredTasks() {
	tasksMutex.Lock()
	defer tasksMutex.Unlock()

	cutoff := time.Now().Add(-24 * time.Hour) // keep tasks from the last 24h

	for id, task := range downloadTasks {
		if task.EndTime != nil && task.EndTime.Before(cutoff) {
			delete(downloadTasks, id)
		}
	}
}

// GetCurrentConfig returns the current system configuration (models, default
// knowledge base settings, and whether the KB already contains files).
func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Getting current system configuration")

	// Tenant-scoped context for the lookups below.
	newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)

	models, err := h.modelService.ListModels(newCtx)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
		return
	}

	kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
	if err != nil {
		logger.ErrorWithFields(ctx, err, nil)
		c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
		return
	}

	// One-item page is enough to learn whether any files exist.
	knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(newCtx,
		types.InitDefaultKnowledgeBaseID, &types.Pagination{
			Page:     1,
			PageSize: 1,
		})
	hasFiles := false
	if err == nil && knowledgeList != nil && knowledgeList.Total > 0 {
		hasFiles = true
	}

	config := buildConfigResponse(models, kb, hasFiles)

	logger.Info(ctx, "Current system configuration retrieved successfully")
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    config,
	})
}

// buildConfigResponse assembles the frontend-facing configuration map from the
// model list and (possibly nil) knowledge base. hasFiles tells the UI whether
// the embedding model may still be changed.
func buildConfigResponse(models []*types.Model,
	kb *types.KnowledgeBase, hasFiles bool,
) map[string]interface{} {
	config := map[string]interface{}{
		"hasFiles": hasFiles,
	}

	// Project each model into its section of the response by type.
	for _, model := range models {
		switch model.Type {
		case types.ModelTypeKnowledgeQA:
			config["llm"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeEmbedding:
			config["embedding"] = map[string]interface{}{
				"source":    string(model.Source),
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
				"dimension": model.Parameters.EmbeddingParameters.Dimension,
			}
		case types.ModelTypeRerank:
			config["rerank"] = map[string]interface{}{
				"enabled":   true,
				"modelName": model.Name,
				"baseUrl":   model.Parameters.BaseURL,
				"apiKey":    model.Parameters.APIKey,
			}
		case types.ModelTypeVLLM:
			if config["multimodal"] == nil {
				config["multimodal"] = map[string]interface{}{
					"enabled": true,
				}
			}
			// ROBUSTNESS FIX: the original read kb.VLMConfig.InterfaceType
			// unconditionally, panicking when kb is nil (the function's own
			// later code guards kb with a nil check).
			interfaceType := ""
			if kb != nil {
				interfaceType = kb.VLMConfig.InterfaceType
			}
			multimodal := config["multimodal"].(map[string]interface{})
			multimodal["vlm"] = map[string]interface{}{
				"modelName":     model.Name,
				"baseUrl":       model.Parameters.BaseURL,
				"apiKey":        model.Parameters.APIKey,
				"interfaceType": interfaceType,
			}
		}
	}

	// No VLM model registered: multimodal is disabled.
	if config["multimodal"] == nil {
		config["multimodal"] = map[string]interface{}{
			"enabled": false,
		}
	}

	// No rerank model registered: rerank is disabled.
	if config["rerank"] == nil {
		config["rerank"] = map[string]interface{}{
			"enabled":   false,
			"modelName": "",
			"baseUrl":   "",
			"apiKey":    "",
		}
	}

	// Knowledge-base-level settings: document splitting and COS.
	if kb != nil {
		config["documentSplitting"] = map[string]interface{}{
			"chunkSize":    kb.ChunkingConfig.ChunkSize,
			"chunkOverlap": kb.ChunkingConfig.ChunkOverlap,
			"separators":   kb.ChunkingConfig.Separators,
		}

		if kb.COSConfig.SecretID != "" {
			if config["multimodal"] == nil {
				config["multimodal"] = map[string]interface{}{
					"enabled": true,
				}
			}
			multimodal := config["multimodal"].(map[string]interface{})
			multimodal["cos"] = map[string]interface{}{
				"secretId":   kb.COSConfig.SecretID,
				"secretKey":  kb.COSConfig.SecretKey,
				"region":     kb.COSConfig.Region,
				"bucketName": kb.COSConfig.BucketName,
				"appId":      kb.COSConfig.AppID,
				"pathPrefix": kb.COSConfig.PathPrefix,
			}
		}
	}

	return config
}

// RemoteModelCheckRequest is the payload for CheckRemoteModel.
type RemoteModelCheckRequest struct {
	ModelName string `json:"modelName" binding:"required"`
	BaseURL   string `json:"baseUrl" binding:"required"`
	APIKey    string `json:"apiKey"`
}

// CheckRemoteModel verifies connectivity (and, when possible, model
// existence) against a remote OpenAI-compatible API. The result is reported
// in-band: {"available": bool, "message": string}.
func (h *InitializationHandler) CheckRemoteModel(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking remote model connection")

	var req RemoteModelCheckRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse remote model check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	if req.ModelName == "" || req.BaseURL == "" {
		logger.Error(ctx, "Model name and base URL are required")
		c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
		return
	}

	// Build a throwaway model config to run the connectivity probe.
	modelConfig := &types.Model{
		Name:   req.ModelName,
		Source: "remote",
		Parameters: types.ModelParameters{
			BaseURL: req.BaseURL,
			APIKey:  req.APIKey,
		},
		Type: "llm", // default type; the check does not distinguish types
	}

	available, message := h.checkRemoteModelConnection(ctx, modelConfig)

	logger.Info(ctx,
		fmt.Sprintf(
			"Remote model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
			req.ModelName, req.BaseURL, available, message,
		),
	)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": available,
			"message":   message,
		},
	})
}

// checkRemoteModelConnection probes the remote API's /models endpoint with a
// 10s timeout and maps HTTP outcomes to (available, message). On 2xx it
// further verifies the model's presence via checkModelExistence.
func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
	model *types.Model,
) (bool, string) {
	client := &http.Client{
		Timeout: 10 * time.Second,
	}

	// Derive the models-listing endpoint from the base URL. OpenAI-style
	// bases ending in /v1 get "/models" appended; otherwise "/v1/models".
	testEndpoint := ""
	if model.Parameters.BaseURL != "" {
		if strings.Contains(model.Parameters.BaseURL, "openai.com") ||
			strings.Contains(model.Parameters.BaseURL, "api.openai.com") {
			testEndpoint = model.Parameters.BaseURL + "/models"
		} else if strings.Contains(model.Parameters.BaseURL, "v1") {
			testEndpoint = model.Parameters.BaseURL + "/models"
		} else {
			testEndpoint = model.Parameters.BaseURL + "/v1/models"
		}
	}

	req, err := http.NewRequestWithContext(ctx, "GET", testEndpoint, nil)
	if err != nil {
		return false, fmt.Sprintf("创建请求失败: %v", err)
	}

	// Bearer auth when an API key is provided.
	if model.Parameters.APIKey != "" {
		req.Header.Set("Authorization", "Bearer "+model.Parameters.APIKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Sprintf("连接失败: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		// Connected; now confirm the model is in the listing.
		return h.checkModelExistence(ctx, resp, model.Name)
	} else if resp.StatusCode == 401 {
		return false, "认证失败,请检查API Key"
	} else if resp.StatusCode == 403 {
		return false, "权限不足,请检查API Key权限"
	} else if resp.StatusCode == 404 {
		return false, "API端点不存在,请检查Base URL"
	} else {
		return false, fmt.Sprintf("API返回错误状态: %d", resp.StatusCode)
	}
}

// checkModelExistence parses an OpenAI-style models-list response and checks
// whether modelName is present. Non-standard responses are treated as a
// successful connection (best effort, never a false negative on parse issues).
func (h *InitializationHandler) checkModelExistence(ctx context.Context,
	resp *http.Response, modelName string) (bool, string) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return true, "连接正常,但无法验证模型列表"
	}

	var modelsResp struct {
		Data []struct {
			ID     string `json:"id"`
			Object string `json:"object"`
		} `json:"data"`
		Object string `json:"object"`
	}

	// Non-standard APIs may not return this shape; connection alone counts.
	if err := json.Unmarshal(body, &modelsResp); err != nil {
		return true, "连接正常"
	}

	for _, model := range modelsResp.Data {
		if model.ID == modelName {
			return true, "连接正常,模型存在"
		}
	}

	// Model missing: suggest up to three available alternatives.
	if len(modelsResp.Data) > 0 {
		availableModels := make([]string, 0, min(3, len(modelsResp.Data)))
		for i, model := range modelsResp.Data {
			if i >= 3 {
				break
			}
			availableModels = append(availableModels, model.ID)
		}
		return false, fmt.Sprintf("模型 '%s' 不存在,可用模型: %s", modelName, strings.Join(availableModels, ", "))
	}

	return false, fmt.Sprintf("模型 '%s' 不存在", modelName)
}

// min returns the minimum of two integers.
// NOTE(review): redundant once the module targets Go 1.21+ (builtin min);
// kept for compatibility with older toolchains.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// checkRerankModelConnection exercises a rerank endpoint end-to-end with a
// small mock query/passages payload (15s timeout) and maps the outcome to
// (available, message), including parsing the returned result list.
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
	modelName, baseURL, apiKey string) (bool, string) {
	client := &http.Client{
		Timeout: 15 * time.Second,
	}

	// Rerank endpoints conventionally live at <base>/rerank without /v1.
	rerankEndpoint := ""
	if strings.Contains(baseURL, "v1") {
		baseURL = strings.Replace(baseURL, "/v1", "", 1)
		rerankEndpoint = baseURL + "/rerank"
	} else {
		rerankEndpoint = baseURL + "/rerank"
	}

	// Mock test data for the probe.
	testQuery := "什么是人工智能?"
	testPassages := []string{
		"机器学习是人工智能的一个子领域,专注于算法和统计模型,使计算机系统能够通过经验自动改进。",
		"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
	}

	rerankRequest := map[string]interface{}{
		"model":                  modelName,
		"query":                  testQuery,
		"documents":              testPassages,
		"truncate_prompt_tokens": 512,
	}

	jsonData, err := json.Marshal(rerankRequest)
	if err != nil {
		return false, fmt.Sprintf("构造请求失败: %v", err)
	}

	// SECURITY FIX: the original logged the raw API key; only log whether
	// a key was supplied.
	logger.Infof(ctx, "Rerank request: %s, modelName=%s, baseURL=%s, apiKeySet=%t",
		string(jsonData), modelName, baseURL, apiKey != "")

	req, err := http.NewRequestWithContext(
		ctx, "POST", rerankEndpoint, strings.NewReader(string(jsonData)),
	)
	if err != nil {
		return false, fmt.Sprintf("创建请求失败: %v", err)
	}

	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return false, fmt.Sprintf("连接失败: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return false, fmt.Sprintf("读取响应失败: %v", err)
	}

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		var rerankResp struct {
			Results []struct {
				Index          int     `json:"index"`
				Document       string  `json:"document"`
				RelevanceScore float64 `json:"relevance_score"`
			} `json:"results"`
		}

		if err := json.Unmarshal(body, &rerankResp); err != nil {
			// Non-standard rerank response shape; connection still worked.
			return true, "连接正常,但响应格式非标准"
		}

		if len(rerankResp.Results) > 0 {
			return true, fmt.Sprintf("重排功能正常,返回%d个结果", len(rerankResp.Results))
		} else {
			return false, "重排接口连接成功,但未返回重排结果"
		}
	} else if resp.StatusCode == 401 {
		return false, "认证失败,请检查API Key"
	} else if resp.StatusCode == 403 {
		return false, "权限不足,请检查API Key权限"
	} else if resp.StatusCode == 404 {
		return false, "重排API端点不存在,请检查Base URL"
	} else if resp.StatusCode == 422 {
		return false, fmt.Sprintf("请求参数错误: %s", string(body))
	} else {
		return false, fmt.Sprintf("API返回错误状态: %d, 响应: %s", resp.StatusCode, string(body))
	}
}

// CheckRerankModel validates a rerank model's connectivity and functionality,
// reporting the result in-band: {"available": bool, "message": string}.
func (h *InitializationHandler) CheckRerankModel(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Checking rerank model connection and functionality")

	var req struct {
		ModelName string `json:"modelName" binding:"required"`
		BaseURL   string `json:"baseUrl" binding:"required"`
		APIKey    string `json:"apiKey"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		logger.Error(ctx, "Failed to parse rerank model check request", err)
		c.Error(errors.NewBadRequestError(err.Error()))
		return
	}

	if req.ModelName == "" || req.BaseURL == "" {
		logger.Error(ctx, "Model name and base URL are required")
		c.Error(errors.NewBadRequestError("模型名称和Base URL不能为空"))
		return
	}

	available, message := h.checkRerankModelConnection(
		ctx, req.ModelName, req.BaseURL, req.APIKey,
	)

	logger.Info(ctx,
		fmt.Sprintf("Rerank model check completed: modelName=%s, baseUrl=%s, available=%v, message=%s",
			req.ModelName, req.BaseURL, available, message,
		),
	)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"available": available,
			"message":   message,
		},
	})
}

// TestMultimodalFunction exercises the multimodal (VLM + COS) pipeline with
// form-supplied configuration, falling back to the default knowledge base's
// stored VLM/COS configuration when the form omits it.
// NOTE(review): this function continues beyond the end of this chunk.
func (h *InitializationHandler) TestMultimodalFunction(c *gin.Context) {
	ctx := c.Request.Context()

	logger.Info(ctx, "Testing multimodal functionality")

	// VLM configuration from the multipart form.
	vlmModel := c.PostForm("vlm_model")
	vlmBaseURL := c.PostForm("vlm_base_url")
	vlmAPIKey := c.PostForm("vlm_api_key")
	vlmInterfaceType := c.PostForm("vlm_interface_type")

	// Ollama interface: base URL is always derived from the environment.
	if vlmInterfaceType == "ollama" {
		vlmBaseURL = os.Getenv("OLLAMA_BASE_URL") + "/v1"
	}

	// Fall back to the knowledge base's stored VLM configuration.
	if vlmModel == "" || vlmBaseURL == "" {
		logger.Info(ctx, "VLM configuration not provided, trying to get from KnowledgeBase")

		kb, err := h.kbService.GetKnowledgeBaseByID(ctx, types.InitDefaultKnowledgeBaseID)
		if err != nil {
			logger.Error(ctx, "Failed to get KnowledgeBase", err)
			c.Error(errors.NewBadRequestError("获取知识库配置失败"))
			return
		}

		if kb.VLMConfig.ModelName != "" && kb.VLMConfig.BaseURL != "" {
			vlmModel = kb.VLMConfig.ModelName
			vlmBaseURL = kb.VLMConfig.BaseURL
			vlmAPIKey = kb.VLMConfig.APIKey
			vlmInterfaceType = kb.VLMConfig.InterfaceType
			logger.Infof(ctx, "Using VLM config from KnowledgeBase: Model=%s, URL=%s, Type=%s",
				vlmModel, vlmBaseURL, vlmInterfaceType)
		} else {
			logger.Error(ctx, "VLM configuration not found in KnowledgeBase")
			c.Error(errors.NewBadRequestError("知识库中未找到VLM配置信息"))
			return
		}
	}

	// COS configuration from the multipart form.
	cosSecretID := c.PostForm("cos_secret_id")
	cosSecretKey := c.PostForm("cos_secret_key")
	cosRegion := c.PostForm("cos_region")
	cosBucketName := c.PostForm("cos_bucket_name")
	cosAppID := c.PostForm("cos_app_id")
	cosPathPrefix := c.PostForm("cos_path_prefix")

	// Fall back to the knowledge base's stored COS configuration.
	if cosSecretID == "" || cosSecretKey == "" ||
		cosRegion == "" || cosBucketName == "" || cosAppID == "" {
		logger.Info(ctx, "COS configuration not provided, trying to get from KnowledgeBase")

		kb, err := h.kbService.GetKnowledgeBaseByID(ctx, types.InitDefaultKnowledgeBaseID)
		if err != nil {
			logger.Error(ctx, "Failed to get KnowledgeBase", err)
			c.Error(errors.NewBadRequestError("获取知识库配置失败"))
			return
		}

		if kb.COSConfig.SecretID != "" && kb.COSConfig.SecretKey != "" {
			cosSecretID = kb.COSConfig.SecretID
			cosSecretKey = kb.COSConfig.SecretKey
			cosRegion = kb.COSConfig.Region
			cosBucketName = kb.COSConfig.BucketName
			cosAppID = kb.COSConfig.AppID
			cosPathPrefix = kb.COSConfig.PathPrefix
			logger.Infof(ctx, "Using COS config from KnowledgeBase: Region=%s, Bucket=%s, App=%s",
				cosRegion, cosBucketName, cosAppID)
		} else {
			logger.Error(ctx, "COS configuration not found in KnowledgeBase")
c.Error(errors.NewBadRequestError("知识库中未找到COS配置信息")) + return + } + } + + // 文档分割配置 + chunkSizeStr := c.PostForm("chunk_size") + chunkOverlapStr := c.PostForm("chunk_overlap") + separatorsStr := c.PostForm("separators") + + if vlmModel == "" || vlmBaseURL == "" { + logger.Error(ctx, "VLM model name and base URL are required") + c.Error(errors.NewBadRequestError("VLM模型名称和Base URL不能为空")) + return + } + + if cosSecretID == "" || cosSecretKey == "" || + cosRegion == "" || cosBucketName == "" || cosAppID == "" { + logger.Error(ctx, "COS configuration is required") + c.Error(errors.NewBadRequestError("COS配置信息不能为空")) + return + } + + // 记录COS配置信息用于日志 + logger.Infof(ctx, "COS config: ID=%s, Region=%s, Bucket=%s, App=%s, Prefix=%s", + cosSecretID, cosRegion, cosBucketName, cosAppID, cosPathPrefix) + logger.Infof(ctx, "VLM config: Model=%s, URL=%s, HasKey=%v, Type=%s", + vlmModel, vlmBaseURL, vlmAPIKey != "", vlmInterfaceType) + + // 获取上传的图片文件 + file, header, err := c.Request.FormFile("image") + if err != nil { + logger.Error(ctx, "Failed to get uploaded image", err) + c.Error(errors.NewBadRequestError("获取上传图片失败")) + return + } + defer file.Close() + + // 验证文件类型 + if !strings.HasPrefix(header.Header.Get("Content-Type"), "image/") { + logger.Error(ctx, "Invalid file type, only images are allowed") + c.Error(errors.NewBadRequestError("只允许上传图片文件")) + return + } + + // 验证文件大小 (10MB) + if header.Size > 10*1024*1024 { + logger.Error(ctx, "File size too large") + c.Error(errors.NewBadRequestError("图片文件大小不能超过10MB")) + return + } + + logger.Infof(ctx, "Processing image: %s, size: %d bytes", header.Filename, header.Size) + + // 解析文档分割配置 + chunkSize, err := strconv.Atoi(chunkSizeStr) + if err != nil || chunkSize < 100 || chunkSize > 10000 { + chunkSize = 1000 // 默认值 + } + + chunkOverlap, err := strconv.Atoi(chunkOverlapStr) + if err != nil || chunkOverlap < 0 || chunkOverlap >= chunkSize { + chunkOverlap = 200 // 默认值 + } + + var separators []string + if separatorsStr != "" { + if err 
:= json.Unmarshal([]byte(separatorsStr), &separators); err != nil { + separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"} // 默认值 + } + } else { + separators = []string{"\n\n", "\n", "。", "!", "?", ";", ";"} // 默认值 + } + + // 读取图片文件内容 + imageContent, err := io.ReadAll(file) + if err != nil { + logger.Error(ctx, "Failed to read image file", err) + c.Error(errors.NewBadRequestError("读取图片文件失败")) + return + } + + // 调用多模态测试 + startTime := time.Now() + result, err := h.testMultimodalWithDocReader(ctx, imageContent, header.Filename, + chunkSize, chunkOverlap, separators, + vlmModel, vlmBaseURL, vlmAPIKey, vlmInterfaceType, + cosSecretID, cosSecretKey, cosRegion, cosBucketName, cosAppID, cosPathPrefix) + processingTime := time.Since(startTime).Milliseconds() + + if err != nil { + logger.ErrorWithFields(ctx, err, map[string]interface{}{ + "vlm_model": vlmModel, + "vlm_base_url": vlmBaseURL, + "filename": header.Filename, + }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "success": false, + "message": err.Error(), + "processing_time": processingTime, + }, + }) + return + } + + logger.Info(ctx, fmt.Sprintf("Multimodal test completed successfully in %dms", processingTime)) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "success": true, + "caption": result["caption"], + "ocr": result["ocr"], + "processing_time": processingTime, + }, + }) +} + +// testMultimodalWithDocReader 调用docreader服务进行多模态处理 +func (h *InitializationHandler) testMultimodalWithDocReader(ctx context.Context, + imageContent []byte, filename string, + chunkSize, chunkOverlap int, separators []string, + vlmModel, vlmBaseURL, vlmAPIKey, vlmInterfaceType, + cosSecretID, cosSecretKey, cosRegion, cosBucketName, cosAppID, cosPathPrefix string) ( + map[string]string, error) { + // 获取文件扩展名 + fileExt := "" + if idx := strings.LastIndex(filename, "."); idx != -1 { + fileExt = strings.ToLower(filename[idx+1:]) + } + + // 检查docreader服务配置 + if h.docReaderClient == nil { 
+ return nil, fmt.Errorf("DocReader service not configured") + } + + // 构造请求 + request := &proto.ReadFromFileRequest{ + FileContent: imageContent, + FileName: filename, + FileType: fileExt, + ReadConfig: &proto.ReadConfig{ + ChunkSize: int32(chunkSize), + ChunkOverlap: int32(chunkOverlap), + Separators: separators, + EnableMultimodal: true, // 启用多模态处理 + VlmConfig: &proto.VLMConfig{ + ModelName: vlmModel, + BaseUrl: vlmBaseURL, + ApiKey: vlmAPIKey, + InterfaceType: vlmInterfaceType, + }, + CosConfig: &proto.COSConfig{ + SecretId: cosSecretID, + SecretKey: cosSecretKey, + Region: cosRegion, + BucketName: cosBucketName, + AppId: cosAppID, + PathPrefix: cosPathPrefix, + }, + }, + RequestId: ctx.Value(types.RequestIDContextKey).(string), + } + + // 调用docreader服务 + response, err := h.docReaderClient.ReadFromFile(ctx, request) + if err != nil { + return nil, fmt.Errorf("调用DocReader服务失败: %v", err) + } + + if response.Error != "" { + return nil, fmt.Errorf("DocReader服务返回错误: %s", response.Error) + } + + // 处理响应,提取Caption和OCR信息 + result := make(map[string]string) + var allCaptions, allOCRTexts []string + + for _, chunk := range response.Chunks { + if len(chunk.Images) > 0 { + for _, image := range chunk.Images { + if image.Caption != "" { + allCaptions = append(allCaptions, image.Caption) + } + if image.OcrText != "" { + allOCRTexts = append(allOCRTexts, image.OcrText) + } + } + } + } + + // 合并所有Caption和OCR结果 + result["caption"] = strings.Join(allCaptions, "; ") + result["ocr"] = strings.Join(allOCRTexts, "; ") + + return result, nil +} diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go index d9ddcdf..dac6ffe 100644 --- a/internal/handler/knowledge.go +++ b/internal/handler/knowledge.go @@ -135,6 +135,10 @@ func (h *KnowledgeHandler) CreateKnowledgeFromFile(c *gin.Context) { if h.handleDuplicateKnowledgeError(c, err, knowledge, "file") { return } + if appErr, ok := errors.IsAppError(err); ok { + c.Error(appErr) + return + } logger.ErrorWithFields(ctx, 
err, nil) c.Error(errors.NewInternalServerError(err.Error())) return diff --git a/internal/handler/session.go b/internal/handler/session.go index 73a88de..7149e19 100644 --- a/internal/handler/session.go +++ b/internal/handler/session.go @@ -18,11 +18,12 @@ import ( // SessionHandler handles all HTTP requests related to conversation sessions type SessionHandler struct { - messageService interfaces.MessageService // Service for managing messages - sessionService interfaces.SessionService // Service for managing sessions - streamManager interfaces.StreamManager // Manager for handling streaming responses - config *config.Config // Application configuration - testDataService *service.TestDataService // Service for test data (models, etc.) + messageService interfaces.MessageService // Service for managing messages + sessionService interfaces.SessionService // Service for managing sessions + streamManager interfaces.StreamManager // Manager for handling streaming responses + config *config.Config // Application configuration + testDataService *service.TestDataService // Service for test data (models, etc.) 
+ knowledgebaseService interfaces.KnowledgeBaseService } // NewSessionHandler creates a new instance of SessionHandler with all necessary dependencies @@ -32,13 +33,15 @@ func NewSessionHandler( streamManager interfaces.StreamManager, config *config.Config, testDataService *service.TestDataService, + knowledgebaseService interfaces.KnowledgeBaseService, ) *SessionHandler { return &SessionHandler{ - sessionService: sessionService, - messageService: messageService, - streamManager: streamManager, - config: config, - testDataService: testDataService, + sessionService: sessionService, + messageService: messageService, + streamManager: streamManager, + config: config, + testDataService: testDataService, + knowledgebaseService: knowledgebaseService, } } @@ -191,27 +194,24 @@ func (h *SessionHandler) CreateSession(c *gin.Context) { logger.Debug(ctx, "Using default session strategy") } - // Get model IDs from test data service if not provided + kb, err := h.knowledgebaseService.GetKnowledgeBaseByID(ctx, request.KnowledgeBaseID) + if err != nil { + logger.Error(ctx, "Failed to get knowledge base", err) + c.Error(errors.NewInternalServerError(err.Error())) + return + } + + // Get model IDs from knowledge base if not provided if createdSession.SummaryModelID == "" { - if h.testDataService != nil { - createdSession.SummaryModelID = h.testDataService.LLMModel.GetModelID() - logger.Debug(ctx, "Using summary model ID from test data service") - } else { - logger.Error(ctx, "Summary model ID is empty and cannot get default value") - c.Error(errors.NewBadRequestError("Summary model ID cannot be empty")) - return - } + createdSession.SummaryModelID = kb.SummaryModelID } if createdSession.RerankModelID == "" { - if h.testDataService != nil && h.testDataService.RerankModel != nil { - createdSession.RerankModelID = h.testDataService.RerankModel.GetModelID() - logger.Debug(ctx, "Using rerank model ID from test data service") - } + createdSession.RerankModelID = kb.RerankModelID } // Call 
service to create session logger.Infof(ctx, "Calling session service to create session") - createdSession, err := h.sessionService.CreateSession(ctx, createdSession) + createdSession, err = h.sessionService.CreateSession(ctx, createdSession) if err != nil { logger.ErrorWithFields(ctx, err, nil) c.Error(errors.NewInternalServerError(err.Error())) diff --git a/internal/handler/test_data.go b/internal/handler/test_data.go index bdfe12b..b0eeebb 100644 --- a/internal/handler/test_data.go +++ b/internal/handler/test_data.go @@ -4,7 +4,6 @@ import ( "errors" "net/http" "os" - "strconv" "github.com/gin-gonic/gin" @@ -65,26 +64,19 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) { return } - tenantID := os.Getenv("INIT_TEST_TENANT_ID") - logger.Debugf(ctx, "Test tenant ID environment variable: %s", tenantID) - - tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64) - if err != nil { - logger.Errorf(ctx, "Failed to parse tenant ID: %s", tenantID) - c.Error(err) - return - } + tenantID := uint(types.InitDefaultTenantID) + logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID) // Retrieve the test tenant data - logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantIDUint) - tenant, err := h.tenantService.GetTenantByID(ctx, uint(tenantIDUint)) + logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantID) + tenant, err := h.tenantService.GetTenantByID(ctx, tenantID) if err != nil { logger.ErrorWithFields(ctx, err, nil) c.Error(err) return } - knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID") + knowledgeBaseID := types.InitDefaultKnowledgeBaseID logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID) // Retrieve the test knowledge base data diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index ddd7072..7f81ef5 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "slices" + "strings" 
"github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/types" @@ -15,14 +16,20 @@ import ( // 无需认证的API列表 var noAuthAPI = map[string][]string{ - "/api/v1/test-data": {"GET"}, - "/api/v1/tenants": {"POST"}, + "/api/v1/test-data": {"GET"}, + "/api/v1/tenants": {"POST"}, + "/api/v1/initialization/*": {"GET", "POST"}, } // 检查请求是否在无需认证的API列表中 func isNoAuthAPI(path string, method string) bool { for api, methods := range noAuthAPI { - if api == path && slices.Contains(methods, method) { + // 如果以*结尾,按照前缀匹配,否则按照全路径匹配 + if strings.HasSuffix(api, "*") { + if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) { + return true + } + } else if path == api && slices.Contains(methods, method) { return true } } diff --git a/internal/models/chat/ollama.go b/internal/models/chat/ollama.go index 5cdfe3b..4371f29 100644 --- a/internal/models/chat/ollama.go +++ b/internal/models/chat/ollama.go @@ -63,7 +63,9 @@ func (c *OllamaChat) buildChatRequest(messages []Message, opts *ChatOptions, isS chatReq.Options["num_predict"] = opts.MaxTokens } if opts.Thinking != nil { - chatReq.Options["think"] = *opts.Thinking + chatReq.Think = &ollamaapi.ThinkValue{ + Value: *opts.Thinking, + } } } diff --git a/internal/models/utils/ollama/ollama.go b/internal/models/utils/ollama/ollama.go index c4fbf9f..76b88ce 100644 --- a/internal/models/utils/ollama/ollama.go +++ b/internal/models/utils/ollama/ollama.go @@ -309,3 +309,8 @@ func (s *OllamaService) Generate(ctx context.Context, req *api.GenerateRequest, // Use official client Generate method return s.client.Generate(ctx, req, fn) } + +// GetClient returns the underlying ollama client for advanced operations +func (s *OllamaService) GetClient() *api.Client { + return s.client +} diff --git a/internal/router/router.go b/internal/router/router.go index 1826cd3..4e1997e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -17,17 +17,18 @@ import ( type RouterParams struct 
{ dig.In - Config *config.Config - KBHandler *handler.KnowledgeBaseHandler - KnowledgeHandler *handler.KnowledgeHandler - TenantHandler *handler.TenantHandler - TenantService interfaces.TenantService - ChunkHandler *handler.ChunkHandler - SessionHandler *handler.SessionHandler - MessageHandler *handler.MessageHandler - TestDataHandler *handler.TestDataHandler - ModelHandler *handler.ModelHandler - EvaluationHandler *handler.EvaluationHandler + Config *config.Config + KBHandler *handler.KnowledgeBaseHandler + KnowledgeHandler *handler.KnowledgeHandler + TenantHandler *handler.TenantHandler + TenantService interfaces.TenantService + ChunkHandler *handler.ChunkHandler + SessionHandler *handler.SessionHandler + MessageHandler *handler.MessageHandler + TestDataHandler *handler.TestDataHandler + ModelHandler *handler.ModelHandler + EvaluationHandler *handler.EvaluationHandler + InitializationHandler *handler.InitializationHandler } // NewRouter 创建新的路由 @@ -62,6 +63,23 @@ func NewRouter(params RouterParams) *gin.Engine { // 测试数据接口(不需要认证) r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData) + // 初始化接口(不需要认证) + r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus) + r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig) + r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize) + + // Ollama相关接口(不需要认证) + r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus) + r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels) + r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel) + r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress) + r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks) + + // 远程API相关接口(不需要认证) + 
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel) + r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel) + r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction) + // 需要认证的API路由 v1 := r.Group("/api/v1") { diff --git a/internal/types/knowledgebase.go b/internal/types/knowledgebase.go index 8e14f27..e4be935 100644 --- a/internal/types/knowledgebase.go +++ b/internal/types/knowledgebase.go @@ -8,6 +8,10 @@ import ( "gorm.io/gorm" ) +const ( + InitDefaultKnowledgeBaseID = "kb-00000001" +) + // KnowledgeBase represents a knowledge base type KnowledgeBase struct { // Unique identifier of the knowledge base @@ -26,6 +30,14 @@ type KnowledgeBase struct { EmbeddingModelID string `yaml:"embedding_model_id" json:"embedding_model_id"` // Summary model ID SummaryModelID string `yaml:"summary_model_id" json:"summary_model_id"` + // Rerank model ID + RerankModelID string `yaml:"rerank_model_id" json:"rerank_model_id"` + // VLM model ID + VLMModelID string `yaml:"vlm_model_id" json:"vlm_model_id"` + // VLM config + VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"` + // COS config + COSConfig COSConfig `yaml:"cos_config" json:"cos_config" gorm:"type:json"` // Creation time of the knowledge base CreatedAt time.Time `yaml:"created_at" json:"created_at"` // Last updated time of the knowledge base @@ -54,6 +66,37 @@ type ChunkingConfig struct { EnableMultimodal bool `yaml:"enable_multimodal" json:"enable_multimodal"` } +// COSConfig represents the COS configuration +type COSConfig struct { + // Secret ID + SecretID string `yaml:"secret_id" json:"secret_id"` + // Secret Key + SecretKey string `yaml:"secret_key" json:"secret_key"` + // Region + Region string `yaml:"region" json:"region"` + // Bucket Name + BucketName string `yaml:"bucket_name" json:"bucket_name"` + // App ID + AppID string `yaml:"app_id" json:"app_id"` + // Path Prefix 
+ PathPrefix string `yaml:"path_prefix" json:"path_prefix"` +} + +func (c *COSConfig) Value() (driver.Value, error) { + return json.Marshal(c) +} + +func (c *COSConfig) Scan(value interface{}) error { + if value == nil { + return nil + } + b, ok := value.([]byte) + if !ok { + return nil + } + return json.Unmarshal(b, c) +} + // ImageProcessingConfig represents the image processing configuration type ImageProcessingConfig struct { // Model ID @@ -93,3 +136,32 @@ func (c *ImageProcessingConfig) Scan(value interface{}) error { } return json.Unmarshal(b, c) } + +// VLMConfig represents the VLM configuration +type VLMConfig struct { + // Model Name + ModelName string `yaml:"model_name" json:"model_name"` + // Base URL + BaseURL string `yaml:"base_url" json:"base_url"` + // API Key + APIKey string `yaml:"api_key" json:"api_key"` + // Interface Type: "ollama" or "openai" + InterfaceType string `yaml:"interface_type" json:"interface_type"` +} + +// Value implements the driver.Valuer interface, used to convert VLMConfig to database value +func (c VLMConfig) Value() (driver.Value, error) { + return json.Marshal(c) +} + +// Scan implements the sql.Scanner interface, used to convert database value to VLMConfig +func (c *VLMConfig) Scan(value interface{}) error { + if value == nil { + return nil + } + b, ok := value.([]byte) + if !ok { + return nil + } + return json.Unmarshal(b, c) +} diff --git a/internal/types/tenant.go b/internal/types/tenant.go index 64485f2..20e2e14 100644 --- a/internal/types/tenant.go +++ b/internal/types/tenant.go @@ -8,6 +8,10 @@ import ( "gorm.io/gorm" ) +const ( + InitDefaultTenantID uint = 1 +) + // Tenant represents the tenant type Tenant struct { // ID diff --git a/migrations/mysql/00-init-db.sql b/migrations/mysql/00-init-db.sql index 0c0f0f5..000e6fc 100644 --- a/migrations/mysql/00-init-db.sql +++ b/migrations/mysql/00-init-db.sql @@ -19,7 +19,7 @@ CREATE TABLE tenants ( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP 
DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, deleted_at TIMESTAMP NULL DEFAULT NULL -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000; CREATE TABLE models ( id VARCHAR(64) PRIMARY KEY, @@ -47,6 +47,10 @@ CREATE TABLE knowledge_bases ( image_processing_config JSON NOT NULL, embedding_model_id VARCHAR(64) NOT NULL, summary_model_id VARCHAR(64) NOT NULL, + rerank_model_id VARCHAR(64) NOT NULL, + vlm_model_id VARCHAR(64) NOT NULL, + cos_config JSON NOT NULL, + vlm_config JSON NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, deleted_at TIMESTAMP NULL DEFAULT NULL diff --git a/migrations/paradedb/00-init-db.sql b/migrations/paradedb/00-init-db.sql index 94d10e3..1354ba1 100644 --- a/migrations/paradedb/00-init-db.sql +++ b/migrations/paradedb/00-init-db.sql @@ -21,6 +21,9 @@ CREATE TABLE IF NOT EXISTS tenants ( deleted_at TIMESTAMP WITH TIME ZONE ); +-- Set the starting value for tenants id sequence +ALTER SEQUENCE tenants_id_seq RESTART WITH 10000; + -- Add indexes CREATE INDEX IF NOT EXISTS idx_tenants_api_key ON tenants(api_key); CREATE INDEX IF NOT EXISTS idx_tenants_status ON tenants(status); @@ -55,6 +58,10 @@ CREATE TABLE IF NOT EXISTS knowledge_bases ( image_processing_config JSONB NOT NULL DEFAULT '{"enable_multimodal": false, "model_id": ""}', embedding_model_id VARCHAR(64) NOT NULL, summary_model_id VARCHAR(64) NOT NULL, + rerank_model_id VARCHAR(64) NOT NULL, + vlm_model_id VARCHAR(64) NOT NULL, + cos_config JSONB NOT NULL DEFAULT '{}', + vlm_config JSONB NOT NULL DEFAULT '{}', created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, deleted_at TIMESTAMP WITH TIME ZONE diff --git a/scripts/start_all.sh b/scripts/start_all.sh index 7bdb0c3..6706d7d 100755 --- a/scripts/start_all.sh +++ b/scripts/start_all.sh @@ -227,12 +227,24 @@ start_docker() { 
source "$PROJECT_ROOT/.env" storage_type=${STORAGE_TYPE:-local} + # 检测当前系统平台 + log_info "检测系统平台信息..." + if [ "$(uname -m)" = "x86_64" ]; then + export PLATFORM="linux/amd64" + elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + export PLATFORM="linux/arm64" + else + log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64" + export PLATFORM="linux/amd64" + fi + log_info "当前平台:$PLATFORM" + # 进入项目根目录再执行docker-compose命令 cd "$PROJECT_ROOT" # 启动基本服务 log_info "启动核心服务容器..." - docker-compose up --build -d + PLATFORM=$PLATFORM docker-compose up --build -d if [ $? -ne 0 ]; then log_error "Docker容器启动失败" return 1 diff --git a/services/docreader/requirements.txt b/services/docreader/requirements.txt index a811c43..245d1c1 100644 --- a/services/docreader/requirements.txt +++ b/services/docreader/requirements.txt @@ -13,7 +13,6 @@ urllib3 markdownify mistletoe goose3[all] -# paddlepaddle==3.0.0 (普通CPU版本已被注释) paddleocr==3.0.0 markdown pypdf @@ -21,6 +20,10 @@ cos-python-sdk-v5 textract antiword openai +ollama ---extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ -paddlepaddle-gpu==3.0.0 \ No newline at end of file +--extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cpu/ +paddlepaddle==3.0.0 + +# --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# paddlepaddle-gpu==3.0.0 \ No newline at end of file diff --git a/services/docreader/scripts/generate_proto.sh b/services/docreader/scripts/generate_proto.sh index bca24c1..86249db 100644 --- a/services/docreader/scripts/generate_proto.sh +++ b/services/docreader/scripts/generate_proto.sh @@ -7,7 +7,7 @@ PYTHON_OUT="src/proto" GO_OUT="src/proto" # 生成Python代码 -python -m grpc_tools.protoc -I${PROTO_DIR} \ +python3 -m grpc_tools.protoc -I${PROTO_DIR} \ --python_out=${PYTHON_OUT} \ --grpc_python_out=${PYTHON_OUT} \ ${PROTO_DIR}/docreader.proto diff --git a/services/docreader/src/parser/base_parser.py b/services/docreader/src/parser/base_parser.py index 
4a83fb6..1a1c2f7 100644 --- a/services/docreader/src/parser/base_parser.py +++ b/services/docreader/src/parser/base_parser.py @@ -3,7 +3,7 @@ import re import os import uuid import asyncio -from typing import List, Dict, Any, Optional, Tuple +from typing import List, Dict, Any, Optional, Tuple, Union from abc import ABC, abstractmethod from dataclasses import dataclass, field import logging @@ -13,6 +13,7 @@ import traceback import numpy as np import time from .ocr_engine import OCREngine +from .image_utils import image_to_base64 # Add parent directory to Python path for src imports current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -100,6 +101,7 @@ class BaseParser(ABC): max_image_size: int = 1920, # Maximum image size max_concurrent_tasks: int = 5, # Max concurrent tasks max_chunks: int = 1000, # Max number of returned chunks + chunking_config: object = None, # Chunking configuration object ): """Initialize parser @@ -127,6 +129,7 @@ class BaseParser(ABC): self.max_image_size = max_image_size self.max_concurrent_tasks = max_concurrent_tasks self.max_chunks = max_chunks + self.chunking_config = chunking_config logger.info( f"Initializing {self.__class__.__name__} for file: {file_name}, type: {self.file_type}" @@ -141,7 +144,12 @@ class BaseParser(ABC): # Only initialize Caption service if multimodal is enabled if self.enable_multimodal: try: - self.caption_parser = Caption() + # Get VLM config from chunking config if available + vlm_config = None + if self.chunking_config and hasattr(self.chunking_config, 'vlm_config'): + vlm_config = self.chunking_config.vlm_config + + self.caption_parser = Caption(vlm_config) except Exception as e: logger.warning(f"Failed to initialize Caption service: {str(e)}") self.caption_parser = None @@ -200,55 +208,51 @@ class BaseParser(ABC): resized_image.close() def _resize_image_if_needed(self, image): - """Resize image to avoid processing large images + """Resize image if it exceeds maximum size limit Args: image: Image 
object (PIL.Image or numpy array) Returns: - Resized image + Resized image object """ try: - # Check if it's a PIL.Image object - if hasattr(image, 'width') and hasattr(image, 'height'): - width, height = image.width, image.height - # Check if resizing is needed + # If it's a PIL Image + if hasattr(image, 'size'): + width, height = image.size if width > self.max_image_size or height > self.max_image_size: - logger.info(f"Resizing image, original size: {width}x{height}") - # Calculate scaling factor + logger.info(f"Resizing PIL image, original size: {width}x{height}") scale = min(self.max_image_size / width, self.max_image_size / height) new_width = int(width * scale) new_height = int(height * scale) - # Resize resized_image = image.resize((new_width, new_height)) - logger.info(f"Image resized to: {new_width}x{new_height}") + logger.info(f"Resized to: {new_width}x{new_height}") return resized_image + else: + logger.info(f"PIL image size {width}x{height} is within limits, no resizing needed") + return image # If it's a numpy array - elif hasattr(image, 'shape') and len(image.shape) == 3: - height, width = image.shape[0], image.shape[1] + elif hasattr(image, 'shape'): + height, width = image.shape[:2] if width > self.max_image_size or height > self.max_image_size: logger.info(f"Resizing numpy image, original size: {width}x{height}") - # Use PIL for resizing - from PIL import Image + scale = min(self.max_image_size / width, self.max_image_size / height) + new_width = int(width * scale) + new_height = int(height * scale) + # Use PIL for resizing numpy arrays pil_image = Image.fromarray(image) - try: - scale = min(self.max_image_size / width, self.max_image_size / height) - new_width = int(width * scale) - new_height = int(height * scale) - resized_image = pil_image.resize((new_width, new_height)) - # Convert back to numpy array - import numpy as np - resized_array = np.array(resized_image) - logger.info(f"numpy image resized to: {new_width}x{new_height}") - return 
resized_array - finally: - # Ensure PIL image is closed - pil_image.close() - if 'resized_image' in locals() and hasattr(resized_image, 'close'): - resized_image.close() - return image + resized_pil = pil_image.resize((new_width, new_height)) + resized_image = np.array(resized_pil) + logger.info(f"Resized to: {new_width}x{new_height}") + return resized_image + else: + logger.info(f"Numpy image size {width}x{height} is within limits, no resizing needed") + return image + else: + logger.warning(f"Unknown image type: {type(image)}, cannot resize") + return image except Exception as e: - logger.warning(f"Failed to resize image: {str(e)}, using original image") + logger.error(f"Error resizing image: {str(e)}") return image def process_image(self, image, image_url=None): @@ -273,15 +277,21 @@ class BaseParser(ABC): ocr_text = self.perform_ocr(image) caption = "" - if self.caption_parser and image_url: + if self.caption_parser: logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption") - caption = self.get_image_caption(image_url) - if caption: - logger.info(f"Successfully obtained image caption: {caption}") + # Convert image to base64 for caption generation + img_base64 = image_to_base64(image) + if img_base64: + caption = self.get_image_caption(img_base64) + if caption: + logger.info(f"Successfully obtained image caption: {caption}") + else: + logger.warning("Failed to get caption") else: - logger.warning("Failed to get caption") + logger.warning("Failed to convert image to base64") + caption = "" else: - logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval") + logger.info("Caption service not initialized, skipping caption retrieval") # Release image resources del image @@ -323,21 +333,27 @@ class BaseParser(ABC): logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption") caption = "" - if self.caption_parser and image_url: + if self.caption_parser: 
try: - # Add timeout to avoid blocking caption retrieval (30 seconds timeout) - caption_task = self.get_image_caption_async(image_url) - image_url, caption = await asyncio.wait_for(caption_task, timeout=30.0) - if caption: - logger.info(f"Successfully obtained image caption: {caption}") + # Convert image to base64 for caption generation + img_base64 = image_to_base64(resized_image) + if img_base64: + # Add timeout to avoid blocking caption retrieval (30 seconds timeout) + caption_task = self.get_image_caption_async(img_base64) + image_data, caption = await asyncio.wait_for(caption_task, timeout=30.0) + if caption: + logger.info(f"Successfully obtained image caption: {caption}") + else: + logger.warning("Failed to get caption") else: - logger.warning("Failed to get caption") + logger.warning("Failed to convert image to base64") + caption = "" except asyncio.TimeoutError: logger.warning("Caption retrieval timed out, skipping") except Exception as e: logger.error(f"Failed to get caption: {str(e)}") else: - logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval") + logger.info("Caption service not initialized, skipping caption retrieval") return ocr_text, caption, image_url finally: @@ -473,56 +489,66 @@ class BaseParser(ABC): logger.info(f"Decoded text length: {len(text)} characters") return text - def get_image_caption(self, image_url: str) -> str: + def get_image_caption(self, image_data: str) -> str: """Get image description Args: - image_url: Image URL + image_data: Image data (base64 encoded string or URL) Returns: Image description """ start_time = time.time() logger.info( - f"Getting caption for image: {image_url[:250]}..." - if len(image_url) > 250 - else f"Getting caption for image: {image_url}" + f"Getting caption for image: {image_data[:250]}..." 
+ if len(image_data) > 250 + else f"Getting caption for image: {image_data}" ) - caption = self.caption_parser.get_caption(image_url) + caption = self.caption_parser.get_caption(image_data) if caption: logger.info( - f"Received caption of length: {len(caption)}, caption: {caption}, image_url: {image_url}," + f"Received caption of length: {len(caption)}, caption: {caption}," f"cost: {time.time() - start_time} seconds" ) else: logger.warning("Failed to get caption for image") return caption - async def get_image_caption_async(self, image_url: str) -> Tuple[str, str]: + async def get_image_caption_async(self, image_data: str) -> Tuple[str, str]: """Asynchronously get image description Args: - image_url: Image URL + image_data: Image data (base64 encoded string or URL) Returns: - Tuple[str, str]: Image URL and corresponding description + Tuple[str, str]: Image data and corresponding description """ - caption = self.get_image_caption(image_url) - return image_url, caption + caption = self.get_image_caption(image_data) + return image_data, caption - def _init_cos_client(self): + def _init_cos_client(self, cos_config=None): """Initialize Tencent Cloud COS client""" try: - # Get COS configuration from environment variables - secret_id = os.getenv("COS_SECRET_ID") - secret_key = os.getenv("COS_SECRET_KEY") - region = os.getenv("COS_REGION") - bucket_name = os.getenv("COS_BUCKET_NAME") - appid = os.getenv("COS_APP_ID") - prefix = os.getenv("COS_PATH_PREFIX") - enable_old_domain = ( - os.getenv("COS_ENABLE_OLD_DOMAIN", "true").lower() == "true" - ) + # Use provided COS config if available, otherwise fall back to environment variables + if cos_config: + secret_id = cos_config.get("secret_id") + secret_key = cos_config.get("secret_key") + region = cos_config.get("region") + bucket_name = cos_config.get("bucket_name") + appid = cos_config.get("app_id") + prefix = cos_config.get("path_prefix", "") + enable_old_domain = cos_config.get("enable_old_domain", "true").lower() == "true" 
+ else: + # Get COS configuration from environment variables + secret_id = os.getenv("COS_SECRET_ID") + secret_key = os.getenv("COS_SECRET_KEY") + region = os.getenv("COS_REGION") + bucket_name = os.getenv("COS_BUCKET_NAME") + appid = os.getenv("COS_APP_ID") + prefix = os.getenv("COS_PATH_PREFIX") + enable_old_domain = ( + os.getenv("COS_ENABLE_OLD_DOMAIN", "true").lower() == "true" + ) if not all([secret_id, secret_key, region, bucket_name, appid]): logger.error( @@ -600,12 +626,13 @@ class BaseParser(ABC): logger.error(f"Failed to upload file to COS: {str(e)}") return "" - def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str: + def upload_bytes(self, content: bytes, file_ext: str = ".png", cos_config=None) -> str: """Directly upload file content to Tencent Cloud COS Args: content: File byte content file_ext: File extension, default is .png + cos_config: COS configuration dictionary Returns: File URL @@ -614,7 +641,7 @@ class BaseParser(ABC): f"Uploading bytes content to COS, content size: {len(content)} bytes" ) try: - client, bucket_name, region, prefix = self._init_cos_client() + client, bucket_name, region, prefix = self._init_cos_client(cos_config) if not client: return "" diff --git a/services/docreader/src/parser/caption.py b/services/docreader/src/parser/caption.py index 6d7f8b3..bd51290 100644 --- a/services/docreader/src/parser/caption.py +++ b/services/docreader/src/parser/caption.py @@ -1,10 +1,13 @@ import json import logging import os +import time from dataclasses import dataclass, field from typing import List, Optional, Union import requests +import ollama + logger = logging.getLogger(__name__) @@ -168,39 +171,108 @@ class Caption: Uses an external API to process images and return textual descriptions. 
""" - def __init__(self): - """Initialize the Caption service with configuration from environment variables.""" + def __init__(self, vlm_config=None): + """Initialize the Caption service with configuration from parameters or environment variables.""" logger.info("Initializing Caption service") self.prompt = """简单凝炼的描述图片的主要内容""" - if os.getenv("VLM_MODEL_BASE_URL") == "" or os.getenv("VLM_MODEL_NAME") == "": - logger.error("VLM_MODEL_BASE_URL or VLM_MODEL_NAME is not set") - return - self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/v1/chat/completions" - self.model = os.getenv("VLM_MODEL_NAME") - self.api_key = os.getenv("VLM_MODEL_API_KEY") + + # Use provided VLM config if available, otherwise fall back to environment variables + if vlm_config: + self.completion_url = vlm_config.get("base_url", "") + "/chat/completions" + self.model = vlm_config.get("model_name", "") + self.api_key = vlm_config.get("api_key", "") + self.interface_type = vlm_config.get("interface_type", "openai").lower() + else: + if os.getenv("VLM_MODEL_BASE_URL") == "" or os.getenv("VLM_MODEL_NAME") == "": + logger.error("VLM_MODEL_BASE_URL or VLM_MODEL_NAME is not set") + return + self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/chat/completions" + self.model = os.getenv("VLM_MODEL_NAME") + self.api_key = os.getenv("VLM_MODEL_API_KEY") + self.interface_type = os.getenv("VLM_INTERFACE_TYPE", "openai").lower() + + # 验证接口类型 + if self.interface_type not in ["ollama", "openai"]: + logger.warning(f"Unknown interface type: {self.interface_type}, defaulting to openai") + self.interface_type = "openai" + logger.info( - f"Service configured with model: {self.model}, endpoint: {self.completion_url}" + f"Service configured with model: {self.model}, endpoint: {self.completion_url}, interface: {self.interface_type}" ) - def _call_caption_api(self, image_url: str) -> Optional[CaptionChatResp]: + def _call_caption_api(self, image_data: str) -> Optional[CaptionChatResp]: """ Call the Caption API 
to generate a description for the given image. Args: - image_url: URL of the image to be captioned + image_data: URL of the image or base64 encoded image data Returns: CaptionChatResp object if successful, None otherwise """ logger.info(f"Calling Caption API for image captioning") - logger.info(f"Processing image from URL: {image_url}") + logger.info(f"Processing image data: {image_data[:50] if len(image_data) > 50 else image_data}") + # 根据接口类型选择调用方式 + if self.interface_type == "ollama": + return self._call_ollama_api(image_data) + else: + return self._call_openai_api(image_data) + + def _call_ollama_api(self, image_base64: str) -> Optional[CaptionChatResp]: + """Call Ollama API for image captioning using base64 encoded image data.""" + + host = self.completion_url.replace("/v1/chat/completions", "") + + client = ollama.Client( + host=host, + ) + + try: + logger.info(f"Calling Ollama API with model: {self.model}") + + # 调用Ollama API,使用images参数传递base64编码的图片 + response = client.generate( + model=self.model, + prompt="简单凝炼的描述图片的主要内容", + images=[image_base64], # image_base64是base64编码的图片数据 + options={"temperature": 0.1}, + stream=False, + ) + + # 构造响应对象 + caption_resp = CaptionChatResp( + id="ollama_response", + created=int(time.time()), + model=self.model, + object="chat.completion", + choices=[ + Choice( + message=Message( + role="assistant", + content=response.response + ) + ) + ] + ) + + logger.info("Successfully received response from Ollama API") + return caption_resp + + except Exception as e: + logger.error(f"Error calling Ollama API: {e}") + return None + + def _call_openai_api(self, image_base64: str) -> Optional[CaptionChatResp]: + """Call OpenAI-compatible API for image captioning.""" + logger.info(f"Calling OpenAI-compatible API with model: {self.model}") + user_msg = UserMessage( role="user", content=[ Content(type="text", text=self.prompt), Content( - type="image_url", image_url=ImageUrl(url=image_url, detail="auto") + type="image_url", 
image_url=ImageUrl(url="data:image/png;base64," + image_base64, detail="auto") ), ], ) @@ -223,7 +295,7 @@ class Caption: headers["Authorization"] = f"Bearer {self.api_key}" try: - logger.info(f"Sending request to Caption API with model: {self.model}") + logger.info(f"Sending request to OpenAI-compatible API with model: {self.model}") response = requests.post( self.completion_url, data=json.dumps(gpt_req, default=lambda o: o.__dict__, indent=4), @@ -232,12 +304,12 @@ class Caption: ) if response.status_code != 200: logger.error( - f"Caption API returned non-200 status code: {response.status_code}" + f"OpenAI-compatible API returned non-200 status code: {response.status_code}" ) response.raise_for_status() logger.info( - f"Successfully received response from Caption API with status: {response.status_code}" + f"Successfully received response from OpenAI-compatible API with status: {response.status_code}" ) logger.info(f"Converting response to CaptionChatResp object") caption_resp = CaptionChatResp.from_json(response.json()) @@ -250,33 +322,30 @@ class Caption: return caption_resp except requests.exceptions.Timeout: - logger.error(f"Timeout while calling Caption API after 30 seconds") + logger.error(f"Timeout while calling OpenAI-compatible API after 30 seconds") return None except requests.exceptions.RequestException as e: - logger.error(f"Request error calling Caption API: {e}") + logger.error(f"Request error calling OpenAI-compatible API: {e}") return None except Exception as e: - logger.error(f"Error calling Caption API: {str(e)}") - logger.info( - f"Request details: model={self.model}, URL={self.completion_url}" - ) + logger.error(f"Unexpected error calling OpenAI-compatible API: {e}") return None - def get_caption(self, image_url: str) -> str: + def get_caption(self, image_data: str) -> str: """ - Get a caption for the provided image URL. + Get a caption for the provided image data. 
Args: - image_url: URL of the image to be captioned + image_data: URL of the image or base64 encoded image data Returns: Caption text as string, or empty string if captioning failed """ logger.info("Getting caption for image") - if not image_url or self.completion_url is None: - logger.error("Image URL is not set") + if not image_data or self.completion_url is None: + logger.error("Image data is not set") return "" - caption_resp = self._call_caption_api(image_url) + caption_resp = self._call_caption_api(image_data) if caption_resp: caption = caption_resp.choice_data() caption_length = len(caption) diff --git a/services/docreader/src/parser/image_parser.py b/services/docreader/src/parser/image_parser.py index fb94449..f29f68a 100644 --- a/services/docreader/src/parser/image_parser.py +++ b/services/docreader/src/parser/image_parser.py @@ -37,7 +37,13 @@ class ImageParser(BaseParser): # Upload image to storage service logger.info("Uploading image to storage") _, ext = os.path.splitext(self.file_name) - image_url = self.upload_bytes(content, file_ext=ext) + + # Get COS config from chunking config if available + cos_config = None + if hasattr(self, 'chunking_config') and self.chunking_config and hasattr(self.chunking_config, 'cos_config'): + cos_config = self.chunking_config.cos_config + + image_url = self.upload_bytes(content, file_ext=ext, cos_config=cos_config) if not image_url: logger.error("Failed to upload image to storage") return "" diff --git a/services/docreader/src/parser/image_utils.py b/services/docreader/src/parser/image_utils.py new file mode 100644 index 0000000..55cb474 --- /dev/null +++ b/services/docreader/src/parser/image_utils.py @@ -0,0 +1,43 @@ +import base64 +import io +import logging +from typing import Union +from PIL import Image +import numpy as np + +logger = logging.getLogger(__name__) + +def image_to_base64(image: Union[str, bytes, Image.Image, np.ndarray]) -> str: + """Convert image to base64 encoded string + + Args: + image: Image file 
path, bytes, PIL Image object, or numpy array + + Returns: + Base64 encoded image string, or empty string if conversion fails + """ + try: + if isinstance(image, str): + # It's a file path + with open(image, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + elif isinstance(image, bytes): + # It's bytes data + return base64.b64encode(image).decode("utf-8") + elif isinstance(image, Image.Image): + # It's a PIL Image + buffer = io.BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + elif isinstance(image, np.ndarray): + # It's a numpy array + pil_image = Image.fromarray(image) + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + else: + logger.error(f"Unsupported image type: {type(image)}") + return "" + except Exception as e: + logger.error(f"Error converting image to base64: {str(e)}") + return "" diff --git a/services/docreader/src/parser/ocr_engine.py b/services/docreader/src/parser/ocr_engine.py index b7ae673..6b3ec27 100644 --- a/services/docreader/src/parser/ocr_engine.py +++ b/services/docreader/src/parser/ocr_engine.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from PIL import Image import io import numpy as np +from .image_utils import image_to_base64 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -161,35 +162,6 @@ class NanonetsOCRBackend(OCRBackend): logger.error(f"Failed to initialize Nanonets OCR: {str(e)}") self.client = None - def _encode_image(self, image: Union[str, bytes, Image.Image]) -> str: - """Encode image to base64 - - Args: - image: Image file path, bytes, or PIL Image object - - Returns: - Base64 encoded image - """ - try: - if isinstance(image, str): - # It's a file path - with open(image, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") - elif isinstance(image, bytes): - # It's bytes data - return 
base64.b64encode(image).decode("utf-8") - elif isinstance(image, Image.Image): - # It's a PIL Image - buffer = io.BytesIO() - image.save(buffer, format="PNG") - return base64.b64encode(buffer.getvalue()).decode("utf-8") - else: - logger.error(f"Unsupported image type: {type(image)}") - return "" - except Exception as e: - logger.error(f"Error encoding image: {str(e)}") - return "" - def predict(self, image: Union[str, bytes, Image.Image]) -> str: """Extract text from an image using Nanonets OCR @@ -205,7 +177,7 @@ class NanonetsOCRBackend(OCRBackend): try: # Encode image to base64 - img_base64 = self._encode_image(image) + img_base64 = image_to_base64(image) if not img_base64: return "" diff --git a/services/docreader/src/parser/parser.py b/services/docreader/src/parser/parser.py index 14c30a7..d81a191 100644 --- a/services/docreader/src/parser/parser.py +++ b/services/docreader/src/parser/parser.py @@ -30,6 +30,8 @@ class ChunkingConfig: enable_multimodal: bool = ( False # Whether to enable multimodal processing (text + images) ) + cos_config: dict = None # COS configuration for file storage + vlm_config: dict = None # VLM configuration for image captioning @dataclass @@ -138,6 +140,7 @@ class Parser: enable_multimodal=config.enable_multimodal, max_image_size=1920, # Limit image size to 1920px max_concurrent_tasks=5, # Limit concurrent tasks to 5 + chunking_config=config, # Pass the entire chunking config ) logger.info(f"Starting to parse file content, size: {len(content)} bytes") @@ -197,6 +200,7 @@ class Parser: enable_multimodal=config.enable_multimodal, max_image_size=1920, # Limit image size max_concurrent_tasks=5, # Limit concurrent tasks + chunking_config=config, ) logger.info(f"Starting to parse URL content") diff --git a/services/docreader/src/proto/docreader.pb.go b/services/docreader/src/proto/docreader.pb.go index 756f87e..70901b9 100644 --- a/services/docreader/src/proto/docreader.pb.go +++ b/services/docreader/src/proto/docreader.pb.go @@ -1,6 +1,6 
@@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.7 // protoc v5.29.3 // source: docreader.proto @@ -21,19 +21,175 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +// COS 配置 +type COSConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + SecretId string `protobuf:"bytes,1,opt,name=secret_id,json=secretId,proto3" json:"secret_id,omitempty"` // COS Secret ID + SecretKey string `protobuf:"bytes,2,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"` // COS Secret Key + Region string `protobuf:"bytes,3,opt,name=region,proto3" json:"region,omitempty"` // COS Region + BucketName string `protobuf:"bytes,4,opt,name=bucket_name,json=bucketName,proto3" json:"bucket_name,omitempty"` // COS Bucket Name + AppId string `protobuf:"bytes,5,opt,name=app_id,json=appId,proto3" json:"app_id,omitempty"` // COS App ID + PathPrefix string `protobuf:"bytes,6,opt,name=path_prefix,json=pathPrefix,proto3" json:"path_prefix,omitempty"` // COS Path Prefix + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *COSConfig) Reset() { + *x = COSConfig{} + mi := &file_docreader_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *COSConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*COSConfig) ProtoMessage() {} + +func (x *COSConfig) ProtoReflect() protoreflect.Message { + mi := &file_docreader_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use COSConfig.ProtoReflect.Descriptor instead. 
+func (*COSConfig) Descriptor() ([]byte, []int) { + return file_docreader_proto_rawDescGZIP(), []int{0} +} + +func (x *COSConfig) GetSecretId() string { + if x != nil { + return x.SecretId + } + return "" +} + +func (x *COSConfig) GetSecretKey() string { + if x != nil { + return x.SecretKey + } + return "" +} + +func (x *COSConfig) GetRegion() string { + if x != nil { + return x.Region + } + return "" +} + +func (x *COSConfig) GetBucketName() string { + if x != nil { + return x.BucketName + } + return "" +} + +func (x *COSConfig) GetAppId() string { + if x != nil { + return x.AppId + } + return "" +} + +func (x *COSConfig) GetPathPrefix() string { + if x != nil { + return x.PathPrefix + } + return "" +} + +// VLM 配置 +type VLMConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // VLM Model Name + BaseUrl string `protobuf:"bytes,2,opt,name=base_url,json=baseUrl,proto3" json:"base_url,omitempty"` // VLM Base URL + ApiKey string `protobuf:"bytes,3,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` // VLM API Key + InterfaceType string `protobuf:"bytes,4,opt,name=interface_type,json=interfaceType,proto3" json:"interface_type,omitempty"` // VLM Interface Type: "ollama" or "openai" + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *VLMConfig) Reset() { + *x = VLMConfig{} + mi := &file_docreader_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *VLMConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*VLMConfig) ProtoMessage() {} + +func (x *VLMConfig) ProtoReflect() protoreflect.Message { + mi := &file_docreader_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + 
+// Deprecated: Use VLMConfig.ProtoReflect.Descriptor instead. +func (*VLMConfig) Descriptor() ([]byte, []int) { + return file_docreader_proto_rawDescGZIP(), []int{1} +} + +func (x *VLMConfig) GetModelName() string { + if x != nil { + return x.ModelName + } + return "" +} + +func (x *VLMConfig) GetBaseUrl() string { + if x != nil { + return x.BaseUrl + } + return "" +} + +func (x *VLMConfig) GetApiKey() string { + if x != nil { + return x.ApiKey + } + return "" +} + +func (x *VLMConfig) GetInterfaceType() string { + if x != nil { + return x.InterfaceType + } + return "" +} + type ReadConfig struct { state protoimpl.MessageState `protogen:"open.v1"` ChunkSize int32 `protobuf:"varint,1,opt,name=chunk_size,json=chunkSize,proto3" json:"chunk_size,omitempty"` // 分块大小 ChunkOverlap int32 `protobuf:"varint,2,opt,name=chunk_overlap,json=chunkOverlap,proto3" json:"chunk_overlap,omitempty"` // 分块重叠 Separators []string `protobuf:"bytes,3,rep,name=separators,proto3" json:"separators,omitempty"` // 分隔符 EnableMultimodal bool `protobuf:"varint,4,opt,name=enable_multimodal,json=enableMultimodal,proto3" json:"enable_multimodal,omitempty"` // 多模态处理 + CosConfig *COSConfig `protobuf:"bytes,5,opt,name=cos_config,json=cosConfig,proto3" json:"cos_config,omitempty"` // COS 配置 + VlmConfig *VLMConfig `protobuf:"bytes,6,opt,name=vlm_config,json=vlmConfig,proto3" json:"vlm_config,omitempty"` // VLM 配置 unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ReadConfig) Reset() { *x = ReadConfig{} - mi := &file_docreader_proto_msgTypes[0] + mi := &file_docreader_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -45,7 +201,7 @@ func (x *ReadConfig) String() string { func (*ReadConfig) ProtoMessage() {} func (x *ReadConfig) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[0] + mi := &file_docreader_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if 
ms.LoadMessageInfo() == nil { @@ -58,7 +214,7 @@ func (x *ReadConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ReadConfig.ProtoReflect.Descriptor instead. func (*ReadConfig) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{0} + return file_docreader_proto_rawDescGZIP(), []int{2} } func (x *ReadConfig) GetChunkSize() int32 { @@ -89,6 +245,20 @@ func (x *ReadConfig) GetEnableMultimodal() bool { return false } +func (x *ReadConfig) GetCosConfig() *COSConfig { + if x != nil { + return x.CosConfig + } + return nil +} + +func (x *ReadConfig) GetVlmConfig() *VLMConfig { + if x != nil { + return x.VlmConfig + } + return nil +} + // 从文件读取文档请求 type ReadFromFileRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -103,7 +273,7 @@ type ReadFromFileRequest struct { func (x *ReadFromFileRequest) Reset() { *x = ReadFromFileRequest{} - mi := &file_docreader_proto_msgTypes[1] + mi := &file_docreader_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -115,7 +285,7 @@ func (x *ReadFromFileRequest) String() string { func (*ReadFromFileRequest) ProtoMessage() {} func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[1] + mi := &file_docreader_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -128,7 +298,7 @@ func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReadFromFileRequest.ProtoReflect.Descriptor instead. 
func (*ReadFromFileRequest) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{1} + return file_docreader_proto_rawDescGZIP(), []int{3} } func (x *ReadFromFileRequest) GetFileContent() []byte { @@ -179,7 +349,7 @@ type ReadFromURLRequest struct { func (x *ReadFromURLRequest) Reset() { *x = ReadFromURLRequest{} - mi := &file_docreader_proto_msgTypes[2] + mi := &file_docreader_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -191,7 +361,7 @@ func (x *ReadFromURLRequest) String() string { func (*ReadFromURLRequest) ProtoMessage() {} func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[2] + mi := &file_docreader_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -204,7 +374,7 @@ func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ReadFromURLRequest.ProtoReflect.Descriptor instead. func (*ReadFromURLRequest) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{2} + return file_docreader_proto_rawDescGZIP(), []int{4} } func (x *ReadFromURLRequest) GetUrl() string { @@ -250,7 +420,7 @@ type Image struct { func (x *Image) Reset() { *x = Image{} - mi := &file_docreader_proto_msgTypes[3] + mi := &file_docreader_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -262,7 +432,7 @@ func (x *Image) String() string { func (*Image) ProtoMessage() {} func (x *Image) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[3] + mi := &file_docreader_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -275,7 +445,7 @@ func (x *Image) ProtoReflect() protoreflect.Message { // Deprecated: Use Image.ProtoReflect.Descriptor instead. 
func (*Image) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{3} + return file_docreader_proto_rawDescGZIP(), []int{5} } func (x *Image) GetUrl() string { @@ -333,7 +503,7 @@ type Chunk struct { func (x *Chunk) Reset() { *x = Chunk{} - mi := &file_docreader_proto_msgTypes[4] + mi := &file_docreader_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -345,7 +515,7 @@ func (x *Chunk) String() string { func (*Chunk) ProtoMessage() {} func (x *Chunk) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[4] + mi := &file_docreader_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -358,7 +528,7 @@ func (x *Chunk) ProtoReflect() protoreflect.Message { // Deprecated: Use Chunk.ProtoReflect.Descriptor instead. func (*Chunk) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{4} + return file_docreader_proto_rawDescGZIP(), []int{6} } func (x *Chunk) GetContent() string { @@ -407,7 +577,7 @@ type ReadResponse struct { func (x *ReadResponse) Reset() { *x = ReadResponse{} - mi := &file_docreader_proto_msgTypes[5] + mi := &file_docreader_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -419,7 +589,7 @@ func (x *ReadResponse) String() string { func (*ReadResponse) ProtoMessage() {} func (x *ReadResponse) ProtoReflect() protoreflect.Message { - mi := &file_docreader_proto_msgTypes[5] + mi := &file_docreader_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -432,7 +602,7 @@ func (x *ReadResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ReadResponse.ProtoReflect.Descriptor instead. 
func (*ReadResponse) Descriptor() ([]byte, []int) { - return file_docreader_proto_rawDescGZIP(), []int{5} + return file_docreader_proto_rawDescGZIP(), []int{7} } func (x *ReadResponse) GetChunks() []*Chunk { @@ -453,7 +623,23 @@ var File_docreader_proto protoreflect.FileDescriptor const file_docreader_proto_rawDesc = "" + "\n" + - "\x0fdocreader.proto\x12\tdocreader\"\x9d\x01\n" + + "\x0fdocreader.proto\x12\tdocreader\"\xb8\x01\n" + + "\tCOSConfig\x12\x1b\n" + + "\tsecret_id\x18\x01 \x01(\tR\bsecretId\x12\x1d\n" + + "\n" + + "secret_key\x18\x02 \x01(\tR\tsecretKey\x12\x16\n" + + "\x06region\x18\x03 \x01(\tR\x06region\x12\x1f\n" + + "\vbucket_name\x18\x04 \x01(\tR\n" + + "bucketName\x12\x15\n" + + "\x06app_id\x18\x05 \x01(\tR\x05appId\x12\x1f\n" + + "\vpath_prefix\x18\x06 \x01(\tR\n" + + "pathPrefix\"\x85\x01\n" + + "\tVLMConfig\x12\x1d\n" + + "\n" + + "model_name\x18\x01 \x01(\tR\tmodelName\x12\x19\n" + + "\bbase_url\x18\x02 \x01(\tR\abaseUrl\x12\x17\n" + + "\aapi_key\x18\x03 \x01(\tR\x06apiKey\x12%\n" + + "\x0einterface_type\x18\x04 \x01(\tR\rinterfaceType\"\x87\x02\n" + "\n" + "ReadConfig\x12\x1d\n" + "\n" + @@ -462,7 +648,11 @@ const file_docreader_proto_rawDesc = "" + "\n" + "separators\x18\x03 \x03(\tR\n" + "separators\x12+\n" + - "\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\"\xc9\x01\n" + + "\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\x123\n" + + "\n" + + "cos_config\x18\x05 \x01(\v2\x14.docreader.COSConfigR\tcosConfig\x123\n" + + "\n" + + "vlm_config\x18\x06 \x01(\v2\x14.docreader.VLMConfigR\tvlmConfig\"\xc9\x01\n" + "\x13ReadFromFileRequest\x12!\n" + "\ffile_content\x18\x01 \x01(\fR\vfileContent\x12\x1b\n" + "\tfile_name\x18\x02 \x01(\tR\bfileName\x12\x1b\n" + @@ -510,29 +700,33 @@ func file_docreader_proto_rawDescGZIP() []byte { return file_docreader_proto_rawDescData } -var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var 
file_docreader_proto_goTypes = []any{ - (*ReadConfig)(nil), // 0: docreader.ReadConfig - (*ReadFromFileRequest)(nil), // 1: docreader.ReadFromFileRequest - (*ReadFromURLRequest)(nil), // 2: docreader.ReadFromURLRequest - (*Image)(nil), // 3: docreader.Image - (*Chunk)(nil), // 4: docreader.Chunk - (*ReadResponse)(nil), // 5: docreader.ReadResponse + (*COSConfig)(nil), // 0: docreader.COSConfig + (*VLMConfig)(nil), // 1: docreader.VLMConfig + (*ReadConfig)(nil), // 2: docreader.ReadConfig + (*ReadFromFileRequest)(nil), // 3: docreader.ReadFromFileRequest + (*ReadFromURLRequest)(nil), // 4: docreader.ReadFromURLRequest + (*Image)(nil), // 5: docreader.Image + (*Chunk)(nil), // 6: docreader.Chunk + (*ReadResponse)(nil), // 7: docreader.ReadResponse } var file_docreader_proto_depIdxs = []int32{ - 0, // 0: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig - 0, // 1: docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig - 3, // 2: docreader.Chunk.images:type_name -> docreader.Image - 4, // 3: docreader.ReadResponse.chunks:type_name -> docreader.Chunk - 1, // 4: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest - 2, // 5: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest - 5, // 6: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse - 5, // 7: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse - 6, // [6:8] is the sub-list for method output_type - 4, // [4:6] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 4, // [4:4] is the sub-list for extension extendee - 0, // [0:4] is the sub-list for field type_name + 0, // 0: docreader.ReadConfig.cos_config:type_name -> docreader.COSConfig + 1, // 1: docreader.ReadConfig.vlm_config:type_name -> docreader.VLMConfig + 2, // 2: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig + 2, // 3: 
docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig + 5, // 4: docreader.Chunk.images:type_name -> docreader.Image + 6, // 5: docreader.ReadResponse.chunks:type_name -> docreader.Chunk + 3, // 6: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest + 4, // 7: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest + 7, // 8: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse + 7, // 9: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse + 8, // [8:10] is the sub-list for method output_type + 6, // [6:8] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_docreader_proto_init() } @@ -546,7 +740,7 @@ func file_docreader_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_docreader_proto_rawDesc), len(file_docreader_proto_rawDesc)), NumEnums: 0, - NumMessages: 6, + NumMessages: 8, NumExtensions: 0, NumServices: 1, }, diff --git a/services/docreader/src/proto/docreader.proto b/services/docreader/src/proto/docreader.proto index 6ca3d5c..3a9ce29 100644 --- a/services/docreader/src/proto/docreader.proto +++ b/services/docreader/src/proto/docreader.proto @@ -12,11 +12,31 @@ service DocReader { rpc ReadFromURL(ReadFromURLRequest) returns (ReadResponse) {} } +// COS 配置 +message COSConfig { + string secret_id = 1; // COS Secret ID + string secret_key = 2; // COS Secret Key + string region = 3; // COS Region + string bucket_name = 4; // COS Bucket Name + string app_id = 5; // COS App ID + string path_prefix = 6; // COS Path Prefix +} + +// VLM 配置 +message VLMConfig { + string model_name = 1; // VLM Model Name + string base_url = 2; // VLM Base URL + string api_key = 3; // VLM API Key + string interface_type = 4; // VLM Interface Type: "ollama" or "openai" 
+} + message ReadConfig { int32 chunk_size = 1; // 分块大小 int32 chunk_overlap = 2; // 分块重叠 repeated string separators = 3; // 分隔符 bool enable_multimodal = 4; // 多模态处理 + COSConfig cos_config = 5; // COS 配置 + VLMConfig vlm_config = 6; // VLM 配置 } // 从文件读取文档请求 diff --git a/services/docreader/src/proto/docreader_pb2_grpc.py b/services/docreader/src/proto/docreader_pb2_grpc.py index fd391bd..1b4eaf7 100644 --- a/services/docreader/src/proto/docreader_pb2_grpc.py +++ b/services/docreader/src/proto/docreader_pb2_grpc.py @@ -1,9 +1,29 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc +import warnings from . import docreader_pb2 as docreader__pb2 +GRPC_GENERATED_VERSION = '1.74.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in docreader_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' 
+ ) + class DocReaderStub(object): """文档读取服务 @@ -19,12 +39,12 @@ class DocReaderStub(object): '/docreader.DocReader/ReadFromFile', request_serializer=docreader__pb2.ReadFromFileRequest.SerializeToString, response_deserializer=docreader__pb2.ReadResponse.FromString, - ) + _registered_method=True) self.ReadFromURL = channel.unary_unary( '/docreader.DocReader/ReadFromURL', request_serializer=docreader__pb2.ReadFromURLRequest.SerializeToString, response_deserializer=docreader__pb2.ReadResponse.FromString, - ) + _registered_method=True) class DocReaderServicer(object): @@ -62,6 +82,7 @@ def add_DocReaderServicer_to_server(servicer, server): generic_handler = grpc.method_handlers_generic_handler( 'docreader.DocReader', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('docreader.DocReader', rpc_method_handlers) # This class is part of an EXPERIMENTAL API. @@ -80,11 +101,21 @@ class DocReader(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromFile', + return grpc.experimental.unary_unary( + request, + target, + '/docreader.DocReader/ReadFromFile', docreader__pb2.ReadFromFileRequest.SerializeToString, docreader__pb2.ReadResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod def ReadFromURL(request, @@ -97,8 +128,18 @@ class DocReader(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromURL', + return grpc.experimental.unary_unary( + request, + target, + '/docreader.DocReader/ReadFromURL', docreader__pb2.ReadFromURLRequest.SerializeToString, docreader__pb2.ReadResponse.FromString, - options, 
channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/services/docreader/src/server/server.py b/services/docreader/src/server/server.py index 6e391ff..c6a3403 100644 --- a/services/docreader/src/server/server.py +++ b/services/docreader/src/server/server.py @@ -74,11 +74,39 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer): f"multimodal={enable_multimodal}" ) + # Get COS and VLM config from request + cos_config = None + vlm_config = None + + if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config: + cos_config = { + 'secret_id': request.read_config.cos_config.secret_id, + 'secret_key': request.read_config.cos_config.secret_key, + 'region': request.read_config.cos_config.region, + 'bucket_name': request.read_config.cos_config.bucket_name, + 'app_id': request.read_config.cos_config.app_id, + 'path_prefix': request.read_config.cos_config.path_prefix or '', + } + logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}") + + if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config: + vlm_config = { + 'model_name': request.read_config.vlm_config.model_name, + 'base_url': request.read_config.vlm_config.base_url, + 'api_key': request.read_config.vlm_config.api_key or '', + 'interface_type': request.read_config.vlm_config.interface_type or 'openai', + } + logger.info(f"Using VLM config: model={vlm_config['model_name']}, " + f"base_url={vlm_config['base_url']}, " + f"interface_type={vlm_config['interface_type']}") + chunking_config = ChunkingConfig( chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=separators, enable_multimodal=enable_multimodal, + cos_config=cos_config, + vlm_config=vlm_config, ) # Parse file @@ -138,11 +166,39 @@ class 
DocReaderServicer(docreader_pb2_grpc.DocReaderServicer): f"multimodal={enable_multimodal}" ) + # Get COS and VLM config from request + cos_config = None + vlm_config = None + + if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config: + cos_config = { + 'secret_id': request.read_config.cos_config.secret_id, + 'secret_key': request.read_config.cos_config.secret_key, + 'region': request.read_config.cos_config.region, + 'bucket_name': request.read_config.cos_config.bucket_name, + 'app_id': request.read_config.cos_config.app_id, + 'path_prefix': request.read_config.cos_config.path_prefix or '', + } + logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}") + + if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config: + vlm_config = { + 'model_name': request.read_config.vlm_config.model_name, + 'base_url': request.read_config.vlm_config.base_url, + 'api_key': request.read_config.vlm_config.api_key or '', + 'interface_type': request.read_config.vlm_config.interface_type or 'openai', + } + logger.info(f"Using VLM config: model={vlm_config['model_name']}, " + f"base_url={vlm_config['base_url']}, " + f"interface_type={vlm_config['interface_type']}") + chunking_config = ChunkingConfig( chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=separators, enable_multimodal=enable_multimodal, + cos_config=cos_config, + vlm_config=vlm_config, ) # Parse URL