feat: 新增异步任务提取服务

This commit is contained in:
begoniezhao
2025-09-24 12:15:17 +08:00
parent de96a52d54
commit 02b78a5908
30 changed files with 2514 additions and 117 deletions

View File

@@ -121,6 +121,18 @@ COS_ENABLE_OLD_DOMAIN=true
# 如果解析网络连接使用Web代理需要配置以下参数 # 如果解析网络连接使用Web代理需要配置以下参数
# WEB_PROXY=your_web_proxy # WEB_PROXY=your_web_proxy
# Neo4j 开关
# NEO4J_ENABLE=false
# Neo4j的访问地址
# NEO4J_URI=neo4j://neo4j:7687
# Neo4j的用户名
# NEO4J_USERNAME=neo4j
# Neo4j的密码
# NEO4J_PASSWORD=password
############################################################## ##############################################################
###### 注意: 以下配置不再生效已在Web“配置初始化”阶段完成 ######### ###### 注意: 以下配置不再生效已在Web“配置初始化”阶段完成 #########

View File

@@ -534,3 +534,69 @@ knowledge_base:
split_markers: ["\n\n", "\n", "。"] split_markers: ["\n\n", "\n", "。"]
image_processing: image_processing:
enable_multimodal: true enable_multimodal: true
# Prompt templates and few-shot examples for the extraction tasks.
# NOTE(review): the `%s` placeholders inside the prompts look like format verbs
# that the loader fills with the configured relation tags — confirm against the consumer.
extract:
# Graph extraction: entities with attributes plus typed relations between them.
extract_graph:
# Instruction prompt; its `%s` is presumably replaced by the allowed relation types.
description: |
请基于给定文本,按以下步骤完成信息提取任务,确保逻辑清晰、信息完整准确:
## 一、实体提取与属性补充
1. **提取核心实体**:通读文本,按逻辑顺序(如文本叙述顺序、实体关联紧密程度)提取所有与任务相关的核心实体。
2. **补充实体详细属性**:针对每个提取的实体,全面补充其在文本中明确提及的详细属性,确保无关键属性遗漏。
## 二、关系提取与验证
1. **明确关系类型**:仅从指定关系列表中选择对应类型,限定关系类型为: %s。
2. **提取有效关系**:基于已提取的实体及属性,识别文本中真实存在的关系,确保关系符合文本事实、无虚假关联。
3. **明确关系主体**:对每一组提取的关系,清晰标注两个关联主体,避免主体混淆。
4. **补充关联属性**:若文本中存在与该关系直接相关的补充信息,需将该信息作为关系的关联属性补充,进一步完善关系信息。
tags:
- "作者"
- "别名"
# Few-shot example: input text followed by the expected nodes and relations.
examples:
- text: |
《红楼梦》又名《石头记》是清代作家曹雪芹创作的中国古典四大名著之一被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著后40回一般认为是高鹗所续。
小说以贾、史、王、薛四大家族的兴衰为背景,以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线,刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。
成书于乾隆年间1743年前后是中国文学史上现实主义的高峰对后世影响深远。
node:
- name: "红楼梦"
attributes:
- "中国古典四大名著之一"
- "又名《石头记》"
- "被誉为中国封建社会的百科全书"
- name: "石头记"
attributes:
- "《红楼梦》的别名"
- name: "曹雪芹"
attributes:
- "清代作家"
- "《红楼梦》前 80 回的作者"
- name: "高鹗"
attributes:
- "一般认为是《红楼梦》后 40 回的续写者"
# Each relation links two node names; every `type` used here also appears in `tags` above.
relation:
- node1: "红楼梦"
node2: "曹雪芹"
type: "作者"
- node1: "红楼梦"
node2: "高鹗"
type: "作者"
- node1: "红楼梦"
node2: "石头记"
type: "别名"
# Entity extraction from a user question: names only, no attributes or relations.
extract_entity:
description: |
请基于用户给的问题,按以下步骤处理关键信息提取任务:
1. 梳理逻辑关联:首先完整分析文本内容,明确其核心逻辑关系,并简要标注该核心逻辑类型;
2. 提取关键实体:围绕梳理出的逻辑关系,精准提取文本中的关键信息并归类为明确实体,确保不遗漏核心信息、不添加冗余内容;
3. 排序实体优先级:按实体与文本核心主题的关联紧密程度排序,优先呈现对理解文本主旨最重要的实体;
examples:
- text: "《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。"
node:
- name: "红楼梦"
- name: "曹雪芹"
- name: "中国古典四大名著"
# Prompts for generating a random sample text.
fabri_text:
# Variant with tags; the `%s` is presumably replaced by the selected tags — confirm.
with_tag: |
请随机生成一段文本,要求内容与 %s 等相关,字数在 [50-200] 之间,并且尽量包含一些与这些标签相关的专业术语或典型元素,使文本更具针对性和相关性。
with_no_tag: |
请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。

View File

@@ -54,6 +54,10 @@ services:
- REDIS_DB=${REDIS_DB:-} - REDIS_DB=${REDIS_DB:-}
- REDIS_PREFIX=${REDIS_PREFIX:-} - REDIS_PREFIX=${REDIS_PREFIX:-}
- ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG:-} - ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG:-}
# Neo4j connection settings. Each value can be overridden from the host
# environment / .env file; NEO4J_URI previously ignored the documented override.
- NEO4J_ENABLE=${NEO4J_ENABLE:-}
- NEO4J_URI=${NEO4J_URI:-bolt://neo4j:7687}
- NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j}
- NEO4J_PASSWORD=${NEO4J_PASSWORD:-password}
- TENANT_AES_KEY=${TENANT_AES_KEY:-} - TENANT_AES_KEY=${TENANT_AES_KEY:-}
- CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5} - CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5}
- INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME:-} - INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME:-}
@@ -76,6 +80,8 @@ services:
condition: service_started condition: service_started
docreader: docreader:
condition: service_healthy condition: service_healthy
neo4j:
condition: service_started
networks: networks:
- WeKnora-network - WeKnora-network
restart: unless-stopped restart: unless-stopped
@@ -209,6 +215,24 @@ services:
networks: networks:
- WeKnora-network - WeKnora-network
# Neo4j graph database service; the app service reaches it at host `neo4j` on WeKnora-network.
neo4j:
# NOTE(review): `latest` is unpinned, so the server version can change between pulls — consider pinning a release.
image: neo4j:latest
container_name: WeKnora-neo4j
volumes:
# Named volume mounted at /data inside the container
- neo4j-data:/data
environment:
# Credentials in `<user>/<password>` form; defaults match the app service's NEO4J_USERNAME / NEO4J_PASSWORD defaults
- NEO4J_AUTH=${NEO4J_USERNAME:-neo4j}/${NEO4J_PASSWORD:-password}
# APOC file export/import toggles and plugin selection
- NEO4J_apoc_export_file_enabled=true
- NEO4J_apoc_import_file_enabled=true
- NEO4J_apoc_import_file_use__neo4j__config=true
- NEO4JLABS_PLUGINS=["apoc"]
ports:
# 7687 is the port used by the app's bolt:// URI; 7474 is presumably the HTTP/browser port — confirm
- "7474:7474"
- "7687:7687"
restart: always
networks:
- WeKnora-network
networks: networks:
WeKnora-network: WeKnora-network:
driver: bridge driver: bridge
@@ -219,3 +243,4 @@ volumes:
jaeger_data: jaeger_data:
redis_data: redis_data:
minio_data: minio_data:
neo4j-data:

View File

@@ -50,6 +50,13 @@ export interface InitializationConfig {
}; };
// Frontend-only hint for storage selection UI // Frontend-only hint for storage selection UI
storageType?: 'cos' | 'minio'; storageType?: 'cos' | 'minio';
nodeExtract: {
enabled: boolean,
text: string,
tags: string[],
nodes: Node[],
relations: Relation[]
}
} }
// 下载任务状态类型 // 下载任务状态类型
@@ -63,8 +70,6 @@ export interface DownloadTask {
endTime?: string; endTime?: string;
} }
// 根据知识库ID执行配置更新 // 根据知识库ID执行配置更新
export function initializeSystemByKB(kbId: string, config: InitializationConfig): Promise<any> { export function initializeSystemByKB(kbId: string, config: InitializationConfig): Promise<any> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
@@ -76,7 +81,7 @@ export function initializeSystemByKB(kbId: string, config: InitializationConfig)
}) })
.catch((error: any) => { .catch((error: any) => {
console.error('知识库配置更新失败:', error); console.error('知识库配置更新失败:', error);
reject(error); reject(error.error || error);
}); });
}); });
} }
@@ -324,4 +329,93 @@ export function testMultimodalFunction(testData: {
reject(error); reject(error);
}); });
}); });
} }
// Types for the text relation extraction API
/** Request body sent by `extractTextRelations`. */
export interface TextRelationExtractionRequest {
// Text to extract entities and relations from
text: string;
// Relation types the extraction is limited to
tags: string[];
// LLM connection settings used for the extraction call
llmConfig: LLMConfig;
}
/** An extracted entity: a name plus free-text attribute strings. */
export interface Node {
name: string;
attributes: string[];
}
/** A typed link between two entities, each referenced by its node name. */
export interface Relation {
node1: string;
node2: string;
type: string;
}
/** LLM connection settings forwarded with each extraction/generation request. */
export interface LLMConfig {
source: 'local' | 'remote';
modelName: string;
baseUrl: string;
apiKey: string;
}
/** Result resolved by `extractTextRelations`. */
export interface TextRelationExtractionResponse {
nodes: Node[];
relations: Relation[];
}
/**
 * Extract entities and relations from a piece of text.
 *
 * POSTs `request` to `/api/v1/initialization/extract/text-relation` and
 * resolves with the response's `data` field. When `data` is falsy an empty
 * result (`{ nodes: [], relations: [] }`) is resolved instead. A rejected
 * request is logged to the console and the same error is re-thrown.
 *
 * @param request text, relation tags and LLM configuration
 * @returns the extracted nodes and relations
 */
export async function extractTextRelations(request: TextRelationExtractionRequest): Promise<TextRelationExtractionResponse> {
  // Start the call outside the try block so only asynchronous failures get logged
  const pending = post('/api/v1/initialization/extract/text-relation', request);
  try {
    const response: any = await pending;
    return response.data || { nodes: [], relations: [] };
  } catch (error: any) {
    console.error('文本内容关系提取失败:', error);
    throw error;
  }
}
/** Request body sent by `fabriText`. */
export interface FabriTextRequest {
  // Tags forwarded to the text generation endpoint
  tags: string[];
  // LLM connection settings used for the generation call
  llmConfig: LLMConfig;
}

/** Result resolved by `fabriText`. */
export interface FabriTextResponse {
  text: string;
}

/**
 * Ask the backend to generate a sample text.
 *
 * POSTs `request` to `/api/v1/initialization/extract/fabri-text` and resolves
 * with the response's `data` field, or `{ text: '' }` when `data` is falsy.
 * A rejected request is logged to the console and the same error is re-thrown.
 *
 * @param request tags and LLM configuration
 * @returns the generated text
 */
export async function fabriText(request: FabriTextRequest): Promise<FabriTextResponse> {
  // Start the call outside the try block so only asynchronous failures get logged
  const pending = post('/api/v1/initialization/extract/fabri-text', request);
  try {
    const response: any = await pending;
    return response.data || { text: '' };
  } catch (error: any) {
    console.error('文本内容生成失败:', error);
    throw error;
  }
}
/** Request body sent by `fabriTag`. */
export interface FabriTagRequest {
  // LLM connection settings used for the generation call
  llmConfig: LLMConfig;
}

/** Result resolved by `fabriTag`. */
export interface FabriTagResponse {
  tags: string[];
}

/**
 * Ask the backend to generate a set of tags.
 *
 * POSTs `request` to `/api/v1/initialization/extract/fabri-tag` and resolves
 * with the response's `data` field, or `{ tags: [] }` when `data` is falsy.
 * A rejected request is logged to the console and the same error is re-thrown.
 *
 * @param request LLM configuration
 * @returns the generated tags
 */
export async function fabriTag(request: FabriTagRequest): Promise<FabriTagResponse> {
  // Start the call outside the try block so only asynchronous failures get logged
  const pending = post('/api/v1/initialization/extract/fabri-tag', request);
  try {
    const response: any = await pending;
    return response.data || { tags: [] as string[] };
  } catch (error: any) {
    console.error('标签生成失败:', error);
    throw error;
  }
}

View File

@@ -144,7 +144,10 @@ let activeSubmenu = ref<number>(-1);
// 是否处于知识库详情页 // 是否处于知识库详情页
const isInKnowledgeBase = computed<boolean>(() => { const isInKnowledgeBase = computed<boolean>(() => {
return route.name === 'knowledgeBaseDetail' || route.name === 'kbCreatChat' || route.name === 'chat' || route.name === 'knowledgeBaseSettings'; return route.name === 'knowledgeBaseDetail' ||
route.name === 'kbCreatChat' ||
route.name === 'chat' ||
route.name === 'knowledgeBaseSettings';
}); });
// 统一的菜单项激活状态判断 // 统一的菜单项激活状态判断

View File

