feat: Added web page for configuring model information

This commit is contained in:
wizardchen
2025-08-10 17:04:39 +08:00
committed by lyingbug
parent 4498442fcc
commit bdabed6bfa
43 changed files with 5689 additions and 338 deletions

View File

@@ -112,73 +112,79 @@ ENABLE_GRAPH_RAG=false
# COS_ENABLE_OLD_DOMAIN=true 表示启用旧的域名格式,默认为 true
COS_ENABLE_OLD_DOMAIN=true
# 初始化默认租户与知识库
# 租户ID通常是一个字符串
INIT_TEST_TENANT_ID=1
# 知识库ID通常是一个字符串
INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001
##############################################################
# LLM Model
# 使用的LLM模型名称
# 默认使用 Ollama 的 Qwen3 8B 模型,ollama 会自动处理模型下载和加载
# 如果需要使用其他模型,请替换为实际的模型名称
INIT_LLM_MODEL_NAME=qwen3:8b
###### 注意: 以下配置不再生效已在Web“配置初始化”阶段完成 #########
# LLM模型的访问地址
# 支持第三方模型服务的URL
# 如果使用 Ollama 的本地服务,可以留空,ollama 会自动处理
# INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url
# LLM模型的API密钥如果需要身份验证可以设置
# 支持第三方模型服务的API密钥
# 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# INIT_LLM_MODEL_API_KEY=your_llm_model_api_key
# # 初始化默认租户与知识库
# # 租户ID通常是一个字符串
# INIT_TEST_TENANT_ID=1
# Embedding Model
# 使用的Embedding模型名称
# 默认使用 nomic-embed-text 模型,支持文本嵌入
# 如果需要使用其他模型,请替换为实际的模型名称
INIT_EMBEDDING_MODEL_NAME=nomic-embed-text
# # 知识库ID通常是一个字符串
# INIT_TEST_KNOWLEDGE_BASE_ID=kb-00000001
# Embedding模型向量维度
INIT_EMBEDDING_MODEL_DIMENSION=768
# # LLM Model
# # 使用的LLM模型名称
# # 默认使用 Ollama 的 Qwen3 8B 模型ollama 会自动处理模型下载和加载
# # 如果需要使用其他模型,请替换为实际的模型名称
# INIT_LLM_MODEL_NAME=qwen3:8b
# Embedding模型的ID通常是一个字符串
INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768
# # LLM模型的访问地址
# # 支持第三方模型服务的URL
# # 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# # INIT_LLM_MODEL_BASE_URL=your_llm_model_base_url
# Embedding模型的访问地址
# 支持第三方模型服务的URL
# 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url
# # LLM模型的API密钥如果需要身份验证可以设置
# # 支持第三方模型服务的API密钥
# # 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# # INIT_LLM_MODEL_API_KEY=your_llm_model_api_key
# Embedding模型的API密钥如果需要身份验证可以设置
# 支持第三方模型服务的API密钥
# 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key
# # Embedding Model
# # 使用的Embedding模型名称
# # 默认使用 nomic-embed-text 模型,支持文本嵌入
# # 如果需要使用其他模型,请替换为实际的模型名称
# INIT_EMBEDDING_MODEL_NAME=nomic-embed-text
# Rerank Model(可选)
# 对于rag来说使用Rerank模型对提升文档搜索的准确度有着重要作用
# 目前 ollama 暂不支持运行 Rerank 模型
# 使用的Rerank模型名称
# INIT_RERANK_MODEL_NAME=your_rerank_model_name
# # Embedding模型向量维度
# INIT_EMBEDDING_MODEL_DIMENSION=768
# Rerank模型的访问地址
# 支持第三方模型服务的URL
# INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url
# # Embedding模型的ID通常是一个字符串
# INIT_EMBEDDING_MODEL_ID=builtin:nomic-embed-text:768
# Rerank模型的API密钥如果需要身份验证可以设置
# 支持第三方模型服务的API密钥
# INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key
# # Embedding模型的访问地址
# # 支持第三方模型服务的URL
# # 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# # INIT_EMBEDDING_MODEL_BASE_URL=your_embedding_model_base_url
# VLM_MODEL_NAME 使用的多模态模型名称
# 用于解析图片数据
# VLM_MODEL_NAME=your_vlm_model_name
# # Embedding模型的API密钥如果需要身份验证可以设置
# # 支持第三方模型服务的API密钥
# # 如果使用 Ollama 的本地服务可以留空ollama 会自动处理
# # INIT_EMBEDDING_MODEL_API_KEY=your_embedding_model_api_key
# VLM_MODEL_BASE_URL 使用的多模态模型访问地址
# 支持第三方模型服务的URL
# VLM_MODEL_BASE_URL=your_vlm_model_base_url
# # Rerank Model(可选)
# # 对于rag来说使用Rerank模型对提升文档搜索的准确度有着重要作用
# # 目前 ollama 暂不支持运行 Rerank 模型
# # 使用的Rerank模型名称
# # INIT_RERANK_MODEL_NAME=your_rerank_model_name
# VLM_MODEL_API_KEY 使用的多模态模型API密钥
# 支持第三方模型服务的API密钥
# VLM_MODEL_API_KEY=your_vlm_model_api_key
# # Rerank模型的访问地址
# # 支持第三方模型服务的URL
# # INIT_RERANK_MODEL_BASE_URL=your_rerank_model_base_url
# # Rerank模型的API密钥如果需要身份验证可以设置
# # 支持第三方模型服务的API密钥
# # INIT_RERANK_MODEL_API_KEY=your_rerank_model_api_key
# # VLM_MODEL_NAME 使用的多模态模型名称
# # 用于解析图片数据
# # VLM_MODEL_NAME=your_vlm_model_name
# # VLM_MODEL_BASE_URL 使用的多模态模型访问地址
# # 支持第三方模型服务的URL
# # VLM_MODEL_BASE_URL=your_vlm_model_base_url
# # VLM_MODEL_API_KEY 使用的多模态模型API密钥
# # 支持第三方模型服务的API密钥
# # VLM_MODEL_API_KEY=your_vlm_model_api_key

View File

@@ -140,6 +140,38 @@ WeKnora 作为[微信对话开放平台](https://chatbot.weixin.qq.com)的核心
- **高效问题管理**:支持高频问题的独立分类管理,提供丰富的数据工具,确保回答精准可靠且易于维护
- **微信生态覆盖**通过微信对话开放平台WeKnora 的智能问答能力可无缝集成到公众号、小程序等微信场景中,提升用户交互体验
## 🔧 初始化配置引导
为了方便用户快速配置各类模型,降低试错成本,我们改进了原来的配置文件初始化方式,增加了 Web UI 界面进行各种模型的配置。在使用之前,请确保代码更新到最新版本。具体使用步骤如下:
如果是第一次使用本项目,可跳过①②步骤,直接进入③④步骤。
### ① 关闭服务
```bash
./scripts/start_all.sh --stop
```
### ② 清空原有数据表(建议在没有重要数据的情况下使用)
```bash
make clean-db
```
### ③ 编译并启动服务
```bash
./scripts/start_all.sh
```
### ④ 访问Web UI
http://localhost
首次访问会自动跳转到初始化配置页面,配置完成后会自动跳转到知识库页面。请按照页面提示信息完成模型的配置。
![配置页面](./docs/images/config.png)
## 📱 功能展示
### Web UI 界面

View File

@@ -14,7 +14,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/application/service"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/container"
"github.com/Tencent/WeKnora/internal/runtime"
@@ -42,7 +41,6 @@ func main() {
cfg *config.Config,
router *gin.Engine,
tracer *tracing.Tracer,
testDataService *service.TestDataService,
resourceCleaner interfaces.ResourceCleaner,
) error {
// Create context for resource cleanup
@@ -58,13 +56,6 @@ func main() {
return tracer.Cleanup(cleanupCtx)
})
// Initialize test data
if testDataService != nil {
if err := testDataService.InitializeTestData(context.Background()); err != nil {
log.Printf("Failed to initialize test data: %v", err)
}
}
// Create HTTP server
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),

View File

@@ -151,26 +151,27 @@ conversation:
Output:
generate_summary_prompt: |
你是一个总结助手,你的任务是总结文章或片段内容。
你是一个精准的文章总结专家。你的任务是提取并总结用户提供的文章或片段的核心内容。
## 准则要求
- 总结结果长度不能超过100个字
- 保持客观,不添加个人观点或评价
- 使用第三人称陈述语气
- 不要基于任何先验知识回答用户的问题,只基于文章内容生成摘要
- 直接输出总结结果,不要有任何前缀或解释
- 使用中文输出总结结果
- 不能输出“无法生成”、“无法总结”等字眼
## 核心要求
- 总结结果长度为50-100个字,根据内容复杂度灵活调整
- 完全基于提供的文章内容生成总结,不添加任何未在文章中出现的信息
- 确保总结包含文章的关键信息点和主要结论
- 即使文章内容较复杂或专业,也必须尝试提取核心要点进行总结
- 直接输出总结结果,不包含任何引言、前缀或解释
## Few-shot示例
## 格式与风格
- 使用客观、中立的第三人称陈述语气
- 使用清晰简洁的中文表达
- 保持逻辑连贯性,确保句与句之间有合理过渡
- 避免重复使用相同的表达方式或句式结构
用户给出的文章内容:
随着5G技术的快速发展各行各业正经历数字化转型。5G网络凭借高速率、低延迟和大连接的特性正在推动智慧城市、工业互联网和远程医疗等领域的创新。专家预测到2025年5G将为全球经济贡献约2.2万亿美元。然而5G建设也面临基础设施投入大、覆盖不均等挑战。
## 注意事项
- 绝对不输出"无法生成"、"无法总结"、"内容不足"等拒绝回应的词语
- 不要照抄或参考示例中的任何内容,确保总结完全基于用户提供的新文章
- 对于任何文本都尽最大努力提取重点并总结,无论长度或复杂度
文章总结:
5G技术凭借高速率、低延迟和大连接特性推动各行业数字化转型促进智慧城市、工业互联网和远程医疗创新。预计2025年将贡献约2.2万亿美元经济价值,但仍面临基础设施投入大和覆盖不均等挑战。
## 用户给出的文章内容是:
## 以下是用户给出的文章相关信息:
generate_session_title_prompt: |
你是一个专业的会话标题生成助手,你的任务是为用户提问创建简洁、精准且具描述性的标题。

View File

@@ -87,6 +87,8 @@ services:
build:
context: .
dockerfile: docker/Dockerfile.docreader
args:
- PLATFORM=${PLATFORM:-linux/amd64}
container_name: WeKnora-docreader
ports:
- "50051:50051"
@@ -104,6 +106,8 @@ services:
networks:
- WeKnora-network
restart: unless-stopped
extra_hosts:
- "host.docker.internal:host-gateway"
jaeger:
image: jaegertracing/all-in-one:latest

