mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-24 19:12:51 +08:00
feat: Added web page for configuring model information
This commit is contained in:
116
.env.example
116
.env.example
@@ -112,73 +112,79 @@ ENABLE_GRAPH_RAG=false
|
||||
# COS_ENABLE_OLD_DOMAIN=true 表示启用旧的域名格式,默认为 true
|
||||
COS_ENABLE_OLD_DOMAIN=true
|
||||
|
||||
# 初始化默认租户与知识库
|
||||
# 租户ID,通常是一个字符串
|
||||
INIT_TEST_TENANT_ID=1
|
||||
|
||||
# 知识库ID,通常是一个字符串
|
||||
INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001
|
||||
##############################################################
|
||||
|
||||
# LLM Model
|
||||
# 使用的LLM模型名称
|
||||
# 默认使用 Ollama 的 Qwen3 8B 模型,ollama 会自动处理模型下载和加载
|
||||
# 如果需要使用其他模型,请替换为实际的模型名称
|
||||
INIT_LLM_MODEL_NAME=qwen3:8b
|
||||
###### 注意: 以下配置不再生效,已在Web“配置初始化”阶段完成 #########
|
||||
|
||||
# LLM模型的访问地址
|
||||
# 支持第三方模型服务的URL
|
||||
# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url
|
||||
|
||||
# LLM模型的API密钥,如果需要身份验证,可以设置
|
||||
# 支持第三方模型服务的API密钥
|
||||
# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# INIT_LLM_MODEL_API_KEY=your_llm_model_api_key
|
||||
# # 初始化默认租户与知识库
|
||||
# # 租户ID,通常是一个字符串
|
||||
# INIT_TEST_TENANT_ID=1
|
||||
|
||||
# Embedding Model
|
||||
# 使用的Embedding模型名称
|
||||
# 默认使用 nomic-embed-text 模型,支持文本嵌入
|
||||
# 如果需要使用其他模型,请替换为实际的模型名称
|
||||
INIT_EMBEDDING_MODEL_NAME=nomic-embed-text
|
||||
# # 知识库ID,通常是一个字符串
|
||||
# INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001
|
||||
|
||||
# Embedding模型向量维度
|
||||
INIT_EMBEDDING_MODEL_DIMENSION=768
|
||||
# # LLM Model
|
||||
# # 使用的LLM模型名称
|
||||
# # 默认使用 Ollama 的 Qwen3 8B 模型,ollama 会自动处理模型下载和加载
|
||||
# # 如果需要使用其他模型,请替换为实际的模型名称
|
||||
# INIT_LLM_MODEL_NAME=qwen3:8b
|
||||
|
||||
# Embedding模型的ID,通常是一个字符串
|
||||
INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768
|
||||
# # LLM模型的访问地址
|
||||
# # 支持第三方模型服务的URL
|
||||
# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# # INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url
|
||||
|
||||
# Embedding模型的访问地址
|
||||
# 支持第三方模型服务的URL
|
||||
# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url
|
||||
# # LLM模型的API密钥,如果需要身份验证,可以设置
|
||||
# # 支持第三方模型服务的API密钥
|
||||
# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# # INIT_LLM_MODEL_API_KEY=your_llm_model_api_key
|
||||
|
||||
# Embedding模型的API密钥,如果需要身份验证,可以设置
|
||||
# 支持第三方模型服务的API密钥
|
||||
# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key
|
||||
# # Embedding Model
|
||||
# # 使用的Embedding模型名称
|
||||
# # 默认使用 nomic-embed-text 模型,支持文本嵌入
|
||||
# # 如果需要使用其他模型,请替换为实际的模型名称
|
||||
# INIT_EMBEDDING_MODEL_NAME=nomic-embed-text
|
||||
|
||||
# Rerank Model(可选)
|
||||
# 对于rag来说,使用Rerank模型对提升文档搜索的准确度有着重要作用
|
||||
# 目前 ollama 暂不支持运行 Rerank 模型
|
||||
# 使用的Rerank模型名称
|
||||
# INIT_RERANK_MODEL_NAME=your_rerank_model_name
|
||||
# # Embedding模型向量维度
|
||||
# INIT_EMBEDDING_MODEL_DIMENSION=768
|
||||
|
||||
# Rerank模型的访问地址
|
||||
# 支持第三方模型服务的URL
|
||||
# INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url
|
||||
# # Embedding模型的ID,通常是一个字符串
|
||||
# INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768
|
||||
|
||||
# Rerank模型的API密钥,如果需要身份验证,可以设置
|
||||
# 支持第三方模型服务的API密钥
|
||||
# INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key
|
||||
# # Embedding模型的访问地址
|
||||
# # 支持第三方模型服务的URL
|
||||
# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# # INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url
|
||||
|
||||
# VLM_MODEL_NAME 使用的多模态模型名称
|
||||
# 用于解析图片数据
|
||||
# VLM_MODEL_NAME=your_vlm_model_name
|
||||
# # Embedding模型的API密钥,如果需要身份验证,可以设置
|
||||
# # 支持第三方模型服务的API密钥
|
||||
# # 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
|
||||
# # INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key
|
||||
|
||||
# VLM_MODEL_BASE_URL 使用的多模态模型访问地址
|
||||
# 支持第三方模型服务的URL
|
||||
# VLM_MODEL_BASE_URL=your_vlm_model_base_url
|
||||
# # Rerank Model(可选)
|
||||
# # 对于rag来说,使用Rerank模型对提升文档搜索的准确度有着重要作用
|
||||
# # 目前 ollama 暂不支持运行 Rerank 模型
|
||||
# # 使用的Rerank模型名称
|
||||
# # INIT_RERANK_MODEL_NAME=your_rerank_model_name
|
||||
|
||||
# VLM_MODEL_API_KEY 使用的多模态模型API密钥
|
||||
# 支持第三方模型服务的API密钥
|
||||
# VLM_MODEL_API_KEY=your_vlm_model_api_key
|
||||
# # Rerank模型的访问地址
|
||||
# # 支持第三方模型服务的URL
|
||||
# # INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url
|
||||
|
||||
# # Rerank模型的API密钥,如果需要身份验证,可以设置
|
||||
# # 支持第三方模型服务的API密钥
|
||||
# # INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key
|
||||
|
||||
# # VLM_MODEL_NAME 使用的多模态模型名称
|
||||
# # 用于解析图片数据
|
||||
# # VLM_MODEL_NAME=your_vlm_model_name
|
||||
|
||||
# # VLM_MODEL_BASE_URL 使用的多模态模型访问地址
|
||||
# # 支持第三方模型服务的URL
|
||||
# # VLM_MODEL_BASE_URL=your_vlm_model_base_url
|
||||
|
||||
# # VLM_MODEL_API_KEY 使用的多模态模型API密钥
|
||||
# # 支持第三方模型服务的API密钥
|
||||
# # VLM_MODEL_API_KEY=your_vlm_model_api_key
|
||||
32
README.md
32
README.md
@@ -140,6 +140,38 @@ WeKnora 作为[微信对话开放平台](https://chatbot.weixin.qq.com)的核心
|
||||
- **高效问题管理**:支持高频问题的独立分类管理,提供丰富的数据工具,确保回答精准可靠且易于维护
|
||||
- **微信生态覆盖**:通过微信对话开放平台,WeKnora 的智能问答能力可无缝集成到公众号、小程序等微信场景中,提升用户交互体验
|
||||
|
||||
## 🔧 初始化配置引导
|
||||
|
||||
为了方便用户快速配置各类模型,降低试错成本,我们改进了原来的配置文件初始化方式,增加了Web UI界面进行各种模型的配置。在使用之前,请确保代码更新到最新版本。具体使用步骤如下:
|
||||
如果是第一次使用本项目,可跳过①②步骤,直接进入③④步骤。
|
||||
|
||||
### ① 关闭服务
|
||||
|
||||
```bash
|
||||
./scripts/start_all.sh --stop
|
||||
```
|
||||
|
||||
### ② 清空原有数据表(建议在没有重要数据的情况下使用)
|
||||
|
||||
```bash
|
||||
make clean-db
|
||||
```
|
||||
|
||||
### ③ 编译并启动服务
|
||||
|
||||
```bash
|
||||
./scripts/start_all.sh
|
||||
```
|
||||
|
||||
### ④ 访问Web UI
|
||||
|
||||
http://localhost
|
||||
|
||||
首次访问会自动跳转到初始化配置页面,配置完成后会自动跳转到知识库页面。请按照页面提示信息完成模型的配置。
|
||||
|
||||

|
||||
|
||||
|
||||
## 📱 功能展示
|
||||
|
||||
### Web UI 界面
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/container"
|
||||
"github.com/Tencent/WeKnora/internal/runtime"
|
||||
@@ -42,7 +41,6 @@ func main() {
|
||||
cfg *config.Config,
|
||||
router *gin.Engine,
|
||||
tracer *tracing.Tracer,
|
||||
testDataService *service.TestDataService,
|
||||
resourceCleaner interfaces.ResourceCleaner,
|
||||
) error {
|
||||
// Create context for resource cleanup
|
||||
@@ -58,13 +56,6 @@ func main() {
|
||||
return tracer.Cleanup(cleanupCtx)
|
||||
})
|
||||
|
||||
// Initialize test data
|
||||
if testDataService != nil {
|
||||
if err := testDataService.InitializeTestData(context.Background()); err != nil {
|
||||
log.Printf("Failed to initialize test data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create HTTP server
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
|
||||
@@ -151,26 +151,27 @@ conversation:
|
||||
Output:
|
||||
|
||||
generate_summary_prompt: |
|
||||
你是一个总结助手,你的任务是总结文章或者片段内容。
|
||||
你是一个精准的文章总结专家。你的任务是提取并总结用户提供的文章或片段的核心内容。
|
||||
|
||||
## 准则要求
|
||||
- 总结结果长度不能超过100个字
|
||||
- 保持客观,不添加个人观点或评价
|
||||
- 使用第三人称陈述语气
|
||||
- 不要基于任何先验知识回答用户的问题,只基于文章内容生成摘要
|
||||
- 直接输出总结结果,不要有任何前缀或解释
|
||||
- 使用中文输出总结结果
|
||||
- 不能输出“无法生成”、“无法总结”等字眼
|
||||
## 核心要求
|
||||
- 总结结果长度为50-100个字,根据内容复杂度灵活调整
|
||||
- 完全基于提供的文章内容生成总结,不添加任何未在文章中出现的信息
|
||||
- 确保总结包含文章的关键信息点和主要结论
|
||||
- 即使文章内容较复杂或专业,也必须尝试提取核心要点进行总结
|
||||
- 直接输出总结结果,不包含任何引言、前缀或解释
|
||||
|
||||
## Few-shot示例
|
||||
## 格式与风格
|
||||
- 使用客观、中立的第三人称陈述语气
|
||||
- 使用清晰简洁的中文表达
|
||||
- 保持逻辑连贯性,确保句与句之间有合理过渡
|
||||
- 避免重复使用相同的表达方式或句式结构
|
||||
|
||||
用户给出的文章内容:
|
||||
随着5G技术的快速发展,各行各业正经历数字化转型。5G网络凭借高速率、低延迟和大连接的特性,正在推动智慧城市、工业互联网和远程医疗等领域的创新。专家预测,到2025年,5G将为全球经济贡献约2.2万亿美元。然而,5G建设也面临基础设施投入大、覆盖不均等挑战。
|
||||
|
||||
文章总结:
|
||||
5G技术凭借高速率、低延迟和大连接特性推动各行业数字化转型,促进智慧城市、工业互联网和远程医疗创新。预计2025年将贡献约2.2万亿美元经济价值,但仍面临基础设施投入大和覆盖不均等挑战。
|
||||
|
||||
## 用户给出的文章内容是:
|
||||
## 注意事项
|
||||
- 绝对不输出"无法生成"、"无法总结"、"内容不足"等拒绝回应的词语
|
||||
- 不要照抄或参考示例中的任何内容,确保总结完全基于用户提供的新文章
|
||||
- 对于任何文本都尽最大努力提取重点并总结,无论长度或复杂度
|
||||
|
||||
## 以下是用户给出的文章相关信息:
|
||||
|
||||
generate_session_title_prompt: |
|
||||
你是一个专业的会话标题生成助手,你的任务是为用户提问创建简洁、精准且具描述性的标题。
|
||||
|
||||
@@ -87,6 +87,8 @@ services:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.docreader
|
||||
args:
|
||||
- PLATFORM=${PLATFORM:-linux/amd64}
|
||||
container_name: WeKnora-docreader
|
||||
ports:
|
||||
- "50051:50051"
|
||||
@@ -104,6 +106,8 @@ services:
|
||||
networks:
|
||||
- WeKnora-network
|
||||
restart: unless-stopped
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
||||
jaeger:
|
||||
image: jaegertracing/all-in-one:latest
|
||||
|
||||
@@ -93,14 +93,29 @@ RUN apt-get update && apt-get install -y \
|
||||
# 下载并安装最新版本的 LibreOffice 25.2.4
|
||||
RUN mkdir -p /tmp/libreoffice && \
|
||||
cd /tmp/libreoffice && \
|
||||
wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/stable/25.2.4/deb/x86_64/LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
|
||||
tar -xzf LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
|
||||
cd LibreOffice_25.2.4*_Linux_x86-64_deb/DEBS && \
|
||||
dpkg -i *.deb && \
|
||||
if [ "$(uname -m)" = "x86_64" ]; then \
|
||||
wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/stable/25.2.4/deb/x86_64/LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
|
||||
tar -xzf LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
|
||||
cd LibreOffice_25.2.4*_Linux_x86-64_deb/DEBS && \
|
||||
dpkg -i *.deb; \
|
||||
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \
|
||||
wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/testing/25.8.0/deb/aarch64/LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \
|
||||
tar -xzf LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \
|
||||
cd LibreOffice_25.8.0*_Linux_aarch64_deb/DEBS && \
|
||||
dpkg -i *.deb; \
|
||||
else \
|
||||
echo "Unsupported architecture: $(uname -m)" && exit 1; \
|
||||
fi && \
|
||||
cd / && \
|
||||
rm -rf /tmp/libreoffice
|
||||
|
||||
# 设置 LibreOffice 环境变量
|
||||
RUN if [ "$(uname -m)" = "x86_64" ]; then \
|
||||
echo 'export LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice' >> /etc/environment; \
|
||||
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \
|
||||
echo 'export LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice' >> /etc/environment; \
|
||||
fi
|
||||
|
||||
ENV LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice
|
||||
|
||||
# 从构建阶段复制已安装的依赖和生成的代码
|
||||
|
||||
BIN
docs/images/config.png
Normal file
BIN
docs/images/config.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 626 KiB |
278
frontend/src/api/initialization/index.ts
Normal file
278
frontend/src/api/initialization/index.ts
Normal file
@@ -0,0 +1,278 @@
|
||||
import { get, post } from '../../utils/request';
|
||||
|
||||
// 初始化配置数据类型
|
||||
export interface InitializationConfig {
|
||||
llm: {
|
||||
source: string;
|
||||
modelName: string;
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
};
|
||||
embedding: {
|
||||
source: string;
|
||||
modelName: string;
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
dimension?: number; // 添加embedding维度字段
|
||||
};
|
||||
rerank: {
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
};
|
||||
multimodal: {
|
||||
enabled: boolean;
|
||||
vlm?: {
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
interfaceType?: string; // "ollama" or "openai"
|
||||
};
|
||||
cos?: {
|
||||
secretId: string;
|
||||
secretKey: string;
|
||||
region: string;
|
||||
bucketName: string;
|
||||
appId: string;
|
||||
pathPrefix?: string;
|
||||
};
|
||||
};
|
||||
documentSplitting: {
|
||||
chunkSize: number;
|
||||
chunkOverlap: number;
|
||||
separators: string[];
|
||||
};
|
||||
}
|
||||
|
||||
// 下载任务状态类型
|
||||
export interface DownloadTask {
|
||||
id: string;
|
||||
modelName: string;
|
||||
status: 'pending' | 'downloading' | 'completed' | 'failed';
|
||||
progress: number;
|
||||
message: string;
|
||||
startTime: string;
|
||||
endTime?: string;
|
||||
}
|
||||
|
||||
// 系统初始化状态检查
|
||||
export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/status')
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { initialized: false });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.warn('检查初始化状态失败,假设需要初始化:', error);
|
||||
resolve({ initialized: false });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 执行系统初始化
|
||||
export function initializeSystem(config: InitializationConfig): Promise<any> {
|
||||
return new Promise((resolve, reject) => {
|
||||
console.log('开始系统初始化...', config);
|
||||
post('/api/v1/initialization/initialize', config)
|
||||
.then((response: any) => {
|
||||
console.log('系统初始化完成', response);
|
||||
// 设置本地初始化状态标记
|
||||
localStorage.setItem('system_initialized', 'true');
|
||||
resolve(response);
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('系统初始化失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 检查Ollama服务状态
|
||||
export function checkOllamaStatus(): Promise<{ available: boolean; version?: string; error?: string }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/ollama/status')
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { available: false });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('检查Ollama状态失败:', error);
|
||||
resolve({ available: false, error: error.message || '检查失败' });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 检查Ollama模型状态
|
||||
export function checkOllamaModels(models: string[]): Promise<{ models: Record<string, boolean> }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/ollama/models/check', { models })
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { models: {} });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('检查Ollama模型状态失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 启动Ollama模型下载(异步)
|
||||
export function downloadOllamaModel(modelName: string): Promise<{ taskId: string; modelName: string; status: string; progress: number }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/ollama/models/download', { modelName })
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { taskId: '', modelName, status: 'failed', progress: 0 });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('启动Ollama模型下载失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 查询下载进度
|
||||
export function getDownloadProgress(taskId: string): Promise<DownloadTask> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get(`/api/v1/initialization/ollama/download/progress/${taskId}`)
|
||||
.then((response: any) => {
|
||||
resolve(response.data);
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('查询下载进度失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 获取所有下载任务
|
||||
export function listDownloadTasks(): Promise<DownloadTask[]> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/ollama/download/tasks')
|
||||
.then((response: any) => {
|
||||
resolve(response.data || []);
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('获取下载任务列表失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 获取当前系统配置
|
||||
export function getCurrentConfig(): Promise<InitializationConfig & { hasFiles: boolean }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/config')
|
||||
.then((response: any) => {
|
||||
resolve(response.data || {});
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('获取当前配置失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 检查远程API模型
|
||||
export function checkRemoteModel(modelConfig: {
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
}): Promise<{
|
||||
available: boolean;
|
||||
message?: string;
|
||||
}> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/remote/check', modelConfig)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || {});
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('检查远程模型失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export function checkRerankModel(modelConfig: {
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
}): Promise<{
|
||||
available: boolean;
|
||||
message?: string;
|
||||
}> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/rerank/check', modelConfig)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || {});
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('检查Rerank模型失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export function testMultimodalFunction(testData: {
|
||||
image: File;
|
||||
vlm_model: string;
|
||||
vlm_base_url: string;
|
||||
vlm_api_key?: string;
|
||||
vlm_interface_type?: string;
|
||||
cos_secret_id: string;
|
||||
cos_secret_key: string;
|
||||
cos_region: string;
|
||||
cos_bucket_name: string;
|
||||
cos_app_id: string;
|
||||
cos_path_prefix?: string;
|
||||
chunk_size: number;
|
||||
chunk_overlap: number;
|
||||
separators: string[];
|
||||
}): Promise<{
|
||||
success: boolean;
|
||||
caption?: string;
|
||||
ocr?: string;
|
||||
processing_time?: number;
|
||||
message?: string;
|
||||
}> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const formData = new FormData();
|
||||
formData.append('image', testData.image);
|
||||
formData.append('vlm_model', testData.vlm_model);
|
||||
formData.append('vlm_base_url', testData.vlm_base_url);
|
||||
if (testData.vlm_api_key) {
|
||||
formData.append('vlm_api_key', testData.vlm_api_key);
|
||||
}
|
||||
if (testData.vlm_interface_type) {
|
||||
formData.append('vlm_interface_type', testData.vlm_interface_type);
|
||||
}
|
||||
formData.append('cos_secret_id', testData.cos_secret_id);
|
||||
formData.append('cos_secret_key', testData.cos_secret_key);
|
||||
formData.append('cos_region', testData.cos_region);
|
||||
formData.append('cos_bucket_name', testData.cos_bucket_name);
|
||||
formData.append('cos_app_id', testData.cos_app_id);
|
||||
if (testData.cos_path_prefix) {
|
||||
formData.append('cos_path_prefix', testData.cos_path_prefix);
|
||||
}
|
||||
formData.append('chunk_size', testData.chunk_size.toString());
|
||||
formData.append('chunk_overlap', testData.chunk_overlap.toString());
|
||||
formData.append('separators', JSON.stringify(testData.separators));
|
||||
|
||||
// 使用原生fetch因为需要发送FormData
|
||||
fetch('/api/v1/initialization/multimodal/test', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then((data: any) => {
|
||||
if (data.success) {
|
||||
resolve(data.data || {});
|
||||
} else {
|
||||
resolve({ success: false, message: data.message || '测试失败' });
|
||||
}
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('多模态测试失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -173,7 +173,12 @@ const getIcon = (path) => {
|
||||
getIcon(route.name)
|
||||
const gotopage = (path) => {
|
||||
pathPrefix.value = path;
|
||||
router.push(`/platform/${path}`);
|
||||
// 如果是系统设置,跳转到初始化配置页面
|
||||
if (path === 'settings') {
|
||||
router.push('/initialization');
|
||||
} else {
|
||||
router.push(`/platform/${path}`);
|
||||
}
|
||||
getIcon(path)
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ export default function () {
|
||||
});
|
||||
const getKnowled = (query = { page: 1, page_size: 35 }) => {
|
||||
getKnowledgeBase(query)
|
||||
.then((result: object) => {
|
||||
.then((result: any) => {
|
||||
let { data, total: totalResult } = result;
|
||||
let cardList_ = data.map((item) => {
|
||||
item["file_name"] = item.file_name.substring(
|
||||
@@ -50,7 +50,7 @@ export default function () {
|
||||
cardList.value[index].isMore = false;
|
||||
moreIndex.value = -1;
|
||||
delKnowledgeDetails(item.id)
|
||||
.then((result) => {
|
||||
.then((result: any) => {
|
||||
if (result.success) {
|
||||
MessagePlugin.info("知识删除成功!");
|
||||
getKnowled();
|
||||
@@ -76,17 +76,29 @@ export default function () {
|
||||
return;
|
||||
}
|
||||
uploadKnowledgeBase({ file })
|
||||
.then((result) => {
|
||||
.then((result: any) => {
|
||||
if (result.success) {
|
||||
MessagePlugin.info("上传成功!");
|
||||
getKnowled();
|
||||
} else {
|
||||
MessagePlugin.error("上传失败!");
|
||||
// 检查错误码,如果是重复文件则显示特定提示
|
||||
if (result.code === 'duplicate_file') {
|
||||
MessagePlugin.error("文件已存在");
|
||||
} else {
|
||||
MessagePlugin.error(result.message || (result.error && result.error.message) || "上传失败!");
|
||||
}
|
||||
}
|
||||
uploadInput.value.value = "";
|
||||
})
|
||||
.catch((err) => {
|
||||
MessagePlugin.error("上传失败!");
|
||||
.catch((err: any) => {
|
||||
// 检查错误响应中的错误码
|
||||
if (err.code === 'duplicate_file') {
|
||||
MessagePlugin.error("文件已存在");
|
||||
} else if (err.message) {
|
||||
MessagePlugin.error(err.message);
|
||||
} else {
|
||||
MessagePlugin.error("上传失败!");
|
||||
}
|
||||
uploadInput.value.value = "";
|
||||
});
|
||||
} else {
|
||||
@@ -101,7 +113,7 @@ export default function () {
|
||||
id: "",
|
||||
});
|
||||
getKnowledgeDetails(item.id)
|
||||
.then((result) => {
|
||||
.then((result: any) => {
|
||||
if (result.success && result.data) {
|
||||
let { data } = result;
|
||||
Object.assign(details, {
|
||||
@@ -116,7 +128,7 @@ export default function () {
|
||||
};
|
||||
const getfDetails = (id, page) => {
|
||||
getKnowledgeDetailsCon(id, page)
|
||||
.then((result) => {
|
||||
.then((result: any) => {
|
||||
if (result.success && result.data) {
|
||||
let { data, total: totalResult } = result;
|
||||
if (page == 1) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createRouter, createWebHistory } from 'vue-router'
|
||||
import { checkInitializationStatus } from '@/api/initialization'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHistory(import.meta.env.BASE_URL),
|
||||
@@ -7,40 +8,82 @@ const router = createRouter({
|
||||
path: "/",
|
||||
redirect: "/platform",
|
||||
},
|
||||
{
|
||||
path: "/initialization",
|
||||
name: "initialization",
|
||||
component: () => import("../views/initialization/InitializationConfig.vue"),
|
||||
meta: { requiresInit: false } // 初始化页面不需要检查初始化状态
|
||||
},
|
||||
{
|
||||
path: "/knowledgeBase",
|
||||
name: "home",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
{
|
||||
path: "/platform",
|
||||
name: "Platform",
|
||||
redirect: "/platform/knowledgeBase",
|
||||
component: () => import("../views/platform/index.vue"),
|
||||
meta: { requiresInit: true },
|
||||
children: [
|
||||
{
|
||||
path: "knowledgeBase",
|
||||
name: "knowledgeBase",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
{
|
||||
path: "creatChat",
|
||||
name: "creatChat",
|
||||
component: () => import("../views/creatChat/creatChat.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
{
|
||||
path: "chat/:chatid",
|
||||
name: "chat",
|
||||
component: () => import("../views/chat/index.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
{
|
||||
path: "settings",
|
||||
name: "settings",
|
||||
component: () => import("../views/settings/Settings.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// 路由守卫:检查系统初始化状态
|
||||
router.beforeEach(async (to, from, next) => {
|
||||
// 如果访问的是初始化页面,直接放行
|
||||
if (to.meta.requiresInit === false) {
|
||||
next();
|
||||
return;
|
||||
}
|
||||
|
||||
1
|
||||
|
||||
try {
|
||||
// 检查系统是否已初始化
|
||||
const { initialized } = await checkInitializationStatus();
|
||||
|
||||
if (initialized) {
|
||||
// 系统已初始化,记录到本地存储并正常跳转
|
||||
localStorage.setItem('system_initialized', 'true');
|
||||
next();
|
||||
} else {
|
||||
// 系统未初始化,跳转到初始化页面
|
||||
console.log('系统未初始化,跳转到初始化页面');
|
||||
next('/initialization');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查初始化状态失败:', error);
|
||||
// 如果检查失败,默认认为需要初始化
|
||||
next('/initialization');
|
||||
}
|
||||
});
|
||||
|
||||
export default router
|
||||
|
||||
2626
frontend/src/views/initialization/InitializationConfig.vue
Normal file
2626
frontend/src/views/initialization/InitializationConfig.vue
Normal file
File diff suppressed because it is too large
Load Diff
2
go.mod
2
go.mod
@@ -13,7 +13,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hibiken/asynq v0.25.1
|
||||
github.com/minio/minio-go/v7 v7.0.90
|
||||
github.com/ollama/ollama v0.9.6
|
||||
github.com/ollama/ollama v0.11.4
|
||||
github.com/panjf2000/ants/v2 v2.11.2
|
||||
github.com/parquet-go/parquet-go v0.25.0
|
||||
github.com/pgvector/pgvector-go v0.3.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -141,8 +141,8 @@ github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSr
|
||||
github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/ollama/ollama v0.9.6 h1:HZNJmB52pMt6zLkGkkheBuXBXM5478eiSAj7GR75AMc=
|
||||
github.com/ollama/ollama v0.9.6/go.mod h1:zLwx3iZ3AI4Rc/egsrx3u1w4RU2MHQ/Ylxse48jvyt4=
|
||||
github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI=
|
||||
github.com/ollama/ollama v0.11.4/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
|
||||
github.com/panjf2000/ants/v2 v2.11.2 h1:AVGpMSePxUNpcLaBO34xuIgM1ZdKOiGnpxLXixLi5Jo=
|
||||
github.com/panjf2000/ants/v2 v2.11.2/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
|
||||
github.com/parquet-go/parquet-go v0.25.0 h1:GwKy11MuF+al/lV6nUsFw8w8HCiPOSAx1/y8yFxjH5c=
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
werrors "github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/models/utils"
|
||||
@@ -105,6 +106,28 @@ func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查多模态配置完整性 - 只在图片文件时校验
|
||||
// 检查是否为图片文件
|
||||
if !IsImageType(getFileType(file.Filename)) {
|
||||
logger.Info(ctx, "Non-image file with multimodal enabled, skipping COS/VLM validation")
|
||||
} else {
|
||||
// 检查COS配置
|
||||
if kb.COSConfig.SecretID == "" || kb.COSConfig.SecretKey == "" ||
|
||||
kb.COSConfig.Region == "" || kb.COSConfig.BucketName == "" ||
|
||||
kb.COSConfig.AppID == "" {
|
||||
logger.Error(ctx, "COS configuration incomplete for image multimodal processing")
|
||||
return nil, werrors.NewBadRequestError("上传图片文件需要完整的COS配置信息")
|
||||
}
|
||||
|
||||
// 检查VLM配置
|
||||
if kb.VLMConfig.ModelName == "" || kb.VLMConfig.BaseURL == "" {
|
||||
logger.Error(ctx, "VLM configuration incomplete for image multimodal processing")
|
||||
return nil, werrors.NewBadRequestError("上传图片文件需要完整的VLM配置信息")
|
||||
}
|
||||
|
||||
logger.Info(ctx, "Image multimodal configuration validation passed")
|
||||
}
|
||||
|
||||
// Validate file type
|
||||
logger.Infof(ctx, "Checking file type: %s", file.Filename)
|
||||
if !isValidFileType(file.Filename) {
|
||||
@@ -647,6 +670,20 @@ func (s *knowledgeService) processDocument(ctx context.Context,
|
||||
ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap),
|
||||
Separators: kb.ChunkingConfig.Separators,
|
||||
EnableMultimodal: enableMultimodel,
|
||||
CosConfig: &proto.COSConfig{
|
||||
SecretId: kb.COSConfig.SecretID,
|
||||
SecretKey: kb.COSConfig.SecretKey,
|
||||
Region: kb.COSConfig.Region,
|
||||
BucketName: kb.COSConfig.BucketName,
|
||||
AppId: kb.COSConfig.AppID,
|
||||
PathPrefix: kb.COSConfig.PathPrefix,
|
||||
},
|
||||
VlmConfig: &proto.VLMConfig{
|
||||
ModelName: kb.VLMConfig.ModelName,
|
||||
BaseUrl: kb.VLMConfig.BaseURL,
|
||||
ApiKey: kb.VLMConfig.APIKey,
|
||||
InterfaceType: kb.VLMConfig.InterfaceType,
|
||||
},
|
||||
},
|
||||
RequestId: ctx.Value(types.RequestIDContextKey).(string),
|
||||
})
|
||||
@@ -687,6 +724,20 @@ func (s *knowledgeService) processDocumentFromURL(ctx context.Context,
|
||||
ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap),
|
||||
Separators: kb.ChunkingConfig.Separators,
|
||||
EnableMultimodal: enableMultimodel,
|
||||
CosConfig: &proto.COSConfig{
|
||||
SecretId: kb.COSConfig.SecretID,
|
||||
SecretKey: kb.COSConfig.SecretKey,
|
||||
Region: kb.COSConfig.Region,
|
||||
BucketName: kb.COSConfig.BucketName,
|
||||
AppId: kb.COSConfig.AppID,
|
||||
PathPrefix: kb.COSConfig.PathPrefix,
|
||||
},
|
||||
VlmConfig: &proto.VLMConfig{
|
||||
ModelName: kb.VLMConfig.ModelName,
|
||||
BaseUrl: kb.VLMConfig.BaseURL,
|
||||
ApiKey: kb.VLMConfig.APIKey,
|
||||
InterfaceType: kb.VLMConfig.InterfaceType,
|
||||
},
|
||||
},
|
||||
RequestId: ctx.Value(types.RequestIDContextKey).(string),
|
||||
})
|
||||
@@ -1145,7 +1196,7 @@ func (s *knowledgeService) getSummary(ctx context.Context,
|
||||
chunkContents = chunkContents + imageAnnotations
|
||||
}
|
||||
|
||||
if len(chunkContents) < 30 {
|
||||
if len(chunkContents) < 300 {
|
||||
return chunkContents, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,9 @@ func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context,
|
||||
kb *types.KnowledgeBase,
|
||||
) (*types.KnowledgeBase, error) {
|
||||
// Generate UUID and set creation timestamps
|
||||
kb.ID = uuid.New().String()
|
||||
if kb.ID == "" {
|
||||
kb.ID = uuid.New().String()
|
||||
}
|
||||
kb.CreatedAt = time.Now()
|
||||
kb.TenantID = ctx.Value(types.TenantIDContextKey).(uint)
|
||||
kb.UpdatedAt = time.Now()
|
||||
@@ -245,6 +247,9 @@ func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context,
|
||||
ImageProcessingConfig: sourceKB.ImageProcessingConfig,
|
||||
EmbeddingModelID: sourceKB.EmbeddingModelID,
|
||||
SummaryModelID: sourceKB.SummaryModelID,
|
||||
RerankModelID: sourceKB.RerankModelID,
|
||||
VLMModelID: sourceKB.VLMModelID,
|
||||
COSConfig: sourceKB.COSConfig,
|
||||
}
|
||||
if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -144,6 +144,7 @@ func (s *TestDataService) initKnowledgeBase(ctx context.Context) error {
|
||||
},
|
||||
EmbeddingModelID: s.EmbedModel.GetModelID(),
|
||||
SummaryModelID: s.LLMModel.GetModelID(),
|
||||
RerankModelID: s.RerankModel.GetModelID(),
|
||||
}
|
||||
|
||||
// 初始化测试知识库
|
||||
|
||||
@@ -115,6 +115,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(handler.NewTestDataHandler))
|
||||
must(container.Provide(handler.NewModelHandler))
|
||||
must(container.Provide(handler.NewEvaluationHandler))
|
||||
must(container.Provide(handler.NewInitializationHandler))
|
||||
|
||||
// Router configuration
|
||||
must(container.Provide(router.NewRouter))
|
||||
|
||||
1718
internal/handler/initialization.go
Normal file
1718
internal/handler/initialization.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -135,6 +135,10 @@ func (h *KnowledgeHandler) CreateKnowledgeFromFile(c *gin.Context) {
|
||||
if h.handleDuplicateKnowledgeError(c, err, knowledge, "file") {
|
||||
return
|
||||
}
|
||||
if appErr, ok := errors.IsAppError(err); ok {
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
|
||||
@@ -18,11 +18,12 @@ import (
|
||||
|
||||
// SessionHandler handles all HTTP requests related to conversation sessions
|
||||
type SessionHandler struct {
|
||||
messageService interfaces.MessageService // Service for managing messages
|
||||
sessionService interfaces.SessionService // Service for managing sessions
|
||||
streamManager interfaces.StreamManager // Manager for handling streaming responses
|
||||
config *config.Config // Application configuration
|
||||
testDataService *service.TestDataService // Service for test data (models, etc.)
|
||||
messageService interfaces.MessageService // Service for managing messages
|
||||
sessionService interfaces.SessionService // Service for managing sessions
|
||||
streamManager interfaces.StreamManager // Manager for handling streaming responses
|
||||
config *config.Config // Application configuration
|
||||
testDataService *service.TestDataService // Service for test data (models, etc.)
|
||||
knowledgebaseService interfaces.KnowledgeBaseService
|
||||
}
|
||||
|
||||
// NewSessionHandler creates a new instance of SessionHandler with all necessary dependencies
|
||||
@@ -32,13 +33,15 @@ func NewSessionHandler(
|
||||
streamManager interfaces.StreamManager,
|
||||
config *config.Config,
|
||||
testDataService *service.TestDataService,
|
||||
knowledgebaseService interfaces.KnowledgeBaseService,
|
||||
) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
sessionService: sessionService,
|
||||
messageService: messageService,
|
||||
streamManager: streamManager,
|
||||
config: config,
|
||||
testDataService: testDataService,
|
||||
sessionService: sessionService,
|
||||
messageService: messageService,
|
||||
streamManager: streamManager,
|
||||
config: config,
|
||||
testDataService: testDataService,
|
||||
knowledgebaseService: knowledgebaseService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,27 +194,24 @@ func (h *SessionHandler) CreateSession(c *gin.Context) {
|
||||
logger.Debug(ctx, "Using default session strategy")
|
||||
}
|
||||
|
||||
// Get model IDs from test data service if not provided
|
||||
kb, err := h.knowledgebaseService.GetKnowledgeBaseByID(ctx, request.KnowledgeBaseID)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "Failed to get knowledge base", err)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get model IDs from knowledge base if not provided
|
||||
if createdSession.SummaryModelID == "" {
|
||||
if h.testDataService != nil {
|
||||
createdSession.SummaryModelID = h.testDataService.LLMModel.GetModelID()
|
||||
logger.Debug(ctx, "Using summary model ID from test data service")
|
||||
} else {
|
||||
logger.Error(ctx, "Summary model ID is empty and cannot get default value")
|
||||
c.Error(errors.NewBadRequestError("Summary model ID cannot be empty"))
|
||||
return
|
||||
}
|
||||
createdSession.SummaryModelID = kb.SummaryModelID
|
||||
}
|
||||
if createdSession.RerankModelID == "" {
|
||||
if h.testDataService != nil && h.testDataService.RerankModel != nil {
|
||||
createdSession.RerankModelID = h.testDataService.RerankModel.GetModelID()
|
||||
logger.Debug(ctx, "Using rerank model ID from test data service")
|
||||
}
|
||||
createdSession.RerankModelID = kb.RerankModelID
|
||||
}
|
||||
|
||||
// Call service to create session
|
||||
logger.Infof(ctx, "Calling session service to create session")
|
||||
createdSession, err := h.sessionService.CreateSession(ctx, createdSession)
|
||||
createdSession, err = h.sessionService.CreateSession(ctx, createdSession)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -65,26 +64,19 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
|
||||
logger.Debugf(ctx, "Test tenant ID environment variable: %s", tenantID)
|
||||
|
||||
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to parse tenant ID: %s", tenantID)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
tenantID := uint(types.InitDefaultTenantID)
|
||||
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
|
||||
|
||||
// Retrieve the test tenant data
|
||||
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantIDUint)
|
||||
tenant, err := h.tenantService.GetTenantByID(ctx, uint(tenantIDUint))
|
||||
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantID)
|
||||
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID")
|
||||
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
|
||||
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
|
||||
|
||||
// Retrieve the test knowledge base data
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
@@ -15,14 +16,20 @@ import (
|
||||
|
||||
// 无需认证的API列表
|
||||
var noAuthAPI = map[string][]string{
|
||||
"/api/v1/test-data": {"GET"},
|
||||
"/api/v1/tenants": {"POST"},
|
||||
"/api/v1/test-data": {"GET"},
|
||||
"/api/v1/tenants": {"POST"},
|
||||
"/api/v1/initialization/*": {"GET", "POST"},
|
||||
}
|
||||
|
||||
// 检查请求是否在无需认证的API列表中
|
||||
func isNoAuthAPI(path string, method string) bool {
|
||||
for api, methods := range noAuthAPI {
|
||||
if api == path && slices.Contains(methods, method) {
|
||||
// 如果以*结尾,按照前缀匹配,否则按照全路径匹配
|
||||
if strings.HasSuffix(api, "*") {
|
||||
if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) {
|
||||
return true
|
||||
}
|
||||
} else if path == api && slices.Contains(methods, method) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,7 +63,9 @@ func (c *OllamaChat) buildChatRequest(messages []Message, opts *ChatOptions, isS
|
||||
chatReq.Options["num_predict"] = opts.MaxTokens
|
||||
}
|
||||
if opts.Thinking != nil {
|
||||
chatReq.Options["think"] = *opts.Thinking
|
||||
chatReq.Think = &ollamaapi.ThinkValue{
|
||||
Value: *opts.Thinking,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -309,3 +309,8 @@ func (s *OllamaService) Generate(ctx context.Context, req *api.GenerateRequest,
|
||||
// Use official client Generate method
|
||||
return s.client.Generate(ctx, req, fn)
|
||||
}
|
||||
|
||||
// GetClient returns the underlying ollama client for advanced operations
|
||||
func (s *OllamaService) GetClient() *api.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
@@ -17,17 +17,18 @@ import (
|
||||
type RouterParams struct {
|
||||
dig.In
|
||||
|
||||
Config *config.Config
|
||||
KBHandler *handler.KnowledgeBaseHandler
|
||||
KnowledgeHandler *handler.KnowledgeHandler
|
||||
TenantHandler *handler.TenantHandler
|
||||
TenantService interfaces.TenantService
|
||||
ChunkHandler *handler.ChunkHandler
|
||||
SessionHandler *handler.SessionHandler
|
||||
MessageHandler *handler.MessageHandler
|
||||
TestDataHandler *handler.TestDataHandler
|
||||
ModelHandler *handler.ModelHandler
|
||||
EvaluationHandler *handler.EvaluationHandler
|
||||
Config *config.Config
|
||||
KBHandler *handler.KnowledgeBaseHandler
|
||||
KnowledgeHandler *handler.KnowledgeHandler
|
||||
TenantHandler *handler.TenantHandler
|
||||
TenantService interfaces.TenantService
|
||||
ChunkHandler *handler.ChunkHandler
|
||||
SessionHandler *handler.SessionHandler
|
||||
MessageHandler *handler.MessageHandler
|
||||
TestDataHandler *handler.TestDataHandler
|
||||
ModelHandler *handler.ModelHandler
|
||||
EvaluationHandler *handler.EvaluationHandler
|
||||
InitializationHandler *handler.InitializationHandler
|
||||
}
|
||||
|
||||
// NewRouter 创建新的路由
|
||||
@@ -62,6 +63,23 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
// 测试数据接口(不需要认证)
|
||||
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
|
||||
|
||||
// 初始化接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
|
||||
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
|
||||
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
|
||||
|
||||
// Ollama相关接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
|
||||
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
|
||||
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
|
||||
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
|
||||
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
|
||||
|
||||
// 远程API相关接口(不需要认证)
|
||||
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
|
||||
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
|
||||
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
|
||||
|
||||
// 需要认证的API路由
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
InitDefaultKnowledgeBaseID = "kb-00000001"
|
||||
)
|
||||
|
||||
// KnowledgeBase represents a knowledge base
|
||||
type KnowledgeBase struct {
|
||||
// Unique identifier of the knowledge base
|
||||
@@ -26,6 +30,14 @@ type KnowledgeBase struct {
|
||||
EmbeddingModelID string `yaml:"embedding_model_id" json:"embedding_model_id"`
|
||||
// Summary model ID
|
||||
SummaryModelID string `yaml:"summary_model_id" json:"summary_model_id"`
|
||||
// Rerank model ID
|
||||
RerankModelID string `yaml:"rerank_model_id" json:"rerank_model_id"`
|
||||
// VLM model ID
|
||||
VLMModelID string `yaml:"vlm_model_id" json:"vlm_model_id"`
|
||||
// VLM config
|
||||
VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"`
|
||||
// COS config
|
||||
COSConfig COSConfig `yaml:"cos_config" json:"cos_config" gorm:"type:json"`
|
||||
// Creation time of the knowledge base
|
||||
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
|
||||
// Last updated time of the knowledge base
|
||||
@@ -54,6 +66,37 @@ type ChunkingConfig struct {
|
||||
EnableMultimodal bool `yaml:"enable_multimodal" json:"enable_multimodal"`
|
||||
}
|
||||
|
||||
// COSConfig represents the COS configuration
|
||||
type COSConfig struct {
|
||||
// Secret ID
|
||||
SecretID string `yaml:"secret_id" json:"secret_id"`
|
||||
// Secret Key
|
||||
SecretKey string `yaml:"secret_key" json:"secret_key"`
|
||||
// Region
|
||||
Region string `yaml:"region" json:"region"`
|
||||
// Bucket Name
|
||||
BucketName string `yaml:"bucket_name" json:"bucket_name"`
|
||||
// App ID
|
||||
AppID string `yaml:"app_id" json:"app_id"`
|
||||
// Path Prefix
|
||||
PathPrefix string `yaml:"path_prefix" json:"path_prefix"`
|
||||
}
|
||||
|
||||
func (c *COSConfig) Value() (driver.Value, error) {
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *COSConfig) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
b, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
|
||||
// ImageProcessingConfig represents the image processing configuration
|
||||
type ImageProcessingConfig struct {
|
||||
// Model ID
|
||||
@@ -93,3 +136,32 @@ func (c *ImageProcessingConfig) Scan(value interface{}) error {
|
||||
}
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
|
||||
// VLMConfig represents the VLM configuration
|
||||
type VLMConfig struct {
|
||||
// Model Name
|
||||
ModelName string `yaml:"model_name" json:"model_name"`
|
||||
// Base URL
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
// API Key
|
||||
APIKey string `yaml:"api_key" json:"api_key"`
|
||||
// Interface Type: "ollama" or "openai"
|
||||
InterfaceType string `yaml:"interface_type" json:"interface_type"`
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface, used to convert VLMConfig to database value
|
||||
func (c VLMConfig) Value() (driver.Value, error) {
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface, used to convert database value to VLMConfig
|
||||
func (c *VLMConfig) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
b, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
InitDefaultTenantID uint = 1
|
||||
)
|
||||
|
||||
// Tenant represents the tenant
|
||||
type Tenant struct {
|
||||
// ID
|
||||
|
||||
@@ -19,7 +19,7 @@ CREATE TABLE tenants (
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP NULL DEFAULT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000;
|
||||
|
||||
CREATE TABLE models (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
@@ -47,6 +47,10 @@ CREATE TABLE knowledge_bases (
|
||||
image_processing_config JSON NOT NULL,
|
||||
embedding_model_id VARCHAR(64) NOT NULL,
|
||||
summary_model_id VARCHAR(64) NOT NULL,
|
||||
rerank_model_id VARCHAR(64) NOT NULL,
|
||||
vlm_model_id VARCHAR(64) NOT NULL,
|
||||
cos_config JSON NOT NULL,
|
||||
vlm_config JSON NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP NULL DEFAULT NULL
|
||||
|
||||
@@ -21,6 +21,9 @@ CREATE TABLE IF NOT EXISTS tenants (
|
||||
deleted_at TIMESTAMP WITH TIME ZONE
|
||||
);
|
||||
|
||||
-- Set the starting value for tenants id sequence
|
||||
ALTER SEQUENCE tenants_id_seq RESTART WITH 10000;
|
||||
|
||||
-- Add indexes
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_api_key ON tenants(api_key);
|
||||
CREATE INDEX IF NOT EXISTS idx_tenants_status ON tenants(status);
|
||||
@@ -55,6 +58,10 @@ CREATE TABLE IF NOT EXISTS knowledge_bases (
|
||||
image_processing_config JSONB NOT NULL DEFAULT '{"enable_multimodal": false, "model_id": ""}',
|
||||
embedding_model_id VARCHAR(64) NOT NULL,
|
||||
summary_model_id VARCHAR(64) NOT NULL,
|
||||
rerank_model_id VARCHAR(64) NOT NULL,
|
||||
vlm_model_id VARCHAR(64) NOT NULL,
|
||||
cos_config JSONB NOT NULL DEFAULT '{}',
|
||||
vlm_config JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE
|
||||
|
||||
@@ -227,12 +227,24 @@ start_docker() {
|
||||
source "$PROJECT_ROOT/.env"
|
||||
storage_type=${STORAGE_TYPE:-local}
|
||||
|
||||
# 检测当前系统平台
|
||||
log_info "检测系统平台信息..."
|
||||
if [ "$(uname -m)" = "x86_64" ]; then
|
||||
export PLATFORM="linux/amd64"
|
||||
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
|
||||
export PLATFORM="linux/arm64"
|
||||
else
|
||||
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
|
||||
export PLATFORM="linux/amd64"
|
||||
fi
|
||||
log_info "当前平台:$PLATFORM"
|
||||
|
||||
# 进入项目根目录再执行docker-compose命令
|
||||
cd "$PROJECT_ROOT"
|
||||
|
||||
# 启动基本服务
|
||||
log_info "启动核心服务容器..."
|
||||
docker-compose up --build -d
|
||||
PLATFORM=$PLATFORM docker-compose up --build -d
|
||||
if [ $? -ne 0 ]; then
|
||||
log_error "Docker容器启动失败"
|
||||
return 1
|
||||
|
||||
@@ -13,7 +13,6 @@ urllib3
|
||||
markdownify
|
||||
mistletoe
|
||||
goose3[all]
|
||||
# paddlepaddle==3.0.0 (普通CPU版本已被注释)
|
||||
paddleocr==3.0.0
|
||||
markdown
|
||||
pypdf
|
||||
@@ -21,6 +20,10 @@ cos-python-sdk-v5
|
||||
textract
|
||||
antiword
|
||||
openai
|
||||
ollama
|
||||
|
||||
--extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
paddlepaddle-gpu==3.0.0
|
||||
--extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cpu/
|
||||
paddlepaddle==3.0.0
|
||||
|
||||
# --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
# paddlepaddle-gpu==3.0.0
|
||||
@@ -7,7 +7,7 @@ PYTHON_OUT="src/proto"
|
||||
GO_OUT="src/proto"
|
||||
|
||||
# 生成Python代码
|
||||
python -m grpc_tools.protoc -I${PROTO_DIR} \
|
||||
python3 -m grpc_tools.protoc -I${PROTO_DIR} \
|
||||
--python_out=${PYTHON_OUT} \
|
||||
--grpc_python_out=${PYTHON_OUT} \
|
||||
${PROTO_DIR}/docreader.proto
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
@@ -13,6 +13,7 @@ import traceback
|
||||
import numpy as np
|
||||
import time
|
||||
from .ocr_engine import OCREngine
|
||||
from .image_utils import image_to_base64
|
||||
|
||||
# Add parent directory to Python path for src imports
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -100,6 +101,7 @@ class BaseParser(ABC):
|
||||
max_image_size: int = 1920, # Maximum image size
|
||||
max_concurrent_tasks: int = 5, # Max concurrent tasks
|
||||
max_chunks: int = 1000, # Max number of returned chunks
|
||||
chunking_config: object = None, # Chunking configuration object
|
||||
):
|
||||
"""Initialize parser
|
||||
|
||||
@@ -127,6 +129,7 @@ class BaseParser(ABC):
|
||||
self.max_image_size = max_image_size
|
||||
self.max_concurrent_tasks = max_concurrent_tasks
|
||||
self.max_chunks = max_chunks
|
||||
self.chunking_config = chunking_config
|
||||
|
||||
logger.info(
|
||||
f"Initializing {self.__class__.__name__} for file: {file_name}, type: {self.file_type}"
|
||||
@@ -141,7 +144,12 @@ class BaseParser(ABC):
|
||||
# Only initialize Caption service if multimodal is enabled
|
||||
if self.enable_multimodal:
|
||||
try:
|
||||
self.caption_parser = Caption()
|
||||
# Get VLM config from chunking config if available
|
||||
vlm_config = None
|
||||
if self.chunking_config and hasattr(self.chunking_config, 'vlm_config'):
|
||||
vlm_config = self.chunking_config.vlm_config
|
||||
|
||||
self.caption_parser = Caption(vlm_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize Caption service: {str(e)}")
|
||||
self.caption_parser = None
|
||||
@@ -200,55 +208,51 @@ class BaseParser(ABC):
|
||||
resized_image.close()
|
||||
|
||||
def _resize_image_if_needed(self, image):
|
||||
"""Resize image to avoid processing large images
|
||||
"""Resize image if it exceeds maximum size limit
|
||||
|
||||
Args:
|
||||
image: Image object (PIL.Image or numpy array)
|
||||
|
||||
Returns:
|
||||
Resized image
|
||||
Resized image object
|
||||
"""
|
||||
try:
|
||||
# Check if it's a PIL.Image object
|
||||
if hasattr(image, 'width') and hasattr(image, 'height'):
|
||||
width, height = image.width, image.height
|
||||
# Check if resizing is needed
|
||||
# If it's a PIL Image
|
||||
if hasattr(image, 'size'):
|
||||
width, height = image.size
|
||||
if width > self.max_image_size or height > self.max_image_size:
|
||||
logger.info(f"Resizing image, original size: {width}x{height}")
|
||||
# Calculate scaling factor
|
||||
logger.info(f"Resizing PIL image, original size: {width}x{height}")
|
||||
scale = min(self.max_image_size / width, self.max_image_size / height)
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
# Resize
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
logger.info(f"Image resized to: {new_width}x{new_height}")
|
||||
logger.info(f"Resized to: {new_width}x{new_height}")
|
||||
return resized_image
|
||||
else:
|
||||
logger.info(f"PIL image size {width}x{height} is within limits, no resizing needed")
|
||||
return image
|
||||
# If it's a numpy array
|
||||
elif hasattr(image, 'shape') and len(image.shape) == 3:
|
||||
height, width = image.shape[0], image.shape[1]
|
||||
elif hasattr(image, 'shape'):
|
||||
height, width = image.shape[:2]
|
||||
if width > self.max_image_size or height > self.max_image_size:
|
||||
logger.info(f"Resizing numpy image, original size: {width}x{height}")
|
||||
# Use PIL for resizing
|
||||
from PIL import Image
|
||||
scale = min(self.max_image_size / width, self.max_image_size / height)
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
# Use PIL for resizing numpy arrays
|
||||
pil_image = Image.fromarray(image)
|
||||
try:
|
||||
scale = min(self.max_image_size / width, self.max_image_size / height)
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
resized_image = pil_image.resize((new_width, new_height))
|
||||
# Convert back to numpy array
|
||||
import numpy as np
|
||||
resized_array = np.array(resized_image)
|
||||
logger.info(f"numpy image resized to: {new_width}x{new_height}")
|
||||
return resized_array
|
||||
finally:
|
||||
# Ensure PIL image is closed
|
||||
pil_image.close()
|
||||
if 'resized_image' in locals() and hasattr(resized_image, 'close'):
|
||||
resized_image.close()
|
||||
return image
|
||||
resized_pil = pil_image.resize((new_width, new_height))
|
||||
resized_image = np.array(resized_pil)
|
||||
logger.info(f"Resized to: {new_width}x{new_height}")
|
||||
return resized_image
|
||||
else:
|
||||
logger.info(f"Numpy image size {width}x{height} is within limits, no resizing needed")
|
||||
return image
|
||||
else:
|
||||
logger.warning(f"Unknown image type: {type(image)}, cannot resize")
|
||||
return image
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resize image: {str(e)}, using original image")
|
||||
logger.error(f"Error resizing image: {str(e)}")
|
||||
return image
|
||||
|
||||
def process_image(self, image, image_url=None):
|
||||
@@ -273,15 +277,21 @@ class BaseParser(ABC):
|
||||
ocr_text = self.perform_ocr(image)
|
||||
caption = ""
|
||||
|
||||
if self.caption_parser and image_url:
|
||||
if self.caption_parser:
|
||||
logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption")
|
||||
caption = self.get_image_caption(image_url)
|
||||
if caption:
|
||||
logger.info(f"Successfully obtained image caption: {caption}")
|
||||
# Convert image to base64 for caption generation
|
||||
img_base64 = image_to_base64(image)
|
||||
if img_base64:
|
||||
caption = self.get_image_caption(img_base64)
|
||||
if caption:
|
||||
logger.info(f"Successfully obtained image caption: {caption}")
|
||||
else:
|
||||
logger.warning("Failed to get caption")
|
||||
else:
|
||||
logger.warning("Failed to get caption")
|
||||
logger.warning("Failed to convert image to base64")
|
||||
caption = ""
|
||||
else:
|
||||
logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval")
|
||||
logger.info("Caption service not initialized, skipping caption retrieval")
|
||||
|
||||
# Release image resources
|
||||
del image
|
||||
@@ -323,21 +333,27 @@ class BaseParser(ABC):
|
||||
|
||||
logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption")
|
||||
caption = ""
|
||||
if self.caption_parser and image_url:
|
||||
if self.caption_parser:
|
||||
try:
|
||||
# Add timeout to avoid blocking caption retrieval (30 seconds timeout)
|
||||
caption_task = self.get_image_caption_async(image_url)
|
||||
image_url, caption = await asyncio.wait_for(caption_task, timeout=30.0)
|
||||
if caption:
|
||||
logger.info(f"Successfully obtained image caption: {caption}")
|
||||
# Convert image to base64 for caption generation
|
||||
img_base64 = image_to_base64(resized_image)
|
||||
if img_base64:
|
||||
# Add timeout to avoid blocking caption retrieval (30 seconds timeout)
|
||||
caption_task = self.get_image_caption_async(img_base64)
|
||||
image_data, caption = await asyncio.wait_for(caption_task, timeout=30.0)
|
||||
if caption:
|
||||
logger.info(f"Successfully obtained image caption: {caption}")
|
||||
else:
|
||||
logger.warning("Failed to get caption")
|
||||
else:
|
||||
logger.warning("Failed to get caption")
|
||||
logger.warning("Failed to convert image to base64")
|
||||
caption = ""
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Caption retrieval timed out, skipping")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get caption: {str(e)}")
|
||||
else:
|
||||
logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval")
|
||||
logger.info("Caption service not initialized, skipping caption retrieval")
|
||||
|
||||
return ocr_text, caption, image_url
|
||||
finally:
|
||||
@@ -473,56 +489,66 @@ class BaseParser(ABC):
|
||||
logger.info(f"Decoded text length: {len(text)} characters")
|
||||
return text
|
||||
|
||||
def get_image_caption(self, image_url: str) -> str:
|
||||
def get_image_caption(self, image_data: str) -> str:
|
||||
"""Get image description
|
||||
|
||||
Args:
|
||||
image_url: Image URL
|
||||
image_data: Image data (base64 encoded string or URL)
|
||||
|
||||
Returns:
|
||||
Image description
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(
|
||||
f"Getting caption for image: {image_url[:250]}..."
|
||||
if len(image_url) > 250
|
||||
else f"Getting caption for image: {image_url}"
|
||||
f"Getting caption for image: {image_data[:250]}..."
|
||||
if len(image_data) > 250
|
||||
else f"Getting caption for image: {image_data}"
|
||||
)
|
||||
caption = self.caption_parser.get_caption(image_url)
|
||||
caption = self.caption_parser.get_caption(image_data)
|
||||
if caption:
|
||||
logger.info(
|
||||
f"Received caption of length: {len(caption)}, caption: {caption}, image_url: {image_url},"
|
||||
f"Received caption of length: {len(caption)}, caption: {caption},"
|
||||
f"cost: {time.time() - start_time} seconds"
|
||||
)
|
||||
else:
|
||||
logger.warning("Failed to get caption for image")
|
||||
return caption
|
||||
|
||||
async def get_image_caption_async(self, image_url: str) -> Tuple[str, str]:
|
||||
async def get_image_caption_async(self, image_data: str) -> Tuple[str, str]:
|
||||
"""Asynchronously get image description
|
||||
|
||||
Args:
|
||||
image_url: Image URL
|
||||
image_data: Image data (base64 encoded string or URL)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: Image URL and corresponding description
|
||||
Tuple[str, str]: Image data and corresponding description
|
||||
"""
|
||||
caption = self.get_image_caption(image_url)
|
||||
return image_url, caption
|
||||
caption = self.get_image_caption(image_data)
|
||||
return image_data, caption
|
||||
|
||||
def _init_cos_client(self):
|
||||
def _init_cos_client(self, cos_config=None):
|
||||
"""Initialize Tencent Cloud COS client"""
|
||||
try:
|
||||
# Get COS configuration from environment variables
|
||||
secret_id = os.getenv("COS_SECRET_ID")
|
||||
secret_key = os.getenv("COS_SECRET_KEY")
|
||||
region = os.getenv("COS_REGION")
|
||||
bucket_name = os.getenv("COS_BUCKET_NAME")
|
||||
appid = os.getenv("COS_APP_ID")
|
||||
prefix = os.getenv("COS_PATH_PREFIX")
|
||||
enable_old_domain = (
|
||||
os.getenv("COS_ENABLE_OLD_DOMAIN", "true").lower() == "true"
|
||||
)
|
||||
# Use provided COS config if available, otherwise fall back to environment variables
|
||||
if cos_config:
|
||||
secret_id = cos_config.get("secret_id")
|
||||
secret_key = cos_config.get("secret_key")
|
||||
region = cos_config.get("region")
|
||||
bucket_name = cos_config.get("bucket_name")
|
||||
appid = cos_config.get("app_id")
|
||||
prefix = cos_config.get("path_prefix", "")
|
||||
enable_old_domain = cos_config.get("enable_old_domain", "true").lower() == "true"
|
||||
else:
|
||||
# Get COS configuration from environment variables
|
||||
secret_id = os.getenv("COS_SECRET_ID")
|
||||
secret_key = os.getenv("COS_SECRET_KEY")
|
||||
region = os.getenv("COS_REGION")
|
||||
bucket_name = os.getenv("COS_BUCKET_NAME")
|
||||
appid = os.getenv("COS_APP_ID")
|
||||
prefix = os.getenv("COS_PATH_PREFIX")
|
||||
enable_old_domain = (
|
||||
os.getenv("COS_ENABLE_OLD_DOMAIN", "true").lower() == "true"
|
||||
)
|
||||
|
||||
if not all([secret_id, secret_key, region, bucket_name, appid]):
|
||||
logger.error(
|
||||
@@ -600,12 +626,13 @@ class BaseParser(ABC):
|
||||
logger.error(f"Failed to upload file to COS: {str(e)}")
|
||||
return ""
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png", cos_config=None) -> str:
|
||||
"""Directly upload file content to Tencent Cloud COS
|
||||
|
||||
Args:
|
||||
content: File byte content
|
||||
file_ext: File extension, default is .png
|
||||
cos_config: COS configuration dictionary
|
||||
|
||||
Returns:
|
||||
File URL
|
||||
@@ -614,7 +641,7 @@ class BaseParser(ABC):
|
||||
f"Uploading bytes content to COS, content size: {len(content)} bytes"
|
||||
)
|
||||
try:
|
||||
client, bucket_name, region, prefix = self._init_cos_client()
|
||||
client, bucket_name, region, prefix = self._init_cos_client(cos_config)
|
||||
if not client:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
import ollama
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -168,39 +171,108 @@ class Caption:
|
||||
Uses an external API to process images and return textual descriptions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Caption service with configuration from environment variables."""
|
||||
def __init__(self, vlm_config=None):
|
||||
"""Initialize the Caption service with configuration from parameters or environment variables."""
|
||||
logger.info("Initializing Caption service")
|
||||
self.prompt = """简单凝炼的描述图片的主要内容"""
|
||||
if os.getenv("VLM_MODEL_BASE_URL") == "" or os.getenv("VLM_MODEL_NAME") == "":
|
||||
logger.error("VLM_MODEL_BASE_URL or VLM_MODEL_NAME is not set")
|
||||
return
|
||||
self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/v1/chat/completions"
|
||||
self.model = os.getenv("VLM_MODEL_NAME")
|
||||
self.api_key = os.getenv("VLM_MODEL_API_KEY")
|
||||
|
||||
# Use provided VLM config if available, otherwise fall back to environment variables
|
||||
if vlm_config:
|
||||
self.completion_url = vlm_config.get("base_url", "") + "/chat/completions"
|
||||
self.model = vlm_config.get("model_name", "")
|
||||
self.api_key = vlm_config.get("api_key", "")
|
||||
self.interface_type = vlm_config.get("interface_type", "openai").lower()
|
||||
else:
|
||||
if os.getenv("VLM_MODEL_BASE_URL") == "" or os.getenv("VLM_MODEL_NAME") == "":
|
||||
logger.error("VLM_MODEL_BASE_URL or VLM_MODEL_NAME is not set")
|
||||
return
|
||||
self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/chat/completions"
|
||||
self.model = os.getenv("VLM_MODEL_NAME")
|
||||
self.api_key = os.getenv("VLM_MODEL_API_KEY")
|
||||
self.interface_type = os.getenv("VLM_INTERFACE_TYPE", "openai").lower()
|
||||
|
||||
# 验证接口类型
|
||||
if self.interface_type not in ["ollama", "openai"]:
|
||||
logger.warning(f"Unknown interface type: {self.interface_type}, defaulting to openai")
|
||||
self.interface_type = "openai"
|
||||
|
||||
logger.info(
|
||||
f"Service configured with model: {self.model}, endpoint: {self.completion_url}"
|
||||
f"Service configured with model: {self.model}, endpoint: {self.completion_url}, interface: {self.interface_type}"
|
||||
)
|
||||
|
||||
def _call_caption_api(self, image_url: str) -> Optional[CaptionChatResp]:
|
||||
def _call_caption_api(self, image_data: str) -> Optional[CaptionChatResp]:
|
||||
"""
|
||||
Call the Caption API to generate a description for the given image.
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to be captioned
|
||||
image_data: URL of the image or base64 encoded image data
|
||||
|
||||
Returns:
|
||||
CaptionChatResp object if successful, None otherwise
|
||||
"""
|
||||
logger.info(f"Calling Caption API for image captioning")
|
||||
logger.info(f"Processing image from URL: {image_url}")
|
||||
logger.info(f"Processing image data: {image_data[:50] if len(image_data) > 50 else image_data}")
|
||||
|
||||
# 根据接口类型选择调用方式
|
||||
if self.interface_type == "ollama":
|
||||
return self._call_ollama_api(image_data)
|
||||
else:
|
||||
return self._call_openai_api(image_data)
|
||||
|
||||
def _call_ollama_api(self, image_base64: str) -> Optional[CaptionChatResp]:
|
||||
"""Call Ollama API for image captioning using base64 encoded image data."""
|
||||
|
||||
host = self.completion_url.replace("/v1/chat/completions", "")
|
||||
|
||||
client = ollama.Client(
|
||||
host=host,
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(f"Calling Ollama API with model: {self.model}")
|
||||
|
||||
# 调用Ollama API,使用images参数传递base64编码的图片
|
||||
response = client.generate(
|
||||
model=self.model,
|
||||
prompt="简单凝炼的描述图片的主要内容",
|
||||
images=[image_base64], # image_base64是base64编码的图片数据
|
||||
options={"temperature": 0.1},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# 构造响应对象
|
||||
caption_resp = CaptionChatResp(
|
||||
id="ollama_response",
|
||||
created=int(time.time()),
|
||||
model=self.model,
|
||||
object="chat.completion",
|
||||
choices=[
|
||||
Choice(
|
||||
message=Message(
|
||||
role="assistant",
|
||||
content=response.response
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
logger.info("Successfully received response from Ollama API")
|
||||
return caption_resp
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Ollama API: {e}")
|
||||
return None
|
||||
|
||||
def _call_openai_api(self, image_base64: str) -> Optional[CaptionChatResp]:
|
||||
"""Call OpenAI-compatible API for image captioning."""
|
||||
logger.info(f"Calling OpenAI-compatible API with model: {self.model}")
|
||||
|
||||
user_msg = UserMessage(
|
||||
role="user",
|
||||
content=[
|
||||
Content(type="text", text=self.prompt),
|
||||
Content(
|
||||
type="image_url", image_url=ImageUrl(url=image_url, detail="auto")
|
||||
type="image_url", image_url=ImageUrl(url="data:image/png;base64," + image_base64, detail="auto")
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -223,7 +295,7 @@ class Caption:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
try:
|
||||
logger.info(f"Sending request to Caption API with model: {self.model}")
|
||||
logger.info(f"Sending request to OpenAI-compatible API with model: {self.model}")
|
||||
response = requests.post(
|
||||
self.completion_url,
|
||||
data=json.dumps(gpt_req, default=lambda o: o.__dict__, indent=4),
|
||||
@@ -232,12 +304,12 @@ class Caption:
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Caption API returned non-200 status code: {response.status_code}"
|
||||
f"OpenAI-compatible API returned non-200 status code: {response.status_code}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.info(
|
||||
f"Successfully received response from Caption API with status: {response.status_code}"
|
||||
f"Successfully received response from OpenAI-compatible API with status: {response.status_code}"
|
||||
)
|
||||
logger.info(f"Converting response to CaptionChatResp object")
|
||||
caption_resp = CaptionChatResp.from_json(response.json())
|
||||
@@ -250,33 +322,30 @@ class Caption:
|
||||
|
||||
return caption_resp
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Timeout while calling Caption API after 30 seconds")
|
||||
logger.error(f"Timeout while calling OpenAI-compatible API after 30 seconds")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Request error calling Caption API: {e}")
|
||||
logger.error(f"Request error calling OpenAI-compatible API: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Caption API: {str(e)}")
|
||||
logger.info(
|
||||
f"Request details: model={self.model}, URL={self.completion_url}"
|
||||
)
|
||||
logger.error(f"Unexpected error calling OpenAI-compatible API: {e}")
|
||||
return None
|
||||
|
||||
def get_caption(self, image_url: str) -> str:
|
||||
def get_caption(self, image_data: str) -> str:
|
||||
"""
|
||||
Get a caption for the provided image URL.
|
||||
Get a caption for the provided image data.
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to be captioned
|
||||
image_data: URL of the image or base64 encoded image data
|
||||
|
||||
Returns:
|
||||
Caption text as string, or empty string if captioning failed
|
||||
"""
|
||||
logger.info("Getting caption for image")
|
||||
if not image_url or self.completion_url is None:
|
||||
logger.error("Image URL is not set")
|
||||
if not image_data or self.completion_url is None:
|
||||
logger.error("Image data is not set")
|
||||
return ""
|
||||
caption_resp = self._call_caption_api(image_url)
|
||||
caption_resp = self._call_caption_api(image_data)
|
||||
if caption_resp:
|
||||
caption = caption_resp.choice_data()
|
||||
caption_length = len(caption)
|
||||
|
||||
@@ -37,7 +37,13 @@ class ImageParser(BaseParser):
|
||||
# Upload image to storage service
|
||||
logger.info("Uploading image to storage")
|
||||
_, ext = os.path.splitext(self.file_name)
|
||||
image_url = self.upload_bytes(content, file_ext=ext)
|
||||
|
||||
# Get COS config from chunking config if available
|
||||
cos_config = None
|
||||
if hasattr(self, 'chunking_config') and self.chunking_config and hasattr(self.chunking_config, 'cos_config'):
|
||||
cos_config = self.chunking_config.cos_config
|
||||
|
||||
image_url = self.upload_bytes(content, file_ext=ext, cos_config=cos_config)
|
||||
if not image_url:
|
||||
logger.error("Failed to upload image to storage")
|
||||
return ""
|
||||
|
||||
43
services/docreader/src/parser/image_utils.py
Normal file
43
services/docreader/src/parser/image_utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Union
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def image_to_base64(image: Union[str, bytes, Image.Image, np.ndarray]) -> str:
|
||||
"""Convert image to base64 encoded string
|
||||
|
||||
Args:
|
||||
image: Image file path, bytes, PIL Image object, or numpy array
|
||||
|
||||
Returns:
|
||||
Base64 encoded image string, or empty string if conversion fails
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, str):
|
||||
# It's a file path
|
||||
with open(image, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
elif isinstance(image, bytes):
|
||||
# It's bytes data
|
||||
return base64.b64encode(image).decode("utf-8")
|
||||
elif isinstance(image, Image.Image):
|
||||
# It's a PIL Image
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
elif isinstance(image, np.ndarray):
|
||||
# It's a numpy array
|
||||
pil_image = Image.fromarray(image)
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
else:
|
||||
logger.error(f"Unsupported image type: {type(image)}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting image to base64: {str(e)}")
|
||||
return ""
|
||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
from .image_utils import image_to_base64
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
@@ -161,35 +162,6 @@ class NanonetsOCRBackend(OCRBackend):
|
||||
logger.error(f"Failed to initialize Nanonets OCR: {str(e)}")
|
||||
self.client = None
|
||||
|
||||
def _encode_image(self, image: Union[str, bytes, Image.Image]) -> str:
|
||||
"""Encode image to base64
|
||||
|
||||
Args:
|
||||
image: Image file path, bytes, or PIL Image object
|
||||
|
||||
Returns:
|
||||
Base64 encoded image
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, str):
|
||||
# It's a file path
|
||||
with open(image, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
elif isinstance(image, bytes):
|
||||
# It's bytes data
|
||||
return base64.b64encode(image).decode("utf-8")
|
||||
elif isinstance(image, Image.Image):
|
||||
# It's a PIL Image
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
else:
|
||||
logger.error(f"Unsupported image type: {type(image)}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding image: {str(e)}")
|
||||
return ""
|
||||
|
||||
def predict(self, image: Union[str, bytes, Image.Image]) -> str:
|
||||
"""Extract text from an image using Nanonets OCR
|
||||
|
||||
@@ -205,7 +177,7 @@ class NanonetsOCRBackend(OCRBackend):
|
||||
|
||||
try:
|
||||
# Encode image to base64
|
||||
img_base64 = self._encode_image(image)
|
||||
img_base64 = image_to_base64(image)
|
||||
if not img_base64:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ class ChunkingConfig:
|
||||
enable_multimodal: bool = (
|
||||
False # Whether to enable multimodal processing (text + images)
|
||||
)
|
||||
cos_config: dict = None # COS configuration for file storage
|
||||
vlm_config: dict = None # VLM configuration for image captioning
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -138,6 +140,7 @@ class Parser:
|
||||
enable_multimodal=config.enable_multimodal,
|
||||
max_image_size=1920, # Limit image size to 1920px
|
||||
max_concurrent_tasks=5, # Limit concurrent tasks to 5
|
||||
chunking_config=config, # Pass the entire chunking config
|
||||
)
|
||||
|
||||
logger.info(f"Starting to parse file content, size: {len(content)} bytes")
|
||||
@@ -197,6 +200,7 @@ class Parser:
|
||||
enable_multimodal=config.enable_multimodal,
|
||||
max_image_size=1920, # Limit image size
|
||||
max_concurrent_tasks=5, # Limit concurrent tasks
|
||||
chunking_config=config,
|
||||
)
|
||||
|
||||
logger.info(f"Starting to parse URL content")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc-gen-go v1.36.7
|
||||
// protoc v5.29.3
|
||||
// source: docreader.proto
|
||||
|
||||
@@ -21,19 +21,175 @@ const (
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// COS 配置
|
||||
type COSConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
SecretId string `protobuf:"bytes,1,opt,name=secret_id,json=secretId,proto3" json:"secret_id,omitempty"` // COS Secret ID
|
||||
SecretKey string `protobuf:"bytes,2,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"` // COS Secret Key
|
||||
Region string `protobuf:"bytes,3,opt,name=region,proto3" json:"region,omitempty"` // COS Region
|
||||
BucketName string `protobuf:"bytes,4,opt,name=bucket_name,json=bucketName,proto3" json:"bucket_name,omitempty"` // COS Bucket Name
|
||||
AppId string `protobuf:"bytes,5,opt,name=app_id,json=appId,proto3" json:"app_id,omitempty"` // COS App ID
|
||||
PathPrefix string `protobuf:"bytes,6,opt,name=path_prefix,json=pathPrefix,proto3" json:"path_prefix,omitempty"` // COS Path Prefix
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *COSConfig) Reset() {
|
||||
*x = COSConfig{}
|
||||
mi := &file_docreader_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *COSConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*COSConfig) ProtoMessage() {}
|
||||
|
||||
func (x *COSConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use COSConfig.ProtoReflect.Descriptor instead.
|
||||
func (*COSConfig) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetSecretId() string {
|
||||
if x != nil {
|
||||
return x.SecretId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetSecretKey() string {
|
||||
if x != nil {
|
||||
return x.SecretKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetRegion() string {
|
||||
if x != nil {
|
||||
return x.Region
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetBucketName() string {
|
||||
if x != nil {
|
||||
return x.BucketName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetAppId() string {
|
||||
if x != nil {
|
||||
return x.AppId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *COSConfig) GetPathPrefix() string {
|
||||
if x != nil {
|
||||
return x.PathPrefix
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// VLM 配置
|
||||
type VLMConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // VLM Model Name
|
||||
BaseUrl string `protobuf:"bytes,2,opt,name=base_url,json=baseUrl,proto3" json:"base_url,omitempty"` // VLM Base URL
|
||||
ApiKey string `protobuf:"bytes,3,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` // VLM API Key
|
||||
InterfaceType string `protobuf:"bytes,4,opt,name=interface_type,json=interfaceType,proto3" json:"interface_type,omitempty"` // VLM Interface Type: "ollama" or "openai"
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *VLMConfig) Reset() {
|
||||
*x = VLMConfig{}
|
||||
mi := &file_docreader_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *VLMConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*VLMConfig) ProtoMessage() {}
|
||||
|
||||
func (x *VLMConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use VLMConfig.ProtoReflect.Descriptor instead.
|
||||
func (*VLMConfig) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *VLMConfig) GetModelName() string {
|
||||
if x != nil {
|
||||
return x.ModelName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *VLMConfig) GetBaseUrl() string {
|
||||
if x != nil {
|
||||
return x.BaseUrl
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *VLMConfig) GetApiKey() string {
|
||||
if x != nil {
|
||||
return x.ApiKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *VLMConfig) GetInterfaceType() string {
|
||||
if x != nil {
|
||||
return x.InterfaceType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ReadConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ChunkSize int32 `protobuf:"varint,1,opt,name=chunk_size,json=chunkSize,proto3" json:"chunk_size,omitempty"` // 分块大小
|
||||
ChunkOverlap int32 `protobuf:"varint,2,opt,name=chunk_overlap,json=chunkOverlap,proto3" json:"chunk_overlap,omitempty"` // 分块重叠
|
||||
Separators []string `protobuf:"bytes,3,rep,name=separators,proto3" json:"separators,omitempty"` // 分隔符
|
||||
EnableMultimodal bool `protobuf:"varint,4,opt,name=enable_multimodal,json=enableMultimodal,proto3" json:"enable_multimodal,omitempty"` // 多模态处理
|
||||
CosConfig *COSConfig `protobuf:"bytes,5,opt,name=cos_config,json=cosConfig,proto3" json:"cos_config,omitempty"` // COS 配置
|
||||
VlmConfig *VLMConfig `protobuf:"bytes,6,opt,name=vlm_config,json=vlmConfig,proto3" json:"vlm_config,omitempty"` // VLM 配置
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ReadConfig) Reset() {
|
||||
*x = ReadConfig{}
|
||||
mi := &file_docreader_proto_msgTypes[0]
|
||||
mi := &file_docreader_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -45,7 +201,7 @@ func (x *ReadConfig) String() string {
|
||||
func (*ReadConfig) ProtoMessage() {}
|
||||
|
||||
func (x *ReadConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[0]
|
||||
mi := &file_docreader_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -58,7 +214,7 @@ func (x *ReadConfig) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ReadConfig.ProtoReflect.Descriptor instead.
|
||||
func (*ReadConfig) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{0}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *ReadConfig) GetChunkSize() int32 {
|
||||
@@ -89,6 +245,20 @@ func (x *ReadConfig) GetEnableMultimodal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *ReadConfig) GetCosConfig() *COSConfig {
|
||||
if x != nil {
|
||||
return x.CosConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ReadConfig) GetVlmConfig() *VLMConfig {
|
||||
if x != nil {
|
||||
return x.VlmConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 从文件读取文档请求
|
||||
type ReadFromFileRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -103,7 +273,7 @@ type ReadFromFileRequest struct {
|
||||
|
||||
func (x *ReadFromFileRequest) Reset() {
|
||||
*x = ReadFromFileRequest{}
|
||||
mi := &file_docreader_proto_msgTypes[1]
|
||||
mi := &file_docreader_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -115,7 +285,7 @@ func (x *ReadFromFileRequest) String() string {
|
||||
func (*ReadFromFileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[1]
|
||||
mi := &file_docreader_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -128,7 +298,7 @@ func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ReadFromFileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ReadFromFileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{1}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *ReadFromFileRequest) GetFileContent() []byte {
|
||||
@@ -179,7 +349,7 @@ type ReadFromURLRequest struct {
|
||||
|
||||
func (x *ReadFromURLRequest) Reset() {
|
||||
*x = ReadFromURLRequest{}
|
||||
mi := &file_docreader_proto_msgTypes[2]
|
||||
mi := &file_docreader_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -191,7 +361,7 @@ func (x *ReadFromURLRequest) String() string {
|
||||
func (*ReadFromURLRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[2]
|
||||
mi := &file_docreader_proto_msgTypes[4]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -204,7 +374,7 @@ func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ReadFromURLRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ReadFromURLRequest) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{2}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *ReadFromURLRequest) GetUrl() string {
|
||||
@@ -250,7 +420,7 @@ type Image struct {
|
||||
|
||||
func (x *Image) Reset() {
|
||||
*x = Image{}
|
||||
mi := &file_docreader_proto_msgTypes[3]
|
||||
mi := &file_docreader_proto_msgTypes[5]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -262,7 +432,7 @@ func (x *Image) String() string {
|
||||
func (*Image) ProtoMessage() {}
|
||||
|
||||
func (x *Image) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[3]
|
||||
mi := &file_docreader_proto_msgTypes[5]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -275,7 +445,7 @@ func (x *Image) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use Image.ProtoReflect.Descriptor instead.
|
||||
func (*Image) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{3}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{5}
|
||||
}
|
||||
|
||||
func (x *Image) GetUrl() string {
|
||||
@@ -333,7 +503,7 @@ type Chunk struct {
|
||||
|
||||
func (x *Chunk) Reset() {
|
||||
*x = Chunk{}
|
||||
mi := &file_docreader_proto_msgTypes[4]
|
||||
mi := &file_docreader_proto_msgTypes[6]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -345,7 +515,7 @@ func (x *Chunk) String() string {
|
||||
func (*Chunk) ProtoMessage() {}
|
||||
|
||||
func (x *Chunk) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[4]
|
||||
mi := &file_docreader_proto_msgTypes[6]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -358,7 +528,7 @@ func (x *Chunk) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use Chunk.ProtoReflect.Descriptor instead.
|
||||
func (*Chunk) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{4}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{6}
|
||||
}
|
||||
|
||||
func (x *Chunk) GetContent() string {
|
||||
@@ -407,7 +577,7 @@ type ReadResponse struct {
|
||||
|
||||
func (x *ReadResponse) Reset() {
|
||||
*x = ReadResponse{}
|
||||
mi := &file_docreader_proto_msgTypes[5]
|
||||
mi := &file_docreader_proto_msgTypes[7]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -419,7 +589,7 @@ func (x *ReadResponse) String() string {
|
||||
func (*ReadResponse) ProtoMessage() {}
|
||||
|
||||
func (x *ReadResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_docreader_proto_msgTypes[5]
|
||||
mi := &file_docreader_proto_msgTypes[7]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -432,7 +602,7 @@ func (x *ReadResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ReadResponse.ProtoReflect.Descriptor instead.
|
||||
func (*ReadResponse) Descriptor() ([]byte, []int) {
|
||||
return file_docreader_proto_rawDescGZIP(), []int{5}
|
||||
return file_docreader_proto_rawDescGZIP(), []int{7}
|
||||
}
|
||||
|
||||
func (x *ReadResponse) GetChunks() []*Chunk {
|
||||
@@ -453,7 +623,23 @@ var File_docreader_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_docreader_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x0fdocreader.proto\x12\tdocreader\"\x9d\x01\n" +
|
||||
"\x0fdocreader.proto\x12\tdocreader\"\xb8\x01\n" +
|
||||
"\tCOSConfig\x12\x1b\n" +
|
||||
"\tsecret_id\x18\x01 \x01(\tR\bsecretId\x12\x1d\n" +
|
||||
"\n" +
|
||||
"secret_key\x18\x02 \x01(\tR\tsecretKey\x12\x16\n" +
|
||||
"\x06region\x18\x03 \x01(\tR\x06region\x12\x1f\n" +
|
||||
"\vbucket_name\x18\x04 \x01(\tR\n" +
|
||||
"bucketName\x12\x15\n" +
|
||||
"\x06app_id\x18\x05 \x01(\tR\x05appId\x12\x1f\n" +
|
||||
"\vpath_prefix\x18\x06 \x01(\tR\n" +
|
||||
"pathPrefix\"\x85\x01\n" +
|
||||
"\tVLMConfig\x12\x1d\n" +
|
||||
"\n" +
|
||||
"model_name\x18\x01 \x01(\tR\tmodelName\x12\x19\n" +
|
||||
"\bbase_url\x18\x02 \x01(\tR\abaseUrl\x12\x17\n" +
|
||||
"\aapi_key\x18\x03 \x01(\tR\x06apiKey\x12%\n" +
|
||||
"\x0einterface_type\x18\x04 \x01(\tR\rinterfaceType\"\x87\x02\n" +
|
||||
"\n" +
|
||||
"ReadConfig\x12\x1d\n" +
|
||||
"\n" +
|
||||
@@ -462,7 +648,11 @@ const file_docreader_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"separators\x18\x03 \x03(\tR\n" +
|
||||
"separators\x12+\n" +
|
||||
"\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\"\xc9\x01\n" +
|
||||
"\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\x123\n" +
|
||||
"\n" +
|
||||
"cos_config\x18\x05 \x01(\v2\x14.docreader.COSConfigR\tcosConfig\x123\n" +
|
||||
"\n" +
|
||||
"vlm_config\x18\x06 \x01(\v2\x14.docreader.VLMConfigR\tvlmConfig\"\xc9\x01\n" +
|
||||
"\x13ReadFromFileRequest\x12!\n" +
|
||||
"\ffile_content\x18\x01 \x01(\fR\vfileContent\x12\x1b\n" +
|
||||
"\tfile_name\x18\x02 \x01(\tR\bfileName\x12\x1b\n" +
|
||||
@@ -510,29 +700,33 @@ func file_docreader_proto_rawDescGZIP() []byte {
|
||||
return file_docreader_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
|
||||
var file_docreader_proto_goTypes = []any{
|
||||
(*ReadConfig)(nil), // 0: docreader.ReadConfig
|
||||
(*ReadFromFileRequest)(nil), // 1: docreader.ReadFromFileRequest
|
||||
(*ReadFromURLRequest)(nil), // 2: docreader.ReadFromURLRequest
|
||||
(*Image)(nil), // 3: docreader.Image
|
||||
(*Chunk)(nil), // 4: docreader.Chunk
|
||||
(*ReadResponse)(nil), // 5: docreader.ReadResponse
|
||||
(*COSConfig)(nil), // 0: docreader.COSConfig
|
||||
(*VLMConfig)(nil), // 1: docreader.VLMConfig
|
||||
(*ReadConfig)(nil), // 2: docreader.ReadConfig
|
||||
(*ReadFromFileRequest)(nil), // 3: docreader.ReadFromFileRequest
|
||||
(*ReadFromURLRequest)(nil), // 4: docreader.ReadFromURLRequest
|
||||
(*Image)(nil), // 5: docreader.Image
|
||||
(*Chunk)(nil), // 6: docreader.Chunk
|
||||
(*ReadResponse)(nil), // 7: docreader.ReadResponse
|
||||
}
|
||||
var file_docreader_proto_depIdxs = []int32{
|
||||
0, // 0: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig
|
||||
0, // 1: docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig
|
||||
3, // 2: docreader.Chunk.images:type_name -> docreader.Image
|
||||
4, // 3: docreader.ReadResponse.chunks:type_name -> docreader.Chunk
|
||||
1, // 4: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest
|
||||
2, // 5: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest
|
||||
5, // 6: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse
|
||||
5, // 7: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse
|
||||
6, // [6:8] is the sub-list for method output_type
|
||||
4, // [4:6] is the sub-list for method input_type
|
||||
4, // [4:4] is the sub-list for extension type_name
|
||||
4, // [4:4] is the sub-list for extension extendee
|
||||
0, // [0:4] is the sub-list for field type_name
|
||||
0, // 0: docreader.ReadConfig.cos_config:type_name -> docreader.COSConfig
|
||||
1, // 1: docreader.ReadConfig.vlm_config:type_name -> docreader.VLMConfig
|
||||
2, // 2: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig
|
||||
2, // 3: docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig
|
||||
5, // 4: docreader.Chunk.images:type_name -> docreader.Image
|
||||
6, // 5: docreader.ReadResponse.chunks:type_name -> docreader.Chunk
|
||||
3, // 6: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest
|
||||
4, // 7: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest
|
||||
7, // 8: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse
|
||||
7, // 9: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse
|
||||
8, // [8:10] is the sub-list for method output_type
|
||||
6, // [6:8] is the sub-list for method input_type
|
||||
6, // [6:6] is the sub-list for extension type_name
|
||||
6, // [6:6] is the sub-list for extension extendee
|
||||
0, // [0:6] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_docreader_proto_init() }
|
||||
@@ -546,7 +740,7 @@ func file_docreader_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_docreader_proto_rawDesc), len(file_docreader_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 6,
|
||||
NumMessages: 8,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -12,11 +12,31 @@ service DocReader {
|
||||
rpc ReadFromURL(ReadFromURLRequest) returns (ReadResponse) {}
|
||||
}
|
||||
|
||||
// COS 配置
|
||||
message COSConfig {
|
||||
string secret_id = 1; // COS Secret ID
|
||||
string secret_key = 2; // COS Secret Key
|
||||
string region = 3; // COS Region
|
||||
string bucket_name = 4; // COS Bucket Name
|
||||
string app_id = 5; // COS App ID
|
||||
string path_prefix = 6; // COS Path Prefix
|
||||
}
|
||||
|
||||
// VLM 配置
|
||||
message VLMConfig {
|
||||
string model_name = 1; // VLM Model Name
|
||||
string base_url = 2; // VLM Base URL
|
||||
string api_key = 3; // VLM API Key
|
||||
string interface_type = 4; // VLM Interface Type: "ollama" or "openai"
|
||||
}
|
||||
|
||||
message ReadConfig {
|
||||
int32 chunk_size = 1; // 分块大小
|
||||
int32 chunk_overlap = 2; // 分块重叠
|
||||
repeated string separators = 3; // 分隔符
|
||||
bool enable_multimodal = 4; // 多模态处理
|
||||
COSConfig cos_config = 5; // COS 配置
|
||||
VLMConfig vlm_config = 6; // VLM 配置
|
||||
}
|
||||
|
||||
// 从文件读取文档请求
|
||||
|
||||
@@ -1,9 +1,29 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import docreader_pb2 as docreader__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.74.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in docreader_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class DocReaderStub(object):
|
||||
"""文档读取服务
|
||||
@@ -19,12 +39,12 @@ class DocReaderStub(object):
|
||||
'/docreader.DocReader/ReadFromFile',
|
||||
request_serializer=docreader__pb2.ReadFromFileRequest.SerializeToString,
|
||||
response_deserializer=docreader__pb2.ReadResponse.FromString,
|
||||
)
|
||||
_registered_method=True)
|
||||
self.ReadFromURL = channel.unary_unary(
|
||||
'/docreader.DocReader/ReadFromURL',
|
||||
request_serializer=docreader__pb2.ReadFromURLRequest.SerializeToString,
|
||||
response_deserializer=docreader__pb2.ReadResponse.FromString,
|
||||
)
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class DocReaderServicer(object):
|
||||
@@ -62,6 +82,7 @@ def add_DocReaderServicer_to_server(servicer, server):
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'docreader.DocReader', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('docreader.DocReader', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
@@ -80,11 +101,21 @@ class DocReader(object):
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromFile',
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docreader.DocReader/ReadFromFile',
|
||||
docreader__pb2.ReadFromFileRequest.SerializeToString,
|
||||
docreader__pb2.ReadResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def ReadFromURL(request,
|
||||
@@ -97,8 +128,18 @@ class DocReader(object):
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromURL',
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docreader.DocReader/ReadFromURL',
|
||||
docreader__pb2.ReadFromURLRequest.SerializeToString,
|
||||
docreader__pb2.ReadResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@@ -74,11 +74,39 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
|
||||
f"multimodal={enable_multimodal}"
|
||||
)
|
||||
|
||||
# Get COS and VLM config from request
|
||||
cos_config = None
|
||||
vlm_config = None
|
||||
|
||||
if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config:
|
||||
cos_config = {
|
||||
'secret_id': request.read_config.cos_config.secret_id,
|
||||
'secret_key': request.read_config.cos_config.secret_key,
|
||||
'region': request.read_config.cos_config.region,
|
||||
'bucket_name': request.read_config.cos_config.bucket_name,
|
||||
'app_id': request.read_config.cos_config.app_id,
|
||||
'path_prefix': request.read_config.cos_config.path_prefix or '',
|
||||
}
|
||||
logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}")
|
||||
|
||||
if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config:
|
||||
vlm_config = {
|
||||
'model_name': request.read_config.vlm_config.model_name,
|
||||
'base_url': request.read_config.vlm_config.base_url,
|
||||
'api_key': request.read_config.vlm_config.api_key or '',
|
||||
'interface_type': request.read_config.vlm_config.interface_type or 'openai',
|
||||
}
|
||||
logger.info(f"Using VLM config: model={vlm_config['model_name']}, "
|
||||
f"base_url={vlm_config['base_url']}, "
|
||||
f"interface_type={vlm_config['interface_type']}")
|
||||
|
||||
chunking_config = ChunkingConfig(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=separators,
|
||||
enable_multimodal=enable_multimodal,
|
||||
cos_config=cos_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
# Parse file
|
||||
@@ -138,11 +166,39 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
|
||||
f"multimodal={enable_multimodal}"
|
||||
)
|
||||
|
||||
# Get COS and VLM config from request
|
||||
cos_config = None
|
||||
vlm_config = None
|
||||
|
||||
if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config:
|
||||
cos_config = {
|
||||
'secret_id': request.read_config.cos_config.secret_id,
|
||||
'secret_key': request.read_config.cos_config.secret_key,
|
||||
'region': request.read_config.cos_config.region,
|
||||
'bucket_name': request.read_config.cos_config.bucket_name,
|
||||
'app_id': request.read_config.cos_config.app_id,
|
||||
'path_prefix': request.read_config.cos_config.path_prefix or '',
|
||||
}
|
||||
logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}")
|
||||
|
||||
if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config:
|
||||
vlm_config = {
|
||||
'model_name': request.read_config.vlm_config.model_name,
|
||||
'base_url': request.read_config.vlm_config.base_url,
|
||||
'api_key': request.read_config.vlm_config.api_key or '',
|
||||
'interface_type': request.read_config.vlm_config.interface_type or 'openai',
|
||||
}
|
||||
logger.info(f"Using VLM config: model={vlm_config['model_name']}, "
|
||||
f"base_url={vlm_config['base_url']}, "
|
||||
f"interface_type={vlm_config['interface_type']}")
|
||||
|
||||
chunking_config = ChunkingConfig(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=separators,
|
||||
enable_multimodal=enable_multimodal,
|
||||
cos_config=cos_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
# Parse URL
|
||||
|
||||
Reference in New Issue
Block a user