@@ -692,6 +692,285 @@
</div> </div>
</div> </div>
<!-- 实体关系提取 -->
<div class="config-section">
<h3><t-icon name="transform" class="section-icon" />实体关系提取</h3>
<div class="form-row">
<t-form-item name="nodeExtract.enabled">
<div class="switch-container">
<t-switch v-model="formData.nodeExtract.enabled" @change="clearExtractExample" />
<span class="switch-label">启用实体关系提取</span>
</div>
</t-form-item>
</div>
<div v-if="formData.nodeExtract.enabled" class="node-config">
<h4>关系标签配置</h4>
<!-- 关系标签配置区域 -->
<div class="form-row">
<t-form-item label="关系类型" name="tags">
<div class="tags-grid">
<div class="btn-tips-form">
<div class="tags-gen-btn">
<t-button
theme="default"
size="medium"
:disabled="!modelStatus.llm.available"
:loading="tagFabring"
@click="handleFabriTag"
class="gen-tags-btn"
>
随机生成标签
</t-button>
</div>
<div v-if="!modelStatus.llm.available" class="btn-tips">
<t-icon name="info-circle" class="tip-icon" />
<span>请完善模型配置信息</span>
</div>
</div>
<div class="tags-config">
<t-select
v-model="formData.nodeExtract.tags"
v-model:input-value="tagInput"
multiple
placeholder="系统将根据选定的关系类型从文本中提取相应的实体关系"
:options="tagOptions"
clearable
@clear="clearTags"
creatable
@create="addTag"
filterable
/>
</div>
</div>
</t-form-item>
</div>
<h4>提取示例</h4>
<!-- 文本内容输入区域 -->
<div class="form-row">
<t-form-item label="示例文本" name="text" :required="true">
<div class="sample-text-form">
<div class="btn-tips-form">
<div class="tags-gen-btn">
<t-button
theme="default"
size="medium"
:disabled="!modelStatus.llm.available"
:title="!modelStatus.llm.available ? 'LLM 模型不可用' : ''"
:loading="textFabring"
@click="handleFabriText"
class="tags-gen-btn"
>
随机生成文本
</t-button>
</div>
<div v-if="!modelStatus.llm.available" class="btn-tips">
<t-icon name="info-circle" class="tip-icon" />
<span>请完善模型配置信息</span>
</div>
</div>
<div class="sample-text">
<t-textarea
v-model="formData.nodeExtract.text"
placeholder="请输入需要分析的文本内容,例如:《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一..."
:autosize="{ minRows: 8, maxRows: 15 }"
show-word-limit
maxlength="5000"
/>
</div>
</div>
</t-form-item>
</div>
<!-- 提取实体 -->
<div class="form-row">
<!-- 实体列表 -->
<t-form-item v-if="formData.nodeExtract.nodes.length > 0" label="实体列表" name="node-form">
<div class="node-list">
<div v-for="(node, nodeIndex) in formData.nodeExtract.nodes" :key="nodeIndex" class="node-item">
<div class="node-header">
<span class="node-icon"><t-icon name="user" class="node-icon-svg" /></span>
<!-- 节点名称输入 -->
<t-input
type="text"
v-model="node.name"
class="node-name-input"
placeholder="节点名称"
/>
<!-- 删除节点按钮 -->
<t-button
class="delete-node-btn"
theme="default"
@click="removeNode(nodeIndex)"
:disabled="formData.nodeExtract.nodes.length === 0"
size="small"
>
<t-icon name="delete" />
</t-button>
</div>
<div class="node-attributes">
<!-- 属性列表 -->
<div v-for="(attribute, attrIndex) in node.attributes" :key="attrIndex" class="attribute-item">
<t-input
type="text"
v-model="node.attributes[attrIndex]"
class="attribute-input"
placeholder="属性值"
/>
<t-button
class="delete-attr-btn"
theme="default"
@click="removeAttribute(nodeIndex, attrIndex)"
:disabled="node.attributes.length === 0"
size="small"
>
<t-icon name="close" />
</t-button>
</div>
<!-- 添加属性按钮 -->
<t-button class="add-attr-btn" @click="addAttribute(nodeIndex)" size="small">
添加属性
</t-button>
</div>
</div>
</div>
</t-form-item>
<!-- 添加实体按钮 -->
<div class="btn-tips-form">
<div class="tags-gen-btn">
<t-button class="add-node-btn" @click="addNode">
添加实体
</t-button>
</div>
<div v-if="!readyNode" class="btn-tips">
<t-icon name="info-circle" class="tip-icon" />
<span>请完善实体信息</span>
</div>
</div>
</div>
<!-- 提取关系 -->
<div class="form-row">
<t-form-item v-if="formData.nodeExtract.relations.length > 0" label="关系连接" name="node-relation">
<div class="relation-list">
<div v-for="(relation, index) in formData.nodeExtract.relations" :key="index" class="relation-item">
<div class="relation-line">
<t-select-input
:value="formData.nodeExtract.relations[index].node1"
:popup-visible="popupVisibleNode1[index]"
placeholder="请选择实体"
clearable
@popup-visible-change="onPopupVisibleNode1Change(index, $event)"
@clear="relationOnClearNode1(index)"
@focus="onFocus"
>
<template #panel>
<ul class="select-input-node">
<li v-for="item in formData.nodeExtract.nodes" :key="item.name" @click="onRelationNode1OptionClick(index, item)">
{{ item.name }}
</li>
</ul>
</template>
<template #suffixIcon>
<ChevronDownIcon />
</template>
</t-select-input>
<t-icon name="arrow-right" class="relation-arrow"/>
<t-select
v-model="formData.nodeExtract.relations[index].type"
placeholder="请选择关系类型"
:options="tagOptions"
clearable
creatable
filterable
/>
<t-icon name="arrow-right" class="relation-arrow"/>
<t-select-input
:value="formData.nodeExtract.relations[index].node2"
:popup-visible="popupVisibleNode2[index]"
placeholder="请选择实体"
clearable
@popup-visible-change="onPopupVisibleNode2Change(index, $event)"
@clear="relationOnClearNode2(index)"
@focus="onFocus"
>
<template #panel>
<ul class="select-input-node">
<li v-for="item in formData.nodeExtract.nodes" :key="item.name" @click="onRelationNode2OptionClick(index, item)">
{{ item.name }}
</li>
</ul>
</template>
<template #suffixIcon>
<ChevronDownIcon />
</template>
</t-select-input>
<t-button
class="delete-node-btn"
theme="default"
@click="removeRelation(index)"
:disabled="formData.nodeExtract.relations.length === 0"
size="small"
>
<t-icon name="delete" />
</t-button>
</div>
</div>
</div>
</t-form-item>
<!-- 添加关系按钮 -->
<div class="btn-tips-form">
<div class="tags-gen-btn">
<t-button class="add-node-btn" @click="addRelation">
添加关系
</t-button>
</div>
<div v-if="!readyRelation" class="btn-tips">
<t-icon name="info-circle" class="tip-icon" />
<span>请完善关系信息</span>
</div>
</div>
</div>
<!-- 重置按钮区域 -->
<div class="extract-button">
<t-button
theme="primary"
size="medium"
:disabled="!modelStatus.llm.available"
:title="!modelStatus.llm.available ? 'LLM 模型不可用' : ''"
:loading="extracting"
@click="handleExtract"
>
{{ extracting ? '正在提取...' : '开始提取' }}
</t-button>
<t-button
theme="default"
size="medium"
@click="defaultExtractExample"
class="default-extract-btn"
>
默认示例
</t-button>
<t-button
theme="default"
size="medium"
@click="clearExtractExample"
class="clear-extract-btn"
>
清空示例
</t-button>
</div>
</div>
</div>
<!-- 提交按钮区域 --> <!-- 提交按钮区域 -->
<div class="submit-section"> <div class="submit-section">
<t-button theme="primary" type="button" size="large" <t-button theme="primary" type="button" size="large"
@@ -724,6 +1003,7 @@
import { ref, reactive, computed, watch, onMounted, onUnmounted, nextTick } from 'vue'; import { ref, reactive, computed, watch, onMounted, onUnmounted, nextTick } from 'vue';
import { useRouter, useRoute } from 'vue-router'; import { useRouter, useRoute } from 'vue-router';
import { MessagePlugin } from 'tdesign-vue-next'; import { MessagePlugin } from 'tdesign-vue-next';
import { ChevronDownIcon } from 'tdesign-icons-vue-next';
import { import {
initializeSystemByKB, initializeSystemByKB,
checkOllamaStatus, checkOllamaStatus,
@@ -736,7 +1016,15 @@ import {
checkRerankModel, checkRerankModel,
testMultimodalFunction, testMultimodalFunction,
listOllamaModels, listOllamaModels,
testEmbeddingModel testEmbeddingModel,
extractTextRelations,
fabriText,
fabriTag,
type TextRelationExtractionRequest,
type Node,
type Relation,
type FabriTagRequest,
type FabriTextRequest
} from '@/api/initialization'; } from '@/api/initialization';
import { getKnowledgeBaseById } from '@/api/knowledge-base'; import { getKnowledgeBaseById } from '@/api/knowledge-base';
import { useAuthStore } from '@/stores/auth'; import { useAuthStore } from '@/stores/auth';
@@ -762,6 +1050,25 @@ const form = ref<TFormRef>(null);
const submitting = ref(false); const submitting = ref(false);
const hasFiles = ref(false); const hasFiles = ref(false);
const isUpdateMode = ref(false); // 是否为更新模式 const isUpdateMode = ref(false); // 是否为更新模式
// Built-in relation-type options; each entry uses the same string as label and value
const tagOptionsDefault = ['内容', '文化', '人物', '事件', '时间', '地点', '作品', '作者', '关系', '属性']
  .map((tag) => ({ label: tag, value: tag }));
// Options currently offered by the relation-type selectors
const tagOptions = ref<{ label: string; value: string }[]>([]);
// Text currently typed into the tag selector's input box
const tagInput = ref('');
// Dropdown open/closed flags for the first and second entity selector of each relation row
const popupVisibleNode1 = ref<boolean[]>([]);
const popupVisibleNode2 = ref<boolean[]>([]);
// In-flight flags for tag generation, text generation and extraction requests
const tagFabring = ref(false);
const textFabring = ref(false);
const extracting = ref(false);
// 防抖机制:防止按钮快速重复点击 // 防抖机制:防止按钮快速重复点击
const submitDebounceTimer = ref<ReturnType<typeof setTimeout> | null>(null); const submitDebounceTimer = ref<ReturnType<typeof setTimeout> | null>(null);
@@ -874,6 +1181,13 @@ const formData = reactive({
chunkSize: 512, chunkSize: 512,
chunkOverlap: 100, chunkOverlap: 100,
separators: ['\n\n', '\n', '。', '', '', ';', ''] separators: ['\n\n', '\n', '。', '', '', ';', '']
},
nodeExtract: {
enabled: false,
text: '',
tags: [] as string[],
nodes: [] as Node[],
relations: [] as Relation[]
} }
}); });
@@ -995,8 +1309,29 @@ const canSubmit = computed(() => {
vlmOk = true; vlmOk = true;
} }
} }
let extractOk = true;
if (formData.nodeExtract.enabled) {
if (formData.nodeExtract.text === '') {
extractOk = false;
}
for (let i = 0; i < formData.nodeExtract.tags.length; i++) {
const tag = formData.nodeExtract.tags[i];
if (tag == '') {
extractOk = false;
break;
}
}
if (!readyNode.value){
extractOk = false;
}
if (!readyRelation.value){
extractOk = false;
}
}
return llmOk && embeddingOk && rerankOk && vlmOk; return llmOk && embeddingOk && rerankOk && vlmOk && extractOk;
}); });
const imageUpload = ref(null); const imageUpload = ref(null);
@@ -1034,6 +1369,10 @@ const rules = {
'embedding.dimension': [ 'embedding.dimension': [
{ required: true, message: '请输入Embedding维度', type: 'error' }, { required: true, message: '请输入Embedding维度', type: 'error' },
{ validator: validateEmbeddingDimension, message: '维度必须为有效整数值常见取值为768, 1024, 1536, 3584等', type: 'error' } { validator: validateEmbeddingDimension, message: '维度必须为有效整数值常见取值为768, 1024, 1536, 3584等', type: 'error' }
],
'nodeExtract.text': [
{ required: true, message: '请输入文本内容', type: 'error' },
{ min: 10, message: '文本内容至少需要10个字符', type: 'error' }
] ]
}; };
@@ -1283,6 +1622,20 @@ const loadCurrentConfig = async () => {
// 如果没有文档分割配置确保使用默认的precision模式 // 如果没有文档分割配置确保使用默认的precision模式
selectedPreset.value = 'precision'; selectedPreset.value = 'precision';
} }
if (config.nodeExtract.enabled) {
formData.nodeExtract.enabled = true;
formData.nodeExtract.text = config.nodeExtract.text;
formData.nodeExtract.tags = config.nodeExtract.tags;
formData.nodeExtract.nodes = config.nodeExtract.nodes;
formData.nodeExtract.relations = config.nodeExtract.relations;
tagOptions.value = [];
for (const tag of config.nodeExtract.tags) {
if (tagOptions.value.find((item) => item.value === tag)) {
continue;
}
tagOptions.value.push({ label: tag, value: tag });
}
}
// 在配置加载完成后,检查模型状态 // 在配置加载完成后,检查模型状态
await checkModelsAfterLoading(config); await checkModelsAfterLoading(config);
@@ -2027,9 +2380,9 @@ const handleSubmit = async () => {
} else { } else {
MessagePlugin.error(result.message || '操作失败'); MessagePlugin.error(result.message || '操作失败');
} }
} catch (error) { } catch (error: any) {
console.error('提交失败:', error); console.error('提交失败:', error);
MessagePlugin.error('操作失败,请检查网络连接'); MessagePlugin.error(error.message || '操作失败,请检查网络连接');
} finally { } finally {
submitting.value = false; submitting.value = false;
@@ -2050,6 +2403,288 @@ const formatFileSize = (bytes: number): string => {
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
}; };
/**
 * Handle a tag created from the selector input.
 *
 * The value is trimmed; an empty result is rejected with an error message.
 * Otherwise the tag is added to the selectable options (if missing) and to
 * the selected tags (an error is shown if it is already selected). The input
 * box is cleared afterwards.
 *
 * @param val raw text entered by the user
 */
const addTag = async (val: string) => {
  const tag = val.trim();
  if (tag.length === 0) {
    MessagePlugin.error('请输入有效的标签');
    return;
  }

  const alreadyOffered = tagOptions.value.some((option) => option.value === tag);
  if (!alreadyOffered) {
    tagOptions.value.push({ label: tag, value: tag });
  }

  if (formData.nodeExtract.tags.includes(tag)) {
    MessagePlugin.error('该标签已存在');
  } else {
    formData.nodeExtract.tags.push(tag);
  }

  tagInput.value = '';
};
/** Drop every selected relation type by replacing the list with a fresh empty array. */
const clearTags = async () => {
  const noTags: string[] = [];
  formData.nodeExtract.tags = noTags;
};
/**
 * Load the built-in example (tags, sample text, nodes and relations) into the
 * form and reset the tag options and dropdown flags to match it.
 */
const defaultExtractExample = async () => {
  formData.nodeExtract.tags = ['作者', '别名'];
  formData.nodeExtract.text = `《红楼梦》又名《石头记》是清代作家曹雪芹创作的中国古典四大名著之一被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著后40回一般认为是高鹗所续。小说以贾、史、王、薛四大家族的兴衰为背景以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。成书于乾隆年间1743年前后是中国文学史上现实主义的高峰对后世影响深远。`;
  formData.nodeExtract.nodes = [
    {name: '红楼梦', attributes: ['中国古典四大名著之一', '又名《石头记》', '被誉为中国封建社会的百科全书']},
    {name: '石头记', attributes: ['《红楼梦》的别名']},
    {name: '曹雪芹', attributes: ['清代作家', '《红楼梦》前 80 回的作者']},
    {name: '高鹗', attributes: ['一般认为是《红楼梦》后 40 回的续写者']}
  ];
  formData.nodeExtract.relations = [
    {node1: '红楼梦', node2: '石头记', type: '别名'},
    {node1: '红楼梦', node2: '曹雪芹', type: '作者'},
    {node1: '红楼梦', node2: '高鹗', type: '作者'}
  ];
  tagOptions.value = [];
  tagOptions.value.push({ label: '作者', value: '作者' });
  tagOptions.value.push({ label: '别名', value: '别名' });
  // The popup flags are indexed by relation row, so size them by the number
  // of relations (previously they were sized by the number of nodes).
  popupVisibleNode1.value = Array(formData.nodeExtract.relations.length).fill(false);
  popupVisibleNode2.value = Array(formData.nodeExtract.relations.length).fill(false);
};
/**
 * Reset the extraction example: empties tags, text, nodes and relations,
 * restores the default tag options and clears the dropdown flags.
 */
const clearExtractExample = async () => {
  const extract = formData.nodeExtract;
  extract.tags = [];
  extract.text = '';
  extract.nodes = [];
  extract.relations = [];

  tagOptions.value = tagOptionsDefault.slice();

  popupVisibleNode1.value = [];
  popupVisibleNode2.value = [];
};
/** Append a blank entity (empty name, no attributes) to the node list. */
const addNode = async () => {
  const blankNode: Node = { name: '', attributes: [] };
  formData.nodeExtract.nodes.push(blankNode);
};

/** Remove the entity at `index` from the node list. */
const removeNode = async (index: number) => {
  const { nodes } = formData.nodeExtract;
  nodes.splice(index, 1);
};

/** Append an empty attribute to the entity at `nodeIndex`. */
const addAttribute = async (nodeIndex: number) => {
  const target = formData.nodeExtract.nodes[nodeIndex];
  target.attributes.push('');
};

/** Remove attribute `attrIndex` from the entity at `nodeIndex`. */
const removeAttribute = async (nodeIndex: number, attrIndex: number) => {
  const target = formData.nodeExtract.nodes[nodeIndex];
  target.attributes.splice(attrIndex, 1);
};
/** Use `item` as the first entity of relation `index` and close that dropdown. */
const onRelationNode1OptionClick = async (index: number, item: Node) => {
  const { name } = item;
  formData.nodeExtract.relations[index].node1 = name;
  popupVisibleNode1.value[index] = false;
};

/** Use `item` as the second entity of relation `index` and close that dropdown. */
const onRelationNode2OptionClick = async (index: number, item: Node) => {
  const { name } = item;
  formData.nodeExtract.relations[index].node2 = name;
  popupVisibleNode2.value[index] = false;
};

/** Clear the first entity of relation `index`. */
const relationOnClearNode1 = async (index: number) => {
  const row = formData.nodeExtract.relations[index];
  row.node1 = '';
};

/** Clear the second entity of relation `index`. */
const relationOnClearNode2 = async (index: number) => {
  const row = formData.nodeExtract.relations[index];
  row.node2 = '';
};