View File

@@ -93,14 +93,29 @@ RUN apt-get update && apt-get install -y \
# 下载并安装最新版本的 LibreOffice 25.2.4
RUN mkdir -p /tmp/libreoffice && \
cd /tmp/libreoffice && \
if [ "$(uname -m)" = "x86_64" ]; then \
wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/stable/25.2.4/deb/x86_64/LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
tar -xzf LibreOffice_25.2.4_Linux_x86-64_deb.tar.gz && \
cd LibreOffice_25.2.4*_Linux_x86-64_deb/DEBS && \
dpkg -i *.deb && \
dpkg -i *.deb; \
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \
wget -q https://mirrors.tuna.tsinghua.edu.cn/libreoffice/libreoffice/testing/25.8.0/deb/aarch64/LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \
tar -xzf LibreOffice_25.8.0.2_Linux_aarch64_deb.tar.gz && \
cd LibreOffice_25.8.0*_Linux_aarch64_deb/DEBS && \
dpkg -i *.deb; \
else \
echo "Unsupported architecture: $(uname -m)" && exit 1; \
fi && \
cd / && \
rm -rf /tmp/libreoffice
# 设置 LibreOffice 环境变量
RUN if [ "$(uname -m)" = "x86_64" ]; then \
echo 'export LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice' >> /etc/environment; \
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \
echo 'export LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice' >> /etc/environment; \
fi
ENV LIBREOFFICE_PATH=/opt/libreoffice25.2/program/soffice
# 从构建阶段复制已安装的依赖和生成的代码

BIN
docs/images/config.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 626 KiB

View File

@@ -0,0 +1,278 @@
import { get, post } from '../../utils/request';
// Initialization configuration payload used by the one-time system setup flow
// (sent to /api/v1/initialization/initialize and returned by .../config).
export interface InitializationConfig {
  // LLM (chat/summary) model settings.
  llm: {
    source: string;       // model source identifier (e.g. "ollama") — TODO confirm allowed values
    modelName: string;
    baseUrl?: string;     // optional for local Ollama; required for remote services
    apiKey?: string;
  };
  // Embedding model settings.
  embedding: {
    source: string;
    modelName: string;
    baseUrl?: string;
    apiKey?: string;
    dimension?: number;   // embedding vector dimension
  };
  // Rerank model settings (base URL is mandatory here, unlike llm/embedding).
  rerank: {
    modelName: string;
    baseUrl: string;
    apiKey?: string;
  };
  // Multimodal (image understanding) settings; vlm/cos are only needed when enabled.
  multimodal: {
    enabled: boolean;
    vlm?: {
      modelName: string;
      baseUrl: string;
      apiKey?: string;
      interfaceType?: string; // "ollama" or "openai"
    };
    // COS object-storage credentials used for image handling.
    cos?: {
      secretId: string;
      secretKey: string;
      region: string;
      bucketName: string;
      appId: string;
      pathPrefix?: string;
    };
  };
  // Document chunking parameters for ingestion.
  documentSplitting: {
    chunkSize: number;
    chunkOverlap: number;
    separators: string[];
  };
}
// State of one asynchronous Ollama model download task, as reported by
// the /api/v1/initialization/ollama/download/* endpoints.
export interface DownloadTask {
  id: string;                                                    // task identifier used to poll progress
  modelName: string;
  status: 'pending' | 'downloading' | 'completed' | 'failed';
  progress: number;                                              // presumably 0-100 — TODO confirm range with backend
  message: string;
  startTime: string;
  endTime?: string;                                              // set only once the task finishes
}
// 系统初始化状态检查
// Queries the backend for the system initialization flag. On any request
// failure it resolves with { initialized: false } instead of rejecting, so
// callers (e.g. the router guard) can treat "unknown" as "needs setup"
// without extra error handling.
export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
  // Return the promise chain directly instead of wrapping it in an explicit
  // `new Promise` (the original wrapper also declared an unused `reject`).
  return get('/api/v1/initialization/status')
    .then((response: any) => response.data || { initialized: false })
    .catch((error: any) => {
      console.warn('检查初始化状态失败,假设需要初始化:', error);
      return { initialized: false };
    });
}
// 执行系统初始化
// Posts the full configuration to the backend. On success it records a local
// marker in localStorage and resolves with the raw response; on failure it
// logs and re-throws so callers can surface the error.
export function initializeSystem(config: InitializationConfig): Promise<any> {
  console.log('开始系统初始化...', config);
  // Returning the chain replaces the redundant `new Promise` wrapper.
  return post('/api/v1/initialization/initialize', config)
    .then((response: any) => {
      console.log('系统初始化完成', response);
      // 设置本地初始化状态标记
      localStorage.setItem('system_initialized', 'true');
      return response;
    })
    .catch((error: any) => {
      console.error('系统初始化失败:', error);
      throw error;
    });
}
// 检查Ollama服务状态
// Probes the Ollama service via the backend. Never rejects: request failures
// resolve as { available: false, error } so status UIs stay simple.
export function checkOllamaStatus(): Promise<{ available: boolean; version?: string; error?: string }> {
  // Direct chain instead of the explicit-Promise antipattern (unused `reject`).
  return get('/api/v1/initialization/ollama/status')
    .then((response: any) => response.data || { available: false })
    .catch((error: any) => {
      console.error('检查Ollama状态失败:', error);
      return { available: false, error: error.message || '检查失败' };
    });
}
// 检查Ollama模型状态
// Asks the backend which of the given Ollama models are present locally.
// Resolves with a name -> installed map; rejects on request failure.
export function checkOllamaModels(models: string[]): Promise<{ models: Record<string, boolean> }> {
  return post('/api/v1/initialization/ollama/models/check', { models })
    .then((response: any) => response.data || { models: {} })
    .catch((error: any) => {
      console.error('检查Ollama模型状态失败:', error);
      throw error;
    });
}
// 启动Ollama模型下载异步
export function downloadOllamaModel(modelName: string): Promise<{ taskId: string; modelName: string; status: string; progress: number }> {
return new Promise((resolve, reject) => {
post('/api/v1/initialization/ollama/models/download', { modelName })
.then((response: any) => {
resolve(response.data || { taskId: '', modelName, status: 'failed', progress: 0 });
})
.catch((error: any) => {
console.error('启动Ollama模型下载失败:', error);
reject(error);
});
});
}
// 查询下载进度
// Fetches the current state of one download task by its id.
// Resolves with the DownloadTask from response.data; rejects on failure.
export function getDownloadProgress(taskId: string): Promise<DownloadTask> {
  return get(`/api/v1/initialization/ollama/download/progress/${taskId}`)
    .then((response: any) => response.data)
    .catch((error: any) => {
      console.error('查询下载进度失败:', error);
      throw error;
    });
}
// 获取所有下载任务
// Lists every known download task; resolves with an empty array when the
// backend returns no data. Rejects on request failure.
export function listDownloadTasks(): Promise<DownloadTask[]> {
  return get('/api/v1/initialization/ollama/download/tasks')
    .then((response: any) => response.data || [])
    .catch((error: any) => {
      console.error('获取下载任务列表失败:', error);
      throw error;
    });
}
// 获取当前系统配置
// Retrieves the saved initialization configuration (plus a hasFiles flag).
// Rejects on request failure.
export function getCurrentConfig(): Promise<InitializationConfig & { hasFiles: boolean }> {
  return get('/api/v1/initialization/config')
    .then((response: any) => response.data || {})
    .catch((error: any) => {
      console.error('获取当前配置失败:', error);
      throw error;
    });
}
// 检查远程API模型
// Asks the backend to verify that a remote (OpenAI-compatible) model is
// reachable with the given credentials. Rejects on request failure.
export function checkRemoteModel(modelConfig: {
  modelName: string;
  baseUrl: string;
  apiKey?: string;
}): Promise<{
  available: boolean;
  message?: string;
}> {
  return post('/api/v1/initialization/remote/check', modelConfig)
    .then((response: any) => response.data || {})
    .catch((error: any) => {
      console.error('检查远程模型失败:', error);
      throw error;
    });
}
// Asks the backend to verify the rerank model configuration (connectivity
// and credentials). Mirrors checkRemoteModel but hits the rerank endpoint.
export function checkRerankModel(modelConfig: {
  modelName: string;
  baseUrl: string;
  apiKey?: string;
}): Promise<{
  available: boolean;
  message?: string;
}> {
  return post('/api/v1/initialization/rerank/check', modelConfig)
    .then((response: any) => response.data || {})
    .catch((error: any) => {
      console.error('检查Rerank模型失败:', error);
      throw error;
    });
}
// Runs an end-to-end multimodal test: uploads a sample image together with
// the VLM / COS / chunking settings and resolves with the caption/OCR result.
// Uses native fetch (not the shared request helpers) because the payload is
// multipart FormData. Resolves { success: false, message } when the backend
// reports failure; rejects only on network / JSON-parse errors.
export function testMultimodalFunction(testData: {
  image: File;
  vlm_model: string;
  vlm_base_url: string;
  vlm_api_key?: string;
  vlm_interface_type?: string;
  cos_secret_id: string;
  cos_secret_key: string;
  cos_region: string;
  cos_bucket_name: string;
  cos_app_id: string;
  cos_path_prefix?: string;
  chunk_size: number;
  chunk_overlap: number;
  separators: string[];
}): Promise<{
  success: boolean;
  caption?: string;
  ocr?: string;
  processing_time?: number;
  message?: string;
}> {
  const formData = new FormData();
  formData.append('image', testData.image);
  formData.append('vlm_model', testData.vlm_model);
  formData.append('vlm_base_url', testData.vlm_base_url);
  // Optional fields are only appended when present so the backend sees them
  // as absent rather than as empty strings.
  if (testData.vlm_api_key) {
    formData.append('vlm_api_key', testData.vlm_api_key);
  }
  if (testData.vlm_interface_type) {
    formData.append('vlm_interface_type', testData.vlm_interface_type);
  }
  formData.append('cos_secret_id', testData.cos_secret_id);
  formData.append('cos_secret_key', testData.cos_secret_key);
  formData.append('cos_region', testData.cos_region);
  formData.append('cos_bucket_name', testData.cos_bucket_name);
  formData.append('cos_app_id', testData.cos_app_id);
  if (testData.cos_path_prefix) {
    formData.append('cos_path_prefix', testData.cos_path_prefix);
  }
  // Numbers and arrays must be serialized to strings for multipart fields.
  formData.append('chunk_size', testData.chunk_size.toString());
  formData.append('chunk_overlap', testData.chunk_overlap.toString());
  formData.append('separators', JSON.stringify(testData.separators));

  // Return the fetch chain directly instead of wrapping it in `new Promise`.
  return fetch('/api/v1/initialization/multimodal/test', {
    method: 'POST',
    body: formData,
  })
    .then(response => response.json())
    .then((data: any) =>
      data.success
        ? data.data || {}
        : { success: false, message: data.message || '测试失败' },
    )
    .catch((error: any) => {
      console.error('多模态测试失败:', error);
      throw error;
    });
}

