From 02b78a5908bb3f623144f31e79e52e44a7c63d5e Mon Sep 17 00:00:00 2001 From: begoniezhao Date: Wed, 24 Sep 2025 12:15:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E6=8F=90=E5=8F=96=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 12 + config/config.yaml | 66 ++ docker-compose.yml | 25 + frontend/src/api/initialization/index.ts | 102 ++- frontend/src/components/menu.vue | 5 +- .../initialization/InitializationContent.vue | 740 +++++++++++++++++- go.mod | 9 +- go.sum | 18 +- .../repository/retriever/neo4j/repository.go | 231 ++++++ .../service/chat_pipline/extract_entity.go | 499 ++++++++++++ .../service/chat_pipline/search.go | 1 + .../service/chat_pipline/search_entity.go | 136 ++++ internal/application/service/extract.go | 137 ++++ internal/application/service/knowledge.go | 39 + internal/application/service/knowledgebase.go | 11 +- internal/common/asyncq.go | 72 -- internal/config/config.go | 24 +- internal/container/container.go | 31 + internal/handler/initialization.go | 293 ++++++- internal/handler/session.go | 2 +- internal/router/router.go | 4 + internal/router/task.go | 66 ++ internal/types/chat_manage.go | 4 + internal/types/embedding.go | 1 + internal/types/extract_graph.go | 51 ++ internal/types/interfaces/extracter.go | 11 + internal/types/interfaces/retriever_graph.go | 13 + internal/types/knowledgebase.go | 26 + migrations/mysql/00-init-db.sql | 1 + migrations/paradedb/00-init-db.sql | 1 + 30 files changed, 2514 insertions(+), 117 deletions(-) create mode 100644 internal/application/repository/retriever/neo4j/repository.go create mode 100644 internal/application/service/chat_pipline/extract_entity.go create mode 100644 internal/application/service/chat_pipline/search_entity.go create mode 100644 internal/application/service/extract.go delete mode 100644 internal/common/asyncq.go create mode 100644 
internal/router/task.go create mode 100644 internal/types/extract_graph.go create mode 100644 internal/types/interfaces/extracter.go create mode 100644 internal/types/interfaces/retriever_graph.go diff --git a/.env.example b/.env.example index e3a44d0..6b035ee 100644 --- a/.env.example +++ b/.env.example @@ -121,6 +121,18 @@ COS_ENABLE_OLD_DOMAIN=true # 如果解析网络连接使用Web代理,需要配置以下参数 # WEB_PROXY=your_web_proxy +# Neo4j 开关 +# NEO4J_ENABLE=false + +# Neo4j的访问地址 +# NEO4J_URI=neo4j://neo4j:7687 + +# Neo4j的用户名和密码 +# NEO4J_USERNAME=neo4j + +# Neo4j的密码 +# NEO4J_PASSWORD=password + ############################################################## ###### 注意: 以下配置不再生效,已在Web“配置初始化”阶段完成 ######### diff --git a/config/config.yaml b/config/config.yaml index 4caa32c..cc8e6e6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -534,3 +534,69 @@ knowledge_base: split_markers: ["\n\n", "\n", "。"] image_processing: enable_multimodal: true + +extract: + extract_graph: + description: | + 请基于给定文本,按以下步骤完成信息提取任务,确保逻辑清晰、信息完整准确: + + ## 一、实体提取与属性补充 + 1. **提取核心实体**:通读文本,按逻辑顺序(如文本叙述顺序、实体关联紧密程度)提取所有与任务相关的核心实体。 + 2. **补充实体详细属性**:针对每个提取的实体,全面补充其在文本中明确提及的详细属性,确保无关键属性遗漏。 + + ## 二、关系提取与验证 + 1. **明确关系类型**:仅从指定关系列表中选择对应类型,限定关系类型为: %s。 + 2. **提取有效关系**:基于已提取的实体及属性,识别文本中真实存在的关系,确保关系符合文本事实、无虚假关联。 + 3. **明确关系主体**:对每一组提取的关系,清晰标注两个关联主体,避免主体混淆。 + 4. 
**补充关联属性**:若文本中存在与该关系直接相关的补充信息,需将该信息作为关系的关联属性补充,进一步完善关系信息。 + tags: + - "作者" + - "别名" + examples: + - text: | + 《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著,后40回一般认为是高鹗所续。 + 小说以贾、史、王、薛四大家族的兴衰为背景,以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线,刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。 + 成书于乾隆年间(1743年前后),是中国文学史上现实主义的高峰,对后世影响深远。 + node: + - name: "红楼梦" + attributes: + - "中国古典四大名著之一" + - "又名《石头记》" + - "被誉为中国封建社会的百科全书" + - name: "石头记" + attributes: + - "《红楼梦》的别名" + - name: "曹雪芹" + attributes: + - "清代作家" + - "《红楼梦》前 80 回的作者" + - name: "高鹗" + attributes: + - "一般认为是《红楼梦》后 40 回的续写者" + relation: + - node1: "红楼梦" + node2: "曹雪芹" + type: "作者" + - node1: "红楼梦" + node2: "高鹗" + type: "作者" + - node1: "红楼梦" + node2: "石头记" + type: "别名" + extract_entity: + description: | + 请基于用户给的问题,按以下步骤处理关键信息提取任务: + 1. 梳理逻辑关联:首先完整分析文本内容,明确其核心逻辑关系,并简要标注该核心逻辑类型; + 2. 提取关键实体:围绕梳理出的逻辑关系,精准提取文本中的关键信息并归类为明确实体,确保不遗漏核心信息、不添加冗余内容; + 3. 排序实体优先级:按实体与文本核心主题的关联紧密程度排序,优先呈现对理解文本主旨最重要的实体; + examples: + - text: "《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。" + node: + - name: "红楼梦" + - name: "曹雪芹" + - name: "中国古典四大名著" + fabri_text: + with_tag: | + 请随机生成一段文本,要求内容与 %s 等相关,字数在 [50-200] 之间,并且尽量包含一些与这些标签相关的专业术语或典型元素,使文本更具针对性和相关性。 + with_no_tag: | + 请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 8889c65..9f29764 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -54,6 +54,10 @@ services: - REDIS_DB=${REDIS_DB:-} - REDIS_PREFIX=${REDIS_PREFIX:-} - ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG:-} + - NEO4J_ENABLE=${NEO4J_ENABLE:-} + - NEO4J_URI=bolt://neo4j:7687 + - NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j} + - NEO4J_PASSWORD=${NEO4J_PASSWORD:-password} - TENANT_AES_KEY=${TENANT_AES_KEY:-} - CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5} - INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME:-} @@ -76,6 +80,8 @@ services: condition: service_started docreader: condition: service_healthy + neo4j: + condition: service_started networks: - WeKnora-network restart: 
unless-stopped @@ -209,6 +215,24 @@ services: networks: - WeKnora-network + neo4j: + image: neo4j:latest + container_name: WeKnora-neo4j + volumes: + - neo4j-data:/data + environment: + - NEO4J_AUTH=${NEO4J_USERNAME:-neo4j}/${NEO4J_PASSWORD:-password} + - NEO4J_apoc_export_file_enabled=true + - NEO4J_apoc_import_file_enabled=true + - NEO4J_apoc_import_file_use__neo4j__config=true + - NEO4JLABS_PLUGINS=["apoc"] + ports: + - "7474:7474" + - "7687:7687" + restart: always + networks: + - WeKnora-network + networks: WeKnora-network: driver: bridge @@ -219,3 +243,4 @@ volumes: jaeger_data: redis_data: minio_data: + neo4j-data: diff --git a/frontend/src/api/initialization/index.ts b/frontend/src/api/initialization/index.ts index 4231cdb..1fea566 100644 --- a/frontend/src/api/initialization/index.ts +++ b/frontend/src/api/initialization/index.ts @@ -50,6 +50,13 @@ export interface InitializationConfig { }; // Frontend-only hint for storage selection UI storageType?: 'cos' | 'minio'; + nodeExtract: { + enabled: boolean, + text: string, + tags: string[], + nodes: Node[], + relations: Relation[] + } } // 下载任务状态类型 @@ -63,8 +70,6 @@ export interface DownloadTask { endTime?: string; } - - // 根据知识库ID执行配置更新 export function initializeSystemByKB(kbId: string, config: InitializationConfig): Promise { return new Promise((resolve, reject) => { @@ -76,7 +81,7 @@ export function initializeSystemByKB(kbId: string, config: InitializationConfig) }) .catch((error: any) => { console.error('知识库配置更新失败:', error); - reject(error); + reject(error.error || error); }); }); } @@ -324,4 +329,93 @@ export function testMultimodalFunction(testData: { reject(error); }); }); -} \ No newline at end of file +} + +// 文本内容关系提取接口 +export interface TextRelationExtractionRequest { + text: string; + tags: string[]; + llmConfig: LLMConfig; +} + +export interface Node { + name: string; + attributes: string[]; +} + +export interface Relation { + node1: string; + node2: string; + type: string; +} + +export interface 
LLMConfig { + source: 'local' | 'remote'; + modelName: string; + baseUrl: string; + apiKey: string; +} + +export interface TextRelationExtractionResponse { + nodes: Node[]; + relations: Relation[]; +} + +// 文本内容关系提取 +export function extractTextRelations(request: TextRelationExtractionRequest): Promise { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/extract/text-relation', request) + .then((response: any) => { + resolve(response.data || { nodes: [], relations: [] }); + }) + .catch((error: any) => { + console.error('文本内容关系提取失败:', error); + reject(error); + }); + }); +} + +export interface FabriTextRequest { + tags: string[]; + llmConfig: LLMConfig; +} + +export interface FabriTextResponse { + text: string; +} + +// 文本内容生成 +export function fabriText(request: FabriTextRequest): Promise { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/extract/fabri-text', request) + .then((response: any) => { + resolve(response.data || { text: '' }); + }) + .catch((error: any) => { + console.error('文本内容生成失败:', error); + reject(error); + }); + }); +} + +export interface FabriTagRequest { + llmConfig: LLMConfig; +} + +export interface FabriTagResponse { + tags: string[]; +} + +// 文本内容生成 +export function fabriTag(request: FabriTagRequest): Promise { + return new Promise((resolve, reject) => { + post('/api/v1/initialization/extract/fabri-tag', request) + .then((response: any) => { + resolve(response.data || { tags: [] as string[] }); + }) + .catch((error: any) => { + console.error('标签生成失败:', error); + reject(error); + }); + }); +} \ No newline at end of file diff --git a/frontend/src/components/menu.vue b/frontend/src/components/menu.vue index efd7f73..b8f6f00 100644 --- a/frontend/src/components/menu.vue +++ b/frontend/src/components/menu.vue @@ -144,7 +144,10 @@ let activeSubmenu = ref(-1); // 是否处于知识库详情页 const isInKnowledgeBase = computed(() => { - return route.name === 'knowledgeBaseDetail' || route.name === 'kbCreatChat' || 
route.name === 'chat' || route.name === 'knowledgeBaseSettings'; + return route.name === 'knowledgeBaseDetail' || + route.name === 'kbCreatChat' || + route.name === 'chat' || + route.name === 'knowledgeBaseSettings'; }); // 统一的菜单项激活状态判断 diff --git a/frontend/src/views/initialization/InitializationContent.vue b/frontend/src/views/initialization/InitializationContent.vue index e467543..e26716f 100644 --- a/frontend/src/views/initialization/InitializationContent.vue +++ b/frontend/src/views/initialization/InitializationContent.vue @@ -692,6 +692,285 @@ + +
+