/** Record the open/closed state of the first-entity dropdown of relation `index`. */
const onPopupVisibleNode1Change = async (index: number, val: boolean) => {
  const flags = popupVisibleNode1.value;
  flags[index] = val;
};

/** Record the open/closed state of the second-entity dropdown of relation `index`. */
const onPopupVisibleNode2Change = async (index: number, val: boolean) => {
  const flags = popupVisibleNode2.value;
  flags[index] = val;
};
/** Append a blank relation row together with its two closed-dropdown flags. */
const addRelation = async () => {
  const blankRelation: Relation = { node1: '', node2: '', type: '' };
  formData.nodeExtract.relations.push(blankRelation);

  popupVisibleNode1.value.push(false);
  popupVisibleNode2.value.push(false);
};
/**
 * Remove the relation row at `index`.
 *
 * The dropdown visibility flags are indexed by relation row, so the matching
 * entries are removed too; otherwise the flags of later rows would shift out
 * of alignment with their relations.
 *
 * @param index position of the relation to delete
 */
const removeRelation = async (index: number) => {
  formData.nodeExtract.relations.splice(index, 1);
  popupVisibleNode1.value.splice(index, 1);
  popupVisibleNode2.value.splice(index, 1);
};
// Intentionally empty focus handler for the relation entity selectors
const onFocus = async () => {};
/**
 * Check the preconditions for running an extraction and show an error
 * message for the first one that fails.
 *
 * Requires a non-empty sample text, at least one relation type with no empty
 * entries, and an available LLM model.
 *
 * @returns a Promise resolving to `true` when extraction may proceed;
 *          because the function is async, callers must await the result
 */
const canExtract = async (): Promise<boolean> => {
  const { text, tags } = formData.nodeExtract;

  if (text === '') {
    MessagePlugin.error('请输入示例文本');
    return false;
  }

  // An empty tag list and an empty tag entry produce the same message
  const tagsMissing = tags.length === 0 || tags.some((tag) => tag === '');
  if (tagsMissing) {
    MessagePlugin.error('请输入关系类型');
    return false;
  }

  if (!modelStatus.llm.available) {
    MessagePlugin.error('请输入 LLM 大语言模型配置');
    return false;
  }

  return true;
};
/**
 * True when there is at least one entity and every entity has a non-empty
 * name and no empty attribute strings.
 */
const readyNode = computed(() => {
  const nodes = formData.nodeExtract.nodes;
  const allComplete = nodes.every((node) => {
    if (node.name === '') {
      return false;
    }
    // A node without an attribute list counts as complete
    return !(node.attributes && node.attributes.some((attr) => attr === ''));
  });
  return allComplete && nodes.length > 0;
});
/**
 * True when there is at least one relation and no relation has an empty
 * first entity, second entity or type.
 */
const readyRelation = computed(() => {
  const relations = formData.nodeExtract.relations;
  // Loose equality is kept on purpose to match the existing comparison semantics
  const hasGap = relations.some(
    (rel) => rel.node1 == '' || rel.node2 == '' || rel.type == ''
  );
  return !hasGap && relations.length > 0;
});
/**
 * Run entity/relation extraction on the sample text.
 *
 * Skips when an extraction is already running, validates the form, checks the
 * extraction preconditions, then calls `extractTextRelations` with the current
 * text, tags and LLM settings. Non-empty node and relation results replace
 * the ones in the form; empty results only show an info message. Errors are
 * logged and reported to the user, and the loading flag is always reset.
 */
const handleExtract = async () => {
  if (extracting.value) return;

  try {
    // Form validation
    const isValid = await form.value?.validate();
    if (!isValid) {
      MessagePlugin.error('请检查表单填写是否正确');
      return;
    }
    // canExtract is async: without `await` the returned Promise is always
    // truthy, so the precondition check could never stop the extraction.
    if (!(await canExtract())) {
      return;
    }

    extracting.value = true;

    const request: TextRelationExtractionRequest = {
      text: formData.nodeExtract.text.trim(),
      tags: formData.nodeExtract.tags,
      llmConfig: {
        source: formData.llm.source as 'local' | 'remote',
        modelName: formData.llm.modelName,
        baseUrl: formData.llm.baseUrl,
        apiKey: formData.llm.apiKey,
      },
    };

    const result = await extractTextRelations(request);
    if (result.nodes.length === 0 ) {
      MessagePlugin.info('未提取有效节点');
    } else {
      formData.nodeExtract.nodes = result.nodes;
    }
    if ( result.relations.length === 0) {
      MessagePlugin.info('未提取有效关系');
    } else {
      formData.nodeExtract.relations = result.relations;
      // Dropdown flags are indexed by relation row; start every new row closed
      popupVisibleNode1.value = Array(result.relations.length).fill(false);
      popupVisibleNode2.value = Array(result.relations.length).fill(false);
    }
  } catch (error) {
    console.error('文本内容关系提取失败:', error);
    MessagePlugin.error('提取失败,请检查网络连接或文本内容格式');
  } finally {
    extracting.value = false;
  }
};
/**
 * Generate relation tags through the backend.
 *
 * Skips when a generation is already running, validates the form, then calls
 * `fabriTag` with the current LLM settings. The returned tags become both the
 * selected tags and the selectable options. Errors are logged and reported to
 * the user, and the loading flag is always reset.
 */
const handleFabriTag = async () => {
  if (tagFabring.value) {
    return;
  }

  try {
    // Form validation
    const formOk = await form.value?.validate();
    if (!formOk) {
      MessagePlugin.error('请检查表单填写是否正确');
      return;
    }

    tagFabring.value = true;

    const { source, modelName, baseUrl, apiKey } = formData.llm;
    const request: FabriTagRequest = {
      llmConfig: {
        source: source as 'local' | 'remote',
        modelName,
        baseUrl,
        apiKey,
      },
    };

    const generated = await fabriTag(request);
    formData.nodeExtract.tags = generated.tags;

    // Rebuild the option list so it contains exactly the generated tags
    tagOptions.value = [];
    generated.tags.forEach((tag) => {
      tagOptions.value.push({ label: tag, value: tag });
    });
  } catch (error) {
    console.error('随机生成标签:', error);
    MessagePlugin.error('生成失败,请重试');
  } finally {
    tagFabring.value = false;
  }
};
/**
 * Generate a sample text through the backend.
 *
 * Skips when a generation is already running, validates the form, then calls
 * `fabriText` with the selected tags and current LLM settings and stores the
 * returned text in the form. Errors are logged and reported to the user, and
 * the loading flag is always reset.
 */
const handleFabriText = async () => {
  if (textFabring.value) {
    return;
  }

  try {
    // Form validation
    const formOk = await form.value?.validate();
    if (!formOk) {
      MessagePlugin.error('请检查表单填写是否正确');
      return;
    }

    textFabring.value = true;

    const { source, modelName, baseUrl, apiKey } = formData.llm;
    const request: FabriTextRequest = {
      tags: formData.nodeExtract.tags,
      llmConfig: {
        source: source as 'local' | 'remote',
        modelName,
        baseUrl,
        apiKey,
      },
    };

    const generated = await fabriText(request);
    formData.nodeExtract.text = generated.text;
  } catch (error) {
    console.error('生成示例文本失败:', error);
    MessagePlugin.error('生成失败,请重试');
  } finally {
    textFabring.value = false;
  }
};
// 组件挂载时检查Ollama状态 // 组件挂载时检查Ollama状态
onMounted(async () => { onMounted(async () => {
// 加载当前配置 // 加载当前配置
@@ -2166,6 +2801,76 @@ onMounted(async () => {
font-size: 20px; font-size: 20px;
} }
} }
.add-tag-container {
display: flex;
align-items: center; /* 垂直居中 */
justify-content: flex-start; /* 水平起始对齐 */
gap: 8px;
}
.extract-button {
display: flex;
justify-content: center;
align-items: center;
gap: 12px;
text-align: center;
}
.node-list {
display: flex;
flex-wrap: wrap;
gap: 12px;
}
.node-header {
display: flex;
align-items: center;
justify-content: flex-start;
gap: 4px;
margin-bottom: 8px;
}
.attribute-item {
display: flex;
align-items: center;
justify-content: flex-start;
margin-bottom: 4px;
}
.relation-line {
display: flex;
align-items: center;
justify-content: flex-start;
gap: 4px;
}
.relation-arrow {
font-size: 50px;
}
.sample-text-form {
display: flex;
flex-direction: column;
width: 100%;
}
}
.btn-tips-form {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 12px;
.btn-tips {
display: flex;
align-items: center;
justify-content: center;
color: #fa8c16;
.tip-icon {
margin-right: 6px;
}
}
} }
.form-row { .form-row {
@@ -2385,7 +3090,7 @@ onMounted(async () => {
} }
} }
.rerank-config, .multimodal-config { .rerank-config, .multimodal-config, .node-config {
// margin-top: 20px; // margin-top: 20px;
// padding: 20px; // padding: 20px;
// background: #f9fcff; // background: #f9fcff;
@@ -2834,4 +3539,29 @@ onMounted(async () => {
} }
} }
} }
// Option list rendered inside the relation entity dropdown panels
.select-input-node {
  display: flex;
  flex-direction: column;
  gap: 2px;
  padding: 0;

  // One selectable entity; long names are cut off with an ellipsis
  > li {
    display: block;
    padding: 3px 8px;
    border-radius: 3px;
    line-height: 22px;
    color: var(--td-text-color-primary);
    cursor: pointer;
    transition: background-color 0.2s linear;
    white-space: nowrap;
    word-wrap: normal;
    overflow: hidden;
    text-overflow: ellipsis;

    &:hover {
      background-color: var(--td-bg-color-container-hover);
    }
  }
}
</style> </style>

9
go.mod
View File