View File

@@ -173,7 +173,12 @@ const getIcon = (path) => {
getIcon(route.name)
const gotopage = (path) => {
pathPrefix.value = path;
// 如果是系统设置,跳转到初始化配置页面
if (path === 'settings') {
router.push('/initialization');
} else {
router.push(`/platform/${path}`);
}
getIcon(path)
}

View File

@@ -23,7 +23,7 @@ export default function () {
});
const getKnowled = (query = { page: 1, page_size: 35 }) => {
getKnowledgeBase(query)
.then((result: object) => {
.then((result: any) => {
let { data, total: totalResult } = result;
let cardList_ = data.map((item) => {
item["file_name"] = item.file_name.substring(
@@ -50,7 +50,7 @@ export default function () {
cardList.value[index].isMore = false;
moreIndex.value = -1;
delKnowledgeDetails(item.id)
.then((result) => {
.then((result: any) => {
if (result.success) {
MessagePlugin.info("知识删除成功!");
getKnowled();
@@ -76,17 +76,29 @@ export default function () {
return;
}
uploadKnowledgeBase({ file })
.then((result) => {
.then((result: any) => {
if (result.success) {
MessagePlugin.info("上传成功!");
getKnowled();
} else {
MessagePlugin.error("上传失败!");
// 检查错误码,如果是重复文件则显示特定提示
if (result.code === 'duplicate_file') {
MessagePlugin.error("文件已存在");
} else {
MessagePlugin.error(result.message || (result.error && result.error.message) || "上传失败!");
}
}
uploadInput.value.value = "";
})
.catch((err) => {
.catch((err: any) => {
// 检查错误响应中的错误码
if (err.code === 'duplicate_file') {
MessagePlugin.error("文件已存在");
} else if (err.message) {
MessagePlugin.error(err.message);
} else {
MessagePlugin.error("上传失败!");
}
uploadInput.value.value = "";
});
} else {
@@ -101,7 +113,7 @@ export default function () {
id: "",
});
getKnowledgeDetails(item.id)
.then((result) => {
.then((result: any) => {
if (result.success && result.data) {
let { data } = result;
Object.assign(details, {
@@ -116,7 +128,7 @@ export default function () {
};
const getfDetails = (id, page) => {
getKnowledgeDetailsCon(id, page)
.then((result) => {
.then((result: any) => {
if (result.success && result.data) {
let { data, total: totalResult } = result;
if (page == 1) {

View File

@@ -1,4 +1,5 @@
import { createRouter, createWebHistory } from 'vue-router'
import { checkInitializationStatus } from '@/api/initialization'
const router = createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
@@ -7,40 +8,82 @@ const router = createRouter({
path: "/",
redirect: "/platform",
},
{
path: "/initialization",
name: "initialization",
component: () => import("../views/initialization/InitializationConfig.vue"),
meta: { requiresInit: false } // 初始化页面不需要检查初始化状态
},
{
path: "/knowledgeBase",
name: "home",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
},
{
path: "/platform",
name: "Platform",
redirect: "/platform/knowledgeBase",
component: () => import("../views/platform/index.vue"),
meta: { requiresInit: true },
children: [
{
path: "knowledgeBase",
name: "knowledgeBase",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
},
{
path: "creatChat",
name: "creatChat",
component: () => import("../views/creatChat/creatChat.vue"),
meta: { requiresInit: true }
},
{
path: "chat/:chatid",
name: "chat",
component: () => import("../views/chat/index.vue"),
meta: { requiresInit: true }
},
{
path: "settings",
name: "settings",
component: () => import("../views/settings/Settings.vue"),
meta: { requiresInit: true }
},
],
},
],
});
// 路由守卫:检查系统初始化状态
// Before every navigation, verify the system has completed initial setup.
// Uninitialized (or unverifiable) systems are redirected to /initialization.
router.beforeEach(async (to, from, next) => {
  // 如果访问的是初始化页面,直接放行
  // Routes that opt out (meta.requiresInit === false, i.e. the initialization
  // page itself) pass through immediately to avoid a redirect loop.
  if (to.meta.requiresInit === false) {
    next();
    return;
  }
  // (removed a stray `1` expression statement left over in the original)
  try {
    // 检查系统是否已初始化
    const { initialized } = await checkInitializationStatus();
    if (initialized) {
      // 系统已初始化,记录到本地存储并正常跳转
      localStorage.setItem('system_initialized', 'true');
      next();
    } else {
      // 系统未初始化,跳转到初始化页面
      console.log('系统未初始化,跳转到初始化页面');
      next('/initialization');
    }
  } catch (error) {
    console.error('检查初始化状态失败:', error);
    // 如果检查失败,默认认为需要初始化
    next('/initialization');
  }
});
export default router

File diff suppressed because it is too large Load Diff

2
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/google/uuid v1.6.0
github.com/hibiken/asynq v0.25.1
github.com/minio/minio-go/v7 v7.0.90
github.com/ollama/ollama v0.9.6
github.com/ollama/ollama v0.11.4
github.com/panjf2000/ants/v2 v2.11.2
github.com/parquet-go/parquet-go v0.25.0
github.com/pgvector/pgvector-go v0.3.0

4
go.sum
View File

@@ -141,8 +141,8 @@ github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSr
github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/ollama/ollama v0.9.6 h1:HZNJmB52pMt6zLkGkkheBuXBXM5478eiSAj7GR75AMc=
github.com/ollama/ollama v0.9.6/go.mod h1:zLwx3iZ3AI4Rc/egsrx3u1w4RU2MHQ/Ylxse48jvyt4=
github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI=
github.com/ollama/ollama v0.11.4/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/panjf2000/ants/v2 v2.11.2 h1:AVGpMSePxUNpcLaBO34xuIgM1ZdKOiGnpxLXixLi5Jo=
github.com/panjf2000/ants/v2 v2.11.2/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek=
github.com/parquet-go/parquet-go v0.25.0 h1:GwKy11MuF+al/lV6nUsFw8w8HCiPOSAx1/y8yFxjH5c=

View File

@@ -19,6 +19,7 @@ import (
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/config"
werrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/models/utils"
@@ -105,6 +106,28 @@ func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context,
return nil, err
}
// 检查多模态配置完整性 - 只在图片文件时校验
// 检查是否为图片文件
if !IsImageType(getFileType(file.Filename)) {
logger.Info(ctx, "Non-image file with multimodal enabled, skipping COS/VLM validation")
} else {
// 检查COS配置
if kb.COSConfig.SecretID == "" || kb.COSConfig.SecretKey == "" ||
kb.COSConfig.Region == "" || kb.COSConfig.BucketName == "" ||
kb.COSConfig.AppID == "" {
logger.Error(ctx, "COS configuration incomplete for image multimodal processing")
return nil, werrors.NewBadRequestError("上传图片文件需要完整的COS配置信息")
}
// 检查VLM配置
if kb.VLMConfig.ModelName == "" || kb.VLMConfig.BaseURL == "" {
logger.Error(ctx, "VLM configuration incomplete for image multimodal processing")
return nil, werrors.NewBadRequestError("上传图片文件需要完整的VLM配置信息")
}
logger.Info(ctx, "Image multimodal configuration validation passed")
}
// Validate file type
logger.Infof(ctx, "Checking file type: %s", file.Filename)
if !isValidFileType(file.Filename) {
@@ -647,6 +670,20 @@ func (s *knowledgeService) processDocument(ctx context.Context,
ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap),
Separators: kb.ChunkingConfig.Separators,
EnableMultimodal: enableMultimodel,
CosConfig: &proto.COSConfig{
SecretId: kb.COSConfig.SecretID,
SecretKey: kb.COSConfig.SecretKey,
Region: kb.COSConfig.Region,
BucketName: kb.COSConfig.BucketName,
AppId: kb.COSConfig.AppID,
PathPrefix: kb.COSConfig.PathPrefix,
},
VlmConfig: &proto.VLMConfig{
ModelName: kb.VLMConfig.ModelName,
BaseUrl: kb.VLMConfig.BaseURL,
ApiKey: kb.VLMConfig.APIKey,
InterfaceType: kb.VLMConfig.InterfaceType,
},
},
RequestId: ctx.Value(types.RequestIDContextKey).(string),
})
@@ -687,6 +724,20 @@ func (s *knowledgeService) processDocumentFromURL(ctx context.Context,
ChunkOverlap: int32(kb.ChunkingConfig.ChunkOverlap),
Separators: kb.ChunkingConfig.Separators,
EnableMultimodal: enableMultimodel,
CosConfig: &proto.COSConfig{
SecretId: kb.COSConfig.SecretID,
SecretKey: kb.COSConfig.SecretKey,
Region: kb.COSConfig.Region,
BucketName: kb.COSConfig.BucketName,
AppId: kb.COSConfig.AppID,
PathPrefix: kb.COSConfig.PathPrefix,
},
VlmConfig: &proto.VLMConfig{
ModelName: kb.VLMConfig.ModelName,
BaseUrl: kb.VLMConfig.BaseURL,
ApiKey: kb.VLMConfig.APIKey,
InterfaceType: kb.VLMConfig.InterfaceType,
},
},
RequestId: ctx.Value(types.RequestIDContextKey).(string),
})
@@ -1145,7 +1196,7 @@ func (s *knowledgeService) getSummary(ctx context.Context,
chunkContents = chunkContents + imageAnnotations
}
if len(chunkContents) < 30 {
if len(chunkContents) < 300 {
return chunkContents, nil
}

View File

@@ -47,7 +47,9 @@ func (s *knowledgeBaseService) CreateKnowledgeBase(ctx context.Context,
kb *types.KnowledgeBase,
) (*types.KnowledgeBase, error) {
// Generate UUID and set creation timestamps
if kb.ID == "" {
kb.ID = uuid.New().String()
}
kb.CreatedAt = time.Now()
kb.TenantID = ctx.Value(types.TenantIDContextKey).(uint)
kb.UpdatedAt = time.Now()
@@ -245,6 +247,9 @@ func (s *knowledgeBaseService) CopyKnowledgeBase(ctx context.Context,
ImageProcessingConfig: sourceKB.ImageProcessingConfig,
EmbeddingModelID: sourceKB.EmbeddingModelID,
SummaryModelID: sourceKB.SummaryModelID,
RerankModelID: sourceKB.RerankModelID,
VLMModelID: sourceKB.VLMModelID,
COSConfig: sourceKB.COSConfig,
}
if err := s.repo.CreateKnowledgeBase(ctx, targetKB); err != nil {
return nil, nil, err

View File

@@ -144,6 +144,7 @@ func (s *TestDataService) initKnowledgeBase(ctx context.Context) error {
},
EmbeddingModelID: s.EmbedModel.GetModelID(),
SummaryModelID: s.LLMModel.GetModelID(),
RerankModelID: s.RerankModel.GetModelID(),
}
// 初始化测试知识库

View File

@@ -115,6 +115,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(handler.NewTestDataHandler))
must(container.Provide(handler.NewModelHandler))
must(container.Provide(handler.NewEvaluationHandler))
must(container.Provide(handler.NewInitializationHandler))
// Router configuration
must(container.Provide(router.NewRouter))

File diff suppressed because it is too large Load Diff

View File

@@ -135,6 +135,10 @@ func (h *KnowledgeHandler) CreateKnowledgeFromFile(c *gin.Context) {
if h.handleDuplicateKnowledgeError(c, err, knowledge, "file") {
return
}
if appErr, ok := errors.IsAppError(err); ok {
c.Error(appErr)
return
}
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return

View File

@@ -23,6 +23,7 @@ type SessionHandler struct {
streamManager interfaces.StreamManager // Manager for handling streaming responses
config *config.Config // Application configuration
testDataService *service.TestDataService // Service for test data (models, etc.)
knowledgebaseService interfaces.KnowledgeBaseService
}
// NewSessionHandler creates a new instance of SessionHandler with all necessary dependencies
@@ -32,6 +33,7 @@ func NewSessionHandler(
streamManager interfaces.StreamManager,
config *config.Config,
testDataService *service.TestDataService,
knowledgebaseService interfaces.KnowledgeBaseService,
) *SessionHandler {
return &SessionHandler{
sessionService: sessionService,
@@ -39,6 +41,7 @@ func NewSessionHandler(
streamManager: streamManager,
config: config,
testDataService: testDataService,
knowledgebaseService: knowledgebaseService,
}
}
@@ -191,27 +194,24 @@ func (h *SessionHandler) CreateSession(c *gin.Context) {
logger.Debug(ctx, "Using default session strategy")
}
// Get model IDs from test data service if not provided
if createdSession.SummaryModelID == "" {
if h.testDataService != nil {
createdSession.SummaryModelID = h.testDataService.LLMModel.GetModelID()
logger.Debug(ctx, "Using summary model ID from test data service")
} else {
logger.Error(ctx, "Summary model ID is empty and cannot get default value")
c.Error(errors.NewBadRequestError("Summary model ID cannot be empty"))
kb, err := h.knowledgebaseService.GetKnowledgeBaseByID(ctx, request.KnowledgeBaseID)
if err != nil {
logger.Error(ctx, "Failed to get knowledge base", err)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
// Get model IDs from knowledge base if not provided
if createdSession.SummaryModelID == "" {
createdSession.SummaryModelID = kb.SummaryModelID
}
if createdSession.RerankModelID == "" {
if h.testDataService != nil && h.testDataService.RerankModel != nil {
createdSession.RerankModelID = h.testDataService.RerankModel.GetModelID()
logger.Debug(ctx, "Using rerank model ID from test data service")
}
createdSession.RerankModelID = kb.RerankModelID
}
// Call service to create session
logger.Infof(ctx, "Calling session service to create session")
createdSession, err := h.sessionService.CreateSession(ctx, createdSession)
createdSession, err = h.sessionService.CreateSession(ctx, createdSession)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))

View File

@@ -4,7 +4,6 @@ import (
"errors"
"net/http"
"os"
"strconv"
"github.com/gin-gonic/gin"
@@ -65,26 +64,19 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
return
}
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
logger.Debugf(ctx, "Test tenant ID environment variable: %s", tenantID)
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
if err != nil {
logger.Errorf(ctx, "Failed to parse tenant ID: %s", tenantID)
c.Error(err)
return
}
tenantID := uint(types.InitDefaultTenantID)
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
// Retrieve the test tenant data
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantIDUint)
tenant, err := h.tenantService.GetTenantByID(ctx, uint(tenantIDUint))
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantID)
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(err)
return
}
knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID")
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
// Retrieve the test knowledge base data

View File

@@ -6,6 +6,7 @@ import (
"log"
"net/http"
"slices"
"strings"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/types"
@@ -17,12 +18,18 @@ import (
var noAuthAPI = map[string][]string{
"/api/v1/test-data": {"GET"},
"/api/v1/tenants": {"POST"},
"/api/v1/initialization/*": {"GET", "POST"},
}
// 检查请求是否在无需认证的API列表中
func isNoAuthAPI(path string, method string) bool {
for api, methods := range noAuthAPI {
if api == path && slices.Contains(methods, method) {
// 如果以*结尾,按照前缀匹配,否则按照全路径匹配
if strings.HasSuffix(api, "*") {
if strings.HasPrefix(path, strings.TrimSuffix(api, "*")) && slices.Contains(methods, method) {
return true
}
} else if path == api && slices.Contains(methods, method) {
return true
}
}

View File

@@ -63,7 +63,9 @@ func (c *OllamaChat) buildChatRequest(messages []Message, opts *ChatOptions, isS
chatReq.Options["num_predict"] = opts.MaxTokens
}
if opts.Thinking != nil {
chatReq.Options["think"] = *opts.Thinking
chatReq.Think = &ollamaapi.ThinkValue{
Value: *opts.Thinking,
}
}
}

View File

@@ -309,3 +309,8 @@ func (s *OllamaService) Generate(ctx context.Context, req *api.GenerateRequest,
// Use official client Generate method
return s.client.Generate(ctx, req, fn)
}
// GetClient returns the underlying ollama client for advanced operations
func (s *OllamaService) GetClient() *api.Client {
return s.client
}

View File

@@ -28,6 +28,7 @@ type RouterParams struct {
TestDataHandler *handler.TestDataHandler
ModelHandler *handler.ModelHandler
EvaluationHandler *handler.EvaluationHandler
InitializationHandler *handler.InitializationHandler
}
// NewRouter 创建新的路由
@@ -62,6 +63,23 @@ func NewRouter(params RouterParams) *gin.Engine {
// 测试数据接口(不需要认证)
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
// 初始化接口(不需要认证)
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
// Ollama相关接口不需要认证
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
// 远程API相关接口不需要认证
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
// 需要认证的API路由
v1 := r.Group("/api/v1")
{

View File

@@ -8,6 +8,10 @@ import (
"gorm.io/gorm"
)
const (
InitDefaultKnowledgeBaseID = "kb-00000001"
)
// KnowledgeBase represents a knowledge base
type KnowledgeBase struct {
// Unique identifier of the knowledge base
@@ -26,6 +30,14 @@ type KnowledgeBase struct {
EmbeddingModelID string `yaml:"embedding_model_id" json:"embedding_model_id"`
// Summary model ID
SummaryModelID string `yaml:"summary_model_id" json:"summary_model_id"`
// Rerank model ID
RerankModelID string `yaml:"rerank_model_id" json:"rerank_model_id"`
// VLM model ID
VLMModelID string `yaml:"vlm_model_id" json:"vlm_model_id"`
// VLM config
VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"`
// COS config
COSConfig COSConfig `yaml:"cos_config" json:"cos_config" gorm:"type:json"`
// Creation time of the knowledge base
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
// Last updated time of the knowledge base
@@ -54,6 +66,37 @@ type ChunkingConfig struct {
EnableMultimodal bool `yaml:"enable_multimodal" json:"enable_multimodal"`
}
// COSConfig represents the Tencent Cloud COS (object storage) configuration
// attached to a knowledge base. It is persisted as a JSON column through the
// driver.Valuer / sql.Scanner implementations below.
type COSConfig struct {
	// Secret ID used to authenticate against COS
	SecretID string `yaml:"secret_id" json:"secret_id"`
	// Secret Key paired with the Secret ID
	SecretKey string `yaml:"secret_key" json:"secret_key"`
	// Region the bucket lives in
	Region string `yaml:"region" json:"region"`
	// Bucket Name of the target bucket
	BucketName string `yaml:"bucket_name" json:"bucket_name"`
	// App ID of the COS account
	AppID string `yaml:"app_id" json:"app_id"`
	// Path Prefix prepended to every uploaded object key
	PathPrefix string `yaml:"path_prefix" json:"path_prefix"`
}

// Value implements the driver.Valuer interface, serializing the config to JSON.
// A value receiver is used so that both COSConfig and *COSConfig satisfy
// driver.Valuer (KnowledgeBase stores the struct by value), matching the
// receiver style of VLMConfig.Value in this file.
func (c COSConfig) Value() (driver.Value, error) {
	return json.Marshal(c)
}

// Scan implements the sql.Scanner interface, deserializing a JSON column into
// the config. A NULL value leaves the receiver untouched. Both []byte and
// string payloads are accepted because drivers differ in how they return JSON
// columns; any other type is ignored, preserving the original lenient behavior.
func (c *COSConfig) Scan(value interface{}) error {
	if value == nil {
		return nil
	}
	switch v := value.(type) {
	case []byte:
		return json.Unmarshal(v, c)
	case string:
		return json.Unmarshal([]byte(v), c)
	default:
		// NOTE(review): unknown source types are silently skipped to keep
		// the historical behavior; consider returning an error instead.
		return nil
	}
}
// ImageProcessingConfig represents the image processing configuration
type ImageProcessingConfig struct {
// Model ID
@@ -93,3 +136,32 @@ func (c *ImageProcessingConfig) Scan(value interface{}) error {
}
return json.Unmarshal(b, c)
}
// VLMConfig represents the vision-language model (VLM) configuration used for
// image captioning. It is persisted as a JSON column through the
// driver.Valuer / sql.Scanner implementations below.
type VLMConfig struct {
	// Model Name of the VLM
	ModelName string `yaml:"model_name" json:"model_name"`
	// Base URL of the model service endpoint
	BaseURL string `yaml:"base_url" json:"base_url"`
	// API Key used to authenticate, if the service requires one
	APIKey string `yaml:"api_key" json:"api_key"`
	// Interface Type: "ollama" or "openai"
	InterfaceType string `yaml:"interface_type" json:"interface_type"`
}

// Value implements the driver.Valuer interface, serializing the config to JSON.
func (c VLMConfig) Value() (driver.Value, error) {
	return json.Marshal(c)
}

// Scan implements the sql.Scanner interface, deserializing a JSON column into
// the config. A NULL value leaves the receiver untouched. Both []byte and
// string payloads are accepted because drivers differ in how they return JSON
// columns; any other type is ignored, preserving the original lenient behavior.
func (c *VLMConfig) Scan(value interface{}) error {
	if value == nil {
		return nil
	}
	switch v := value.(type) {
	case []byte:
		return json.Unmarshal(v, c)
	case string:
		return json.Unmarshal([]byte(v), c)
	default:
		// NOTE(review): unknown source types are silently skipped to keep
		// the historical behavior; consider returning an error instead.
		return nil
	}
}

View File

@@ -8,6 +8,10 @@ import (
"gorm.io/gorm"
)
const (
InitDefaultTenantID uint = 1
)
// Tenant represents the tenant
type Tenant struct {
// ID

View File

@@ -19,7 +19,7 @@ CREATE TABLE tenants (
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000;
CREATE TABLE models (
id VARCHAR(64) PRIMARY KEY,
@@ -47,6 +47,10 @@ CREATE TABLE knowledge_bases (
image_processing_config JSON NOT NULL,
embedding_model_id VARCHAR(64) NOT NULL,
summary_model_id VARCHAR(64) NOT NULL,
rerank_model_id VARCHAR(64) NOT NULL,
vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSON NOT NULL,
vlm_config JSON NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL DEFAULT NULL

View File

@@ -21,6 +21,9 @@ CREATE TABLE IF NOT EXISTS tenants (
deleted_at TIMESTAMP WITH TIME ZONE
);
-- Set the starting value for tenants id sequence
ALTER SEQUENCE tenants_id_seq RESTART WITH 10000;
-- Add indexes
CREATE INDEX IF NOT EXISTS idx_tenants_api_key ON tenants(api_key);
CREATE INDEX IF NOT EXISTS idx_tenants_status ON tenants(status);
@@ -55,6 +58,10 @@ CREATE TABLE IF NOT EXISTS knowledge_bases (
image_processing_config JSONB NOT NULL DEFAULT '{"enable_multimodal": false, "model_id": ""}',
embedding_model_id VARCHAR(64) NOT NULL,
summary_model_id VARCHAR(64) NOT NULL,
rerank_model_id VARCHAR(64) NOT NULL,
vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSONB NOT NULL DEFAULT '{}',
vlm_config JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP WITH TIME ZONE

View File

@@ -227,12 +227,24 @@ start_docker() {
source "$PROJECT_ROOT/.env"
storage_type=${STORAGE_TYPE:-local}
# 检测当前系统平台
log_info "检测系统平台信息..."
if [ "$(uname -m)" = "x86_64" ]; then
export PLATFORM="linux/amd64"
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
export PLATFORM="linux/arm64"
else
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
export PLATFORM="linux/amd64"
fi
log_info "当前平台:$PLATFORM"
# 进入项目根目录再执行docker-compose命令
cd "$PROJECT_ROOT"
# 启动基本服务
log_info "启动核心服务容器..."
docker-compose up --build -d
PLATFORM=$PLATFORM docker-compose up --build -d
if [ $? -ne 0 ]; then
log_error "Docker容器启动失败"
return 1

View File

@@ -13,7 +13,6 @@ urllib3
markdownify
mistletoe
goose3[all]
# paddlepaddle==3.0.0 (普通CPU版本已被注释)
paddleocr==3.0.0
markdown
pypdf
@@ -21,6 +20,10 @@ cos-python-sdk-v5
textract
antiword
openai
ollama
--extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/
paddlepaddle-gpu==3.0.0
--extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cpu/
paddlepaddle==3.0.0
# --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/
# paddlepaddle-gpu==3.0.0

View File

@@ -7,7 +7,7 @@ PYTHON_OUT="src/proto"
GO_OUT="src/proto"
# 生成Python代码
python -m grpc_tools.protoc -I${PROTO_DIR} \
python3 -m grpc_tools.protoc -I${PROTO_DIR} \
--python_out=${PYTHON_OUT} \
--grpc_python_out=${PYTHON_OUT} \
${PROTO_DIR}/docreader.proto

View File

@@ -3,7 +3,7 @@ import re
import os
import uuid
import asyncio
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Any, Optional, Tuple, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
@@ -13,6 +13,7 @@ import traceback
import numpy as np
import time
from .ocr_engine import OCREngine
from .image_utils import image_to_base64
# Add parent directory to Python path for src imports
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -100,6 +101,7 @@ class BaseParser(ABC):
max_image_size: int = 1920, # Maximum image size
max_concurrent_tasks: int = 5, # Max concurrent tasks
max_chunks: int = 1000, # Max number of returned chunks
chunking_config: object = None, # Chunking configuration object
):
"""Initialize parser
@@ -127,6 +129,7 @@ class BaseParser(ABC):
self.max_image_size = max_image_size
self.max_concurrent_tasks = max_concurrent_tasks
self.max_chunks = max_chunks
self.chunking_config = chunking_config
logger.info(
f"Initializing {self.__class__.__name__} for file: {file_name}, type: {self.file_type}"
@@ -141,7 +144,12 @@ class BaseParser(ABC):
# Only initialize Caption service if multimodal is enabled
if self.enable_multimodal:
try:
self.caption_parser = Caption()
# Get VLM config from chunking config if available
vlm_config = None
if self.chunking_config and hasattr(self.chunking_config, 'vlm_config'):
vlm_config = self.chunking_config.vlm_config
self.caption_parser = Caption(vlm_config)
except Exception as e:
logger.warning(f"Failed to initialize Caption service: {str(e)}")
self.caption_parser = None
@@ -200,55 +208,51 @@ class BaseParser(ABC):
resized_image.close()
def _resize_image_if_needed(self, image):
"""Resize image to avoid processing large images
"""Resize image if it exceeds maximum size limit
Args:
image: Image object (PIL.Image or numpy array)
Returns:
Resized image
Resized image object
"""
try:
# Check if it's a PIL.Image object
if hasattr(image, 'width') and hasattr(image, 'height'):
width, height = image.width, image.height
# Check if resizing is needed
# If it's a PIL Image
if hasattr(image, 'size'):
width, height = image.size
if width > self.max_image_size or height > self.max_image_size:
logger.info(f"Resizing image, original size: {width}x{height}")
# Calculate scaling factor
logger.info(f"Resizing PIL image, original size: {width}x{height}")
scale = min(self.max_image_size / width, self.max_image_size / height)
new_width = int(width * scale)
new_height = int(height * scale)
# Resize
resized_image = image.resize((new_width, new_height))
logger.info(f"Image resized to: {new_width}x{new_height}")
logger.info(f"Resized to: {new_width}x{new_height}")
return resized_image
else:
logger.info(f"PIL image size {width}x{height} is within limits, no resizing needed")
return image
# If it's a numpy array
elif hasattr(image, 'shape') and len(image.shape) == 3:
height, width = image.shape[0], image.shape[1]
elif hasattr(image, 'shape'):
height, width = image.shape[:2]
if width > self.max_image_size or height > self.max_image_size:
logger.info(f"Resizing numpy image, original size: {width}x{height}")
# Use PIL for resizing
from PIL import Image
pil_image = Image.fromarray(image)
try:
scale = min(self.max_image_size / width, self.max_image_size / height)
new_width = int(width * scale)
new_height = int(height * scale)
resized_image = pil_image.resize((new_width, new_height))
# Convert back to numpy array
import numpy as np
resized_array = np.array(resized_image)
logger.info(f"numpy image resized to: {new_width}x{new_height}")
return resized_array
finally:
# Ensure PIL image is closed
pil_image.close()
if 'resized_image' in locals() and hasattr(resized_image, 'close'):
resized_image.close()
# Use PIL for resizing numpy arrays
pil_image = Image.fromarray(image)
resized_pil = pil_image.resize((new_width, new_height))
resized_image = np.array(resized_pil)
logger.info(f"Resized to: {new_width}x{new_height}")
return resized_image
else:
logger.info(f"Numpy image size {width}x{height} is within limits, no resizing needed")
return image
else:
logger.warning(f"Unknown image type: {type(image)}, cannot resize")
return image
except Exception as e:
logger.warning(f"Failed to resize image: {str(e)}, using original image")
logger.error(f"Error resizing image: {str(e)}")
return image
def process_image(self, image, image_url=None):
@@ -273,15 +277,21 @@ class BaseParser(ABC):
ocr_text = self.perform_ocr(image)
caption = ""
if self.caption_parser and image_url:
if self.caption_parser:
logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption")
caption = self.get_image_caption(image_url)
# Convert image to base64 for caption generation
img_base64 = image_to_base64(image)
if img_base64:
caption = self.get_image_caption(img_base64)
if caption:
logger.info(f"Successfully obtained image caption: {caption}")
else:
logger.warning("Failed to get caption")
else:
logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval")
logger.warning("Failed to convert image to base64")
caption = ""
else:
logger.info("Caption service not initialized, skipping caption retrieval")
# Release image resources
del image
@@ -323,21 +333,27 @@ class BaseParser(ABC):
logger.info(f"OCR successfully extracted {len(ocr_text)} characters, continuing to get caption")
caption = ""
if self.caption_parser and image_url:
if self.caption_parser:
try:
# Convert image to base64 for caption generation
img_base64 = image_to_base64(resized_image)
if img_base64:
# Add timeout to avoid blocking caption retrieval (30 seconds timeout)
caption_task = self.get_image_caption_async(image_url)
image_url, caption = await asyncio.wait_for(caption_task, timeout=30.0)
caption_task = self.get_image_caption_async(img_base64)
image_data, caption = await asyncio.wait_for(caption_task, timeout=30.0)
if caption:
logger.info(f"Successfully obtained image caption: {caption}")
else:
logger.warning("Failed to get caption")
else:
logger.warning("Failed to convert image to base64")
caption = ""
except asyncio.TimeoutError:
logger.warning("Caption retrieval timed out, skipping")
except Exception as e:
logger.error(f"Failed to get caption: {str(e)}")
else:
logger.info("image_url not provided or Caption service not initialized, skipping caption retrieval")
logger.info("Caption service not initialized, skipping caption retrieval")
return ocr_text, caption, image_url
finally:
@@ -473,46 +489,56 @@ class BaseParser(ABC):
logger.info(f"Decoded text length: {len(text)} characters")
return text
def get_image_caption(self, image_url: str) -> str:
def get_image_caption(self, image_data: str) -> str:
"""Get image description
Args:
image_url: Image URL
image_data: Image data (base64 encoded string or URL)
Returns:
Image description
"""
start_time = time.time()
logger.info(
f"Getting caption for image: {image_url[:250]}..."
if len(image_url) > 250
else f"Getting caption for image: {image_url}"
f"Getting caption for image: {image_data[:250]}..."
if len(image_data) > 250
else f"Getting caption for image: {image_data}"
)
caption = self.caption_parser.get_caption(image_url)
caption = self.caption_parser.get_caption(image_data)
if caption:
logger.info(
f"Received caption of length: {len(caption)}, caption: {caption}, image_url: {image_url},"
f"Received caption of length: {len(caption)}, caption: {caption},"
f"cost: {time.time() - start_time} seconds"
)
else:
logger.warning("Failed to get caption for image")
return caption
async def get_image_caption_async(self, image_url: str) -> Tuple[str, str]:
async def get_image_caption_async(self, image_data: str) -> Tuple[str, str]:
"""Asynchronously get image description
Args:
image_url: Image URL
image_data: Image data (base64 encoded string or URL)
Returns:
Tuple[str, str]: Image URL and corresponding description
Tuple[str, str]: Image data and corresponding description
"""
caption = self.get_image_caption(image_url)
return image_url, caption
caption = self.get_image_caption(image_data)
return image_data, caption
def _init_cos_client(self):
def _init_cos_client(self, cos_config=None):
"""Initialize Tencent Cloud COS client"""
try:
# Use provided COS config if available, otherwise fall back to environment variables
if cos_config:
secret_id = cos_config.get("secret_id")
secret_key = cos_config.get("secret_key")
region = cos_config.get("region")
bucket_name = cos_config.get("bucket_name")
appid = cos_config.get("app_id")
prefix = cos_config.get("path_prefix", "")
enable_old_domain = cos_config.get("enable_old_domain", "true").lower() == "true"
else:
# Get COS configuration from environment variables
secret_id = os.getenv("COS_SECRET_ID")
secret_key = os.getenv("COS_SECRET_KEY")
@@ -600,12 +626,13 @@ class BaseParser(ABC):
logger.error(f"Failed to upload file to COS: {str(e)}")
return ""
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
def upload_bytes(self, content: bytes, file_ext: str = ".png", cos_config=None) -> str:
"""Directly upload file content to Tencent Cloud COS
Args:
content: File byte content
file_ext: File extension, default is .png
cos_config: COS configuration dictionary
Returns:
File URL
@@ -614,7 +641,7 @@ class BaseParser(ABC):
f"Uploading bytes content to COS, content size: {len(content)} bytes"
)
try:
client, bucket_name, region, prefix = self._init_cos_client()
client, bucket_name, region, prefix = self._init_cos_client(cos_config)
if not client:
return ""

View File

@@ -1,10 +1,13 @@
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import List, Optional, Union
import requests
import ollama
logger = logging.getLogger(__name__)
@@ -168,39 +171,108 @@ class Caption:
Uses an external API to process images and return textual descriptions.
"""
def __init__(self):
"""Initialize the Caption service with configuration from environment variables."""
def __init__(self, vlm_config=None):
"""Initialize the Caption service with configuration from parameters or environment variables."""
logger.info("Initializing Caption service")
self.prompt = """简单凝炼的描述图片的主要内容"""
# Use provided VLM config if available, otherwise fall back to environment variables
if vlm_config:
self.completion_url = vlm_config.get("base_url", "") + "/chat/completions"
self.model = vlm_config.get("model_name", "")
self.api_key = vlm_config.get("api_key", "")
self.interface_type = vlm_config.get("interface_type", "openai").lower()
else:
if os.getenv("VLM_MODEL_BASE_URL") == "" or os.getenv("VLM_MODEL_NAME") == "":
logger.error("VLM_MODEL_BASE_URL or VLM_MODEL_NAME is not set")
return
self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/v1/chat/completions"
self.completion_url = os.getenv("VLM_MODEL_BASE_URL") + "/chat/completions"
self.model = os.getenv("VLM_MODEL_NAME")
self.api_key = os.getenv("VLM_MODEL_API_KEY")
self.interface_type = os.getenv("VLM_INTERFACE_TYPE", "openai").lower()
# 验证接口类型
if self.interface_type not in ["ollama", "openai"]:
logger.warning(f"Unknown interface type: {self.interface_type}, defaulting to openai")
self.interface_type = "openai"
logger.info(
f"Service configured with model: {self.model}, endpoint: {self.completion_url}"
f"Service configured with model: {self.model}, endpoint: {self.completion_url}, interface: {self.interface_type}"
)
def _call_caption_api(self, image_url: str) -> Optional[CaptionChatResp]:
def _call_caption_api(self, image_data: str) -> Optional[CaptionChatResp]:
"""
Call the Caption API to generate a description for the given image.
Args:
image_url: URL of the image to be captioned
image_data: URL of the image or base64 encoded image data
Returns:
CaptionChatResp object if successful, None otherwise
"""
logger.info(f"Calling Caption API for image captioning")
logger.info(f"Processing image from URL: {image_url}")
logger.info(f"Processing image data: {image_data[:50] if len(image_data) > 50 else image_data}")
# 根据接口类型选择调用方式
if self.interface_type == "ollama":
return self._call_ollama_api(image_data)
else:
return self._call_openai_api(image_data)
def _call_ollama_api(self, image_base64: str) -> Optional[CaptionChatResp]:
    """Call the Ollama API to caption a base64-encoded image.

    Args:
        image_base64: Base64 encoded image data (raw, without a data-URI prefix).

    Returns:
        CaptionChatResp object if successful, None otherwise.
    """
    # completion_url is an OpenAI-style chat endpoint, but the Ollama client
    # expects only the host. __init__ builds the URL as base_url +
    # "/chat/completions", so depending on whether the configured base URL
    # already contained a "/v1" segment either suffix may be present; strip
    # whichever one matches instead of only "/v1/chat/completions".
    host = self.completion_url
    for suffix in ("/v1/chat/completions", "/chat/completions"):
        if host.endswith(suffix):
            host = host[: -len(suffix)]
            break
    client = ollama.Client(host=host)
    try:
        logger.info(f"Calling Ollama API with model: {self.model}")
        # Pass the base64-encoded image through the `images` parameter; use
        # the shared prompt configured in __init__ so both backends caption
        # with the same instruction.
        response = client.generate(
            model=self.model,
            prompt=self.prompt,
            images=[image_base64],
            options={"temperature": 0.1},
            stream=False,
        )
        # Wrap the Ollama response in the same response type the
        # OpenAI-compatible path returns so callers stay backend-agnostic.
        caption_resp = CaptionChatResp(
            id="ollama_response",
            created=int(time.time()),
            model=self.model,
            object="chat.completion",
            choices=[
                Choice(
                    message=Message(
                        role="assistant",
                        content=response.response,
                    )
                )
            ],
        )
        logger.info("Successfully received response from Ollama API")
        return caption_resp
    except Exception as e:
        logger.error(f"Error calling Ollama API: {e}")
        return None
def _call_openai_api(self, image_base64: str) -> Optional[CaptionChatResp]:
"""Call OpenAI-compatible API for image captioning."""
logger.info(f"Calling OpenAI-compatible API with model: {self.model}")
user_msg = UserMessage(
role="user",
content=[
Content(type="text", text=self.prompt),
Content(
type="image_url", image_url=ImageUrl(url=image_url, detail="auto")
type="image_url", image_url=ImageUrl(url="data:image/png;base64," + image_base64, detail="auto")
),
],
)
@@ -223,7 +295,7 @@ class Caption:
headers["Authorization"] = f"Bearer {self.api_key}"
try:
logger.info(f"Sending request to Caption API with model: {self.model}")
logger.info(f"Sending request to OpenAI-compatible API with model: {self.model}")
response = requests.post(
self.completion_url,
data=json.dumps(gpt_req, default=lambda o: o.__dict__, indent=4),
@@ -232,12 +304,12 @@ class Caption:
)
if response.status_code != 200:
logger.error(
f"Caption API returned non-200 status code: {response.status_code}"
f"OpenAI-compatible API returned non-200 status code: {response.status_code}"
)
response.raise_for_status()
logger.info(
f"Successfully received response from Caption API with status: {response.status_code}"
f"Successfully received response from OpenAI-compatible API with status: {response.status_code}"
)
logger.info(f"Converting response to CaptionChatResp object")
caption_resp = CaptionChatResp.from_json(response.json())
@@ -250,33 +322,30 @@ class Caption:
return caption_resp
except requests.exceptions.Timeout:
logger.error(f"Timeout while calling Caption API after 30 seconds")
logger.error(f"Timeout while calling OpenAI-compatible API after 30 seconds")
return None
except requests.exceptions.RequestException as e:
logger.error(f"Request error calling Caption API: {e}")
logger.error(f"Request error calling OpenAI-compatible API: {e}")
return None
except Exception as e:
logger.error(f"Error calling Caption API: {str(e)}")
logger.info(
f"Request details: model={self.model}, URL={self.completion_url}"
)
logger.error(f"Unexpected error calling OpenAI-compatible API: {e}")
return None
def get_caption(self, image_url: str) -> str:
def get_caption(self, image_data: str) -> str:
"""
Get a caption for the provided image URL.
Get a caption for the provided image data.
Args:
image_url: URL of the image to be captioned
image_data: URL of the image or base64 encoded image data
Returns:
Caption text as string, or empty string if captioning failed
"""
logger.info("Getting caption for image")
if not image_url or self.completion_url is None:
logger.error("Image URL is not set")
if not image_data or self.completion_url is None:
logger.error("Image data is not set")
return ""
caption_resp = self._call_caption_api(image_url)
caption_resp = self._call_caption_api(image_data)
if caption_resp:
caption = caption_resp.choice_data()
caption_length = len(caption)

View File

@@ -37,7 +37,13 @@ class ImageParser(BaseParser):
# Upload image to storage service
logger.info("Uploading image to storage")
_, ext = os.path.splitext(self.file_name)
image_url = self.upload_bytes(content, file_ext=ext)
# Get COS config from chunking config if available
cos_config = None
if hasattr(self, 'chunking_config') and self.chunking_config and hasattr(self.chunking_config, 'cos_config'):
cos_config = self.chunking_config.cos_config
image_url = self.upload_bytes(content, file_ext=ext, cos_config=cos_config)
if not image_url:
logger.error("Failed to upload image to storage")
return ""

View File

@@ -0,0 +1,43 @@
import base64
import io
import logging
from typing import Union
from PIL import Image
import numpy as np
logger = logging.getLogger(__name__)
def image_to_base64(image: Union[str, bytes, Image.Image, np.ndarray]) -> str:
    """Convert an image to a base64-encoded string.

    Args:
        image: A file path, raw bytes, a PIL Image object, or a numpy array.

    Returns:
        Base64 encoded image string, or empty string if conversion fails.
    """

    def _encode(raw: bytes) -> str:
        # Single place where bytes become a UTF-8 base64 string.
        return base64.b64encode(raw).decode("utf-8")

    def _png_bytes(pil_img: Image.Image) -> bytes:
        # Serialize a PIL image as PNG into an in-memory buffer.
        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        return buffer.getvalue()

    try:
        if isinstance(image, str):
            # A string is treated as a path on disk.
            with open(image, "rb") as image_file:
                return _encode(image_file.read())
        if isinstance(image, bytes):
            return _encode(image)
        if isinstance(image, Image.Image):
            return _encode(_png_bytes(image))
        if isinstance(image, np.ndarray):
            # Route numpy arrays through PIL to obtain PNG bytes.
            return _encode(_png_bytes(Image.fromarray(image)))
        logger.error(f"Unsupported image type: {type(image)}")
        return ""
    except Exception as e:
        logger.error(f"Error converting image to base64: {str(e)}")
        return ""

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from PIL import Image
import io
import numpy as np
from .image_utils import image_to_base64
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -161,35 +162,6 @@ class NanonetsOCRBackend(OCRBackend):
logger.error(f"Failed to initialize Nanonets OCR: {str(e)}")
self.client = None
def _encode_image(self, image: Union[str, bytes, Image.Image]) -> str:
"""Encode image to base64
Args:
image: Image file path, bytes, or PIL Image object
Returns:
Base64 encoded image
"""
try:
if isinstance(image, str):
# It's a file path
with open(image, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
elif isinstance(image, bytes):
# It's bytes data
return base64.b64encode(image).decode("utf-8")
elif isinstance(image, Image.Image):
# It's a PIL Image
buffer = io.BytesIO()
image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
logger.error(f"Unsupported image type: {type(image)}")
return ""
except Exception as e:
logger.error(f"Error encoding image: {str(e)}")
return ""
def predict(self, image: Union[str, bytes, Image.Image]) -> str:
"""Extract text from an image using Nanonets OCR
@@ -205,7 +177,7 @@ class NanonetsOCRBackend(OCRBackend):
try:
# Encode image to base64
img_base64 = self._encode_image(image)
img_base64 = image_to_base64(image)
if not img_base64:
return ""

View File

@@ -30,6 +30,8 @@ class ChunkingConfig:
enable_multimodal: bool = (
False # Whether to enable multimodal processing (text + images)
)
cos_config: dict = None # COS configuration for file storage
vlm_config: dict = None # VLM configuration for image captioning
@dataclass
@@ -138,6 +140,7 @@ class Parser:
enable_multimodal=config.enable_multimodal,
max_image_size=1920, # Limit image size to 1920px
max_concurrent_tasks=5, # Limit concurrent tasks to 5
chunking_config=config, # Pass the entire chunking config
)
logger.info(f"Starting to parse file content, size: {len(content)} bytes")
@@ -197,6 +200,7 @@ class Parser:
enable_multimodal=config.enable_multimodal,
max_image_size=1920, # Limit image size
max_concurrent_tasks=5, # Limit concurrent tasks
chunking_config=config,
)
logger.info(f"Starting to parse URL content")

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc-gen-go v1.36.7
// protoc v5.29.3
// source: docreader.proto
@@ -21,19 +21,175 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// COS 配置
type COSConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
SecretId string `protobuf:"bytes,1,opt,name=secret_id,json=secretId,proto3" json:"secret_id,omitempty"` // COS Secret ID
SecretKey string `protobuf:"bytes,2,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"` // COS Secret Key
Region string `protobuf:"bytes,3,opt,name=region,proto3" json:"region,omitempty"` // COS Region
BucketName string `protobuf:"bytes,4,opt,name=bucket_name,json=bucketName,proto3" json:"bucket_name,omitempty"` // COS Bucket Name
AppId string `protobuf:"bytes,5,opt,name=app_id,json=appId,proto3" json:"app_id,omitempty"` // COS App ID
PathPrefix string `protobuf:"bytes,6,opt,name=path_prefix,json=pathPrefix,proto3" json:"path_prefix,omitempty"` // COS Path Prefix
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *COSConfig) Reset() {
*x = COSConfig{}
mi := &file_docreader_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *COSConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*COSConfig) ProtoMessage() {}
func (x *COSConfig) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use COSConfig.ProtoReflect.Descriptor instead.
func (*COSConfig) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{0}
}
func (x *COSConfig) GetSecretId() string {
if x != nil {
return x.SecretId
}
return ""
}
func (x *COSConfig) GetSecretKey() string {
if x != nil {
return x.SecretKey
}
return ""
}
func (x *COSConfig) GetRegion() string {
if x != nil {
return x.Region
}
return ""
}
func (x *COSConfig) GetBucketName() string {
if x != nil {
return x.BucketName
}
return ""
}
func (x *COSConfig) GetAppId() string {
if x != nil {
return x.AppId
}
return ""
}
func (x *COSConfig) GetPathPrefix() string {
if x != nil {
return x.PathPrefix
}
return ""
}
// VLM 配置
type VLMConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // VLM Model Name
BaseUrl string `protobuf:"bytes,2,opt,name=base_url,json=baseUrl,proto3" json:"base_url,omitempty"` // VLM Base URL
ApiKey string `protobuf:"bytes,3,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` // VLM API Key
InterfaceType string `protobuf:"bytes,4,opt,name=interface_type,json=interfaceType,proto3" json:"interface_type,omitempty"` // VLM Interface Type: "ollama" or "openai"
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *VLMConfig) Reset() {
*x = VLMConfig{}
mi := &file_docreader_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *VLMConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*VLMConfig) ProtoMessage() {}
func (x *VLMConfig) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use VLMConfig.ProtoReflect.Descriptor instead.
func (*VLMConfig) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{1}
}
func (x *VLMConfig) GetModelName() string {
if x != nil {
return x.ModelName
}
return ""
}
func (x *VLMConfig) GetBaseUrl() string {
if x != nil {
return x.BaseUrl
}
return ""
}
func (x *VLMConfig) GetApiKey() string {
if x != nil {
return x.ApiKey
}
return ""
}
func (x *VLMConfig) GetInterfaceType() string {
if x != nil {
return x.InterfaceType
}
return ""
}
type ReadConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
ChunkSize int32 `protobuf:"varint,1,opt,name=chunk_size,json=chunkSize,proto3" json:"chunk_size,omitempty"` // 分块大小
ChunkOverlap int32 `protobuf:"varint,2,opt,name=chunk_overlap,json=chunkOverlap,proto3" json:"chunk_overlap,omitempty"` // 分块重叠
Separators []string `protobuf:"bytes,3,rep,name=separators,proto3" json:"separators,omitempty"` // 分隔符
EnableMultimodal bool `protobuf:"varint,4,opt,name=enable_multimodal,json=enableMultimodal,proto3" json:"enable_multimodal,omitempty"` // 多模态处理
CosConfig *COSConfig `protobuf:"bytes,5,opt,name=cos_config,json=cosConfig,proto3" json:"cos_config,omitempty"` // COS 配置
VlmConfig *VLMConfig `protobuf:"bytes,6,opt,name=vlm_config,json=vlmConfig,proto3" json:"vlm_config,omitempty"` // VLM 配置
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ReadConfig) Reset() {
*x = ReadConfig{}
mi := &file_docreader_proto_msgTypes[0]
mi := &file_docreader_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -45,7 +201,7 @@ func (x *ReadConfig) String() string {
func (*ReadConfig) ProtoMessage() {}
func (x *ReadConfig) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[0]
mi := &file_docreader_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -58,7 +214,7 @@ func (x *ReadConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use ReadConfig.ProtoReflect.Descriptor instead.
func (*ReadConfig) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{0}
return file_docreader_proto_rawDescGZIP(), []int{2}
}
func (x *ReadConfig) GetChunkSize() int32 {
@@ -89,6 +245,20 @@ func (x *ReadConfig) GetEnableMultimodal() bool {
return false
}
func (x *ReadConfig) GetCosConfig() *COSConfig {
if x != nil {
return x.CosConfig
}
return nil
}
func (x *ReadConfig) GetVlmConfig() *VLMConfig {
if x != nil {
return x.VlmConfig
}
return nil
}
// 从文件读取文档请求
type ReadFromFileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -103,7 +273,7 @@ type ReadFromFileRequest struct {
func (x *ReadFromFileRequest) Reset() {
*x = ReadFromFileRequest{}
mi := &file_docreader_proto_msgTypes[1]
mi := &file_docreader_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -115,7 +285,7 @@ func (x *ReadFromFileRequest) String() string {
func (*ReadFromFileRequest) ProtoMessage() {}
func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[1]
mi := &file_docreader_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -128,7 +298,7 @@ func (x *ReadFromFileRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use ReadFromFileRequest.ProtoReflect.Descriptor instead.
func (*ReadFromFileRequest) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{1}
return file_docreader_proto_rawDescGZIP(), []int{3}
}
func (x *ReadFromFileRequest) GetFileContent() []byte {
@@ -179,7 +349,7 @@ type ReadFromURLRequest struct {
func (x *ReadFromURLRequest) Reset() {
*x = ReadFromURLRequest{}
mi := &file_docreader_proto_msgTypes[2]
mi := &file_docreader_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -191,7 +361,7 @@ func (x *ReadFromURLRequest) String() string {
func (*ReadFromURLRequest) ProtoMessage() {}
func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[2]
mi := &file_docreader_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -204,7 +374,7 @@ func (x *ReadFromURLRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use ReadFromURLRequest.ProtoReflect.Descriptor instead.
func (*ReadFromURLRequest) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{2}
return file_docreader_proto_rawDescGZIP(), []int{4}
}
func (x *ReadFromURLRequest) GetUrl() string {
@@ -250,7 +420,7 @@ type Image struct {
func (x *Image) Reset() {
*x = Image{}
mi := &file_docreader_proto_msgTypes[3]
mi := &file_docreader_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -262,7 +432,7 @@ func (x *Image) String() string {
func (*Image) ProtoMessage() {}
func (x *Image) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[3]
mi := &file_docreader_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -275,7 +445,7 @@ func (x *Image) ProtoReflect() protoreflect.Message {
// Deprecated: Use Image.ProtoReflect.Descriptor instead.
func (*Image) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{3}
return file_docreader_proto_rawDescGZIP(), []int{5}
}
func (x *Image) GetUrl() string {
@@ -333,7 +503,7 @@ type Chunk struct {
func (x *Chunk) Reset() {
*x = Chunk{}
mi := &file_docreader_proto_msgTypes[4]
mi := &file_docreader_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -345,7 +515,7 @@ func (x *Chunk) String() string {
func (*Chunk) ProtoMessage() {}
func (x *Chunk) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[4]
mi := &file_docreader_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -358,7 +528,7 @@ func (x *Chunk) ProtoReflect() protoreflect.Message {
// Deprecated: Use Chunk.ProtoReflect.Descriptor instead.
func (*Chunk) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{4}
return file_docreader_proto_rawDescGZIP(), []int{6}
}
func (x *Chunk) GetContent() string {
@@ -407,7 +577,7 @@ type ReadResponse struct {
func (x *ReadResponse) Reset() {
*x = ReadResponse{}
mi := &file_docreader_proto_msgTypes[5]
mi := &file_docreader_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -419,7 +589,7 @@ func (x *ReadResponse) String() string {
func (*ReadResponse) ProtoMessage() {}
func (x *ReadResponse) ProtoReflect() protoreflect.Message {
mi := &file_docreader_proto_msgTypes[5]
mi := &file_docreader_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -432,7 +602,7 @@ func (x *ReadResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use ReadResponse.ProtoReflect.Descriptor instead.
func (*ReadResponse) Descriptor() ([]byte, []int) {
return file_docreader_proto_rawDescGZIP(), []int{5}
return file_docreader_proto_rawDescGZIP(), []int{7}
}
func (x *ReadResponse) GetChunks() []*Chunk {
@@ -453,7 +623,23 @@ var File_docreader_proto protoreflect.FileDescriptor
const file_docreader_proto_rawDesc = "" +
"\n" +
"\x0fdocreader.proto\x12\tdocreader\"\x9d\x01\n" +
"\x0fdocreader.proto\x12\tdocreader\"\xb8\x01\n" +
"\tCOSConfig\x12\x1b\n" +
"\tsecret_id\x18\x01 \x01(\tR\bsecretId\x12\x1d\n" +
"\n" +
"secret_key\x18\x02 \x01(\tR\tsecretKey\x12\x16\n" +
"\x06region\x18\x03 \x01(\tR\x06region\x12\x1f\n" +
"\vbucket_name\x18\x04 \x01(\tR\n" +
"bucketName\x12\x15\n" +
"\x06app_id\x18\x05 \x01(\tR\x05appId\x12\x1f\n" +
"\vpath_prefix\x18\x06 \x01(\tR\n" +
"pathPrefix\"\x85\x01\n" +
"\tVLMConfig\x12\x1d\n" +
"\n" +
"model_name\x18\x01 \x01(\tR\tmodelName\x12\x19\n" +
"\bbase_url\x18\x02 \x01(\tR\abaseUrl\x12\x17\n" +
"\aapi_key\x18\x03 \x01(\tR\x06apiKey\x12%\n" +
"\x0einterface_type\x18\x04 \x01(\tR\rinterfaceType\"\x87\x02\n" +
"\n" +
"ReadConfig\x12\x1d\n" +
"\n" +
@@ -462,7 +648,11 @@ const file_docreader_proto_rawDesc = "" +
"\n" +
"separators\x18\x03 \x03(\tR\n" +
"separators\x12+\n" +
"\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\"\xc9\x01\n" +
"\x11enable_multimodal\x18\x04 \x01(\bR\x10enableMultimodal\x123\n" +
"\n" +
"cos_config\x18\x05 \x01(\v2\x14.docreader.COSConfigR\tcosConfig\x123\n" +
"\n" +
"vlm_config\x18\x06 \x01(\v2\x14.docreader.VLMConfigR\tvlmConfig\"\xc9\x01\n" +
"\x13ReadFromFileRequest\x12!\n" +
"\ffile_content\x18\x01 \x01(\fR\vfileContent\x12\x1b\n" +
"\tfile_name\x18\x02 \x01(\tR\bfileName\x12\x1b\n" +
@@ -510,29 +700,33 @@ func file_docreader_proto_rawDescGZIP() []byte {
return file_docreader_proto_rawDescData
}
var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_docreader_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
var file_docreader_proto_goTypes = []any{
(*ReadConfig)(nil), // 0: docreader.ReadConfig
(*ReadFromFileRequest)(nil), // 1: docreader.ReadFromFileRequest
(*ReadFromURLRequest)(nil), // 2: docreader.ReadFromURLRequest
(*Image)(nil), // 3: docreader.Image
(*Chunk)(nil), // 4: docreader.Chunk
(*ReadResponse)(nil), // 5: docreader.ReadResponse
(*COSConfig)(nil), // 0: docreader.COSConfig
(*VLMConfig)(nil), // 1: docreader.VLMConfig
(*ReadConfig)(nil), // 2: docreader.ReadConfig
(*ReadFromFileRequest)(nil), // 3: docreader.ReadFromFileRequest
(*ReadFromURLRequest)(nil), // 4: docreader.ReadFromURLRequest
(*Image)(nil), // 5: docreader.Image
(*Chunk)(nil), // 6: docreader.Chunk
(*ReadResponse)(nil), // 7: docreader.ReadResponse
}
var file_docreader_proto_depIdxs = []int32{
0, // 0: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig
0, // 1: docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig
3, // 2: docreader.Chunk.images:type_name -> docreader.Image
4, // 3: docreader.ReadResponse.chunks:type_name -> docreader.Chunk
1, // 4: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest
2, // 5: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest
5, // 6: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse
5, // 7: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
0, // 0: docreader.ReadConfig.cos_config:type_name -> docreader.COSConfig
1, // 1: docreader.ReadConfig.vlm_config:type_name -> docreader.VLMConfig
2, // 2: docreader.ReadFromFileRequest.read_config:type_name -> docreader.ReadConfig
2, // 3: docreader.ReadFromURLRequest.read_config:type_name -> docreader.ReadConfig
5, // 4: docreader.Chunk.images:type_name -> docreader.Image
6, // 5: docreader.ReadResponse.chunks:type_name -> docreader.Chunk
3, // 6: docreader.DocReader.ReadFromFile:input_type -> docreader.ReadFromFileRequest
4, // 7: docreader.DocReader.ReadFromURL:input_type -> docreader.ReadFromURLRequest
7, // 8: docreader.DocReader.ReadFromFile:output_type -> docreader.ReadResponse
7, // 9: docreader.DocReader.ReadFromURL:output_type -> docreader.ReadResponse
8, // [8:10] is the sub-list for method output_type
6, // [6:8] is the sub-list for method input_type
6, // [6:6] is the sub-list for extension type_name
6, // [6:6] is the sub-list for extension extendee
0, // [0:6] is the sub-list for field type_name
}
func init() { file_docreader_proto_init() }
@@ -546,7 +740,7 @@ func file_docreader_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_docreader_proto_rawDesc), len(file_docreader_proto_rawDesc)),
NumEnums: 0,
NumMessages: 6,
NumMessages: 8,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -12,11 +12,31 @@ service DocReader {
rpc ReadFromURL(ReadFromURLRequest) returns (ReadResponse) {}
}
// COS 配置
message COSConfig {
string secret_id = 1; // COS Secret ID
string secret_key = 2; // COS Secret Key
string region = 3; // COS Region
string bucket_name = 4; // COS Bucket Name
string app_id = 5; // COS App ID
string path_prefix = 6; // COS Path Prefix
}
// VLM 配置
message VLMConfig {
string model_name = 1; // VLM Model Name
string base_url = 2; // VLM Base URL
string api_key = 3; // VLM API Key
string interface_type = 4; // VLM Interface Type: "ollama" or "openai"
}
message ReadConfig {
int32 chunk_size = 1; // 分块大小
int32 chunk_overlap = 2; // 分块重叠
repeated string separators = 3; // 分隔符
bool enable_multimodal = 4; // 多模态处理
COSConfig cos_config = 5; // COS 配置
VLMConfig vlm_config = 6; // VLM 配置
}
// 从文件读取文档请求

View File

@@ -1,9 +1,29 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import docreader_pb2 as docreader__pb2
GRPC_GENERATED_VERSION = '1.74.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in docreader_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class DocReaderStub(object):
"""文档读取服务
@@ -19,12 +39,12 @@ class DocReaderStub(object):
'/docreader.DocReader/ReadFromFile',
request_serializer=docreader__pb2.ReadFromFileRequest.SerializeToString,
response_deserializer=docreader__pb2.ReadResponse.FromString,
)
_registered_method=True)
self.ReadFromURL = channel.unary_unary(
'/docreader.DocReader/ReadFromURL',
request_serializer=docreader__pb2.ReadFromURLRequest.SerializeToString,
response_deserializer=docreader__pb2.ReadResponse.FromString,
)
_registered_method=True)
class DocReaderServicer(object):
@@ -62,6 +82,7 @@ def add_DocReaderServicer_to_server(servicer, server):
generic_handler = grpc.method_handlers_generic_handler(
'docreader.DocReader', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('docreader.DocReader', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
@@ -80,11 +101,21 @@ class DocReader(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromFile',
return grpc.experimental.unary_unary(
request,
target,
'/docreader.DocReader/ReadFromFile',
docreader__pb2.ReadFromFileRequest.SerializeToString,
docreader__pb2.ReadResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def ReadFromURL(request,
@@ -97,8 +128,18 @@ class DocReader(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/docreader.DocReader/ReadFromURL',
return grpc.experimental.unary_unary(
request,
target,
'/docreader.DocReader/ReadFromURL',
docreader__pb2.ReadFromURLRequest.SerializeToString,
docreader__pb2.ReadResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -74,11 +74,39 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
f"multimodal={enable_multimodal}"
)
# Get COS and VLM config from request
cos_config = None
vlm_config = None
if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config:
cos_config = {
'secret_id': request.read_config.cos_config.secret_id,
'secret_key': request.read_config.cos_config.secret_key,
'region': request.read_config.cos_config.region,
'bucket_name': request.read_config.cos_config.bucket_name,
'app_id': request.read_config.cos_config.app_id,
'path_prefix': request.read_config.cos_config.path_prefix or '',
}
logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}")
if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config:
vlm_config = {
'model_name': request.read_config.vlm_config.model_name,
'base_url': request.read_config.vlm_config.base_url,
'api_key': request.read_config.vlm_config.api_key or '',
'interface_type': request.read_config.vlm_config.interface_type or 'openai',
}
logger.info(f"Using VLM config: model={vlm_config['model_name']}, "
f"base_url={vlm_config['base_url']}, "
f"interface_type={vlm_config['interface_type']}")
chunking_config = ChunkingConfig(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=separators,
enable_multimodal=enable_multimodal,
cos_config=cos_config,
vlm_config=vlm_config,
)
# Parse file
@@ -138,11 +166,39 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
f"multimodal={enable_multimodal}"
)
# Get COS and VLM config from request
cos_config = None
vlm_config = None
if hasattr(request.read_config, 'cos_config') and request.read_config.cos_config:
cos_config = {
'secret_id': request.read_config.cos_config.secret_id,
'secret_key': request.read_config.cos_config.secret_key,
'region': request.read_config.cos_config.region,
'bucket_name': request.read_config.cos_config.bucket_name,
'app_id': request.read_config.cos_config.app_id,
'path_prefix': request.read_config.cos_config.path_prefix or '',
}
logger.info(f"Using COS config: region={cos_config['region']}, bucket={cos_config['bucket_name']}")
if hasattr(request.read_config, 'vlm_config') and request.read_config.vlm_config:
vlm_config = {
'model_name': request.read_config.vlm_config.model_name,
'base_url': request.read_config.vlm_config.base_url,
'api_key': request.read_config.vlm_config.api_key or '',
'interface_type': request.read_config.vlm_config.interface_type or 'openai',
}
logger.info(f"Using VLM config: model={vlm_config['model_name']}, "
f"base_url={vlm_config['base_url']}, "
f"interface_type={vlm_config['interface_type']}")
chunking_config = ChunkingConfig(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=separators,
enable_multimodal=enable_multimodal,
cos_config=cos_config,
vlm_config=vlm_config,
)
# Parse URL