实体关系提取

+ +
+ +
+ + 启用实体关系提取 +
+
+
+ +
+

关系标签配置

+ +
+ +
+
+
+ + 随机生成标签 + +
+
+ + 请完善模型配置信息 +
+
+
+ +
+
+
+
+ +

提取示例

+ +
+ +
+
+
+ + 随机生成文本 + +
+
+ + 请完善模型配置信息 +
+
+
+ +
+
+
+
+ + +
+ + +
+
+
+ + + + + + + +
+ +
+ +
+ + + + +
+ + + + 添加属性 + +
+
+
+
+ +
+
+ + 添加实体 + +
+
+ + 请完善实体信息 +
+
+
+ + +
+ +
+
+
+ + + + + + + + + + + + + + +
+
+
+
+ + +
+
+ + 添加关系 + +
+
+ + 请完善关系信息 +
+
+
+ + +
+ + {{ extracting ? '正在提取...' : '开始提取' }} + + + + 默认示例 + + + + 清空示例 + +
+
+
+
(null); const submitting = ref(false); const hasFiles = ref(false); const isUpdateMode = ref(false); // 是否为更新模式 +const tagOptionsDefault = [ + { label: '内容', value: '内容' }, + { label: '文化', value: '文化' }, + { label: '人物', value: '人物' }, + { label: '事件', value: '事件' }, + { label: '时间', value: '时间' }, + { label: '地点', value: '地点' }, + { label: '作品', value: '作品' }, + { label: '作者', value: '作者' }, + { label: '关系', value: '关系' }, + { label: '属性', value: '属性' } +]; +const tagOptions = ref([] as {label: string, value: string}[]); +const tagInput = ref(''); +const popupVisibleNode1 = ref([]); +const popupVisibleNode2 = ref([]); +const tagFabring = ref(false); +const textFabring = ref(false); +const extracting = ref(false); // 防抖机制:防止按钮快速重复点击 const submitDebounceTimer = ref | null>(null); @@ -874,6 +1181,13 @@ const formData = reactive({ chunkSize: 512, chunkOverlap: 100, separators: ['\n\n', '\n', '。', '!', '?', ';', ';'] + }, + nodeExtract: { + enabled: false, + text: '', + tags: [] as string[], + nodes: [] as Node[], + relations: [] as Relation[] } }); @@ -995,8 +1309,29 @@ const canSubmit = computed(() => { vlmOk = true; } } + + let extractOk = true; + if (formData.nodeExtract.enabled) { + if (formData.nodeExtract.text === '') { + extractOk = false; + } + for (let i = 0; i < formData.nodeExtract.tags.length; i++) { + const tag = formData.nodeExtract.tags[i]; + if (tag == '') { + extractOk = false; + break; + } + } + + if (!readyNode.value){ + extractOk = false; + } + if (!readyRelation.value){ + extractOk = false; + } + } - return llmOk && embeddingOk && rerankOk && vlmOk; + return llmOk && embeddingOk && rerankOk && vlmOk && extractOk; }); const imageUpload = ref(null); @@ -1034,6 +1369,10 @@ const rules = { 'embedding.dimension': [ { required: true, message: '请输入Embedding维度', type: 'error' }, { validator: validateEmbeddingDimension, message: '维度必须为有效整数值,常见取值为768, 1024, 1536, 3584等', type: 'error' } + ], + 'nodeExtract.text': [ + { required: true, message: '请输入文本内容', 
type: 'error' }, + { min: 10, message: '文本内容至少需要10个字符', type: 'error' } ] }; @@ -1283,6 +1622,20 @@ const loadCurrentConfig = async () => { // 如果没有文档分割配置,确保使用默认的precision模式 selectedPreset.value = 'precision'; } + if (config.nodeExtract.enabled) { + formData.nodeExtract.enabled = true; + formData.nodeExtract.text = config.nodeExtract.text; + formData.nodeExtract.tags = config.nodeExtract.tags; + formData.nodeExtract.nodes = config.nodeExtract.nodes; + formData.nodeExtract.relations = config.nodeExtract.relations; + tagOptions.value = []; + for (const tag of config.nodeExtract.tags) { + if (tagOptions.value.find((item) => item.value === tag)) { + continue; + } + tagOptions.value.push({ label: tag, value: tag }); + } + } // 在配置加载完成后,检查模型状态 await checkModelsAfterLoading(config); @@ -2027,9 +2380,9 @@ const handleSubmit = async () => { } else { MessagePlugin.error(result.message || '操作失败'); } - } catch (error) { + } catch (error: any) { console.error('提交失败:', error); - MessagePlugin.error('操作失败,请检查网络连接'); + MessagePlugin.error(error.message || '操作失败,请检查网络连接'); } finally { submitting.value = false; @@ -2050,6 +2403,288 @@ const formatFileSize = (bytes: number): string => { return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; }; +const addTag = async (val: string) => { + val = val.trim(); + if (val === '') { + MessagePlugin.error('请输入有效的标签'); + return; + } + if (!tagOptions.value.find(item => item.value === val)){ + tagOptions.value.push({ label: val, value: val }); + } + if (!formData.nodeExtract.tags.includes(val)) { + formData.nodeExtract.tags.push(val); + }else { + MessagePlugin.error('该标签已存在'); + } + tagInput.value = ''; +} + +const clearTags = async () => { + formData.nodeExtract.tags = []; +} + +const defaultExtractExample = async () => { + formData.nodeExtract.tags = ['作者', '别名']; + formData.nodeExtract.text = 
`《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著,后40回一般认为是高鹗所续。小说以贾、史、王、薛四大家族的兴衰为背景,以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线,刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。成书于乾隆年间(1743年前后),是中国文学史上现实主义的高峰,对后世影响深远。`; + formData.nodeExtract.nodes = [ + {name: '红楼梦', attributes: ['中国古典四大名著之一', '又名《石头记》', '被誉为中国封建社会的百科全书']}, + {name: '石头记', attributes: ['《红楼梦》的别名']}, + {name: '曹雪芹', attributes: ['清代作家', '《红楼梦》前 80 回的作者']}, + {name: '高鹗', attributes: ['一般认为是《红楼梦》后 40 回的续写者']} + ]; + formData.nodeExtract.relations = [ + {node1: '红楼梦', node2: '石头记', type: '别名'}, + {node1: '红楼梦', node2: '曹雪芹', type: '作者'}, + {node1: '红楼梦', node2: '高鹗', type: '作者'} + ]; + tagOptions.value = []; + tagOptions.value.push({ label: '作者', value: '作者' }); + tagOptions.value.push({ label: '别名', value: '别名' }); + popupVisibleNode1.value = Array(formData.nodeExtract.nodes.length).fill(false); + popupVisibleNode2.value = Array(formData.nodeExtract.nodes.length).fill(false); +} + +const clearExtractExample = async () => { + formData.nodeExtract.tags = []; + formData.nodeExtract.text = ''; + formData.nodeExtract.nodes = []; + formData.nodeExtract.relations = []; + tagOptions.value = [...tagOptionsDefault]; + popupVisibleNode1.value = []; + popupVisibleNode2.value = []; +} + +const addNode = async () =>{ + formData.nodeExtract.nodes.push({ + name: '', + attributes: [] + }); +} + +const removeNode = async (index: number) => { + formData.nodeExtract.nodes.splice(index, 1); +} + +const addAttribute = async (nodeIndex: number) => { + formData.nodeExtract.nodes[nodeIndex].attributes.push(''); +} + +const removeAttribute = async(nodeIndex: number, attrIndex: number) => { + formData.nodeExtract.nodes[nodeIndex].attributes.splice(attrIndex, 1); +} + +const onRelationNode1OptionClick = async (index: number, item: Node) => { + formData.nodeExtract.relations[index].node1 = item.name; + popupVisibleNode1.value[index] = false; +} + +const onRelationNode2OptionClick = async (index: number, item: Node) => { + 
formData.nodeExtract.relations[index].node2 = item.name; + popupVisibleNode2.value[index] = false; +} + +const relationOnClearNode1 = async (index: number) => { + formData.nodeExtract.relations[index].node1 = ''; +} + +const relationOnClearNode2 = async (index: number) => { + formData.nodeExtract.relations[index].node2 = ''; +} + +const onPopupVisibleNode1Change = async (index: number, val: boolean) => { + popupVisibleNode1.value[index] = val; +}; + +const onPopupVisibleNode2Change = async (index: number, val: boolean) => { + popupVisibleNode2.value[index] = val; +}; + +const addRelation = async () => { + formData.nodeExtract.relations.push({ + node1: '', + node2: '', + type: '' + }); + popupVisibleNode1.value.push(false); + popupVisibleNode2.value.push(false); +} + +const removeRelation = async (index: number) => { + formData.nodeExtract.relations.splice(index, 1); +} + +const onFocus = async () => {}; + +const canExtract = async (): Promise =>{ + if (formData.nodeExtract.text === '') { + MessagePlugin.error('请输入示例文本'); + return false; + } + if (formData.nodeExtract.tags.length === 0) { + MessagePlugin.error('请输入关系类型'); + return false; + } + for (let i = 0; i < formData.nodeExtract.tags.length; i++) { + if (formData.nodeExtract.tags[i] === '') { + MessagePlugin.error('请输入关系类型'); + return false; + } + } + if (!modelStatus.llm.available) { + MessagePlugin.error('请输入 LLM 大语言模型配置'); + return false; + } + return true; +} + +const readyNode = computed(() => { + for (let i = 0; i < formData.nodeExtract.nodes.length; i++) { + let node = formData.nodeExtract.nodes[i]; + if (node.name === '') { + return false; + } + if (node.attributes){ + for (let j = 0; j < node.attributes.length; j++) { + if (node.attributes[j] === '') { + return false; + } + } + } + } + return formData.nodeExtract.nodes.length > 0; +}) + +const readyRelation = computed(() => { + for (let i = 0; i < formData.nodeExtract.relations.length; i++) { + let relation = formData.nodeExtract.relations[i]; + if 
(relation.node1 == '' || relation.node2 == '' || relation.type == '' ) { + return false + } + } + return formData.nodeExtract.relations.length > 0; +}) + +// 处理提取 +const handleExtract = async () => { + if (extracting.value) return; + + try { + // 表单验证 + const isValid = await form.value?.validate(); + if (!isValid) { + MessagePlugin.error('请检查表单填写是否正确'); + return; + } + if (!canExtract()){ + return; + } + + extracting.value = true; + + const request: TextRelationExtractionRequest = { + text: formData.nodeExtract.text.trim(), + tags: formData.nodeExtract.tags, + llmConfig: { + source: formData.llm.source as 'local' | 'remote', + modelName: formData.llm.modelName, + baseUrl: formData.llm.baseUrl, + apiKey: formData.llm.apiKey, + }, + }; + + const result = await extractTextRelations(request); + if (result.nodes.length === 0 ) { + MessagePlugin.info('未提取有效节点'); + } else { + formData.nodeExtract.nodes = result.nodes; + } + if ( result.relations.length === 0) { + MessagePlugin.info('未提取有效关系'); + } else { + formData.nodeExtract.relations = result.relations; + } + } catch (error) { + console.error('文本内容关系提取失败:', error); + MessagePlugin.error('提取失败,请检查网络连接或文本内容格式'); + } finally { + extracting.value = false; + } +}; + +// 处理标签 +const handleFabriTag = async () => { + if (tagFabring.value) return; + + try { + // 表单验证 + const isValid = await form.value?.validate(); + if (!isValid) { + MessagePlugin.error('请检查表单填写是否正确'); + return; + } + + tagFabring.value = true; + + const request: FabriTagRequest = { + llmConfig: { + source: formData.llm.source as 'local' | 'remote', + modelName: formData.llm.modelName, + baseUrl: formData.llm.baseUrl, + apiKey: formData.llm.apiKey, + }, + }; + + const result = await fabriTag(request); + formData.nodeExtract.tags = result.tags; + tagOptions.value = []; + for (let i = 0; i < result.tags.length; i++) { + tagOptions.value.push({ label: result.tags[i], value: result.tags[i] }); + } + + } catch (error) { + console.error('随机生成标签:', error); + 
MessagePlugin.error('生成失败,请重试'); + } finally { + tagFabring.value = false; + } +}; + +// 处理示例文本 +const handleFabriText = async () => { + if (textFabring.value) return; + + try { + // 表单验证 + const isValid = await form.value?.validate(); + if (!isValid) { + MessagePlugin.error('请检查表单填写是否正确'); + return; + } + + textFabring.value = true; + + const request: FabriTextRequest = { + tags: formData.nodeExtract.tags, + llmConfig: { + source: formData.llm.source as 'local' | 'remote', + modelName: formData.llm.modelName, + baseUrl: formData.llm.baseUrl, + apiKey: formData.llm.apiKey, + }, + }; + + const result = await fabriText(request); + formData.nodeExtract.text = result.text; + } catch (error) { + console.error('生成示例文本失败:', error); + MessagePlugin.error('生成失败,请重试'); + } finally { + textFabring.value = false; + } +}; + + // 组件挂载时检查Ollama状态 onMounted(async () => { // 加载当前配置 @@ -2166,6 +2801,76 @@ onMounted(async () => { font-size: 20px; } } + + .add-tag-container { + display: flex; + align-items: center; /* 垂直居中 */ + justify-content: flex-start; /* 水平起始对齐 */ + gap: 8px; + } + .extract-button { + display: flex; + justify-content: center; + align-items: center; + gap: 12px; + text-align: center; + } + + .node-list { + display: flex; + flex-wrap: wrap; + gap: 12px; + } + + .node-header { + display: flex; + align-items: center; + justify-content: flex-start; + gap: 4px; + margin-bottom: 8px; + } + + .attribute-item { + display: flex; + align-items: center; + justify-content: flex-start; + margin-bottom: 4px; + } + + .relation-line { + display: flex; + align-items: center; + justify-content: flex-start; + gap: 4px; + } + + .relation-arrow { + font-size: 50px; + } + + .sample-text-form { + display: flex; + flex-direction: column; + width: 100%; + } + } + + .btn-tips-form { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 12px; + + .btn-tips { + display: flex; + align-items: center; + justify-content: center; + color: #fa8c16; + + .tip-icon { + margin-right: 
6px; + } + } } .form-row { @@ -2385,7 +3090,7 @@ onMounted(async () => { } } - .rerank-config, .multimodal-config { + .rerank-config, .multimodal-config, .node-config { // margin-top: 20px; // padding: 20px; // background: #f9fcff; @@ -2834,4 +3539,29 @@ onMounted(async () => { } } } + +.select-input-node { + display: flex; + flex-direction: column; + padding: 0; + gap: 2px; +} + +.select-input-node > li { + display: block; + border-radius: 3px; + line-height: 22px; + cursor: pointer; + padding: 3px 8px; + color: var(--td-text-color-primary); + transition: background-color 0.2s linear; + white-space: nowrap; + word-wrap: normal; + overflow: hidden; + text-overflow: ellipsis; +} + +.select-input-node > li:hover { + background-color: var(--td-bg-color-container-hover); +} \ No newline at end of file diff --git a/go.mod b/go.mod index c2e7bb8..1675d54 100644 --- a/go.mod +++ b/go.mod @@ -14,11 +14,12 @@ require ( github.com/google/uuid v1.6.0 github.com/hibiken/asynq v0.25.1 github.com/minio/minio-go/v7 v7.0.90 + github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1 github.com/ollama/ollama v0.11.4 github.com/panjf2000/ants/v2 v2.11.2 github.com/parquet-go/parquet-go v0.25.0 github.com/pgvector/pgvector-go v0.3.0 - github.com/redis/go-redis/v9 v9.7.3 + github.com/redis/go-redis/v9 v9.14.0 github.com/sashabaranov/go-openai v1.40.5 github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.20.1 @@ -35,7 +36,7 @@ require ( golang.org/x/crypto v0.42.0 golang.org/x/sync v0.17.0 google.golang.org/grpc v1.73.0 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.9 gorm.io/driver/postgres v1.5.11 gorm.io/gorm v1.25.12 ) @@ -92,7 +93,7 @@ require ( github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect - github.com/spf13/cast v1.7.1 // indirect + github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/subosito/gotenv v1.6.0 // indirect 
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -105,7 +106,7 @@ require ( golang.org/x/net v0.43.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.29.0 // indirect - golang.org/x/time v0.11.0 // indirect + golang.org/x/time v0.13.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 13d41ea..e8a3ae1 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ= github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60= +github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1 h1:nV3ZdYJTi73jel0mm3dpWumNY3i3nwyo25y69SPGwyg= +github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI= @@ -157,8 +159,8 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= -github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= 
+github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE= +github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -178,8 +180,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= @@ -269,8 +271,8 @@ golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= -golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= -golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= +golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= 
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= @@ -278,8 +280,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= +google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/application/repository/retriever/neo4j/repository.go b/internal/application/repository/retriever/neo4j/repository.go new file mode 100644 index 0000000..028086e --- /dev/null +++ b/internal/application/repository/retriever/neo4j/repository.go @@ -0,0 +1,231 @@ +package neo4j + +import ( + "context" + "fmt" + "strings" + + "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/neo4j/neo4j-go-driver/v6/neo4j" +) + +type Neo4jRepository 
struct { + driver neo4j.Driver + nodePrefix string +} + +func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository { + return &Neo4jRepository{driver: driver, nodePrefix: "ENTITY"} +} + +func _remove_hyphen(s string) string { + return strings.ReplaceAll(s, "-", "_") +} + +func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string { + res := make([]string, 0) + for _, label := range namespace.Labels() { + res = append(res, n.nodePrefix+_remove_hyphen(label)) + } + return res +} + +func (n *Neo4jRepository) Label(namespace types.NameSpace) string { + labels := n.Labels(namespace) + return strings.Join(labels, ":") +} + +// AddGraph implements interfaces.RetrieveGraphRepository. +func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error { + if n.driver == nil { + logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH") + return nil + } + for _, graph := range graphs { + if err := n.addGraph(ctx, namespace, graph); err != nil { + return err + } + } + return nil +} + +func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error { + session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}) + defer session.Close(ctx) + + _, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) { + node_import_query := ` + UNWIND $data AS row + CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node + SET node.chunks = apoc.coll.union(node.chunks, row.chunks) + RETURN distinct 'done' AS result + ` + nodeData := []map[string]interface{}{} + for _, node := range graph.Node { + nodeData = append(nodeData, map[string]interface{}{ + "name": node.Name, + "knowledge_id": namespace.Knowledge, + "props": map[string][]string{"attributes": node.Attributes}, + "chunks": node.Chunks, + "labels": n.Labels(namespace), + }) + } + if _, err := tx.Run(ctx, 
node_import_query, map[string]interface{}{"data": nodeData}); err != nil { + return nil, fmt.Errorf("failed to create nodes: %v", err) + } + + rel_import_query := ` + UNWIND $data AS row + CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source + CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target + CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel + RETURN distinct 'done' + ` + relData := []map[string]interface{}{} + for _, rel := range graph.Relation { + relData = append(relData, map[string]interface{}{ + "source": rel.Node1, + "target": rel.Node2, + "knowledge_id": namespace.Knowledge, + "type": rel.Type, + "source_labels": n.Labels(namespace), + "target_labels": n.Labels(namespace), + }) + } + if _, err := tx.Run(ctx, rel_import_query, map[string]interface{}{"data": relData}); err != nil { + return nil, fmt.Errorf("failed to create relationships: %v", err) + } + return nil, nil + }) + if err != nil { + logger.Errorf(ctx, "failed to add graph: %v", err) + return err + } + return nil +} + +// DelGraph implements interfaces.RetrieveGraphRepository. 
+func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error { + if n.driver == nil { + logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH") + return nil + } + session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}) + defer session.Close(ctx) + + result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) { + for _, namespace := range namespaces { + labelExpr := n.Label(namespace) + + deleteRelsQuery := ` + CALL apoc.periodic.iterate( + "MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]-(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r", + "DELETE r", + {batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}} + ) YIELD batches, total + RETURN total + ` + if _, err := tx.Run(ctx, deleteRelsQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil { + return nil, fmt.Errorf("failed to delete relationships: %v", err) + } + + deleteNodesQuery := ` + CALL apoc.periodic.iterate( + "MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n", + "DELETE n", + {batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}} + ) YIELD batches, total + RETURN total + ` + if _, err := tx.Run(ctx, deleteNodesQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil { + return nil, fmt.Errorf("failed to delete nodes: %v", err) + } + } + return nil, nil + }) + if err != nil { + return err + } + logger.Infof(ctx, "delete graph result: %v", result) + return nil +} + +func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) { + if n.driver == nil { + logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH") + return nil, nil + } + session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead}) + defer session.Close(ctx) + + result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) { + labelExpr 
:= n.Label(namespace) + query := ` + MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `) + WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText) + RETURN n, r, m + ` + params := map[string]interface{}{"nodes": nodes} + result, err := tx.Run(ctx, query, params) + if err != nil { + return nil, fmt.Errorf("failed to run query: %v", err) + } + + graphData := &types.GraphData{} + nodeSeen := make(map[string]bool) + for result.Next(ctx) { + record := result.Record() + node, _ := record.Get("n") + rel, _ := record.Get("r") + targetNode, _ := record.Get("m") + + nodeData := node.(neo4j.Node) + targetNodeData := targetNode.(neo4j.Node) + + // Convert node to types.Node + for _, n := range []neo4j.Node{nodeData, targetNodeData} { + nameStr := n.Props["name"].(string) + if _, ok := nodeSeen[nameStr]; !ok { + nodeSeen[nameStr] = true + graphData.Node = append(graphData.Node, &types.GraphNode{ + Name: nameStr, + Chunks: listI2listS(n.Props["chunks"].([]interface{})), + Attributes: listI2listS(n.Props["attributes"].([]interface{})), + }) + } + } + + // Convert relationship to types.Relation + relData := rel.(neo4j.Relationship) + graphData.Relation = append(graphData.Relation, &types.GraphRelation{ + Node1: nodeData.Props["name"].(string), + Node2: targetNodeData.Props["name"].(string), + Type: relData.Type, + }) + } + return graphData, nil + }) + if err != nil { + logger.Errorf(ctx, "search node failed: %v", err) + return nil, err + } + return result.(*types.GraphData), nil +} + +func listI2listS(list []any) []string { + result := make([]string, len(list)) + for i, v := range list { + result[i] = fmt.Sprintf("%v", v) + } + return result +} + +func mapI2mapS(prop map[string]any) map[string]string { + attributes := make(map[string]string) + for k, v := range prop { + attributes[k] = fmt.Sprintf("%v", v) + } + return attributes +} diff --git a/internal/application/service/chat_pipline/extract_entity.go b/internal/application/service/chat_pipline/extract_entity.go new file 
mode 100644 index 0000000..170a965 --- /dev/null +++ b/internal/application/service/chat_pipline/extract_entity.go @@ -0,0 +1,499 @@ +package chatpipline + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "regexp" + "strings" + + "github.com/Tencent/WeKnora/internal/config" + "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/models/chat" + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" +) + +// PluginExtractEntity is a plugin for extracting entities from user queries +// It uses historical dialog context and large language models to identify key entities in the user's original query +type PluginExtractEntity struct { + modelService interfaces.ModelService // Model service for calling large language models + template *types.PromptTemplateStructured // Template for generating prompts + knowledgeBaseRepo interfaces.KnowledgeBaseRepository +} + +// NewPluginRewrite creates a new query rewriting plugin instance +// Also registers the plugin with the event manager +func NewPluginExtractEntity( + eventManager *EventManager, + modelService interfaces.ModelService, + knowledgeBaseRepo interfaces.KnowledgeBaseRepository, + config *config.Config, +) *PluginExtractEntity { + res := &PluginExtractEntity{ + modelService: modelService, + template: config.ExtractManager.ExtractEntity, + knowledgeBaseRepo: knowledgeBaseRepo, + } + eventManager.Register(res) + return res +} + +// ActivationEvents returns the list of event types this plugin responds to +// This plugin only responds to REWRITE_QUERY events +func (p *PluginExtractEntity) ActivationEvents() []types.EventType { + return []types.EventType{types.REWRITE_QUERY} +} + +// OnEvent processes triggered events +// When receiving a REWRITE_QUERY event, it rewrites the user query using conversation history and the language model +func (p *PluginExtractEntity) OnEvent(ctx context.Context, + eventType types.EventType, chatManage 
*types.ChatManage, next func() *PluginError, +) *PluginError { + if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" { + logger.Debugf(ctx, "skipping extract entity, neo4j is disabled") + return next() + } + + query := chatManage.Query + + model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID) + if err != nil { + logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err) + return next() + } + + kb, err := p.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chatManage.KnowledgeBaseID) + if err != nil { + logger.Errorf(ctx, "failed to get knowledge base: %v", err) + return next() + } + if kb.ExtractConfig == nil { + logger.Warnf(ctx, "failed to get extract config") + return next() + } + + template := &types.PromptTemplateStructured{ + Description: p.template.Description, + Examples: p.template.Examples, + } + extractor := NewExtractor(model, template) + graph, err := extractor.Extract(ctx, query) + if err != nil { + logger.Errorf(ctx, "Failed to extract entities, session_id: %s, error: %v", chatManage.SessionID, err) + return next() + } + nodes := []string{} + for _, node := range graph.Node { + nodes = append(nodes, node.Name) + } + logger.Debugf(ctx, "extracted node: %v", nodes) + chatManage.Entity = nodes + return next() +} + +type Extractor struct { + chat chat.Chat + formater *Formater + template *types.PromptTemplateStructured + chatOpt *chat.ChatOptions +} + +func NewExtractor( + chatModel chat.Chat, + template *types.PromptTemplateStructured, +) Extractor { + think := false + return Extractor{ + chat: chatModel, + formater: NewFormater(), + template: template, + chatOpt: &chat.ChatOptions{ + Temperature: 0.3, + MaxTokens: 4096, + Thinking: &think, + }, + } +} + +func (e *Extractor) Extract(ctx context.Context, content string) (*types.GraphData, error) { + generator := NewQAPromptGenerator(e.formater, e.template) + + // logger.Debugf(ctx, "chat system: %s", generator.System(ctx)) + // logger.Debugf(ctx, "chat 
user: %s", generator.User(ctx, content)) + + chatResponse, err := e.chat.Chat(ctx, generator.Render(ctx, content), e.chatOpt) + if err != nil { + logger.Errorf(ctx, "failed to chat: %v", err) + return nil, err + } + + graph, err := e.formater.ParseGraph(ctx, chatResponse.Content) + if err != nil { + logger.Errorf(ctx, "failed to parse graph: %v", err) + return nil, err + } + // e.RemoveUnknownRelation(ctx, graph) + return graph, nil +} + +func (e *Extractor) RemoveUnknownRelation(ctx context.Context, graph *types.GraphData) { + relationType := make(map[string]bool) + for _, tag := range e.template.Tags { + relationType[tag] = true + } + + relationNew := make([]*types.GraphRelation, 0) + for _, relation := range graph.Relation { + if _, ok := relationType[relation.Type]; ok { + relationNew = append(relationNew, relation) + } else { + logger.Infof(ctx, "Unknown relation type %s with %v, ignore it", relation.Type, e.template.Tags) + } + } + graph.Relation = relationNew +} + +type QAPromptGenerator struct { + Formater *Formater + Template *types.PromptTemplateStructured + ExamplesHeading string + QuestionHeading string + QuestionPrefix string + AnswerPrefix string +} + +func NewQAPromptGenerator(formater *Formater, template *types.PromptTemplateStructured) *QAPromptGenerator { + return &QAPromptGenerator{ + Formater: formater, + Template: template, + ExamplesHeading: "# Examples", + QuestionHeading: "# Question", + QuestionPrefix: "Q: ", + AnswerPrefix: "A: ", + } +} + +func (qa *QAPromptGenerator) System(ctx context.Context) string { + promptLines := []string{} + + if len(qa.Template.Tags) == 0 { + promptLines = append(promptLines, qa.Template.Description) + } else { + tags, _ := json.Marshal(qa.Template.Tags) + promptLines = append(promptLines, fmt.Sprintf(qa.Template.Description, string(tags))) + } + if len(qa.Template.Examples) > 0 { + promptLines = append(promptLines, qa.ExamplesHeading) + for _, example := range qa.Template.Examples { + // Question + promptLines 
= append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, strings.TrimSpace(example.Text))) + + // Answer + answer, err := qa.Formater.formatExtraction(example.Node, example.Relation) + if err != nil { + return "" + } + promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.AnswerPrefix, answer)) + + // new line + promptLines = append(promptLines, "") + } + } + return strings.Join(promptLines, "\n") +} + +func (qa *QAPromptGenerator) User(ctx context.Context, question string) string { + promptLines := []string{} + promptLines = append(promptLines, qa.QuestionHeading) + promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, question)) + promptLines = append(promptLines, qa.AnswerPrefix) + return strings.Join(promptLines, "\n") +} + +func (qa *QAPromptGenerator) Render(ctx context.Context, question string) []chat.Message { + return []chat.Message{ + { + Role: "system", + Content: qa.System(ctx), + }, + { + Role: "user", + Content: qa.User(ctx, question), + }, + } +} + +type FormatType string + +const ( + FormatTypeJSON FormatType = "json" + FormatTypeYAML FormatType = "yaml" +) + +const ( + _FENCE_START = "```" + _LANGUAGE_TAG = `(?P[A-Za-z0-9_+-]+)?` + _FENCE_NEWLINE = `(?:\s*\n)?` + _FENCE_BODY = `(?P[\s\S]*?)` + _FENCE_END = "```" +) + +var _FENCE_RE = regexp.MustCompile( + _FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END, +) + +type Formater struct { + attributeSuffix string + formatType FormatType + useFences bool + nodePrefix string + + relationSource string + relationTarget string + relationPrefix string +} + +func NewFormater() *Formater { + return &Formater{ + attributeSuffix: "_attributes", + formatType: FormatTypeJSON, + useFences: true, + nodePrefix: "entity", + relationSource: "entity1", + relationTarget: "entity2", + relationPrefix: "relation", + } +} + +func (f *Formater) formatExtraction(nodes []*types.GraphNode, relations []*types.GraphRelation) (string, error) { + items := make([]map[string]interface{}, 
0) + for _, node := range nodes { + item := map[string]interface{}{ + f.nodePrefix: node.Name, + } + if len(node.Attributes) > 0 { + item[fmt.Sprintf("%s%s", f.nodePrefix, f.attributeSuffix)] = node.Attributes + } + items = append(items, item) + } + for _, relation := range relations { + item := map[string]interface{}{ + f.relationSource: relation.Node1, + f.relationTarget: relation.Node2, + f.relationPrefix: relation.Type, + } + items = append(items, item) + } + formatted := "" + switch f.formatType { + default: + formattedBytes, err := json.MarshalIndent(items, "", " ") + if err != nil { + return "", err + } + formatted = string(formattedBytes) + } + if f.useFences { + formatted = f.addFences(formatted) + } + return formatted, nil +} + +func (f *Formater) parseOutput(ctx context.Context, text string) ([]map[string]interface{}, error) { + if text == "" { + return nil, errors.New("Empty or invalid input string.") + } + content := f.extractContent(ctx, text) + // logger.Debugf(ctx, "Extracted content: %s", content) + if content == "" { + return nil, errors.New("Empty or invalid input string.") + } + + var parsed interface{} + var err error + if f.formatType == FormatTypeJSON { + err = json.Unmarshal([]byte(content), &parsed) + } + if err != nil { + return nil, fmt.Errorf("Failed to parse %s content: %s", strings.ToUpper(string(f.formatType)), err.Error()) + } + if parsed == nil { + return nil, fmt.Errorf("Content must be a list of extractions or a dict.") + } + + var items []interface{} + if parsedMap, ok := parsed.(map[string]interface{}); ok { + items = []interface{}{parsedMap} + } else if parsedList, ok := parsed.([]interface{}); ok { + items = parsedList + } else { + return nil, fmt.Errorf("Expected list or dict, got %T", parsed) + } + + itemsList := make([]map[string]interface{}, 0) + for _, item := range items { + if itemMap, ok := item.(map[string]interface{}); ok { + itemsList = append(itemsList, itemMap) + } else { + return nil, fmt.Errorf("Each item in the 
sequence must be a mapping.") + } + } + return itemsList, nil +} + +func (f *Formater) ParseGraph(ctx context.Context, text string) (*types.GraphData, error) { + matchData, err := f.parseOutput(ctx, text) + if err != nil { + return nil, err + } + if len(matchData) == 0 { + logger.Debugf(ctx, "Received empty extraction data.") + return &types.GraphData{}, nil + } + // mm, _ := json.Marshal(matchData) + // logger.Debugf(ctx, "Parsed graph data: %s", string(mm)) + + var nodes []*types.GraphNode + var relations []*types.GraphRelation + + for _, group := range matchData { + switch { + case group[f.nodePrefix] != nil: + attributes := make([]string, 0) + attributesKey := f.nodePrefix + f.attributeSuffix + if attr, ok := group[attributesKey].([]interface{}); ok { + for _, v := range attr { + attributes = append(attributes, fmt.Sprintf("%v", v)) + } + } + nodes = append(nodes, &types.GraphNode{ + Name: fmt.Sprintf("%v", group[f.nodePrefix]), + Attributes: attributes, + }) + case group[f.relationSource] != nil && group[f.relationTarget] != nil: + relations = append(relations, &types.GraphRelation{ + Node1: fmt.Sprintf("%v", group[f.relationSource]), + Node2: fmt.Sprintf("%v", group[f.relationTarget]), + Type: fmt.Sprintf("%v", group[f.relationPrefix]), + }) + default: + logger.Warnf(ctx, "Unsupported graph group: %v", group) + continue + } + } + graph := &types.GraphData{ + Node: nodes, + Relation: relations, + } + f.rebuildGraph(ctx, graph) + return graph, nil +} + +func (f *Formater) rebuildGraph(ctx context.Context, graph *types.GraphData) { + nodeMap := make(map[string]*types.GraphNode) + nodes := make([]*types.GraphNode, 0, len(graph.Node)) + for _, node := range graph.Node { + if prenode, ok := nodeMap[node.Name]; ok { + logger.Infof(ctx, "Duplicate node ID: %s, merge attribute", node.Name) + // 修复panic:检查Attributes是否为nil + if node.Attributes == nil { + node.Attributes = make([]string, 0) + } + if prenode.Attributes != nil { + for _, attr := range prenode.Attributes { 
+ node.Attributes = append(node.Attributes, attr) + } + } + continue + } + nodeMap[node.Name] = node + nodes = append(nodes, node) + } + + relations := make([]*types.GraphRelation, 0, len(graph.Relation)) + for _, relation := range graph.Relation { + if relation.Node1 == relation.Node2 { + logger.Infof(ctx, "Duplicate relation, ignore it") + continue + } + + if _, ok := nodeMap[relation.Node1]; !ok { + node := &types.GraphNode{Name: relation.Node1} + nodes = append(nodes, node) + nodeMap[relation.Node1] = node + logger.Infof(ctx, "Add unknown source node ID: %s", relation.Node1) + } + if _, ok := nodeMap[relation.Node2]; !ok { + node := &types.GraphNode{Name: relation.Node2} + nodes = append(nodes, node) + nodeMap[relation.Node2] = node + logger.Infof(ctx, "Add unknown target node ID: %s", relation.Node2) + } + + relations = append(relations, relation) + } + *graph = types.GraphData{ + Node: nodes, + Relation: relations, + } +} + +func (f *Formater) extractContent(ctx context.Context, text string) string { + if !f.useFences { + return strings.TrimSpace(text) + } + validTags := map[FormatType]map[string]struct{}{ + FormatTypeYAML: {"yaml": {}, "yml": {}}, + FormatTypeJSON: {"json": {}}, + } + matches := _FENCE_RE.FindAllStringSubmatch(text, -1) + var candidates []string + for _, match := range matches { + lang := match[1] + body := match[2] + if f.isValidLanguageTag(lang, validTags) { + candidates = append(candidates, body) + } + } + switch { + case len(candidates) == 1: + return strings.TrimSpace(candidates[0]) + + case len(candidates) > 1: + logger.Warnf(ctx, "multiple candidates found: %d", len(candidates)) + return strings.TrimSpace(candidates[0]) + + case len(matches) == 1: + logger.Debugf(ctx, "no candidate found, use first match without language tag: %s", matches[0][1]) + return strings.TrimSpace(matches[0][2]) + + case len(matches) > 1: + logger.Warnf(ctx, "multiple matches found: %d", len(matches)) + return strings.TrimSpace(matches[0][2]) + + default: + 
logger.Warnf(ctx, "no match found") + return strings.TrimSpace(text) + } +} + +func (f *Formater) addFences(content string) string { + content = strings.TrimSpace(content) + return fmt.Sprintf("```%s\n%s\n```", f.formatType, content) +} + +func (f *Formater) isValidLanguageTag(lang string, validTags map[FormatType]map[string]struct{}) bool { + if lang == "" { + return true + } + tag := strings.TrimSpace(strings.ToLower(lang)) + validSet, ok := validTags[f.formatType] + if !ok { + return false + } + _, exists := validSet[tag] + return exists +} diff --git a/internal/application/service/chat_pipline/search.go b/internal/application/service/chat_pipline/search.go index 92a019d..a5b2904 100644 --- a/internal/application/service/chat_pipline/search.go +++ b/internal/application/service/chat_pipline/search.go @@ -77,6 +77,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context, } chatManage.SearchResult = append(chatManage.SearchResult, searchResults...) } + // remove duplicate results chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult) diff --git a/internal/application/service/chat_pipline/search_entity.go b/internal/application/service/chat_pipline/search_entity.go new file mode 100644 index 0000000..a059174 --- /dev/null +++ b/internal/application/service/chat_pipline/search_entity.go @@ -0,0 +1,136 @@ +package chatpipline + +import ( + "context" + + "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" +) + +// PluginSearch implements search functionality for chat pipeline +type PluginSearchEntity struct { + graphRepo interfaces.RetrieveGraphRepository + chunkRepo interfaces.ChunkRepository + knowledgeRepo interfaces.KnowledgeRepository +} + +func NewPluginSearchEntity( + eventManager *EventManager, + graphRepository interfaces.RetrieveGraphRepository, + chunkRepository interfaces.ChunkRepository, + knowledgeRepository interfaces.KnowledgeRepository, +) 
*PluginSearchEntity { + res := &PluginSearchEntity{ + graphRepo: graphRepository, + chunkRepo: chunkRepository, + knowledgeRepo: knowledgeRepository, + } + eventManager.Register(res) + return res +} + +// ActivationEvents returns the event types this plugin handles +func (p *PluginSearchEntity) ActivationEvents() []types.EventType { + return []types.EventType{types.ENTITY_SEARCH} +} + +// OnEvent handles search events in the chat pipeline +func (p *PluginSearchEntity) OnEvent(ctx context.Context, + eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError, +) *PluginError { + entity := chatManage.Entity + if len(entity) == 0 { + logger.Infof(ctx, "No entity found") + return next() + } + + graph, err := p.graphRepo.SearchNode(ctx, types.NameSpace{KnowledgeBase: chatManage.KnowledgeBaseID}, entity) + if err != nil { + logger.Errorf(ctx, "Failed to search node, session_id: %s, error: %v", chatManage.SessionID, err) + return next() + } + chatManage.GraphResult = graph + logger.Infof(ctx, "search entity result count: %d", len(graph.Node)) + // graphStr, _ := json.Marshal(graph) + // logger.Debugf(ctx, "search entity result: %s", string(graphStr)) + + chunkIDs := filterSeenChunk(ctx, graph, chatManage.SearchResult) + if len(chunkIDs) == 0 { + logger.Infof(ctx, "No new chunk found") + return next() + } + chunks, err := p.chunkRepo.ListChunksByID(ctx, ctx.Value(types.TenantIDContextKey).(uint), chunkIDs) + if err != nil { + logger.Errorf(ctx, "Failed to list chunks, session_id: %s, error: %v", chatManage.SessionID, err) + return next() + } + knowledgeIDs := []string{} + for _, chunk := range chunks { + knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID) + } + knowledges, err := p.knowledgeRepo.GetKnowledgeBatch(ctx, ctx.Value(types.TenantIDContextKey).(uint), knowledgeIDs) + if err != nil { + logger.Errorf(ctx, "Failed to list knowledge, session_id: %s, error: %v", chatManage.SessionID, err) + return next() + } + + knowledgeMap := 
map[string]*types.Knowledge{} + for _, knowledge := range knowledges { + knowledgeMap[knowledge.ID] = knowledge + } + for _, chunk := range chunks { + searchResult := chunk2SearchResult(chunk, knowledgeMap[chunk.KnowledgeID]) + chatManage.SearchResult = append(chatManage.SearchResult, searchResult) + } + // remove duplicate results + chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult) + if len(chatManage.SearchResult) == 0 { + logger.Infof(ctx, "No new search result, session_id: %s", chatManage.SessionID) + return ErrSearchNothing + } + logger.Infof(ctx, "search entity result count: %d, session_id: %s", len(chatManage.SearchResult), chatManage.SessionID) + return next() +} + +func filterSeenChunk(ctx context.Context, graph *types.GraphData, searchResult []*types.SearchResult) []string { + seen := map[string]bool{} + for _, chunk := range searchResult { + seen[chunk.ID] = true + } + logger.Infof(ctx, "filterSeenChunk: seen count: %d", len(seen)) + + chunkIDs := []string{} + for _, node := range graph.Node { + for _, chunkID := range node.Chunks { + if seen[chunkID] { + continue + } + seen[chunkID] = true + chunkIDs = append(chunkIDs, chunkID) + } + } + logger.Infof(ctx, "filterSeenChunk: new chunkIDs count: %d", len(chunkIDs)) + return chunkIDs +} + +func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.SearchResult { + return &types.SearchResult{ + ID: chunk.ID, + Content: chunk.Content, + KnowledgeID: chunk.KnowledgeID, + ChunkIndex: chunk.ChunkIndex, + KnowledgeTitle: knowledge.Title, + StartAt: chunk.StartAt, + EndAt: chunk.EndAt, + Seq: chunk.ChunkIndex, + Score: 1.0, + MatchType: types.MatchTypeGraph, + Metadata: knowledge.GetMetadata(), + ChunkType: string(chunk.ChunkType), + ParentChunkID: chunk.ParentChunkID, + ImageInfo: chunk.ImageInfo, + KnowledgeFilename: knowledge.FileName, + KnowledgeSource: knowledge.Source, + } +} diff --git a/internal/application/service/extract.go 
b/internal/application/service/extract.go new file mode 100644 index 0000000..06f3db3 --- /dev/null +++ b/internal/application/service/extract.go @@ -0,0 +1,137 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline" + "github.com/Tencent/WeKnora/internal/config" + "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/google/uuid" + "github.com/hibiken/asynq" +) + +func NewChunkExtractTask(ctx context.Context, client *asynq.Client, tenantID uint, chunkID string, modelID string) error { + if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" { + logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH") + return nil + } + payload, err := json.Marshal(types.ExtractChunkPayload{ + TenantID: tenantID, + ChunkID: chunkID, + ModelID: modelID, + }) + if err != nil { + return err + } + task := asynq.NewTask(types.TypeChunkExtract, payload, asynq.MaxRetry(3)) + info, err := client.Enqueue(task) + if err != nil { + logger.Errorf(ctx, "failed to enqueue task: %v", err) + return fmt.Errorf("failed to enqueue task: %v", err) + } + logger.Infof(ctx, "enqueued task: id=%s queue=%s chunk=%s", info.ID, info.Queue, chunkID) + return nil +} + +type ChunkExtractService struct { + template *types.PromptTemplateStructured + modelService interfaces.ModelService + knowledgeBaseRepo interfaces.KnowledgeBaseRepository + chunkRepo interfaces.ChunkRepository + graphEngine interfaces.RetrieveGraphRepository +} + +func NewChunkExtractService( + config *config.Config, + modelService interfaces.ModelService, + knowledgeBaseRepo interfaces.KnowledgeBaseRepository, + chunkRepo interfaces.ChunkRepository, + graphEngine interfaces.RetrieveGraphRepository, +) interfaces.Extracter { + generator := chatpipline.NewQAPromptGenerator(chatpipline.NewFormater(), 
config.ExtractManager.ExtractGraph) + ctx := context.Background() + logger.Debugf(ctx, "chunk extract system prompt: %s", generator.System(ctx)) + logger.Debugf(ctx, "chunk extract user prompt: %s", generator.User(ctx, "demo")) + return &ChunkExtractService{ + template: config.ExtractManager.ExtractGraph, + modelService: modelService, + knowledgeBaseRepo: knowledgeBaseRepo, + chunkRepo: chunkRepo, + graphEngine: graphEngine, + } +} + +func (s *ChunkExtractService) Extract(ctx context.Context, t *asynq.Task) error { + var p types.ExtractChunkPayload + if err := json.Unmarshal(t.Payload(), &p); err != nil { + logger.Errorf(ctx, "failed to unmarshal task payload: %v", err) + return err + } + ctx = logger.WithRequestID(ctx, uuid.New().String()) + ctx = logger.WithField(ctx, "extract", p.ChunkID) + ctx = context.WithValue(ctx, types.TenantIDContextKey, p.TenantID) + + chunk, err := s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID) + if err != nil { + logger.Errorf(ctx, "failed to get chunk: %v", err) + return err + } + kb, err := s.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chunk.KnowledgeBaseID) + if err != nil { + logger.Errorf(ctx, "failed to get knowledge base: %v", err) + return err + } + if kb.ExtractConfig == nil { + logger.Warnf(ctx, "failed to get extract config") + return err + } + + chatModel, err := s.modelService.GetChatModel(ctx, p.ModelID) + if err != nil { + logger.Errorf(ctx, "failed to get chat model: %v", err) + return err + } + + template := &types.PromptTemplateStructured{ + Description: s.template.Description, + Tags: kb.ExtractConfig.Tags, + Examples: []types.GraphData{ + { + Text: kb.ExtractConfig.Text, + Node: kb.ExtractConfig.Nodes, + Relation: kb.ExtractConfig.Relations, + }, + }, + } + extractor := chatpipline.NewExtractor(chatModel, template) + graph, err := extractor.Extract(ctx, chunk.Content) + if err != nil { + return err + } + + chunk, err = s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID) + if err != nil { + logger.Warnf(ctx, 
"graph ignore chunk %s: %v", p.ChunkID, err) + return nil + } + + for _, node := range graph.Node { + node.Chunks = []string{chunk.ID} + } + if err = s.graphEngine.AddGraph(ctx, + types.NameSpace{KnowledgeBase: chunk.KnowledgeBaseID, Knowledge: chunk.KnowledgeID}, + []*types.GraphData{graph}, + ); err != nil { + logger.Errorf(ctx, "failed to add graph: %v", err) + return err + } + // gg, _ := json.Marshal(graph) + // logger.Infof(ctx, "extracted graph: %s", string(gg)) + return nil +} diff --git a/internal/application/service/knowledge.go b/internal/application/service/knowledge.go index 198b7eb..414ecc5 100644 --- a/internal/application/service/knowledge.go +++ b/internal/application/service/knowledge.go @@ -29,6 +29,7 @@ import ( "github.com/Tencent/WeKnora/services/docreader/src/client" "github.com/Tencent/WeKnora/services/docreader/src/proto" "github.com/google/uuid" + "github.com/hibiken/asynq" "go.opentelemetry.io/otel/attribute" "golang.org/x/sync/errgroup" ) @@ -61,6 +62,8 @@ type knowledgeService struct { chunkRepo interfaces.ChunkRepository fileSvc interfaces.FileService modelService interfaces.ModelService + task *asynq.Client + graphEngine interfaces.RetrieveGraphRepository } // NewKnowledgeService creates a new knowledge service instance @@ -74,6 +77,8 @@ func NewKnowledgeService( chunkRepo interfaces.ChunkRepository, fileSvc interfaces.FileService, modelService interfaces.ModelService, + task *asynq.Client, + graphEngine interfaces.RetrieveGraphRepository, ) (interfaces.KnowledgeService, error) { return &knowledgeService{ config: config, @@ -85,6 +90,8 @@ func NewKnowledgeService( chunkRepo: chunkRepo, fileSvc: fileSvc, modelService: modelService, + task: task, + graphEngine: graphEngine, }, nil } @@ -488,6 +495,16 @@ func (s *knowledgeService) DeleteKnowledge(ctx context.Context, id string) error return nil }) + // Delete the knowledge graph + wg.Go(func() error { + namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: 
knowledge.ID} + if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil { + logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed") + return err + } + return nil + }) + + if err = wg.Wait(); err != nil { + return err + } @@ -561,6 +578,19 @@ func (s *knowledgeService) DeleteKnowledgeList(ctx context.Context, ids []string return nil }) + // Delete the knowledge graph + wg.Go(func() error { + namespaces := []types.NameSpace{} + for _, knowledge := range knowledgeList { + namespaces = append(namespaces, types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}) + } + if err := s.graphEngine.DelGraph(ctx, namespaces); err != nil { + logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledgeList delete knowledge graph failed") + return err + } + return nil + }) + + if err = wg.Wait(); err != nil { + return err + } @@ -1158,6 +1188,15 @@ func (s *knowledgeService) processChunks(ctx context.Context, } logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList)) + logger.Infof(ctx, "processChunks create relationship rag task") + for _, chunk := range textChunks { + err := NewChunkExtractTask(ctx, s.task, chunk.TenantID, chunk.ID, kb.SummaryModelID) + if err != nil { + logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks create chunk extract task failed") + span.RecordError(err) + } + } + + // Update knowledge status to completed knowledge.ParseStatus = "completed" knowledge.EnableStatus = "enabled" diff --git a/internal/application/service/knowledgebase.go b/internal/application/service/knowledgebase.go index 19e7a7f..cd991e3 100644 --- a/internal/application/service/knowledgebase.go +++ b/internal/application/service/knowledgebase.go @@ -2,12 +2,11 @@ package service import ( "context" + "encoding/json" "errors" "slices" "time" - "encoding/json" - "github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/common" "github.com/Tencent/WeKnora/internal/logger" @@ -376,8 +375,8 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context, // processSearchResults handles the processing of search results, optimizing database queries func (s *knowledgeBaseService) processSearchResults(ctx context.Context, - chunks []*types.IndexWithScore) ([]*types.SearchResult, error) { - + chunks []*types.IndexWithScore, +) ([]*types.SearchResult, error) { if len(chunks) == 0 { return nil, nil } @@ -527,8 +526,8 @@ func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, proces func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk, knowledge *types.Knowledge, score float64, - matchType types.MatchType) *types.SearchResult { - + matchType types.MatchType, +) *types.SearchResult { return &types.SearchResult{ ID: chunk.ID, Content: chunk.Content, diff --git a/internal/common/asyncq.go b/internal/common/asyncq.go deleted file mode 100644 index 9316f89..0000000 --- a/internal/common/asyncq.go +++ /dev/null @@ -1,72 +0,0 @@ -package common - -import ( - "log" - - "github.com/Tencent/WeKnora/internal/config" - "github.com/hibiken/asynq" -) - -// client is the global asyncq client instance -var client *asynq.Client - -// InitAsyncq initializes the asyncq client with configuration -// It creates a new client and starts the server in a goroutine -func InitAsyncq(config *config.Config) error { - cfg := config.Asynq - client = asynq.NewClient(asynq.RedisClientOpt{ - Addr: cfg.Addr, - Username: cfg.Username, - Password: cfg.Password, - ReadTimeout: cfg.ReadTimeout, - WriteTimeout: cfg.WriteTimeout, - }) - go run(cfg) - return nil -} - -// GetAsyncqClient returns the global asyncq client instance -func GetAsyncqClient() *asynq.Client { - return client -} - -// handleFunc stores registered task handlers -var handleFunc = map[string]asynq.HandlerFunc{} - -// RegisterHandlerFunc registers a handler function for a specific task type 
-func RegisterHandlerFunc(taskType string, handlerFunc asynq.HandlerFunc) { - handleFunc[taskType] = handlerFunc -} - -// run starts the asyncq server with the given configuration -// It creates a new server, sets up handlers, and runs the server -func run(config *config.AsynqConfig) { - srv := asynq.NewServer( - asynq.RedisClientOpt{ - Addr: config.Addr, - Username: config.Username, - Password: config.Password, - ReadTimeout: config.ReadTimeout, - WriteTimeout: config.WriteTimeout, - }, - asynq.Config{ - Concurrency: config.Concurrency, - Queues: map[string]int{ - "critical": 6, // Highest priority queue - "default": 3, // Default priority queue - "low": 1, // Lowest priority queue - }, - }, - ) - - // Create a new mux and register all handlers - mux := asynq.NewServeMux() - for typ, handler := range handleFunc { - mux.HandleFunc(typ, handler) - } - - // Start the server - if err := srv.Run(mux); err != nil { - log.Fatalf("could not run server: %v", err) - } -} diff --git a/internal/config/config.go b/internal/config/config.go index afbae03..8072813 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/Tencent/WeKnora/internal/types" "github.com/go-viper/mapstructure/v2" "github.com/spf13/viper" ) @@ -18,10 +19,10 @@ type Config struct { KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"` Tenant *TenantConfig `yaml:"tenant" json:"tenant"` Models []ModelConfig `yaml:"models" json:"models"` - Asynq *AsynqConfig `yaml:"asynq" json:"asynq"` VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"` DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"` StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"` + ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"` } type DocReaderConfig struct { @@ -109,15 +110,6 @@ type ModelConfig struct { Parameters map[string]interface{} `yaml:"parameters" 
json:"parameters"` } -type AsynqConfig struct { - Addr string `yaml:"addr" json:"addr"` - Username string `yaml:"username" json:"username"` - Password string `yaml:"password" json:"password"` - ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` - WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` - Concurrency int `yaml:"concurrency" json:"concurrency"` -} - // StreamManagerConfig 流管理器配置 type StreamManagerConfig struct { Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis" @@ -134,6 +126,18 @@ type RedisConfig struct { TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时) } +// ExtractManagerConfig 抽取管理器配置 +type ExtractManagerConfig struct { + ExtractGraph *types.PromptTemplateStructured `yaml:"extract_graph" json:"extract_graph"` + ExtractEntity *types.PromptTemplateStructured `yaml:"extract_entity" json:"extract_entity"` + FabriText *FebriText `yaml:"fabri_text" json:"fabri_text"` +} + +type FebriText struct { + WithTag string `yaml:"with_tag" json:"with_tag"` + WithNoTag string `yaml:"with_no_tag" json:"with_no_tag"` +} + // LoadConfig 从配置文件加载配置 func LoadConfig() (*Config, error) { // 设置配置文件名和路径 diff --git a/internal/container/container.go b/internal/container/container.go index 3a47bf3..4a686e6 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -14,6 +14,7 @@ import ( esv7 "github.com/elastic/go-elasticsearch/v7" "github.com/elastic/go-elasticsearch/v8" + "github.com/neo4j/neo4j-go-driver/v6/neo4j" "github.com/panjf2000/ants/v2" "go.uber.org/dig" "gorm.io/driver/postgres" @@ -22,6 +23,7 @@ import ( "github.com/Tencent/WeKnora/internal/application/repository" elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7" elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8" + neo4jRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/neo4j" postgresRepo 
"github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres" "github.com/Tencent/WeKnora/internal/application/service" chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline" @@ -68,6 +70,7 @@ func BuildContainer(container *dig.Container) *dig.Container { // External service clients must(container.Provide(initDocReaderClient)) must(container.Provide(initOllamaService)) + must(container.Provide(initNeo4jClient)) must(container.Provide(stream.NewStreamManager)) // Data repositories layer @@ -80,6 +83,7 @@ func BuildContainer(container *dig.Container) *dig.Container { must(container.Provide(repository.NewModelRepository)) must(container.Provide(repository.NewUserRepository)) must(container.Provide(repository.NewAuthTokenRepository)) + must(container.Provide(neo4jRepo.NewNeo4jRepository)) // Business service layer must(container.Provide(service.NewTenantService)) @@ -93,6 +97,7 @@ func BuildContainer(container *dig.Container) *dig.Container { must(container.Provide(service.NewDatasetService)) must(container.Provide(service.NewEvaluationService)) must(container.Provide(service.NewUserService)) + must(container.Provide(service.NewChunkExtractService)) // Chat pipeline components for processing chat requests must(container.Provide(chatpipline.NewEventManager)) @@ -107,6 +112,8 @@ func BuildContainer(container *dig.Container) *dig.Container { must(container.Invoke(chatpipline.NewPluginFilterTopK)) must(container.Invoke(chatpipline.NewPluginPreprocess)) must(container.Invoke(chatpipline.NewPluginRewrite)) + must(container.Invoke(chatpipline.NewPluginExtractEntity)) + must(container.Invoke(chatpipline.NewPluginSearchEntity)) // HTTP handlers layer must(container.Provide(handler.NewTenantHandler)) @@ -123,6 +130,9 @@ func BuildContainer(container *dig.Container) *dig.Container { // Router configuration must(container.Provide(router.NewRouter)) + must(container.Provide(router.NewAsyncqClient)) + 
must(container.Provide(router.NewAsynqServer)) + must(container.Invoke(router.RunAsynqServer)) return container } @@ -184,6 +194,7 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) { err = db.AutoMigrate( &types.User{}, &types.AuthToken{}, + &types.KnowledgeBase{}, ) if err != nil { return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err) @@ -385,3 +396,23 @@ func initOllamaService() (*ollama.OllamaService, error) { // Get Ollama service from existing factory function return ollama.GetOllamaService() } + +func initNeo4jClient() (neo4j.Driver, error) { + if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" { + logger.Debugf(context.Background(), "NOT SUPPORT RETRIEVE GRAPH") + return nil, nil + } + uri := os.Getenv("NEO4J_URI") + username := os.Getenv("NEO4J_USERNAME") + password := os.Getenv("NEO4J_PASSWORD") + + driver, err := neo4j.NewDriver(uri, neo4j.BasicAuth(username, password, "")) + if err != nil { + return nil, err + } + err = driver.VerifyAuthentication(context.Background(), nil) + if err != nil { + return nil, err + } + return driver, nil +} diff --git a/internal/handler/initialization.go b/internal/handler/initialization.go index 4749c44..e82a0d9 100644 --- a/internal/handler/initialization.go +++ b/internal/handler/initialization.go @@ -5,14 +5,15 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net/http" "os" + "strconv" "strings" "sync" "time" - "strconv" - + chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline" "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/errors" "github.com/Tencent/WeKnora/internal/logger" @@ -133,6 +134,21 @@ type InitializationRequest struct { ChunkOverlap int `json:"chunkOverlap" binding:"min=0"` Separators []string `json:"separators" binding:"required,min=1"` } `json:"documentSplitting" binding:"required"` + + NodeExtract struct { + Enabled bool `json:"enabled"` + Text string `json:"text"` + Tags []string `json:"tags"` + Nodes 
[]struct { + Name string `json:"name"` + Attributes []string `json:"attributes"` + } `json:"nodes"` + Relations []struct { + Node1 string `json:"node1"` + Node2 string `json:"node2"` + Type string `json:"type"` + } `json:"relations"` + } `json:"nodeExtract"` } // InitializeByKB 根据知识库ID执行配置更新 @@ -207,6 +223,25 @@ func (h *InitializationHandler) InitializeByKB(c *gin.Context) { } } + // 验证Node Extractor配置(如果启用) + if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" && req.NodeExtract.Enabled { + logger.Error(ctx, "Node Extractor configuration incomplete") + c.Error(errors.NewBadRequestError("请正确配置环境变量NEO4J_ENABLE")) + return + } + if req.NodeExtract.Enabled { + if req.NodeExtract.Text == "" || len(req.NodeExtract.Tags) == 0 { + logger.Error(ctx, "Node Extractor configuration incomplete") + c.Error(errors.NewBadRequestError("Node Extractor配置不完整")) + return + } + if len(req.NodeExtract.Nodes) == 0 || len(req.NodeExtract.Relations) == 0 { + logger.Error(ctx, "Node Extractor configuration incomplete") + c.Error(errors.NewBadRequestError("请先提取实体和关系")) + return + } + } + // 处理模型创建/更新 modelsToProcess := []struct { modelType types.ModelType @@ -406,6 +441,29 @@ func (h *InitializationHandler) InitializeByKB(c *gin.Context) { kb.StorageConfig = types.StorageConfig{} } + if req.NodeExtract.Enabled { + kb.ExtractConfig = &types.ExtractConfig{ + Text: req.NodeExtract.Text, + Tags: req.NodeExtract.Tags, + Nodes: make([]*types.GraphNode, 0), + Relations: make([]*types.GraphRelation, 0), + } + for _, rnode := range req.NodeExtract.Nodes { + node := &types.GraphNode{ + Name: rnode.Name, + Attributes: rnode.Attributes, + } + kb.ExtractConfig.Nodes = append(kb.ExtractConfig.Nodes, node) + } + for _, relation := range req.NodeExtract.Relations { + kb.ExtractConfig.Relations = append(kb.ExtractConfig.Relations, &types.GraphRelation{ + Node1: relation.Node1, + Node2: relation.Node2, + Type: relation.Type, + }) + } + } + err = h.kbRepository.UpdateKnowledgeBase(ctx, kb) if err != nil { 
logger.ErrorWithFields(ctx, err, map[string]interface{}{"kbId": kbIdStr}) @@ -710,7 +768,6 @@ func (h *InitializationHandler) downloadModelAsync(ctx context.Context, err := h.pullModelWithProgress(ctx, modelName, func(progress float64, message string) { h.updateTaskStatus(taskID, "downloading", progress, message) }) - if err != nil { logger.ErrorWithFields(ctx, err, map[string]interface{}{ "model_name": modelName, @@ -777,7 +834,6 @@ func (h *InitializationHandler) pullModelWithProgress(ctx context.Context, ) return nil }) - if err != nil { return fmt.Errorf("failed to pull model: %w", err) } @@ -971,6 +1027,20 @@ func buildConfigResponse(models []*types.Model, } } + if kb.ExtractConfig != nil { + config["nodeExtract"] = map[string]interface{}{ + "enabled": true, + "text": kb.ExtractConfig.Text, + "tags": kb.ExtractConfig.Tags, + "nodes": kb.ExtractConfig.Nodes, + "relations": kb.ExtractConfig.Relations, + } + } else { + config["nodeExtract"] = map[string]interface{}{ + "enabled": false, + } + } + return config } @@ -1148,8 +1218,8 @@ func (h *InitializationHandler) checkRemoteModelConnection(ctx context.Context, // checkRerankModelConnection 检查Rerank模型连接和功能的内部方法 func (h *InitializationHandler) checkRerankModelConnection(ctx context.Context, - modelName, baseURL, apiKey string) (bool, string) { - + modelName, baseURL, apiKey string, +) (bool, string) { // 创建Reranker配置 config := &rerank.RerankerConfig{ APIKey: apiKey, @@ -1492,3 +1562,214 @@ func (h *InitializationHandler) testMultimodalWithDocReader( return result, nil } + +// TextRelationExtractionRequest 文本关系提取请求结构 +type TextRelationExtractionRequest struct { + Text string `json:"text" binding:"required"` + Tags []string `json:"tags" binding:"required"` + LLMConfig LLMConfig `json:"llmConfig"` +} + +type LLMConfig struct { + Source string `json:"source"` + ModelName string `json:"modelName"` + BaseUrl string `json:"baseUrl"` + ApiKey string `json:"apiKey"` +} + +// TextRelationExtractionResponse 文本关系提取响应结构 +type 
TextRelationExtractionResponse struct { + Nodes []*types.GraphNode `json:"nodes"` + Relations []*types.GraphRelation `json:"relations"` +} + +func (h *InitializationHandler) ExtractTextRelations(c *gin.Context) { + ctx := c.Request.Context() + + var req TextRelationExtractionRequest + if err := c.ShouldBindJSON(&req); err != nil { + logger.Error(ctx, "文本关系提取请求参数错误") + c.Error(errors.NewBadRequestError("文本关系提取请求参数错误")) + return + } + + // 验证文本内容 + if len(req.Text) == 0 { + c.Error(errors.NewBadRequestError("文本内容不能为空")) + return + } + + if len(req.Text) > 5000 { + c.Error(errors.NewBadRequestError("文本内容长度不能超过5000字符")) + return + } + + // 验证标签 + if len(req.Tags) == 0 { + c.Error(errors.NewBadRequestError("至少需要选择一个关系标签")) + return + } + + // 调用模型服务进行文本关系提取 + result, err := h.extractRelationsFromText(ctx, req.Text, req.Tags, req.LLMConfig) + if err != nil { + logger.Error(ctx, "文本关系提取失败", err) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "success": false, + "message": err.Error(), + }, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": result, + }) +} + +// extractRelationsFromText 从文本中提取关系 +func (h *InitializationHandler) extractRelationsFromText(ctx context.Context, text string, tags []string, llm LLMConfig) (*TextRelationExtractionResponse, error) { + chatModel, err := chat.NewChat(&chat.ChatConfig{ + ModelID: "initialization", + APIKey: llm.ApiKey, + BaseURL: llm.BaseUrl, + ModelName: llm.ModelName, + Source: types.ModelSource(llm.Source), + }) + if err != nil { + logger.Error(ctx, "初始化模型服务失败", err) + return nil, err + } + + template := &types.PromptTemplateStructured{ + Description: h.config.ExtractManager.ExtractGraph.Description, + Tags: tags, + Examples: h.config.ExtractManager.ExtractGraph.Examples, + } + + extractor := chatpipline.NewExtractor(chatModel, template) + graph, err := extractor.Extract(ctx, text) + if err != nil { + logger.Error(ctx, "文本关系提取失败", err) + return nil, err + } + 
extractor.RemoveUnknownRelation(ctx, graph) + + result := &TextRelationExtractionResponse{ + Nodes: graph.Node, + Relations: graph.Relation, + } + + return result, nil +} + +type FabriTextRequest struct { + Tags []string `json:"tags"` + LLMConfig LLMConfig `json:"llmConfig"` +} + +type FabriTextResponse struct { + Text string `json:"text"` +} + +func (h *InitializationHandler) FabriText(c *gin.Context) { + ctx := c.Request.Context() + + var req FabriTextRequest + if err := c.ShouldBindJSON(&req); err != nil { + logger.Error(ctx, "生成示例文本请求参数错误") + c.Error(errors.NewBadRequestError("生成示例文本请求参数错误")) + return + } + + result, err := h.fabriText(ctx, req.Tags, req.LLMConfig) + if err != nil { + logger.Error(ctx, "生成示例文本失败", err) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "success": false, + "message": err.Error(), + }, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": FabriTextResponse{Text: result}, + }) +} + +func (h *InitializationHandler) fabriText(ctx context.Context, tags []string, llm LLMConfig) (string, error) { + chatModel, err := chat.NewChat(&chat.ChatConfig{ + ModelID: "initialization", + APIKey: llm.ApiKey, + BaseURL: llm.BaseUrl, + ModelName: llm.ModelName, + Source: types.ModelSource(llm.Source), + }) + if err != nil { + logger.Error(ctx, "初始化模型服务失败", err) + return "", err + } + + content := h.config.ExtractManager.FabriText.WithNoTag + if len(tags) > 0 { + tagStr, _ := json.Marshal(tags) + content = fmt.Sprintf(h.config.ExtractManager.FabriText.WithTag, string(tagStr)) + } + + think := false + result, err := chatModel.Chat(ctx, []chat.Message{ + {Role: "user", Content: content}, + }, &chat.ChatOptions{ + Temperature: 0.3, + MaxTokens: 4096, + Thinking: &think, + }) + if err != nil { + logger.Error(ctx, "生成示例文本失败", err) + return "", err + } + return result.Content, nil +} + +type FabriTagRequest struct { + LLMConfig LLMConfig `json:"llmConfig"` +} + +type FabriTagResponse struct { + Tags []string 
`json:"tags"` +} + +var tagOptions = []string{ + "内容", "文化", "人物", "事件", "时间", "地点", "作品", "作者", "关系", "属性", +} + +func (h *InitializationHandler) FabriTag(c *gin.Context) { + tagRandom := RandomSelect(tagOptions, rand.Intn(len(tagOptions)-1)+1) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": FabriTagResponse{Tags: tagRandom}, + }) +} + +func RandomSelect(strs []string, n int) []string { + if n <= 0 { + return []string{} + } + result := make([]string, len(strs)) + copy(result, strs) + rand.Shuffle(len(result), func(i, j int) { + result[i], result[j] = result[j], result[i] + }) + + if n > len(strs) { + n = len(strs) + } + return result[:n] +} diff --git a/internal/handler/session.go b/internal/handler/session.go index 870a024..374a614 100644 --- a/internal/handler/session.go +++ b/internal/handler/session.go @@ -83,7 +83,7 @@ type CreateSessionRequest struct { func (h *SessionHandler) CreateSession(c *gin.Context) { ctx := c.Request.Context() - logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation) + // logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation) // Parse and validate the request body var request CreateSessionRequest diff --git a/internal/router/router.go b/internal/router/router.go index 9d23542..38785c8 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -267,6 +267,10 @@ func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.Initializ r.POST("/initialization/embedding/test", handler.TestEmbeddingModel) r.POST("/initialization/rerank/check", handler.CheckRerankModel) r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction) + + r.POST("/initialization/extract/text-relation", handler.ExtractTextRelations) + r.POST("/initialization/extract/fabri-tag", handler.FabriTag) + r.POST("/initialization/extract/fabri-text", handler.FabriText) } // RegisterSystemRoutes registers system information routes diff --git a/internal/router/task.go 
b/internal/router/task.go new file mode 100644 index 0000000..547c4f1 --- /dev/null +++ b/internal/router/task.go @@ -0,0 +1,66 @@ +package router + +import ( + "log" + "os" + "time" + + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/hibiken/asynq" + "go.uber.org/dig" +) + +type AsynqTaskParams struct { + dig.In + + Server *asynq.Server + Extracter interfaces.Extracter +} + +func getAsynqRedisClientOpt() *asynq.RedisClientOpt { + opt := &asynq.RedisClientOpt{ + Addr: os.Getenv("REDIS_ADDR"), + Password: os.Getenv("REDIS_PASSWORD"), + ReadTimeout: 100 * time.Millisecond, + WriteTimeout: 200 * time.Millisecond, + DB: 0, + } + return opt +} + +func NewAsyncqClient() *asynq.Client { + opt := getAsynqRedisClientOpt() + client := asynq.NewClient(opt) + return client +} + +func NewAsynqServer() *asynq.Server { + opt := getAsynqRedisClientOpt() + srv := asynq.NewServer( + opt, + asynq.Config{ + Queues: map[string]int{ + "critical": 6, // Highest priority queue + "default": 3, // Default priority queue + "low": 1, // Lowest priority queue + }, + }, + ) + return srv +} + +func RunAsynqServer(params AsynqTaskParams) *asynq.ServeMux { + // Create a new mux and register all handlers + mux := asynq.NewServeMux() + + mux.HandleFunc(types.TypeChunkExtract, params.Extracter.Extract) + + go func() { + // Start the server + if err := params.Server.Run(mux); err != nil { + log.Fatalf("could not run server: %v", err) + } + }() + return mux +} diff --git a/internal/types/chat_manage.go b/internal/types/chat_manage.go index 917cbdb..fff243d 100644 --- a/internal/types/chat_manage.go +++ b/internal/types/chat_manage.go @@ -28,6 +28,8 @@ type ChatManage struct { SearchResult []*SearchResult `json:"-"` // Results from search phase RerankResult []*SearchResult `json:"-"` // Results after reranking MergeResult []*SearchResult `json:"-"` // Final merged results after all processing + Entity []string `json:"-"` // List of 
identified entities + GraphResult *GraphData `json:"-"` // Graph data from search phase UserContent string `json:"-"` // Processed user content ChatResponse *ChatResponse `json:"-"` // Final response from chat model ResponseChan <-chan StreamResponse `json:"-"` // Channel for streaming responses @@ -75,6 +77,7 @@ const ( PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks + ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks INTO_CHAT_MESSAGE EventType = "into_chat_message" // Convert chunks into chat messages @@ -104,6 +107,7 @@ var Pipline = map[string][]EventType{ REWRITE_QUERY, PREPROCESS_QUERY, CHUNK_SEARCH, + ENTITY_SEARCH, CHUNK_RERANK, CHUNK_MERGE, FILTER_TOP_K, diff --git a/internal/types/embedding.go b/internal/types/embedding.go index b21cd78..19fc810 100644 --- a/internal/types/embedding.go +++ b/internal/types/embedding.go @@ -19,6 +19,7 @@ const ( MatchTypeHistory MatchTypeParentChunk // 父Chunk匹配类型 MatchTypeRelationChunk // 关系Chunk匹配类型 + MatchTypeGraph ) // IndexInfo contains information about indexed content diff --git a/internal/types/extract_graph.go b/internal/types/extract_graph.go new file mode 100644 index 0000000..2bffdaf --- /dev/null +++ b/internal/types/extract_graph.go @@ -0,0 +1,51 @@ +package types + +const ( + TypeChunkExtract = "chunk:extract" +) + +type ExtractChunkPayload struct { + TenantID uint `json:"tenant_id"` + ChunkID string `json:"chunk_id"` + ModelID string `json:"model_id"` +} + +type PromptTemplateStructured struct { + Description string `json:"description"` + Tags []string `json:"tags"` + Examples []GraphData `json:"examples"` +} + +type GraphNode struct { + Name string `json:"name,omitempty"` + Chunks 
[]string `json:"chunks,omitempty"` + Attributes []string `json:"attributes,omitempty"` +} + +type GraphRelation struct { + Node1 string `json:"node1,omitempty"` + Node2 string `json:"node2,omitempty"` + Type string `json:"type,omitempty"` +} + +type GraphData struct { + Text string `json:"text,omitempty"` + Node []*GraphNode `json:"node,omitempty"` + Relation []*GraphRelation `json:"relation,omitempty"` +} + +type NameSpace struct { + KnowledgeBase string `json:"knowledge_base"` + Knowledge string `json:"knowledge"` +} + +func (n NameSpace) Labels() []string { + res := make([]string, 0) + if n.KnowledgeBase != "" { + res = append(res, n.KnowledgeBase) + } + if n.Knowledge != "" { + res = append(res, n.Knowledge) + } + return res +} diff --git a/internal/types/interfaces/extracter.go b/internal/types/interfaces/extracter.go new file mode 100644 index 0000000..c764ad1 --- /dev/null +++ b/internal/types/interfaces/extracter.go @@ -0,0 +1,11 @@ +package interfaces + +import ( + "context" + + "github.com/hibiken/asynq" +) + +type Extracter interface { + Extract(ctx context.Context, t *asynq.Task) error +} diff --git a/internal/types/interfaces/retriever_graph.go b/internal/types/interfaces/retriever_graph.go new file mode 100644 index 0000000..04f939d --- /dev/null +++ b/internal/types/interfaces/retriever_graph.go @@ -0,0 +1,13 @@ +package interfaces + +import ( + "context" + + "github.com/Tencent/WeKnora/internal/types" +) + +type RetrieveGraphRepository interface { + AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error + DelGraph(ctx context.Context, namespace []types.NameSpace) error + SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) +} diff --git a/internal/types/knowledgebase.go b/internal/types/knowledgebase.go index 5a1523b..c2a42ec 100644 --- a/internal/types/knowledgebase.go +++ b/internal/types/knowledgebase.go @@ -38,6 +38,8 @@ type KnowledgeBase struct { VLMConfig 
VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"` // Storage config StorageConfig StorageConfig `yaml:"cos_config" json:"cos_config" gorm:"column:cos_config;type:json"` + // Extract config + ExtractConfig *ExtractConfig `yaml:"extract_config" json:"extract_config" gorm:"column:extract_config;type:json"` // Creation time of the knowledge base CreatedAt time.Time `yaml:"created_at" json:"created_at"` // Last updated time of the knowledge base @@ -167,3 +169,27 @@ func (c *VLMConfig) Scan(value interface{}) error { } return json.Unmarshal(b, c) } + +type ExtractConfig struct { + Text string `yaml:"text" json:"text"` + Tags []string `yaml:"tags" json:"tags"` + Nodes []*GraphNode `yaml:"nodes" json:"nodes"` + Relations []*GraphRelation `yaml:"relations" json:"relations"` +} + +// Value implements the driver.Valuer interface, used to convert ExtractConfig to database value +func (e ExtractConfig) Value() (driver.Value, error) { + return json.Marshal(e) +} + +// Scan implements the sql.Scanner interface, used to convert database value to ExtractConfig +func (e *ExtractConfig) Scan(value interface{}) error { + if value == nil { + return nil + } + b, ok := value.([]byte) + if !ok { + return nil + } + return json.Unmarshal(b, e) +} diff --git a/migrations/mysql/00-init-db.sql b/migrations/mysql/00-init-db.sql index 000e6fc..e0863ab 100644 --- a/migrations/mysql/00-init-db.sql +++ b/migrations/mysql/00-init-db.sql @@ -51,6 +51,7 @@ CREATE TABLE knowledge_bases ( vlm_model_id VARCHAR(64) NOT NULL, cos_config JSON NOT NULL, vlm_config JSON NOT NULL, + extract_config JSON NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, deleted_at TIMESTAMP NULL DEFAULT NULL diff --git a/migrations/paradedb/00-init-db.sql b/migrations/paradedb/00-init-db.sql index 1354ba1..403f188 100644 --- a/migrations/paradedb/00-init-db.sql +++ b/migrations/paradedb/00-init-db.sql @@ -62,6 +62,7 @@ CREATE TABLE IF 
NOT EXISTS knowledge_bases ( vlm_model_id VARCHAR(64) NOT NULL, cos_config JSONB NOT NULL DEFAULT '{}', vlm_config JSONB NOT NULL DEFAULT '{}', + extract_config JSONB NULL DEFAULT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, deleted_at TIMESTAMP WITH TIME ZONE