@@ -14,11 +14,12 @@ require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/hibiken/asynq v0.25.1 github.com/hibiken/asynq v0.25.1
github.com/minio/minio-go/v7 v7.0.90 github.com/minio/minio-go/v7 v7.0.90
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1
github.com/ollama/ollama v0.11.4 github.com/ollama/ollama v0.11.4
github.com/panjf2000/ants/v2 v2.11.2 github.com/panjf2000/ants/v2 v2.11.2
github.com/parquet-go/parquet-go v0.25.0 github.com/parquet-go/parquet-go v0.25.0
github.com/pgvector/pgvector-go v0.3.0 github.com/pgvector/pgvector-go v0.3.0
github.com/redis/go-redis/v9 v9.7.3 github.com/redis/go-redis/v9 v9.14.0
github.com/sashabaranov/go-openai v1.40.5 github.com/sashabaranov/go-openai v1.40.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/viper v1.20.1 github.com/spf13/viper v1.20.1
@@ -35,7 +36,7 @@ require (
golang.org/x/crypto v0.42.0 golang.org/x/crypto v0.42.0
golang.org/x/sync v0.17.0 golang.org/x/sync v0.17.0
google.golang.org/grpc v1.73.0 google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6 google.golang.org/protobuf v1.36.9
gorm.io/driver/postgres v1.5.11 gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.25.12 gorm.io/gorm v1.25.12
) )
@@ -92,7 +93,7 @@ require (
github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -105,7 +106,7 @@ require (
golang.org/x/net v0.43.0 // indirect golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.29.0 // indirect golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.11.0 // indirect golang.org/x/time v0.13.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

18
go.sum
View File

@@ -141,6 +141,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ= github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ=
github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60= github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1 h1:nV3ZdYJTi73jel0mm3dpWumNY3i3nwyo25y69SPGwyg=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI= github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI=
@@ -157,8 +159,8 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -178,8 +180,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
@@ -269,8 +271,8 @@ golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc=
@@ -278,8 +280,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -0,0 +1,231 @@
package neo4j
import (
"context"
"fmt"
"strings"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// Neo4jRepository stores and queries graph data through a Neo4j driver.
type Neo4jRepository struct {
// driver is the Neo4j connection handle; the exported methods treat a nil
// driver as "graph retrieval not supported" and return without touching the database.
driver neo4j.Driver
// nodePrefix is prepended to every namespace label to build a node label.
nodePrefix string
}
// NewNeo4jRepository wraps driver in a Neo4jRepository whose node labels
// carry the "ENTITY" prefix.
func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository {
	repo := new(Neo4jRepository)
	repo.driver = driver
	repo.nodePrefix = "ENTITY"
	return repo
}
// _remove_hyphen returns s with every "-" replaced by "_".
func _remove_hyphen(s string) string {
	parts := strings.Split(s, "-")
	return strings.Join(parts, "_")
}
// Labels returns one node label per namespace label: the repository prefix
// followed by the namespace label with hyphens turned into underscores.
func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string {
	source := namespace.Labels()
	labels := make([]string, 0, len(source))
	for _, raw := range source {
		sanitized := _remove_hyphen(raw)
		labels = append(labels, n.nodePrefix+sanitized)
	}
	return labels
}
// Label joins the namespace's node labels with ":" so the result can be
// placed after a node variable in a Cypher pattern.
func (n *Neo4jRepository) Label(namespace types.NameSpace) string {
	return strings.Join(n.Labels(namespace), ":")
}
// AddGraph implements interfaces.RetrieveGraphRepository.
// It writes the graphs one at a time and stops at the first failure.
// With a nil driver it only logs a warning and reports success.
func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil
	}
	for i := range graphs {
		err := n.addGraph(ctx, namespace, graphs[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// addGraph writes a single graph inside one managed write transaction.
//
// Nodes are merged on (name, kg) with the namespace labels, their chunk lists
// are unioned with the incoming chunks, and relationships are then merged
// between the (possibly newly created) endpoint nodes. A nil graph is a no-op.
//
// Errors from the driver are wrapped with %w so the underlying driver error
// stays reachable through errors.Is / errors.As.
func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error {
	// Nothing to persist; reading graph.Node below would panic on nil.
	if graph == nil {
		return nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
	defer session.Close(ctx)

	// Every row uses the same label set, so compute it once.
	labels := n.Labels(namespace)

	_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
		node_import_query := `
UNWIND $data AS row
CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node
SET node.chunks = apoc.coll.union(node.chunks, row.chunks)
RETURN distinct 'done' AS result
`
		nodeData := []map[string]interface{}{}
		for _, node := range graph.Node {
			if node == nil {
				continue
			}
			nodeData = append(nodeData, map[string]interface{}{
				"name":         node.Name,
				"knowledge_id": namespace.Knowledge,
				"props":        map[string][]string{"attributes": node.Attributes},
				"chunks":       node.Chunks,
				"labels":       labels,
			})
		}
		if _, err := tx.Run(ctx, node_import_query, map[string]interface{}{"data": nodeData}); err != nil {
			return nil, fmt.Errorf("failed to create nodes: %w", err)
		}

		// NOTE(review): the query reads row.attributes, but the rows built
		// below carry no "attributes" key, so it is null at runtime — confirm
		// that this is intended.
		rel_import_query := `
UNWIND $data AS row
CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source
CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target
CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel
RETURN distinct 'done'
`
		relData := []map[string]interface{}{}
		for _, rel := range graph.Relation {
			if rel == nil {
				continue
			}
			relData = append(relData, map[string]interface{}{
				"source":        rel.Node1,
				"target":        rel.Node2,
				"knowledge_id":  namespace.Knowledge,
				"type":          rel.Type,
				"source_labels": labels,
				"target_labels": labels,
			})
		}
		if _, err := tx.Run(ctx, rel_import_query, map[string]interface{}{"data": relData}); err != nil {
			return nil, fmt.Errorf("failed to create relationships: %w", err)
		}
		return nil, nil
	})
	if err != nil {
		logger.Errorf(ctx, "failed to add graph: %v", err)
		return err
	}
	return nil
}
// DelGraph implements interfaces.RetrieveGraphRepository.
//
// For every namespace it first deletes the relationships between nodes that
// carry the namespace labels and the namespace's kg value, then deletes those
// nodes, both through apoc.periodic.iterate in batches of 1000. With a nil
// driver it only logs a warning and reports success.
//
// Errors from the driver are wrapped with %w so the underlying driver error
// stays reachable through errors.Is / errors.As.
func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
	defer session.Close(ctx)

	result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
		for _, namespace := range namespaces {
			// The label expression is spliced into the query text because
			// labels cannot be passed as query parameters here.
			labelExpr := n.Label(namespace)
			params := map[string]interface{}{"knowledge_id": namespace.Knowledge}

			// NOTE(review): both iterate calls run with parallel: true —
			// confirm that concurrent batches cannot conflict on shared nodes.
			deleteRelsQuery := `
CALL apoc.periodic.iterate(
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]-(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r",
"DELETE r",
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
) YIELD batches, total
RETURN total
`
			if _, err := tx.Run(ctx, deleteRelsQuery, params); err != nil {
				return nil, fmt.Errorf("failed to delete relationships: %w", err)
			}

			deleteNodesQuery := `
CALL apoc.periodic.iterate(
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n",
"DELETE n",
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
) YIELD batches, total
RETURN total
`
			if _, err := tx.Run(ctx, deleteNodesQuery, params); err != nil {
				return nil, fmt.Errorf("failed to delete nodes: %w", err)
			}
		}
		return nil, nil
	})
	if err != nil {
		return err
	}
	// The transaction function always returns a nil value, so this logs "<nil>".
	logger.Infof(ctx, "delete graph result: %v", result)
	return nil
}
// SearchNode returns every node whose name contains one of the given texts,
// together with its directly connected neighbours and the relationships
// between them, restricted to nodes carrying the namespace labels.
//
// With a nil driver it logs a warning and returns (nil, nil), so callers must
// be prepared for a nil graph. Nodes are de-duplicated by name. Missing or
// unexpectedly typed properties are tolerated instead of panicking: nodes
// created only by the relationship import have no "chunks"/"attributes".
func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil, nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
	defer session.Close(ctx)

	// toStrings converts a list property to []string; a missing or non-list
	// property yields an empty slice.
	toStrings := func(v any) []string {
		list, ok := v.([]any)
		if !ok {
			return []string{}
		}
		return listI2listS(list)
	}

	result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
		labelExpr := n.Label(namespace)
		query := `
MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `)
WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText)
RETURN n, r, m
`
		records, err := tx.Run(ctx, query, map[string]interface{}{"nodes": nodes})
		if err != nil {
			return nil, fmt.Errorf("failed to run query: %w", err)
		}

		graphData := &types.GraphData{}
		nodeSeen := make(map[string]bool)
		for records.Next(ctx) {
			record := records.Record()
			rawSource, _ := record.Get("n")
			rawRel, _ := record.Get("r")
			rawTarget, _ := record.Get("m")

			sourceNode, sourceOK := rawSource.(neo4j.Node)
			targetNode, targetOK := rawTarget.(neo4j.Node)
			relData, relOK := rawRel.(neo4j.Relationship)
			if !sourceOK || !targetOK || !relOK {
				logger.Warnf(ctx, "search node: skip record with unexpected value types")
				continue
			}

			sourceName, sourceOK := sourceNode.Props["name"].(string)
			targetName, targetOK := targetNode.Props["name"].(string)
			if !sourceOK || !targetOK {
				logger.Warnf(ctx, "search node: skip record whose node has no string name")
				continue
			}

			// Convert both endpoints to types.GraphNode, once per name.
			endpoints := []struct {
				name string
				node neo4j.Node
			}{
				{sourceName, sourceNode},
				{targetName, targetNode},
			}
			for _, endpoint := range endpoints {
				if nodeSeen[endpoint.name] {
					continue
				}
				nodeSeen[endpoint.name] = true
				graphData.Node = append(graphData.Node, &types.GraphNode{
					Name:       endpoint.name,
					Chunks:     toStrings(endpoint.node.Props["chunks"]),
					Attributes: toStrings(endpoint.node.Props["attributes"]),
				})
			}

			// Convert the relationship to types.GraphRelation.
			graphData.Relation = append(graphData.Relation, &types.GraphRelation{
				Node1: sourceName,
				Node2: targetName,
				Type:  relData.Type,
			})
		}
		// Next returns false both at the end of the stream and on failure.
		if err := records.Err(); err != nil {
			return nil, fmt.Errorf("failed to read query result: %w", err)
		}
		return graphData, nil
	})
	if err != nil {
		logger.Errorf(ctx, "search node failed: %v", err)
		return nil, err
	}

	graphData, ok := result.(*types.GraphData)
	if !ok {
		return nil, fmt.Errorf("search node returned unexpected type %T", result)
	}
	return graphData, nil
}
// listI2listS renders every element of list with its default fmt formatting
// and returns the resulting strings in the same order.
func listI2listS(list []any) []string {
	converted := make([]string, 0, len(list))
	for _, item := range list {
		converted = append(converted, fmt.Sprint(item))
	}
	return converted
}
// mapI2mapS copies prop into a new map, rendering every value with its
// default fmt formatting.
func mapI2mapS(prop map[string]any) map[string]string {
	converted := make(map[string]string, len(prop))
	for key := range prop {
		converted[key] = fmt.Sprint(prop[key])
	}
	return converted
}

View File

@@ -0,0 +1,499 @@
package chatpipline
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
"strings"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginExtractEntity is a chat-pipeline plugin that extracts entities from the user query.
// It prompts a large language model with a structured template and stores the
// extracted entity names on the chat state.
type PluginExtractEntity struct {
modelService interfaces.ModelService // Model service for calling large language models
template *types.PromptTemplateStructured // Template for generating prompts
knowledgeBaseRepo interfaces.KnowledgeBaseRepository // Repository used to load the knowledge base of the current chat
}
// NewPluginExtractEntity creates an entity-extraction plugin that uses the
// ExtractEntity prompt template from the configuration, and registers the
// plugin with the event manager before returning it.
func NewPluginExtractEntity(
	eventManager *EventManager,
	modelService interfaces.ModelService,
	knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
	config *config.Config,
) *PluginExtractEntity {
	plugin := new(PluginExtractEntity)
	plugin.modelService = modelService
	plugin.template = config.ExtractManager.ExtractEntity
	plugin.knowledgeBaseRepo = knowledgeBaseRepo

	eventManager.Register(plugin)
	return plugin
}
// ActivationEvents lists the event types that trigger this plugin:
// only REWRITE_QUERY.
func (p *PluginExtractEntity) ActivationEvents() []types.EventType {
	events := make([]types.EventType, 0, 1)
	events = append(events, types.REWRITE_QUERY)
	return events
}
// OnEvent extracts entity names from the user query.
// It does nothing unless the NEO4J_ENABLE environment variable is "true"
// (case-insensitive). Any failure — model lookup, knowledge-base lookup,
// missing extract config, extraction error — is logged and the pipeline simply
// continues. On success the extracted node names are stored in chatManage.Entity.
func (p *PluginExtractEntity) OnEvent(ctx context.Context,
	eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
	if enabled := strings.ToLower(os.Getenv("NEO4J_ENABLE")); enabled != "true" {
		logger.Debugf(ctx, "skipping extract entity, neo4j is disabled")
		return next()
	}

	chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
	if err != nil {
		logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	knowledgeBase, err := p.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chatManage.KnowledgeBaseID)
	if err != nil {
		logger.Errorf(ctx, "failed to get knowledge base: %v", err)
		return next()
	}
	if knowledgeBase.ExtractConfig == nil {
		logger.Warnf(ctx, "failed to get extract config")
		return next()
	}

	// Only the description and examples of the configured template are used.
	extractor := NewExtractor(chatModel, &types.PromptTemplateStructured{
		Description: p.template.Description,
		Examples:    p.template.Examples,
	})
	graph, err := extractor.Extract(ctx, chatManage.Query)
	if err != nil {
		logger.Errorf(ctx, "Failed to extract entities, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	names := make([]string, 0, len(graph.Node))
	for i := range graph.Node {
		names = append(names, graph.Node[i].Name)
	}
	logger.Debugf(ctx, "extracted node: %v", names)
	chatManage.Entity = names
	return next()
}
// Extractor asks a chat model to extract a graph from text and parses the reply.
type Extractor struct {
// chat is the model the prompt is sent to.
chat chat.Chat
// formater renders the prompt examples and parses the model output.
formater *Formater
// template supplies the prompt description, examples and relation tags.
template *types.PromptTemplateStructured
// chatOpt holds the options passed with every chat call.
chatOpt *chat.ChatOptions
}
// NewExtractor builds an Extractor around chatModel and template, with a
// fresh Formater and fixed chat options: temperature 0.3, at most 4096
// tokens, thinking turned off.
func NewExtractor(
	chatModel chat.Chat,
	template *types.PromptTemplateStructured,
) Extractor {
	thinkingEnabled := false
	options := &chat.ChatOptions{
		Temperature: 0.3,
		MaxTokens:   4096,
		Thinking:    &thinkingEnabled,
	}

	var extractor Extractor
	extractor.chat = chatModel
	extractor.formater = NewFormater()
	extractor.template = template
	extractor.chatOpt = options
	return extractor
}
// Extract renders the prompt for content, sends it to the chat model and
// parses the reply into a graph. Chat and parse failures are logged and
// returned unchanged.
func (e *Extractor) Extract(ctx context.Context, content string) (*types.GraphData, error) {
	messages := NewQAPromptGenerator(e.formater, e.template).Render(ctx, content)

	reply, chatErr := e.chat.Chat(ctx, messages, e.chatOpt)
	if chatErr != nil {
		logger.Errorf(ctx, "failed to chat: %v", chatErr)
		return nil, chatErr
	}

	parsed, parseErr := e.formater.ParseGraph(ctx, reply.Content)
	if parseErr != nil {
		logger.Errorf(ctx, "failed to parse graph: %v", parseErr)
		return nil, parseErr
	}

	// Relation-type filtering is currently not applied:
	// e.RemoveUnknownRelation(ctx, parsed)
	return parsed, nil
}
// RemoveUnknownRelation drops, in place, every relation whose type is not
// one of the template's tags, logging each dropped relation.
func (e *Extractor) RemoveUnknownRelation(ctx context.Context, graph *types.GraphData) {
	allowed := make(map[string]bool, len(e.template.Tags))
	for i := range e.template.Tags {
		allowed[e.template.Tags[i]] = true
	}

	kept := make([]*types.GraphRelation, 0)
	for _, candidate := range graph.Relation {
		if !allowed[candidate.Type] {
			logger.Infof(ctx, "Unknown relation type %s with %v, ignore it", candidate.Type, e.template.Tags)
			continue
		}
		kept = append(kept, candidate)
	}
	graph.Relation = kept
}
// QAPromptGenerator renders a question/answer style prompt (system and user
// messages) from a structured prompt template.
type QAPromptGenerator struct {
// Formater renders the example answers placed in the system prompt.
Formater *Formater
// Template supplies the description, relation tags and examples.
Template *types.PromptTemplateStructured
// ExamplesHeading is the line that introduces the examples in the system prompt.
ExamplesHeading string
// QuestionHeading is the first line of the user prompt.
QuestionHeading string
// QuestionPrefix is put in front of every question text.
QuestionPrefix string
// AnswerPrefix is put in front of every example answer and ends the user prompt.
AnswerPrefix string
}
// NewQAPromptGenerator returns a generator for the given formater and
// template using the default headings ("# Examples", "# Question") and
// prefixes ("Q: ", "A: ").
func NewQAPromptGenerator(formater *Formater, template *types.PromptTemplateStructured) *QAPromptGenerator {
	generator := new(QAPromptGenerator)
	generator.Formater = formater
	generator.Template = template
	generator.ExamplesHeading = "# Examples"
	generator.QuestionHeading = "# Question"
	generator.QuestionPrefix = "Q: "
	generator.AnswerPrefix = "A: "
	return generator
}
// System builds the system prompt. The first line is the template
// description; when the template has tags, the description is used as a
// format string and receives the JSON-encoded tag list. When examples exist
// they follow under the examples heading as question/answer pairs, each
// followed by an empty line. An example that cannot be formatted makes the
// whole prompt come back empty.
func (qa *QAPromptGenerator) System(ctx context.Context) string {
	description := qa.Template.Description
	if len(qa.Template.Tags) != 0 {
		encodedTags, _ := json.Marshal(qa.Template.Tags)
		description = fmt.Sprintf(qa.Template.Description, string(encodedTags))
	}
	lines := []string{description}

	if len(qa.Template.Examples) == 0 {
		return strings.Join(lines, "\n")
	}

	lines = append(lines, qa.ExamplesHeading)
	for _, sample := range qa.Template.Examples {
		rendered, err := qa.Formater.formatExtraction(sample.Node, sample.Relation)
		if err != nil {
			return ""
		}
		lines = append(lines,
			qa.QuestionPrefix+strings.TrimSpace(sample.Text), // question
			qa.AnswerPrefix+rendered,                         // answer
			"",                                               // separator line
		)
	}
	return strings.Join(lines, "\n")
}
// User builds the user prompt: the question heading, the prefixed question,
// and a bare answer prefix for the model to continue from.
func (qa *QAPromptGenerator) User(ctx context.Context, question string) string {
	return strings.Join([]string{
		qa.QuestionHeading,
		qa.QuestionPrefix + question,
		qa.AnswerPrefix,
	}, "\n")
}
// Render returns the two chat messages for a question: the system prompt
// followed by the user prompt.
func (qa *QAPromptGenerator) Render(ctx context.Context, question string) []chat.Message {
	systemMessage := chat.Message{Role: "system", Content: qa.System(ctx)}
	userMessage := chat.Message{Role: "user", Content: qa.User(ctx, question)}
	return []chat.Message{systemMessage, userMessage}
}
// FormatType names the serialization format used for extraction examples and
// model output.
type FormatType string
// Known format types. Their string values are also used as the fence
// language tag.
const (
FormatTypeJSON FormatType = "json"
FormatTypeYAML FormatType = "yaml"
)
// Building blocks of the fenced-code-block pattern.
const (
// Opening triple backtick.
_FENCE_START = "```"
// Optional language tag directly after the opening fence (capture group "lang").
_LANGUAGE_TAG = `(?P<lang>[A-Za-z0-9_+-]+)?`
// Optional whitespace followed by a newline after the tag.
_FENCE_NEWLINE = `(?:\s*\n)?`
// Fence content, matched lazily up to the next closing fence (capture group "body").
_FENCE_BODY = `(?P<body>[\s\S]*?)`
// Closing triple backtick.
_FENCE_END = "```"
)
// _FENCE_RE matches one fenced block; submatch 1 is the language tag (empty
// when absent) and submatch 2 is the body.
var _FENCE_RE = regexp.MustCompile(
_FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END,
)
// Formater converts between graph data and the textual item list exchanged
// with the model: it renders nodes and relations for prompt examples and
// parses model output back into a graph.
type Formater struct {
// attributeSuffix is appended to nodePrefix to form the key holding a node's attributes.
attributeSuffix string
// formatType selects the serialization format and the fence language tag.
formatType FormatType
// useFences controls whether output is wrapped in, and input extracted from, code fences.
useFences bool
// nodePrefix is the item key holding a node's name.
nodePrefix string
// relationSource is the item key holding a relation's first node.
relationSource string
// relationTarget is the item key holding a relation's second node.
relationTarget string
// relationPrefix is the item key holding a relation's type.
relationPrefix string
}
// NewFormater returns a Formater with the default settings: fenced JSON,
// node key "entity" (attributes under "entity_attributes"), and relation
// keys "entity1", "entity2" and "relation".
func NewFormater() *Formater {
	formater := new(Formater)
	formater.formatType = FormatTypeJSON
	formater.useFences = true
	formater.nodePrefix = "entity"
	formater.attributeSuffix = "_attributes"
	formater.relationSource = "entity1"
	formater.relationTarget = "entity2"
	formater.relationPrefix = "relation"
	return formater
}
// formatExtraction serializes nodes followed by relations as one indented
// JSON list of items (whatever formatType is set) and, when fences are
// enabled, wraps the text in a code fence. Node attributes are only emitted
// when the node has at least one.
func (f *Formater) formatExtraction(nodes []*types.GraphNode, relations []*types.GraphRelation) (string, error) {
	attributesKey := f.nodePrefix + f.attributeSuffix

	entries := make([]map[string]interface{}, 0, len(nodes)+len(relations))
	for _, node := range nodes {
		entry := map[string]interface{}{}
		entry[f.nodePrefix] = node.Name
		if len(node.Attributes) > 0 {
			entry[attributesKey] = node.Attributes
		}
		entries = append(entries, entry)
	}
	for _, relation := range relations {
		entries = append(entries, map[string]interface{}{
			f.relationSource: relation.Node1,
			f.relationTarget: relation.Node2,
			f.relationPrefix: relation.Type,
		})
	}

	encoded, err := json.MarshalIndent(entries, "", " ")
	if err != nil {
		return "", err
	}

	if !f.useFences {
		return string(encoded), nil
	}
	return f.addFences(string(encoded)), nil
}
// parseOutput extracts the payload from text and decodes it into a list of
// mappings. A single mapping is accepted and returned as a one-element list.
// It fails on empty input, on undecodable content, on a null payload, on a
// payload that is neither a list nor a mapping, and on list elements that are
// not mappings. Only JSON is decoded; with any other format type the payload
// stays nil and is rejected.
func (f *Formater) parseOutput(ctx context.Context, text string) ([]map[string]interface{}, error) {
	if text == "" {
		return nil, errors.New("Empty or invalid input string.")
	}
	payload := f.extractContent(ctx, text)
	// logger.Debugf(ctx, "Extracted content: %s", payload)
	if payload == "" {
		return nil, errors.New("Empty or invalid input string.")
	}

	var decoded interface{}
	if f.formatType == FormatTypeJSON {
		if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
			return nil, fmt.Errorf("Failed to parse %s content: %s", strings.ToUpper(string(f.formatType)), err.Error())
		}
	}
	if decoded == nil {
		return nil, fmt.Errorf("Content must be a list of extractions or a dict.")
	}

	var rawItems []interface{}
	switch value := decoded.(type) {
	case map[string]interface{}:
		rawItems = []interface{}{value}
	case []interface{}:
		rawItems = value
	default:
		return nil, fmt.Errorf("Expected list or dict, got %T", decoded)
	}

	mappings := make([]map[string]interface{}, 0)
	for _, raw := range rawItems {
		mapping, isMapping := raw.(map[string]interface{})
		if !isMapping {
			return nil, fmt.Errorf("Each item in the sequence must be a mapping.")
		}
		mappings = append(mappings, mapping)
	}
	return mappings, nil
}
// ParseGraph turns model output into a graph. Items carrying the node key
// become nodes (with optional attributes), items carrying both relation
// endpoint keys become relations, and anything else is logged and skipped.
// The result is normalized by rebuildGraph before it is returned.
func (f *Formater) ParseGraph(ctx context.Context, text string) (*types.GraphData, error) {
	items, err := f.parseOutput(ctx, text)
	if err != nil {
		return nil, err
	}
	if len(items) == 0 {
		logger.Debugf(ctx, "Received empty extraction data.")
		return &types.GraphData{}, nil
	}
	// mm, _ := json.Marshal(items)
	// logger.Debugf(ctx, "Parsed graph data: %s", string(mm))

	attributesKey := f.nodePrefix + f.attributeSuffix

	var nodes []*types.GraphNode
	var relations []*types.GraphRelation
	for _, item := range items {
		if item[f.nodePrefix] != nil {
			attributes := make([]string, 0)
			if rawAttributes, isList := item[attributesKey].([]interface{}); isList {
				for _, rawAttribute := range rawAttributes {
					attributes = append(attributes, fmt.Sprintf("%v", rawAttribute))
				}
			}
			nodes = append(nodes, &types.GraphNode{
				Name:       fmt.Sprintf("%v", item[f.nodePrefix]),
				Attributes: attributes,
			})
		} else if item[f.relationSource] != nil && item[f.relationTarget] != nil {
			relations = append(relations, &types.GraphRelation{
				Node1: fmt.Sprintf("%v", item[f.relationSource]),
				Node2: fmt.Sprintf("%v", item[f.relationTarget]),
				Type:  fmt.Sprintf("%v", item[f.relationPrefix]),
			})
		} else {
			logger.Warnf(ctx, "Unsupported graph group: %v", item)
		}
	}

	result := &types.GraphData{Node: nodes, Relation: relations}
	f.rebuildGraph(ctx, result)
	return result, nil
}
// rebuildGraph normalizes graph in place:
//   - nodes with the same name are collapsed into the first occurrence, which
//     receives the attributes of every later duplicate;
//   - relations whose two endpoints are the same node are dropped;
//   - relation endpoints that have no node yet get a new attribute-less node.
//
// The previous implementation appended the kept node's attributes to the
// discarded duplicate, so the duplicate's attributes were silently lost.
func (f *Formater) rebuildGraph(ctx context.Context, graph *types.GraphData) {
	nodeMap := make(map[string]*types.GraphNode)
	nodes := make([]*types.GraphNode, 0, len(graph.Node))
	for _, node := range graph.Node {
		if prenode, ok := nodeMap[node.Name]; ok {
			logger.Infof(ctx, "Duplicate node ID: %s, merge attribute", node.Name)
			// Merge into the node that is kept; appending to a nil slice is safe.
			prenode.Attributes = append(prenode.Attributes, node.Attributes...)
			continue
		}
		nodeMap[node.Name] = node
		nodes = append(nodes, node)
	}

	relations := make([]*types.GraphRelation, 0, len(graph.Relation))
	for _, relation := range graph.Relation {
		// A relation from a node to itself is dropped (logged as "duplicate").
		if relation.Node1 == relation.Node2 {
			logger.Infof(ctx, "Duplicate relation, ignore it")
			continue
		}
		if _, ok := nodeMap[relation.Node1]; !ok {
			node := &types.GraphNode{Name: relation.Node1}
			nodes = append(nodes, node)
			nodeMap[relation.Node1] = node
			logger.Infof(ctx, "Add unknown source node ID: %s", relation.Node1)
		}
		if _, ok := nodeMap[relation.Node2]; !ok {
			node := &types.GraphNode{Name: relation.Node2}
			nodes = append(nodes, node)
			nodeMap[relation.Node2] = node
			logger.Infof(ctx, "Add unknown target node ID: %s", relation.Node2)
		}
		relations = append(relations, relation)
	}

	*graph = types.GraphData{
		Node:     nodes,
		Relation: relations,
	}
}
// extractContent returns the trimmed payload of text. Without fences the
// whole text is the payload. With fences, the first fenced block whose
// language tag is empty or valid for the current format wins; failing that,
// the first fenced block of any language; failing that, the whole text.
func (f *Formater) extractContent(ctx context.Context, text string) string {
	if !f.useFences {
		return strings.TrimSpace(text)
	}

	validTags := map[FormatType]map[string]struct{}{
		FormatTypeYAML: {"yaml": {}, "yml": {}},
		FormatTypeJSON: {"json": {}},
	}

	fenced := _FENCE_RE.FindAllStringSubmatch(text, -1)
	var accepted []string
	for i := range fenced {
		language, body := fenced[i][1], fenced[i][2]
		if !f.isValidLanguageTag(language, validTags) {
			continue
		}
		accepted = append(accepted, body)
	}

	if count := len(accepted); count > 0 {
		if count > 1 {
			logger.Warnf(ctx, "multiple candidates found: %d", count)
		}
		return strings.TrimSpace(accepted[0])
	}

	if count := len(fenced); count > 0 {
		if count == 1 {
			logger.Debugf(ctx, "no candidate found, use first match without language tag: %s", fenced[0][1])
		} else {
			logger.Warnf(ctx, "multiple matches found: %d", count)
		}
		return strings.TrimSpace(fenced[0][2])
	}

	logger.Warnf(ctx, "no match found")
	return strings.TrimSpace(text)
}
// addFences trims content and wraps it in a code fence tagged with the
// formater's format type.
func (f *Formater) addFences(content string) string {
	trimmed := strings.TrimSpace(content)
	return "```" + string(f.formatType) + "\n" + trimmed + "\n```"
}
// isValidLanguageTag reports whether lang is acceptable for the formater's
// format type: an empty tag always is; otherwise the trimmed, lower-cased tag
// must be listed in validTags for that format type.
func (f *Formater) isValidLanguageTag(lang string, validTags map[FormatType]map[string]struct{}) bool {
	if lang == "" {
		return true
	}
	allowed, known := validTags[f.formatType]
	if !known {
		return false
	}
	normalized := strings.ToLower(lang)
	normalized = strings.TrimSpace(normalized)
	_, listed := allowed[normalized]
	return listed
}

