package chatpipline import ( "context" "encoding/json" "errors" "fmt" "os" "regexp" "strings" "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/models/chat" "github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types/interfaces" ) // PluginExtractEntity is a plugin for extracting entities from user queries // It uses historical dialog context and large language models to identify key entities in the user's original query type PluginExtractEntity struct { modelService interfaces.ModelService // Model service for calling large language models template *types.PromptTemplateStructured // Template for generating prompts knowledgeBaseRepo interfaces.KnowledgeBaseRepository } // NewPluginRewrite creates a new query rewriting plugin instance // Also registers the plugin with the event manager func NewPluginExtractEntity( eventManager *EventManager, modelService interfaces.ModelService, knowledgeBaseRepo interfaces.KnowledgeBaseRepository, config *config.Config, ) *PluginExtractEntity { res := &PluginExtractEntity{ modelService: modelService, template: config.ExtractManager.ExtractEntity, knowledgeBaseRepo: knowledgeBaseRepo, } eventManager.Register(res) return res } // ActivationEvents returns the list of event types this plugin responds to // This plugin only responds to REWRITE_QUERY events func (p *PluginExtractEntity) ActivationEvents() []types.EventType { return []types.EventType{types.REWRITE_QUERY} } // OnEvent processes triggered events // When receiving a REWRITE_QUERY event, it rewrites the user query using conversation history and the language model func (p *PluginExtractEntity) OnEvent(ctx context.Context, eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError, ) *PluginError { if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" { logger.Debugf(ctx, "skipping extract entity, neo4j is disabled") return next() } query := chatManage.Query model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID) if err != nil { logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err) return next() } kb, err := p.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chatManage.KnowledgeBaseID) if err != nil { logger.Errorf(ctx, "failed to get knowledge base: %v", err) return next() } if kb.ExtractConfig == nil { logger.Warnf(ctx, "failed to get extract config") return next() } template := &types.PromptTemplateStructured{ Description: p.template.Description, Examples: p.template.Examples, } extractor := NewExtractor(model, template) graph, err := extractor.Extract(ctx, query) if err != nil { logger.Errorf(ctx, "Failed to extract entities, session_id: %s, error: %v", chatManage.SessionID, err) return next() } nodes := []string{} for _, node := range graph.Node { nodes = append(nodes, node.Name) } logger.Debugf(ctx, "extracted node: %v", nodes) chatManage.Entity = nodes return next() } type Extractor struct { chat chat.Chat formater *Formater template *types.PromptTemplateStructured chatOpt *chat.ChatOptions } func NewExtractor( chatModel chat.Chat, template *types.PromptTemplateStructured, ) Extractor { think := false return Extractor{ chat: chatModel, formater: NewFormater(), template: template, chatOpt: &chat.ChatOptions{ Temperature: 0.3, MaxTokens: 4096, Thinking: &think, }, } } func (e *Extractor) Extract(ctx context.Context, content string) (*types.GraphData, error) { generator := NewQAPromptGenerator(e.formater, e.template) // logger.Debugf(ctx, "chat system: %s", generator.System(ctx)) // logger.Debugf(ctx, "chat user: %s", generator.User(ctx, content)) chatResponse, err := e.chat.Chat(ctx, generator.Render(ctx, content), e.chatOpt) if err != nil { logger.Errorf(ctx, "failed to chat: %v", err) return nil, err } graph, err := e.formater.ParseGraph(ctx, chatResponse.Content) if err != nil { logger.Errorf(ctx, "failed to parse graph: %v", err) return nil, err } // e.RemoveUnknownRelation(ctx, graph) return graph, nil } func (e *Extractor) RemoveUnknownRelation(ctx context.Context, graph *types.GraphData) { relationType := make(map[string]bool) for _, tag := range e.template.Tags { relationType[tag] = true } relationNew := make([]*types.GraphRelation, 0) for _, relation := range graph.Relation { if _, ok := relationType[relation.Type]; ok { relationNew = append(relationNew, relation) } else { logger.Infof(ctx, "Unknown relation type %s with %v, ignore it", relation.Type, e.template.Tags) } } graph.Relation = relationNew } type QAPromptGenerator struct { Formater *Formater Template *types.PromptTemplateStructured ExamplesHeading string QuestionHeading string QuestionPrefix string AnswerPrefix string } func NewQAPromptGenerator(formater *Formater, template *types.PromptTemplateStructured) *QAPromptGenerator { return &QAPromptGenerator{ Formater: formater, Template: template, ExamplesHeading: "# Examples", QuestionHeading: "# Question", QuestionPrefix: "Q: ", AnswerPrefix: "A: ", } } func (qa *QAPromptGenerator) System(ctx context.Context) string { promptLines := []string{} if len(qa.Template.Tags) == 0 { promptLines = append(promptLines, qa.Template.Description) } else { tags, _ := json.Marshal(qa.Template.Tags) promptLines = append(promptLines, fmt.Sprintf(qa.Template.Description, string(tags))) } if len(qa.Template.Examples) > 0 { promptLines = append(promptLines, qa.ExamplesHeading) for _, example := range qa.Template.Examples { // Question promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, strings.TrimSpace(example.Text))) // Answer answer, err := qa.Formater.formatExtraction(example.Node, example.Relation) if err != nil { return "" } promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.AnswerPrefix, answer)) // new line promptLines = append(promptLines, "") } } return strings.Join(promptLines, "\n") } func (qa *QAPromptGenerator) User(ctx context.Context, question string) string { promptLines := []string{} promptLines = append(promptLines, qa.QuestionHeading) promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, question)) promptLines = append(promptLines, qa.AnswerPrefix) return strings.Join(promptLines, "\n") } func (qa *QAPromptGenerator) Render(ctx context.Context, question string) []chat.Message { return []chat.Message{ { Role: "system", Content: qa.System(ctx), }, { Role: "user", Content: qa.User(ctx, question), }, } } type FormatType string const ( FormatTypeJSON FormatType = "json" FormatTypeYAML FormatType = "yaml" ) const ( _FENCE_START = "```" _LANGUAGE_TAG = `(?P[A-Za-z0-9_+-]+)?` _FENCE_NEWLINE = `(?:\s*\n)?` _FENCE_BODY = `(?P[\s\S]*?)` _FENCE_END = "```" ) var _FENCE_RE = regexp.MustCompile( _FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END, ) type Formater struct { attributeSuffix string formatType FormatType useFences bool nodePrefix string relationSource string relationTarget string relationPrefix string } func NewFormater() *Formater { return &Formater{ attributeSuffix: "_attributes", formatType: FormatTypeJSON, useFences: true, nodePrefix: "entity", relationSource: "entity1", relationTarget: "entity2", relationPrefix: "relation", } } func (f *Formater) formatExtraction(nodes []*types.GraphNode, relations []*types.GraphRelation) (string, error) { items := make([]map[string]interface{}, 0) for _, node := range nodes { item := map[string]interface{}{ f.nodePrefix: node.Name, } if len(node.Attributes) > 0 { item[fmt.Sprintf("%s%s", f.nodePrefix, f.attributeSuffix)] = node.Attributes } items = append(items, item) } for _, relation := range relations { item := map[string]interface{}{ f.relationSource: relation.Node1, f.relationTarget: relation.Node2, f.relationPrefix: relation.Type, } items = append(items, item) } formatted := "" switch f.formatType { default: formattedBytes, err := json.MarshalIndent(items, "", " ") if err != nil { return "", err } formatted = string(formattedBytes) } if f.useFences { formatted = f.addFences(formatted) } return formatted, nil } func (f *Formater) parseOutput(ctx context.Context, text string) ([]map[string]interface{}, error) { if text == "" { return nil, errors.New("Empty or invalid input string.") } content := f.extractContent(ctx, text) // logger.Debugf(ctx, "Extracted content: %s", content) if content == "" { return nil, errors.New("Empty or invalid input string.") } var parsed interface{} var err error if f.formatType == FormatTypeJSON { err = json.Unmarshal([]byte(content), &parsed) } if err != nil { return nil, fmt.Errorf("Failed to parse %s content: %s", strings.ToUpper(string(f.formatType)), err.Error()) } if parsed == nil { return nil, fmt.Errorf("Content must be a list of extractions or a dict.") } var items []interface{} if parsedMap, ok := parsed.(map[string]interface{}); ok { items = []interface{}{parsedMap} } else if parsedList, ok := parsed.([]interface{}); ok { items = parsedList } else { return nil, fmt.Errorf("Expected list or dict, got %T", parsed) } itemsList := make([]map[string]interface{}, 0) for _, item := range items { if itemMap, ok := item.(map[string]interface{}); ok { itemsList = append(itemsList, itemMap) } else { return nil, fmt.Errorf("Each item in the sequence must be a mapping.") } } return itemsList, nil } func (f *Formater) ParseGraph(ctx context.Context, text string) (*types.GraphData, error) { matchData, err := f.parseOutput(ctx, text) if err != nil { return nil, err } if len(matchData) == 0 { logger.Debugf(ctx, "Received empty extraction data.") return &types.GraphData{}, nil } // mm, _ := json.Marshal(matchData) // logger.Debugf(ctx, "Parsed graph data: %s", string(mm)) var nodes []*types.GraphNode var relations []*types.GraphRelation for _, group := range matchData { switch { case group[f.nodePrefix] != nil: attributes := make([]string, 0) attributesKey := f.nodePrefix + f.attributeSuffix if attr, ok := group[attributesKey].([]interface{}); ok { for _, v := range attr { attributes = append(attributes, fmt.Sprintf("%v", v)) } } nodes = append(nodes, &types.GraphNode{ Name: fmt.Sprintf("%v", group[f.nodePrefix]), Attributes: attributes, }) case group[f.relationSource] != nil && group[f.relationTarget] != nil: relations = append(relations, &types.GraphRelation{ Node1: fmt.Sprintf("%v", group[f.relationSource]), Node2: fmt.Sprintf("%v", group[f.relationTarget]), Type: fmt.Sprintf("%v", group[f.relationPrefix]), }) default: logger.Warnf(ctx, "Unsupported graph group: %v", group) continue } } graph := &types.GraphData{ Node: nodes, Relation: relations, } f.rebuildGraph(ctx, graph) return graph, nil } func (f *Formater) rebuildGraph(ctx context.Context, graph *types.GraphData) { nodeMap := make(map[string]*types.GraphNode) nodes := make([]*types.GraphNode, 0, len(graph.Node)) for _, node := range graph.Node { if prenode, ok := nodeMap[node.Name]; ok { logger.Infof(ctx, "Duplicate node ID: %s, merge attribute", node.Name) // 修复panic:检查Attributes是否为nil if node.Attributes == nil { node.Attributes = make([]string, 0) } if prenode.Attributes != nil { for _, attr := range prenode.Attributes { node.Attributes = append(node.Attributes, attr) } } continue } nodeMap[node.Name] = node nodes = append(nodes, node) } relations := make([]*types.GraphRelation, 0, len(graph.Relation)) for _, relation := range graph.Relation { if relation.Node1 == relation.Node2 { logger.Infof(ctx, "Duplicate relation, ignore it") continue } if _, ok := nodeMap[relation.Node1]; !ok { node := &types.GraphNode{Name: relation.Node1} nodes = append(nodes, node) nodeMap[relation.Node1] = node logger.Infof(ctx, "Add unknown source node ID: %s", relation.Node1) } if _, ok := nodeMap[relation.Node2]; !ok { node := &types.GraphNode{Name: relation.Node2} nodes = append(nodes, node) nodeMap[relation.Node2] = node logger.Infof(ctx, "Add unknown target node ID: %s", relation.Node2) } relations = append(relations, relation) } *graph = types.GraphData{ Node: nodes, Relation: relations, } } func (f *Formater) extractContent(ctx context.Context, text string) string { if !f.useFences { return strings.TrimSpace(text) } validTags := map[FormatType]map[string]struct{}{ FormatTypeYAML: {"yaml": {}, "yml": {}}, FormatTypeJSON: {"json": {}}, } matches := _FENCE_RE.FindAllStringSubmatch(text, -1) var candidates []string for _, match := range matches { lang := match[1] body := match[2] if f.isValidLanguageTag(lang, validTags) { candidates = append(candidates, body) } } switch { case len(candidates) == 1: return strings.TrimSpace(candidates[0]) case len(candidates) > 1: logger.Warnf(ctx, "multiple candidates found: %d", len(candidates)) return strings.TrimSpace(candidates[0]) case len(matches) == 1: logger.Debugf(ctx, "no candidate found, use first match without language tag: %s", matches[0][1]) return strings.TrimSpace(matches[0][2]) case len(matches) > 1: logger.Warnf(ctx, "multiple matches found: %d", len(matches)) return strings.TrimSpace(matches[0][2]) default: logger.Warnf(ctx, "no match found") return strings.TrimSpace(text) } } func (f *Formater) addFences(content string) string { content = strings.TrimSpace(content) return fmt.Sprintf("```%s\n%s\n```", f.formatType, content) } func (f *Formater) isValidLanguageTag(lang string, validTags map[FormatType]map[string]struct{}) bool { if lang == "" { return true } tag := strings.TrimSpace(strings.ToLower(lang)) validSet, ok := validTags[f.formatType] if !ok { return false } _, exists := validSet[tag] return exists }