View File

@@ -77,6 +77,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
} }
chatManage.SearchResult = append(chatManage.SearchResult, searchResults...) chatManage.SearchResult = append(chatManage.SearchResult, searchResults...)
} }
// remove duplicate results // remove duplicate results
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult) chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)

View File

@@ -0,0 +1,136 @@
package chatpipline
import (
"context"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginSearchEntity is a chat-pipeline plugin that looks up previously
// extracted entities in the graph store and adds the chunks referenced by the
// matching nodes to the search results.
type PluginSearchEntity struct {
// graphRepo is queried for nodes matching the extracted entities.
graphRepo interfaces.RetrieveGraphRepository
// chunkRepo loads the chunks referenced by the matching nodes.
chunkRepo interfaces.ChunkRepository
// knowledgeRepo loads the knowledge records those chunks belong to.
knowledgeRepo interfaces.KnowledgeRepository
}
// NewPluginSearchEntity creates the entity-search plugin from its three
// repositories and registers it with the event manager before returning it.
func NewPluginSearchEntity(
	eventManager *EventManager,
	graphRepository interfaces.RetrieveGraphRepository,
	chunkRepository interfaces.ChunkRepository,
	knowledgeRepository interfaces.KnowledgeRepository,
) *PluginSearchEntity {
	plugin := new(PluginSearchEntity)
	plugin.graphRepo = graphRepository
	plugin.chunkRepo = chunkRepository
	plugin.knowledgeRepo = knowledgeRepository

	eventManager.Register(plugin)
	return plugin
}
// ActivationEvents lists the event types that trigger this plugin:
// only ENTITY_SEARCH.
func (p *PluginSearchEntity) ActivationEvents() []types.EventType {
	events := make([]types.EventType, 0, 1)
	events = append(events, types.ENTITY_SEARCH)
	return events
}
// OnEvent enriches the search results with chunks found through the graph.
//
// It searches the graph for the entities stored in chatManage.Entity, records
// the graph in chatManage.GraphResult, loads the chunks referenced by the
// matching nodes that are not already in the search results, and appends them
// (de-duplicated) to chatManage.SearchResult. Lookup failures are logged and
// the pipeline continues. ErrSearchNothing is returned when the search
// results are still empty after new chunks were processed.
//
// Compared with the earlier version this guards three panics: a nil graph
// returned without an error, a context without a uint tenant id, and a chunk
// whose knowledge record could not be loaded.
func (p *PluginSearchEntity) OnEvent(ctx context.Context,
	eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
	entity := chatManage.Entity
	if len(entity) == 0 {
		logger.Infof(ctx, "No entity found")
		return next()
	}

	graph, err := p.graphRepo.SearchNode(ctx, types.NameSpace{KnowledgeBase: chatManage.KnowledgeBaseID}, entity)
	if err != nil {
		logger.Errorf(ctx, "Failed to search node, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}
	// A repository may report "no graph available" as (nil, nil).
	if graph == nil {
		logger.Infof(ctx, "No graph result, session_id: %s", chatManage.SessionID)
		return next()
	}
	chatManage.GraphResult = graph
	logger.Infof(ctx, "search entity result count: %d", len(graph.Node))
	// graphStr, _ := json.Marshal(graph)
	// logger.Debugf(ctx, "search entity result: %s", string(graphStr))

	chunkIDs := filterSeenChunk(ctx, graph, chatManage.SearchResult)
	if len(chunkIDs) == 0 {
		logger.Infof(ctx, "No new chunk found")
		return next()
	}

	tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint)
	if !ok {
		logger.Errorf(ctx, "Failed to get tenant id from context, session_id: %s", chatManage.SessionID)
		return next()
	}

	chunks, err := p.chunkRepo.ListChunksByID(ctx, tenantID, chunkIDs)
	if err != nil {
		logger.Errorf(ctx, "Failed to list chunks, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	knowledgeIDs := []string{}
	for _, chunk := range chunks {
		knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID)
	}
	knowledges, err := p.knowledgeRepo.GetKnowledgeBatch(ctx, tenantID, knowledgeIDs)
	if err != nil {
		logger.Errorf(ctx, "Failed to list knowledge, session_id: %s, error: %v", chatManage.SessionID, err)
		return next()
	}

	knowledgeMap := map[string]*types.Knowledge{}
	for _, knowledge := range knowledges {
		if knowledge == nil {
			continue
		}
		knowledgeMap[knowledge.ID] = knowledge
	}
	for _, chunk := range chunks {
		knowledge, found := knowledgeMap[chunk.KnowledgeID]
		if !found {
			// Without its knowledge record the result cannot be described; skip it.
			logger.Warnf(ctx, "Knowledge %s not found for chunk %s, skip it", chunk.KnowledgeID, chunk.ID)
			continue
		}
		searchResult := chunk2SearchResult(chunk, knowledge)
		chatManage.SearchResult = append(chatManage.SearchResult, searchResult)
	}
	// remove duplicate results
	chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
	if len(chatManage.SearchResult) == 0 {
		logger.Infof(ctx, "No new search result, session_id: %s", chatManage.SessionID)
		return ErrSearchNothing
	}
	logger.Infof(ctx, "search entity result count: %d, session_id: %s", len(chatManage.SearchResult), chatManage.SessionID)
	return next()
}
// filterSeenChunk collects the chunk IDs referenced by the graph's nodes that
// are neither present in searchResult nor already collected, in first-seen
// order.
func filterSeenChunk(ctx context.Context, graph *types.GraphData, searchResult []*types.SearchResult) []string {
	known := make(map[string]bool, len(searchResult))
	for i := range searchResult {
		known[searchResult[i].ID] = true
	}
	logger.Infof(ctx, "filterSeenChunk: seen count: %d", len(known))

	fresh := []string{}
	for _, node := range graph.Node {
		for _, id := range node.Chunks {
			if !known[id] {
				known[id] = true
				fresh = append(fresh, id)
			}
		}
	}
	logger.Infof(ctx, "filterSeenChunk: new chunkIDs count: %d", len(fresh))
	return fresh
}
// chunk2SearchResult converts a chunk found through the graph into a search
// result with score 1.0 and match type MatchTypeGraph. The title, metadata,
// file name and source are copied from knowledge; when knowledge is nil those
// fields keep their zero values instead of causing a nil dereference.
func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.SearchResult {
	result := &types.SearchResult{
		ID:            chunk.ID,
		Content:       chunk.Content,
		KnowledgeID:   chunk.KnowledgeID,
		ChunkIndex:    chunk.ChunkIndex,
		StartAt:       chunk.StartAt,
		EndAt:         chunk.EndAt,
		Seq:           chunk.ChunkIndex,
		Score:         1.0,
		MatchType:     types.MatchTypeGraph,
		ChunkType:     string(chunk.ChunkType),
		ParentChunkID: chunk.ParentChunkID,
		ImageInfo:     chunk.ImageInfo,
	}
	if knowledge != nil {
		result.KnowledgeTitle = knowledge.Title
		result.Metadata = knowledge.GetMetadata()
		result.KnowledgeFilename = knowledge.FileName
		result.KnowledgeSource = knowledge.Source
	}
	return result
}

View File

@@ -0,0 +1,137 @@
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/google/uuid"
"github.com/hibiken/asynq"
)
// NewChunkExtractTask enqueues an asynchronous graph-extraction task for one
// chunk.
//
// When the NEO4J_ENABLE environment variable is not "true" (case-insensitive)
// the function is a no-op and returns nil. Otherwise it serializes an
// ExtractChunkPayload and enqueues it with at most 3 retries.
//
// Returned errors wrap the underlying cause (%w) so callers can inspect it
// with errors.Is / errors.As.
func NewChunkExtractTask(ctx context.Context, client *asynq.Client, tenantID uint, chunkID string, modelID string) error {
	if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
		logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil
	}

	payload, err := json.Marshal(types.ExtractChunkPayload{
		TenantID: tenantID,
		ChunkID:  chunkID,
		ModelID:  modelID,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal chunk extract payload: %w", err)
	}

	task := asynq.NewTask(types.TypeChunkExtract, payload, asynq.MaxRetry(3))
	info, err := client.Enqueue(task)
	if err != nil {
		logger.Errorf(ctx, "failed to enqueue task: %v", err)
		return fmt.Errorf("failed to enqueue task: %w", err)
	}
	logger.Infof(ctx, "enqueued task: id=%s queue=%s chunk=%s", info.ID, info.Queue, chunkID)
	return nil
}
// ChunkExtractService processes chunk-extraction tasks: it loads a chunk,
// extracts a graph from its content with a chat model, and stores that graph.
type ChunkExtractService struct {
// template supplies the prompt Description used when building the extractor.
template *types.PromptTemplateStructured
// modelService resolves the chat model named in the task payload.
modelService interfaces.ModelService
// knowledgeBaseRepo loads the knowledge base that carries the extract config.
knowledgeBaseRepo interfaces.KnowledgeBaseRepository
// chunkRepo loads the chunk to be processed.
chunkRepo interfaces.ChunkRepository
// graphEngine receives the extracted graph via AddGraph.
graphEngine interfaces.RetrieveGraphRepository
}
// NewChunkExtractService wires up a ChunkExtractService from its dependencies
// and returns it as an interfaces.Extracter. At construction time it also
// renders the graph-extraction prompts once (with the placeholder input
// "demo") and writes them to the debug log.
func NewChunkExtractService(
	config *config.Config,
	modelService interfaces.ModelService,
	knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
	chunkRepo interfaces.ChunkRepository,
	graphEngine interfaces.RetrieveGraphRepository,
) interfaces.Extracter {
	graphTemplate := config.ExtractManager.ExtractGraph

	// Render the prompts purely for diagnostics; the generator is not retained.
	bg := context.Background()
	promptGen := chatpipline.NewQAPromptGenerator(chatpipline.NewFormater(), graphTemplate)
	logger.Debugf(bg, "chunk extract system prompt: %s", promptGen.System(bg))
	logger.Debugf(bg, "chunk extract user prompt: %s", promptGen.User(bg, "demo"))

	svc := &ChunkExtractService{
		template:          graphTemplate,
		modelService:      modelService,
		knowledgeBaseRepo: knowledgeBaseRepo,
		chunkRepo:         chunkRepo,
		graphEngine:       graphEngine,
	}
	return svc
}
// Extract is the asynq handler for chunk-extraction tasks. It loads the chunk
// named in the payload, runs the knowledge base's extraction prompt through
// the requested chat model, and stores the resulting graph in the graph engine.
//
// Return value as seen by asynq:
//   - nil: the task is done. This includes the "nothing to do" cases, i.e. the
//     knowledge base has no extract config, the model produced no graph, or
//     the chunk can no longer be loaded once extraction has finished.
//   - an error wrapping asynq.SkipRetry: the payload could not be decoded, so
//     retrying is pointless.
//   - any other error: the task failed and may be retried.
func (s *ChunkExtractService) Extract(ctx context.Context, t *asynq.Task) error {
	var p types.ExtractChunkPayload
	if err := json.Unmarshal(t.Payload(), &p); err != nil {
		logger.Errorf(ctx, "failed to unmarshal task payload: %v", err)
		// A malformed payload will fail identically on every attempt.
		return fmt.Errorf("failed to unmarshal task payload: %v: %w", err, asynq.SkipRetry)
	}

	// Tag all subsequent logs and repository calls with a fresh request ID,
	// the chunk being processed, and the tenant.
	ctx = logger.WithRequestID(ctx, uuid.New().String())
	ctx = logger.WithField(ctx, "extract", p.ChunkID)
	ctx = context.WithValue(ctx, types.TenantIDContextKey, p.TenantID)

	chunk, err := s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
	if err != nil {
		logger.Errorf(ctx, "failed to get chunk: %v", err)
		return err
	}

	kb, err := s.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chunk.KnowledgeBaseID)
	if err != nil {
		logger.Errorf(ctx, "failed to get knowledge base: %v", err)
		return err
	}
	if kb.ExtractConfig == nil {
		// Extraction is not configured for this knowledge base: skip the task.
		// Returned explicitly as nil rather than through the stale err variable.
		logger.Warnf(ctx, "failed to get extract config")
		return nil
	}

	chatModel, err := s.modelService.GetChatModel(ctx, p.ModelID)
	if err != nil {
		logger.Errorf(ctx, "failed to get chat model: %v", err)
		return err
	}

	// The description comes from the service-wide template; tags and the single
	// worked example come from the knowledge base's own extract config.
	template := &types.PromptTemplateStructured{
		Description: s.template.Description,
		Tags:        kb.ExtractConfig.Tags,
		Examples: []types.GraphData{
			{
				Text:     kb.ExtractConfig.Text,
				Node:     kb.ExtractConfig.Nodes,
				Relation: kb.ExtractConfig.Relations,
			},
		},
	}
	extractor := chatpipline.NewExtractor(chatModel, template)
	graph, err := extractor.Extract(ctx, chunk.Content)
	if err != nil {
		logger.Errorf(ctx, "failed to extract graph: %v", err)
		return err
	}
	if graph == nil {
		// NOTE(review): guards against an extractor that returns (nil, nil);
		// confirm whether that can actually happen.
		logger.Warnf(ctx, "graph ignore chunk %s: empty extraction result", p.ChunkID)
		return nil
	}

	// Load the chunk again before writing. NOTE(review): presumably this skips
	// chunks removed while the model call was running - confirm.
	chunk, err = s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
	if err != nil {
		logger.Warnf(ctx, "graph ignore chunk %s: %v", p.ChunkID, err)
		return nil
	}

	// Every extracted node originates from this one chunk.
	for _, node := range graph.Node {
		node.Chunks = []string{chunk.ID}
	}

	if err = s.graphEngine.AddGraph(ctx,
		types.NameSpace{KnowledgeBase: chunk.KnowledgeBaseID, Knowledge: chunk.KnowledgeID},
		[]*types.GraphData{graph},
	); err != nil {
		logger.Errorf(ctx, "failed to add graph: %v", err)
		return err
	}
	return nil
}

View File

@@ -29,6 +29,7 @@ import (
"github.com/Tencent/WeKnora/services/docreader/src/client" "github.com/Tencent/WeKnora/services/docreader/src/client"
"github.com/Tencent/WeKnora/services/docreader/src/proto" "github.com/Tencent/WeKnora/services/docreader/src/proto"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hibiken/asynq"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -61,6 +62,8 @@ type knowledgeService struct {
chunkRepo interfaces.ChunkRepository chunkRepo interfaces.ChunkRepository
fileSvc interfaces.FileService fileSvc interfaces.FileService
modelService interfaces.ModelService modelService interfaces.ModelService
task *asynq.Client
graphEngine interfaces.RetrieveGraphRepository
} }
// NewKnowledgeService creates a new knowledge service instance // NewKnowledgeService creates a new knowledge service instance
@@ -74,6 +77,8 @@ func NewKnowledgeService(
chunkRepo interfaces.ChunkRepository, chunkRepo interfaces.ChunkRepository,
fileSvc interfaces.FileService, fileSvc interfaces.FileService,
modelService interfaces.ModelService, modelService interfaces.ModelService,
task *asynq.Client,
graphEngine interfaces.RetrieveGraphRepository,
) (interfaces.KnowledgeService, error) { ) (interfaces.KnowledgeService, error) {
return &knowledgeService{ return &knowledgeService{
config: config, config: config,
@@ -85,6 +90,8 @@ func NewKnowledgeService(
chunkRepo: chunkRepo, chunkRepo: chunkRepo,
fileSvc: fileSvc, fileSvc: fileSvc,
modelService: modelService, modelService: modelService,
task: task,
graphEngine: graphEngine,
}, nil }, nil
} }
@@ -488,6 +495,16 @@ func (s *knowledgeService) DeleteKnowledge(ctx context.Context, id string) error
return nil return nil
}) })
// Delete the knowledge graph
wg.Go(func() error {
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil { if err = wg.Wait(); err != nil {
return err return err
} }
@@ -561,6 +578,19 @@ func (s *knowledgeService) DeleteKnowledgeList(ctx context.Context, ids []string
return nil return nil
}) })
// Delete the knowledge graph
wg.Go(func() error {
namespaces := []types.NameSpace{}
for _, knowledge := range knowledgeList {
namespaces = append(namespaces, types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID})
}
if err := s.graphEngine.DelGraph(ctx, namespaces); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil { if err = wg.Wait(); err != nil {
return err return err
} }
@@ -1158,6 +1188,15 @@ func (s *knowledgeService) processChunks(ctx context.Context,
} }
logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList)) logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList))
logger.Infof(ctx, "processChunks create relationship rag task")
for _, chunk := range textChunks {
err := NewChunkExtractTask(ctx, s.task, chunk.TenantID, chunk.ID, kb.SummaryModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks create chunk extract task failed")
span.RecordError(err)
}
}
// Update knowledge status to completed // Update knowledge status to completed
knowledge.ParseStatus = "completed" knowledge.ParseStatus = "completed"
knowledge.EnableStatus = "enabled" knowledge.EnableStatus = "enabled"

View File

@@ -2,12 +2,11 @@ package service
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"slices" "slices"
"time" "time"
"encoding/json"
"github.com/Tencent/WeKnora/internal/application/service/retriever" "github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/common" "github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/logger"
@@ -376,8 +375,8 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
// processSearchResults handles the processing of search results, optimizing database queries // processSearchResults handles the processing of search results, optimizing database queries
func (s *knowledgeBaseService) processSearchResults(ctx context.Context, func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
chunks []*types.IndexWithScore) ([]*types.SearchResult, error) { chunks []*types.IndexWithScore,
) ([]*types.SearchResult, error) {
if len(chunks) == 0 { if len(chunks) == 0 {
return nil, nil return nil, nil
} }
@@ -527,8 +526,8 @@ func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, proces
func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk, func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
knowledge *types.Knowledge, knowledge *types.Knowledge,
score float64, score float64,
matchType types.MatchType) *types.SearchResult { matchType types.MatchType,
) *types.SearchResult {
return &types.SearchResult{ return &types.SearchResult{
ID: chunk.ID, ID: chunk.ID,
Content: chunk.Content, Content: chunk.Content,

View File

@@ -1,72 +0,0 @@
package common
import (
"log"
"github.com/Tencent/WeKnora/internal/config"
"github.com/hibiken/asynq"
)
// client is the package-level asynq client. It is assigned by InitAsyncq and
// handed out by GetAsyncqClient; it is nil until InitAsyncq has been called.
var client *asynq.Client
// InitAsyncq builds the package-level asynq client from the Asynq section of
// the configuration and launches the task server in a background goroutine.
// It always returns nil.
func InitAsyncq(config *config.Config) error {
	asynqCfg := config.Asynq

	redisOpt := asynq.RedisClientOpt{
		Addr:         asynqCfg.Addr,
		Username:     asynqCfg.Username,
		Password:     asynqCfg.Password,
		ReadTimeout:  asynqCfg.ReadTimeout,
		WriteTimeout: asynqCfg.WriteTimeout,
	}
	client = asynq.NewClient(redisOpt)

	// The server blocks while running, so it gets its own goroutine.
	go run(asynqCfg)
	return nil
}
// GetAsyncqClient returns the package-level asynq client. The result is nil
// if InitAsyncq has not been called yet.
func GetAsyncqClient() *asynq.Client {
return client
}
// handleFunc maps a task type name to the handler registered for it. It is
// filled by RegisterHandlerFunc and read once by run when the mux is built.
var handleFunc = map[string]asynq.HandlerFunc{}
// RegisterHandlerFunc registers handlerFunc for taskType, replacing any
// handler previously stored under the same type. The map is written without
// locking, and run copies it into the mux only once at server start.
func RegisterHandlerFunc(taskType string, handlerFunc asynq.HandlerFunc) {
handleFunc[taskType] = handlerFunc
}
// run builds an asynq server from the given configuration, mounts every
// handler currently held in handleFunc, and blocks serving tasks. If the
// server stops with an error the process is terminated via log.Fatalf.
func run(config *config.AsynqConfig) {
	redisOpt := asynq.RedisClientOpt{
		Addr:         config.Addr,
		Username:     config.Username,
		Password:     config.Password,
		ReadTimeout:  config.ReadTimeout,
		WriteTimeout: config.WriteTimeout,
	}
	serverCfg := asynq.Config{
		Concurrency: config.Concurrency,
		// Queue name -> relative priority weight.
		Queues: map[string]int{
			"critical": 6, // Highest priority queue
			"default":  3, // Default priority queue
			"low":      1, // Lowest priority queue
		},
	}
	srv := asynq.NewServer(redisOpt, serverCfg)

	// Mount all registered handlers on a fresh mux.
	mux := asynq.NewServeMux()
	for taskType, h := range handleFunc {
		mux.HandleFunc(taskType, h)
	}

	// Serve until the server stops.
	err := srv.Run(mux)
	if err != nil {
		log.Fatalf("could not run server: %v", err)
	}
}

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/go-viper/mapstructure/v2" "github.com/go-viper/mapstructure/v2"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@@ -18,10 +19,10 @@ type Config struct {
KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"` KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"`
Tenant *TenantConfig `yaml:"tenant" json:"tenant"` Tenant *TenantConfig `yaml:"tenant" json:"tenant"`
Models []ModelConfig `yaml:"models" json:"models"` Models []ModelConfig `yaml:"models" json:"models"`
Asynq *AsynqConfig `yaml:"asynq" json:"asynq"`
VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"` VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"`
DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"` DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"`
StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"` StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"`
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
} }
type DocReaderConfig struct { type DocReaderConfig struct {
@@ -109,15 +110,6 @@ type ModelConfig struct {
Parameters map[string]interface{} `yaml:"parameters" json:"parameters"` Parameters map[string]interface{} `yaml:"parameters" json:"parameters"`
} }
type AsynqConfig struct {
Addr string `yaml:"addr" json:"addr"`
Username string `yaml:"username" json:"username"`
Password string `yaml:"password" json:"password"`
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"`
Concurrency int `yaml:"concurrency" json:"concurrency"`
}
// StreamManagerConfig 流管理器配置 // StreamManagerConfig 流管理器配置
type StreamManagerConfig struct { type StreamManagerConfig struct {
Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis" Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis"
@@ -134,6 +126,18 @@ type RedisConfig struct {
TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时) TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时)
} }
// ExtractManagerConfig holds the extraction settings loaded from the
// top-level "extract" configuration section.
type ExtractManagerConfig struct {
// ExtractGraph is the structured prompt template for graph extraction.
ExtractGraph *types.PromptTemplateStructured `yaml:"extract_graph" json:"extract_graph"`
// ExtractEntity is the structured prompt template for entity extraction.
ExtractEntity *types.PromptTemplateStructured `yaml:"extract_entity" json:"extract_entity"`
// FabriText holds the prompts used to generate sample text.
// NOTE(review): the field is spelled "Fabri" but its type is "FebriText".
FabriText *FebriText `yaml:"fabri_text" json:"fabri_text"`
}
// FebriText holds the two prompt variants read from the "fabri_text"
// configuration key.
type FebriText struct {
// WithTag is the "with_tag" prompt variant.
WithTag string `yaml:"with_tag" json:"with_tag"`
// WithNoTag is the "with_no_tag" prompt variant.
WithNoTag string `yaml:"with_no_tag" json:"with_no_tag"`
}
// LoadConfig 从配置文件加载配置 // LoadConfig 从配置文件加载配置
func LoadConfig() (*Config, error) { func LoadConfig() (*Config, error) {
// 设置配置文件名和路径 // 设置配置文件名和路径

View File

@@ -14,6 +14,7 @@ import (
esv7 "github.com/elastic/go-elasticsearch/v7" esv7 "github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v8" "github.com/elastic/go-elasticsearch/v8"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"github.com/panjf2000/ants/v2" "github.com/panjf2000/ants/v2"
"go.uber.org/dig" "go.uber.org/dig"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
@@ -22,6 +23,7 @@ import (
"github.com/Tencent/WeKnora/internal/application/repository" "github.com/Tencent/WeKnora/internal/application/repository"
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7" elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8" elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
neo4jRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/neo4j"
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres" postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
"github.com/Tencent/WeKnora/internal/application/service" "github.com/Tencent/WeKnora/internal/application/service"
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline" chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
@@ -68,6 +70,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
// External service clients // External service clients
must(container.Provide(initDocReaderClient)) must(container.Provide(initDocReaderClient))
must(container.Provide(initOllamaService)) must(container.Provide(initOllamaService))
must(container.Provide(initNeo4jClient))
must(container.Provide(stream.NewStreamManager)) must(container.Provide(stream.NewStreamManager))
// Data repositories layer // Data repositories layer
@@ -80,6 +83,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(repository.NewModelRepository)) must(container.Provide(repository.NewModelRepository))
must(container.Provide(repository.NewUserRepository)) must(container.Provide(repository.NewUserRepository))
must(container.Provide(repository.NewAuthTokenRepository)) must(container.Provide(repository.NewAuthTokenRepository))
must(container.Provide(neo4jRepo.NewNeo4jRepository))
// Business service layer // Business service layer
must(container.Provide(service.NewTenantService)) must(container.Provide(service.NewTenantService))
@@ -93,6 +97,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(service.NewDatasetService)) must(container.Provide(service.NewDatasetService))
must(container.Provide(service.NewEvaluationService)) must(container.Provide(service.NewEvaluationService))
must(container.Provide(service.NewUserService)) must(container.Provide(service.NewUserService))
must(container.Provide(service.NewChunkExtractService))
// Chat pipeline components for processing chat requests // Chat pipeline components for processing chat requests
must(container.Provide(chatpipline.NewEventManager)) must(container.Provide(chatpipline.NewEventManager))
@@ -107,6 +112,8 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Invoke(chatpipline.NewPluginFilterTopK)) must(container.Invoke(chatpipline.NewPluginFilterTopK))
must(container.Invoke(chatpipline.NewPluginPreprocess)) must(container.Invoke(chatpipline.NewPluginPreprocess))
must(container.Invoke(chatpipline.NewPluginRewrite)) must(container.Invoke(chatpipline.NewPluginRewrite))
must(container.Invoke(chatpipline.NewPluginExtractEntity))
must(container.Invoke(chatpipline.NewPluginSearchEntity))
// HTTP handlers layer // HTTP handlers layer
must(container.Provide(handler.NewTenantHandler)) must(container.Provide(handler.NewTenantHandler))
@@ -123,6 +130,9 @@ func BuildContainer(container *dig.Container) *dig.Container {
// Router configuration // Router configuration
must(container.Provide(router.NewRouter)) must(container.Provide(router.NewRouter))
must(container.Provide(router.NewAsyncqClient))
must(container.Provide(router.NewAsynqServer))
must(container.Invoke(router.RunAsynqServer))
return container return container
} }
@@ -184,6 +194,7 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
err = db.AutoMigrate( err = db.AutoMigrate(
&types.User{}, &types.User{},
&types.AuthToken{}, &types.AuthToken{},
&types.KnowledgeBase{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err) return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err)
@@ -385,3 +396,23 @@ func initOllamaService() (*ollama.OllamaService, error) {
// Get Ollama service from existing factory function // Get Ollama service from existing factory function
return ollama.GetOllamaService() return ollama.GetOllamaService()
} }
// initNeo4jClient creates the Neo4j driver used for graph retrieval.
//
// When the NEO4J_ENABLE environment variable is not "true" (case-insensitive)
// graph support is disabled and the function returns (nil, nil). Otherwise it
// connects with NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD and verifies the
// credentials before returning the driver.
//
// On a failed verification the driver is closed before the error is returned,
// so a failed start-up does not leave a half-initialized driver behind.
func initNeo4jClient() (neo4j.Driver, error) {
	ctx := context.Background()
	if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
		logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil, nil
	}

	uri := os.Getenv("NEO4J_URI")
	username := os.Getenv("NEO4J_USERNAME")
	password := os.Getenv("NEO4J_PASSWORD")

	driver, err := neo4j.NewDriver(uri, neo4j.BasicAuth(username, password, ""))
	if err != nil {
		return nil, fmt.Errorf("failed to create neo4j driver: %w", err)
	}

	// A nil auth token verifies the credentials the driver was created with.
	if err = driver.VerifyAuthentication(ctx, nil); err != nil {
		// Best-effort cleanup; the verification error is the one worth reporting.
		_ = driver.Close(ctx)
		return nil, fmt.Errorf("failed to verify neo4j authentication: %w", err)
	}
	return driver, nil
}

View File

@@ -5,14 +5,15 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"strconv" chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
"github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/errors" "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/logger"
@@ -133,6 +134,21 @@ type InitializationRequest struct {
ChunkOverlap int `json:"chunkOverlap" binding:"min=0"` ChunkOverlap int `json:"chunkOverlap" binding:"min=0"`
Separators []string `json:"separators" binding:"required,min=1"` Separators []string `json:"separators" binding:"required,min=1"`
} `json:"documentSplitting" binding:"required"` } `json:"documentSplitting" binding:"required"`
NodeExtract struct {
Enabled bool `json:"enabled"`
Text string `json:"text"`
Tags []string `json:"tags"`
Nodes []struct {
Name string `json:"name"`
Attributes []string `json:"attributes"`
} `json:"nodes"`
Relations []struct {
Node1 string `json:"node1"`
Node2 string `json:"node2"`
Type string `json:"type"`
} `json:"relations"`
} `json:"nodeExtract"`
} }
// InitializeByKB 根据知识库ID执行配置更新 // InitializeByKB 根据知识库ID执行配置更新
@@ -207,6 +223,25 @@ func (h *InitializationHandler) InitializeByKB(c *gin.Context) {
} }
} }
// 验证Node Extractor配置如果启用
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" && req.NodeExtract.Enabled {
logger.Error(ctx, "Node Extractor configuration incomplete")
c.Error(errors.NewBadRequestError("请正确配置环境变量NEO4J_ENABLE"))
return
}
if req.NodeExtract.Enabled {
if req.NodeExtract.Text == "" || len(req.NodeExtract.Tags) == 0 {
logger.Error(ctx, "Node Extractor configuration incomplete")
c.Error(errors.NewBadRequestError("Node Extractor配置不完整"))
return
}
if len(req.NodeExtract.Nodes) == 0 || len(req.NodeExtract.Relations) == 0 {
logger.Error(ctx, "Node Extractor configuration incomplete")
c.Error(errors.NewBadRequestError("请先提取实体和关系"))
return
}
}
// 处理模型创建/更新 // 处理模型创建/更新
modelsToProcess := []struct { modelsToProcess := []struct {
modelType types.ModelType modelType types.ModelType
@@ -406,6 +441,29 @@ func (h *InitializationHandler) InitializeByKB(c *gin.Context) {
kb.StorageConfig = types.StorageConfig{} kb.StorageConfig = types.StorageConfig{}
} }
if req.NodeExtract.Enabled {
kb.ExtractConfig = &types.ExtractConfig{
Text: req.NodeExtract.Text,
Tags: req.NodeExtract.Tags,
Nodes: make([]*types.GraphNode, 0),
Relations: make([]*types.GraphRelation, 0),
}
for _, rnode := range req.NodeExtract.Nodes {
node := &types.GraphNode{
Name: rnode.Name,
Attributes: rnode.Attributes,
}
kb.ExtractConfig.Nodes = append(kb.ExtractConfig.Nodes, node)
}
for _, relation := range req.NodeExtract.Relations {
kb.ExtractConfig.Relations = append(kb.ExtractConfig.Relations, &types.GraphRelation{
Node1: relation.Node1,
Node2: relation.Node2,
Type: relation.Type,
})
}
}
err = h.kbRepository.UpdateKnowledgeBase(ctx, kb) err = h.kbRepository.UpdateKnowledgeBase(ctx, kb)
if err != nil { if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{"kbId": kbIdStr}) logger.ErrorWithFields(ctx, err, map[string]interface{}{"kbId": kbIdStr})
@@ -710,7 +768,6 @@ func (h *InitializationHandler) downloadModelAsync(ctx context.Context,
err := h.pullModelWithProgress(ctx, modelName, func(progress float64, message string) { err := h.pullModelWithProgress(ctx, modelName, func(progress float64, message string) {
h.updateTaskStatus(taskID, "downloading", progress, message) h.updateTaskStatus(taskID, "downloading", progress, message)
}) })
if err != nil { if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{ logger.ErrorWithFields(ctx, err, map[string]interface{}{
"model_name": modelName, "model_name": modelName,
@@ -777,7 +834,6 @@ func (h *InitializationHandler) pullModelWithProgress(ctx context.Context,
) )
return nil return nil
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to pull model: %w", err) return fmt.Errorf("failed to pull model: %w", err)
} }
@@ -971,6 +1027,20 @@ func buildConfigResponse(models []*types.Model,
} }
} }
if kb.ExtractConfig != nil {
config["nodeExtract"] = map[string]interface{}{
"enabled": true,
"text": kb.ExtractConfig.Text,
"tags": kb.ExtractConfig.Tags,
"nodes": kb.ExtractConfig.Nodes,
"relations": kb.ExtractConfig.Relations,
}
} else {
config["nodeExtract"] = map[string]interface{}{
"enabled": false,
}
}
return config return config
} }
@@ -1148,8 +1218,8 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context,
// checkRerankModelConnection 检查Rerank模型连接和功能的内部方法 // checkRerankModelConnection 检查Rerank模型连接和功能的内部方法
func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context, func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context,
modelName, baseURL, apiKey string) (bool, string) { modelName, baseURL, apiKey string,
) (bool, string) {
// 创建Reranker配置 // 创建Reranker配置
config := &rerank.RerankerConfig{ config := &rerank.RerankerConfig{
APIKey: apiKey, APIKey: apiKey,
@@ -1492,3 +1562,214 @@ func (h *InitializationHandler) testMultimodalWithDocReader(
return result, nil return result, nil
} }
// TextRelationExtractionRequest is the JSON request body for text relation
// extraction.
type TextRelationExtractionRequest struct {
// Text is the source text to analyse; required.
Text string `json:"text" binding:"required"`
// Tags is the list of relation tags to extract; required.
Tags []string `json:"tags" binding:"required"`
// LLMConfig describes the model used to perform the extraction.
LLMConfig LLMConfig `json:"llmConfig"`
}
// LLMConfig carries the client-supplied model connection settings that the
// handlers pass on to chat.NewChat.
type LLMConfig struct {
// Source is converted to types.ModelSource when the chat model is built.
Source string `json:"source"`
// ModelName is the name of the model to call.
ModelName string `json:"modelName"`
// BaseUrl is the base URL of the model endpoint.
BaseUrl string `json:"baseUrl"`
// ApiKey is the API key for the model endpoint.
ApiKey string `json:"apiKey"`
}
// TextRelationExtractionResponse is the result of text relation extraction:
// the nodes and relations of the extracted graph.
type TextRelationExtractionResponse struct {
Nodes []*types.GraphNode `json:"nodes"`
Relations []*types.GraphRelation `json:"relations"`
}
// ExtractTextRelations handles POST /initialization/extract/text-relation.
//
// It validates the request (non-empty text of at most 5000 characters and at
// least one tag), runs relation extraction with the model described in the
// request, and writes the extracted nodes and relations as JSON.
//
// Validation failures are reported through c.Error as bad-request errors.
// An extraction failure is reported with HTTP 200 and a nested
// {"success": false, "message": ...} payload under "data".
func (h *InitializationHandler) ExtractTextRelations(c *gin.Context) {
	ctx := c.Request.Context()

	var req TextRelationExtractionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		// Include the bind error so malformed requests can be diagnosed.
		logger.Error(ctx, "文本关系提取请求参数错误", err)
		c.Error(errors.NewBadRequestError("文本关系提取请求参数错误"))
		return
	}

	// Validate the text content.
	if len(req.Text) == 0 {
		c.Error(errors.NewBadRequestError("文本内容不能为空"))
		return
	}
	// The limit is expressed in characters, so count runes rather than bytes;
	// otherwise multi-byte (e.g. CJK) text is rejected far below the limit.
	if len([]rune(req.Text)) > 5000 {
		c.Error(errors.NewBadRequestError("文本内容长度不能超过5000字符"))
		return
	}

	// Validate the tags.
	if len(req.Tags) == 0 {
		c.Error(errors.NewBadRequestError("至少需要选择一个关系标签"))
		return
	}

	// Run the extraction through the model service.
	result, err := h.extractRelationsFromText(ctx, req.Text, req.Tags, req.LLMConfig)
	if err != nil {
		logger.Error(ctx, "文本关系提取失败", err)
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data": gin.H{
				"success": false,
				"message": err.Error(),
			},
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    result,
	})
}
// extractRelationsFromText builds a chat model from the supplied LLM settings,
// extracts a graph from text using the configured graph-extraction template
// (with the caller's tags), drops unknown relations, and returns the remaining
// nodes and relations.
func (h *InitializationHandler) extractRelationsFromText(ctx context.Context, text string, tags []string, llm LLMConfig) (*TextRelationExtractionResponse, error) {
	chatCfg := &chat.ChatConfig{
		ModelID:   "initialization",
		APIKey:    llm.ApiKey,
		BaseURL:   llm.BaseUrl,
		ModelName: llm.ModelName,
		Source:    types.ModelSource(llm.Source),
	}
	chatModel, err := chat.NewChat(chatCfg)
	if err != nil {
		logger.Error(ctx, "初始化模型服务失败", err)
		return nil, err
	}

	// Reuse the configured description and examples, but extract the caller's tags.
	graphTpl := h.config.ExtractManager.ExtractGraph
	extractor := chatpipline.NewExtractor(chatModel, &types.PromptTemplateStructured{
		Description: graphTpl.Description,
		Tags:        tags,
		Examples:    graphTpl.Examples,
	})

	graph, err := extractor.Extract(ctx, text)
	if err != nil {
		logger.Error(ctx, "文本关系提取失败", err)
		return nil, err
	}
	extractor.RemoveUnknownRelation(ctx, graph)

	return &TextRelationExtractionResponse{
		Nodes:     graph.Node,
		Relations: graph.Relation,
	}, nil
}
// FabriTextRequest is the JSON request body for sample-text generation.
type FabriTextRequest struct {
// Tags are optional; when present they are embedded into the prompt.
Tags []string `json:"tags"`
// LLMConfig describes the model used to generate the text.
LLMConfig LLMConfig `json:"llmConfig"`
}
// FabriTextResponse carries the generated sample text.
type FabriTextResponse struct {
Text string `json:"text"`
}
// FabriText handles POST /initialization/extract/fabri-text. It asks the model
// described in the request to generate a sample text and returns it as JSON.
// A generation failure is reported with HTTP 200 and a nested
// {"success": false, "message": ...} payload under "data".
func (h *InitializationHandler) FabriText(c *gin.Context) {
	ctx := c.Request.Context()

	var req FabriTextRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		logger.Error(ctx, "生成示例文本请求参数错误")
		c.Error(errors.NewBadRequestError("生成示例文本请求参数错误"))
		return
	}

	text, err := h.fabriText(ctx, req.Tags, req.LLMConfig)
	if err == nil {
		// Happy path: wrap the generated text in the response envelope.
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"data":    FabriTextResponse{Text: text},
		})
		return
	}

	logger.Error(ctx, "生成示例文本失败", err)
	failure := gin.H{
		"success": false,
		"message": err.Error(),
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    failure,
	})
}
// fabriText generates a sample text with the model described by llm. With no
// tags the configured "with no tag" prompt is sent as-is; otherwise the tags
// are JSON-encoded and substituted into the configured "with tag" prompt.
// It returns the content of the model's reply.
func (h *InitializationHandler) fabriText(ctx context.Context, tags []string, llm LLMConfig) (string, error) {
	chatCfg := &chat.ChatConfig{
		ModelID:   "initialization",
		APIKey:    llm.ApiKey,
		BaseURL:   llm.BaseUrl,
		ModelName: llm.ModelName,
		Source:    types.ModelSource(llm.Source),
	}
	chatModel, err := chat.NewChat(chatCfg)
	if err != nil {
		logger.Error(ctx, "初始化模型服务失败", err)
		return "", err
	}

	// Choose the prompt variant depending on whether tags were supplied.
	prompts := h.config.ExtractManager.FabriText
	var prompt string
	if len(tags) > 0 {
		encodedTags, _ := json.Marshal(tags)
		prompt = fmt.Sprintf(prompts.WithTag, string(encodedTags))
	} else {
		prompt = prompts.WithNoTag
	}

	// Thinking is explicitly switched off for this call.
	thinking := false
	messages := []chat.Message{
		{Role: "user", Content: prompt},
	}
	options := &chat.ChatOptions{
		Temperature: 0.3,
		MaxTokens:   4096,
		Thinking:    &thinking,
	}

	reply, err := chatModel.Chat(ctx, messages, options)
	if err != nil {
		logger.Error(ctx, "生成示例文本失败", err)
		return "", err
	}
	return reply.Content, nil
}
// FabriTagRequest is the JSON request shape for tag generation.
// NOTE(review): the FabriTag handler shown here does not bind a request body.
type FabriTagRequest struct {
LLMConfig LLMConfig `json:"llmConfig"`
}
// FabriTagResponse carries the generated tags.
type FabriTagResponse struct {
Tags []string `json:"tags"`
}
// tagOptions is the fixed pool of tags from which FabriTag picks a random
// subset.
var tagOptions = []string{
"内容", "文化", "人物", "事件", "时间", "地点", "作品", "作者", "关系", "属性",
}
// FabriTag handles POST /initialization/extract/fabri-tag. It responds with a
// random selection of between 1 and len(tagOptions)-1 tags from tagOptions.
func (h *InitializationHandler) FabriTag(c *gin.Context) {
	// rand.Intn(k) is in [0, k), so count is in [1, len(tagOptions)-1].
	count := rand.Intn(len(tagOptions)-1) + 1
	payload := FabriTagResponse{Tags: RandomSelect(tagOptions, count)}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    payload,
	})
}
// RandomSelect returns up to n elements chosen at random, without repetition,
// from strs. The input slice is left untouched. A non-positive n yields an
// empty, non-nil slice; an n larger than len(strs) yields all elements in
// random order.
func RandomSelect(strs []string, n int) []string {
	if n <= 0 {
		return []string{}
	}

	// Shuffle a private copy so the caller's slice keeps its order.
	shuffled := make([]string, 0, len(strs))
	shuffled = append(shuffled, strs...)
	rand.Shuffle(len(shuffled), func(a, b int) {
		shuffled[a], shuffled[b] = shuffled[b], shuffled[a]
	})

	// Never slice past the number of available elements.
	limit := len(strs)
	if n < limit {
		limit = n
	}
	return shuffled[:limit]
}

View File

@@ -83,7 +83,7 @@ type CreateSessionRequest struct {
func (h *SessionHandler) CreateSession(c *gin.Context) { func (h *SessionHandler) CreateSession(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation) // logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation)
// Parse and validate the request body // Parse and validate the request body
var request CreateSessionRequest var request CreateSessionRequest

View File

@@ -267,6 +267,10 @@ func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.Initializ
r.POST("/initialization/embedding/test", handler.TestEmbeddingModel) r.POST("/initialization/embedding/test", handler.TestEmbeddingModel)
r.POST("/initialization/rerank/check", handler.CheckRerankModel) r.POST("/initialization/rerank/check", handler.CheckRerankModel)
r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction) r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction)
r.POST("/initialization/extract/text-relation", handler.ExtractTextRelations)
r.POST("/initialization/extract/fabri-tag", handler.FabriTag)
r.POST("/initialization/extract/fabri-text", handler.FabriText)
} }
// RegisterSystemRoutes registers system information routes // RegisterSystemRoutes registers system information routes

66
internal/router/task.go Normal file
View File

@@ -0,0 +1,66 @@
package router
import (
"log"
"os"
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/hibiken/asynq"
"go.uber.org/dig"
)
// AsynqTaskParams groups the dependencies of RunAsynqServer. Embedding dig.In
// makes dig populate the fields from the container.
type AsynqTaskParams struct {
dig.In
// Server is the asynq server that RunAsynqServer starts.
Server *asynq.Server
// Extracter supplies the handler for chunk-extraction tasks.
Extracter interfaces.Extracter
}
// getAsynqRedisClientOpt builds the Redis connection options shared by the
// asynq client and server. The address and password come from the REDIS_ADDR
// and REDIS_PASSWORD environment variables; the timeouts and database index
// are fixed.
func getAsynqRedisClientOpt() *asynq.RedisClientOpt {
	addr := os.Getenv("REDIS_ADDR")
	password := os.Getenv("REDIS_PASSWORD")
	return &asynq.RedisClientOpt{
		Addr:         addr,
		Password:     password,
		ReadTimeout:  100 * time.Millisecond,
		WriteTimeout: 200 * time.Millisecond,
		DB:           0,
	}
}
// NewAsyncqClient creates an asynq client connected with the shared Redis
// options.
func NewAsyncqClient() *asynq.Client {
	return asynq.NewClient(getAsynqRedisClientOpt())
}
// NewAsynqServer creates an asynq server connected with the shared Redis
// options and serving three weighted queues.
func NewAsynqServer() *asynq.Server {
	redisOpt := getAsynqRedisClientOpt()

	// Queue name -> relative priority weight.
	queueWeights := map[string]int{
		"critical": 6, // Highest priority queue
		"default":  3, // Default priority queue
		"low":      1, // Lowest priority queue
	}

	return asynq.NewServer(redisOpt, asynq.Config{Queues: queueWeights})
}
// RunAsynqServer registers the chunk-extraction handler on a new mux, starts
// the asynq server with it in a background goroutine, and returns the mux.
// If the server stops with an error the process is terminated via log.Fatalf.
func RunAsynqServer(params AsynqTaskParams) *asynq.ServeMux {
	handlers := asynq.NewServeMux()
	handlers.HandleFunc(types.TypeChunkExtract, params.Extracter.Extract)

	// Server.Run blocks, so it is moved off the caller's goroutine.
	serve := func() {
		err := params.Server.Run(handlers)
		if err != nil {
			log.Fatalf("could not run server: %v", err)
		}
	}
	go serve()

	return handlers
}

View File

@@ -28,6 +28,8 @@ type ChatManage struct {
SearchResult []*SearchResult `json:"-"` // Results from search phase SearchResult []*SearchResult `json:"-"` // Results from search phase
RerankResult []*SearchResult `json:"-"` // Results after reranking RerankResult []*SearchResult `json:"-"` // Results after reranking
MergeResult []*SearchResult `json:"-"` // Final merged results after all processing MergeResult []*SearchResult `json:"-"` // Final merged results after all processing
Entity []string `json:"-"` // List of identified entities
GraphResult *GraphData `json:"-"` // Graph data from search phase
UserContent string `json:"-"` // Processed user content UserContent string `json:"-"` // Processed user content
ChatResponse *ChatResponse `json:"-"` // Final response from chat model ChatResponse *ChatResponse `json:"-"` // Final response from chat model
ResponseChan <-chan StreamResponse `json:"-"` // Channel for streaming responses ResponseChan <-chan StreamResponse `json:"-"` // Channel for streaming responses
@@ -75,6 +77,7 @@ const (
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
INTO_CHAT_MESSAGE EventType = "into_chat_message" // Convert chunks into chat messages INTO_CHAT_MESSAGE EventType = "into_chat_message" // Convert chunks into chat messages
@@ -104,6 +107,7 @@ var Pipline = map[string][]EventType{
REWRITE_QUERY, REWRITE_QUERY,
PREPROCESS_QUERY, PREPROCESS_QUERY,
CHUNK_SEARCH, CHUNK_SEARCH,
ENTITY_SEARCH,
CHUNK_RERANK, CHUNK_RERANK,
CHUNK_MERGE, CHUNK_MERGE,
FILTER_TOP_K, FILTER_TOP_K,

View File

@@ -19,6 +19,7 @@ const (
MatchTypeHistory MatchTypeHistory
MatchTypeParentChunk // 父Chunk匹配类型 MatchTypeParentChunk // 父Chunk匹配类型
MatchTypeRelationChunk // 关系Chunk匹配类型 MatchTypeRelationChunk // 关系Chunk匹配类型
MatchTypeGraph
) )
// IndexInfo contains information about indexed content // IndexInfo contains information about indexed content

View File

@@ -0,0 +1,51 @@
package types
// Asynq task type names.
const (
// TypeChunkExtract is the task type name that the router maps to
// Extracter.Extract.
TypeChunkExtract = "chunk:extract"
)
// ExtractChunkPayload is presumably the JSON payload carried by
// TypeChunkExtract tasks — confirm against the code that enqueues them.
type ExtractChunkPayload struct {
// TenantID is serialized as "tenant_id".
TenantID uint `json:"tenant_id"`
// ChunkID is serialized as "chunk_id".
ChunkID string `json:"chunk_id"`
// ModelID is serialized as "model_id".
ModelID string `json:"model_id"`
}
// PromptTemplateStructured describes a structured prompt: a description,
// a list of tags, and example graphs.
// NOTE(review): its shape appears to match the extract_graph entry of the
// YAML config, but only json tags are declared here — confirm how the
// config loader maps keys onto these fields.
type PromptTemplateStructured struct {
// Description is the instruction text of the prompt.
Description string `json:"description"`
// Tags lists the tags associated with the prompt.
Tags []string `json:"tags"`
// Examples holds sample graphs (text plus nodes and relations).
Examples []GraphData `json:"examples"`
}
// GraphNode is a single node of a graph. Every field is omitted from the
// JSON output when empty.
type GraphNode struct {
// Name is the node's name.
Name string `json:"name,omitempty"`
// Chunks is presumably the set of chunk IDs the node was found in —
// TODO confirm.
Chunks []string `json:"chunks,omitempty"`
// Attributes lists free-text attributes attached to the node.
Attributes []string `json:"attributes,omitempty"`
}
// GraphRelation is a typed relation between two nodes. Every field is
// omitted from the JSON output when empty.
type GraphRelation struct {
// Node1 and Node2 presumably hold the names of the related nodes —
// TODO confirm.
Node1 string `json:"node1,omitempty"`
Node2 string `json:"node2,omitempty"`
// Type is the relation type.
Type string `json:"type,omitempty"`
}
// GraphData groups a piece of text with a list of nodes and a list of
// relations. Every field is omitted from the JSON output when empty.
type GraphData struct {
// Text is presumably the source text the graph was extracted from —
// TODO confirm.
Text string `json:"text,omitempty"`
// Node holds the graph's nodes.
Node []*GraphNode `json:"node,omitempty"`
// Relation holds the graph's relations.
Relation []*GraphRelation `json:"relation,omitempty"`
}
// NameSpace identifies the scope of graph data by knowledge base and,
// optionally, by a single knowledge entry.
type NameSpace struct {
	KnowledgeBase string `json:"knowledge_base"`
	Knowledge     string `json:"knowledge"`
}

// Labels returns the non-empty namespace components, KnowledgeBase first
// and Knowledge second. The result is never nil; it is an empty slice when
// both components are empty.
func (n NameSpace) Labels() []string {
	labels := []string{}
	for _, candidate := range [...]string{n.KnowledgeBase, n.Knowledge} {
		if candidate == "" {
			continue
		}
		labels = append(labels, candidate)
	}
	return labels
}

View File

@@ -0,0 +1,11 @@
package interfaces
import (
"context"
"github.com/hibiken/asynq"
)
// Extracter handles extraction tasks delivered by asynq.
type Extracter interface {
// Extract processes task t. Its signature matches what
// asynq.ServeMux.HandleFunc accepts; it returns a non-nil error on failure.
Extract(ctx context.Context, t *asynq.Task) error
}

View File

@@ -0,0 +1,13 @@
package interfaces
import (
"context"
"github.com/Tencent/WeKnora/internal/types"
)
// RetrieveGraphRepository is the storage abstraction for graph data,
// scoped by types.NameSpace.
type RetrieveGraphRepository interface {
// AddGraph stores the given graphs under namespace.
AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error
// DelGraph removes graph data for each of the given namespaces.
DelGraph(ctx context.Context, namespace []types.NameSpace) error
// SearchNode looks up the named nodes within namespace and returns the
// matching graph data. NOTE(review): whether related nodes and relations
// are included is implementation-defined — confirm against the implementation.
SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error)
}

View File

@@ -38,6 +38,8 @@ type KnowledgeBase struct {
VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"` VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"`
// Storage config // Storage config
StorageConfig StorageConfig `yaml:"cos_config" json:"cos_config" gorm:"column:cos_config;type:json"` StorageConfig StorageConfig `yaml:"cos_config" json:"cos_config" gorm:"column:cos_config;type:json"`
// Extract config
ExtractConfig *ExtractConfig `yaml:"extract_config" json:"extract_config" gorm:"column:extract_config;type:json"`
// Creation time of the knowledge base // Creation time of the knowledge base
CreatedAt time.Time `yaml:"created_at" json:"created_at"` CreatedAt time.Time `yaml:"created_at" json:"created_at"`
// Last updated time of the knowledge base // Last updated time of the knowledge base
@@ -167,3 +169,27 @@ func (c *VLMConfig) Scan(value interface{}) error {
} }
return json.Unmarshal(b, c) return json.Unmarshal(b, c)
} }
// ExtractConfig is the extraction configuration of a knowledge base. It is
// persisted as JSON in the knowledge_bases.extract_config column.
type ExtractConfig struct {
	// Text is a piece of text associated with the configuration.
	Text string `yaml:"text" json:"text"`
	// Tags lists the tags associated with the configuration.
	Tags []string `yaml:"tags" json:"tags"`
	// Nodes holds the graph nodes associated with Text.
	Nodes []*GraphNode `yaml:"nodes" json:"nodes"`
	// Relations holds the graph relations associated with Text.
	Relations []*GraphRelation `yaml:"relations" json:"relations"`
}

// Value implements the driver.Valuer interface by serializing the
// configuration to JSON.
func (e ExtractConfig) Value() (driver.Value, error) {
	return json.Marshal(e)
}

// Scan implements the sql.Scanner interface by decoding a JSON database
// value into e.
//
// A nil value (SQL NULL) leaves e untouched. JSON is accepted both as
// []byte and as string, because database/sql allows a driver to deliver
// either. Any other source type is ignored and nil is returned, which keeps
// the best-effort behaviour of the other Scan implementations in this file.
// A JSON decoding error is returned as is.
func (e *ExtractConfig) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		return nil
	case []byte:
		raw = v
	case string:
		// Some drivers hand text-format columns over as string, not []byte.
		raw = []byte(v)
	default:
		return nil
	}
	return json.Unmarshal(raw, e)
}

View File

@@ -51,6 +51,7 @@ CREATE TABLE knowledge_bases (
vlm_model_id VARCHAR(64) NOT NULL, vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSON NOT NULL, cos_config JSON NOT NULL,
vlm_config JSON NOT NULL, vlm_config JSON NOT NULL,
extract_config JSON NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL DEFAULT NULL deleted_at TIMESTAMP NULL DEFAULT NULL

View File

@@ -62,6 +62,7 @@ CREATE TABLE IF NOT EXISTS knowledge_bases (
vlm_model_id VARCHAR(64) NOT NULL, vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSONB NOT NULL DEFAULT '{}', cos_config JSONB NOT NULL DEFAULT '{}',
vlm_config JSONB NOT NULL DEFAULT '{}', vlm_config JSONB NOT NULL DEFAULT '{}',
extract_config JSONB NULL DEFAULT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP WITH TIME ZONE deleted_at TIMESTAMP WITH TIME ZONE