Files
WeKnora/internal/application/repository/retriever/neo4j/repository.go
2025-10-16 17:48:21 +08:00

232 lines
7.3 KiB
Go

package neo4j
import (
"context"
"fmt"
"strings"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
// Neo4jRepository is the Neo4j-backed graph store of this package.
// NewNeo4jRepository hands it out as an interfaces.RetrieveGraphRepository.
type Neo4jRepository struct {
// driver is the Neo4j driver used to open sessions. AddGraph, DelGraph and
// SearchNode log a warning and do nothing when it is nil.
driver neo4j.Driver
// nodePrefix is prepended to every label produced by Labels.
nodePrefix string
}
// NewNeo4jRepository wraps driver in a *Neo4jRepository whose node-label
// prefix is "ENTITY" and returns it as an interfaces.RetrieveGraphRepository.
//
// The driver is stored as given and is not validated here.
func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository {
	repo := new(Neo4jRepository)
	repo.driver = driver
	repo.nodePrefix = "ENTITY"
	return repo
}
// _remove_hyphen returns s with every '-' replaced by '_'.
//
// Strings without a hyphen are returned unchanged. The replacement works on
// bytes, so every other byte of s is preserved exactly.
func _remove_hyphen(s string) string {
	if strings.IndexByte(s, '-') < 0 {
		return s
	}
	buf := []byte(s)
	for i, c := range buf {
		if c == '-' {
			buf[i] = '_'
		}
	}
	return string(buf)
}
// Labels returns the Neo4j labels for namespace: every entry of
// namespace.Labels() with hyphens turned into underscores and the
// repository's node prefix put in front. The returned slice is never nil.
func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string {
	source := namespace.Labels()
	out := make([]string, 0, len(source))
	for _, raw := range source {
		out = append(out, n.nodePrefix+_remove_hyphen(raw))
	}
	return out
}
// Label returns the labels of namespace (see Labels) joined with ":", the
// form that is spliced after "(n:" in the Cypher patterns of this file.
func (n *Neo4jRepository) Label(namespace types.NameSpace) string {
	return strings.Join(n.Labels(namespace), ":")
}
// AddGraph implements interfaces.RetrieveGraphRepository.
//
// It stores each graph in graphs under namespace, in order, and stops at the
// first failure, returning that error. With a nil driver it only logs a
// warning and reports success.
func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil
	}

	var err error
	for i := 0; i < len(graphs) && err == nil; i++ {
		err = n.addGraph(ctx, namespace, graphs[i])
	}
	return err
}
// addGraph writes a single graph into Neo4j inside one managed write
// transaction: nodes are merged first, then relationships.
//
// Nodes are merged with apoc.merge.node on (name, kg) under the namespace
// labels, where kg is namespace.Knowledge. node.Attributes is sent as the
// "attributes" entry of the props argument and node.Chunks is unioned into
// the node's existing chunks. Relationships are merged with
// apoc.merge.relationship between endpoints that are themselves merged on
// (name, kg).
//
// A nil graph, nil node/relation entries and empty batches are skipped, and
// no session is opened when there is nothing to write. Failures are logged
// and returned; the underlying driver error is wrapped with %w so callers can
// still inspect it with errors.Is / errors.As.
func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error {
	if graph == nil {
		return nil
	}

	const nodeImportQuery = `
		UNWIND $data AS row
		CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node
		SET node.chunks = apoc.coll.union(node.chunks, row.chunks)
		RETURN distinct 'done' AS result
	`
	// NOTE(review): this query reads row.attributes, but the relationship
	// rows built below carry no "attributes" key, so it evaluates to null.
	// Confirm whether relation attributes are meant to be sent.
	const relImportQuery = `
		UNWIND $data AS row
		CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source
		CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target
		CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel
		RETURN distinct 'done'
	`

	// The label list is the same for every row; compute it once. The rows are
	// built outside the transaction function because the driver may invoke
	// that function more than once.
	labels := n.Labels(namespace)

	nodeRows := make([]map[string]any, 0, len(graph.Node))
	for _, node := range graph.Node {
		if node == nil {
			continue
		}
		nodeRows = append(nodeRows, map[string]any{
			"name":         node.Name,
			"knowledge_id": namespace.Knowledge,
			"props":        map[string][]string{"attributes": node.Attributes},
			"chunks":       node.Chunks,
			"labels":       labels,
		})
	}

	relRows := make([]map[string]any, 0, len(graph.Relation))
	for _, rel := range graph.Relation {
		if rel == nil {
			continue
		}
		relRows = append(relRows, map[string]any{
			"source":        rel.Node1,
			"target":        rel.Node2,
			"knowledge_id":  namespace.Knowledge,
			"type":          rel.Type,
			"source_labels": labels,
			"target_labels": labels,
		})
	}

	// UNWIND over an empty list writes nothing, so skip the round trip.
	if len(nodeRows) == 0 && len(relRows) == 0 {
		return nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
	defer session.Close(ctx)

	_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
		if len(nodeRows) > 0 {
			if _, err := tx.Run(ctx, nodeImportQuery, map[string]any{"data": nodeRows}); err != nil {
				return nil, fmt.Errorf("failed to create nodes: %w", err)
			}
		}
		if len(relRows) > 0 {
			if _, err := tx.Run(ctx, relImportQuery, map[string]any{"data": relRows}); err != nil {
				return nil, fmt.Errorf("failed to create relationships: %w", err)
			}
		}
		return nil, nil
	})
	if err != nil {
		logger.Errorf(ctx, "failed to add graph: %v", err)
		return err
	}
	return nil
}
// DelGraph implements interfaces.RetrieveGraphRepository.
//
// For every namespace it deletes, in batches of 1000 via
// apoc.periodic.iterate, first the relationships between nodes carrying the
// namespace labels and kg = namespace.Knowledge, then those nodes themselves.
// All namespaces are processed inside one managed write transaction function.
//
// With a nil driver it only logs a warning; with no namespaces it returns
// immediately. It returns an error when a statement cannot be run or when
// apoc.periodic.iterate reports failed batches. On success the summed "total"
// reported by the delete statements is logged.
func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil
	}
	if len(namespaces) == 0 {
		return nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
	defer session.Close(ctx)

	result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
		var deleted int64
		for _, namespace := range namespaces {
			labelExpr := n.Label(namespace)
			params := map[string]any{"knowledge_id": namespace.Knowledge}

			// The pattern is directed so that each relationship is returned
			// once; an undirected pattern with identical constraints on both
			// ends yields every relationship twice. Batches run sequentially
			// because APOC documents that parallel batches may deadlock when
			// their statements conflict.
			deleteRelsQuery := `
				CALL apoc.periodic.iterate(
					"MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]->(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r",
					"DELETE r",
					{batchSize: 1000, parallel: false, params: {knowledge_id: $knowledge_id}}
				) YIELD batches, total, failedBatches, errorMessages
				RETURN total, failedBatches, errorMessages
			`
			relTotal, err := runPeriodicDelete(ctx, tx, deleteRelsQuery, params)
			if err != nil {
				return nil, fmt.Errorf("failed to delete relationships: %w", err)
			}

			deleteNodesQuery := `
				CALL apoc.periodic.iterate(
					"MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n",
					"DELETE n",
					{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
				) YIELD batches, total, failedBatches, errorMessages
				RETURN total, failedBatches, errorMessages
			`
			nodeTotal, err := runPeriodicDelete(ctx, tx, deleteNodesQuery, params)
			if err != nil {
				return nil, fmt.Errorf("failed to delete nodes: %w", err)
			}

			deleted += relTotal + nodeTotal
		}
		return deleted, nil
	})
	if err != nil {
		return err
	}
	logger.Infof(ctx, "delete graph result: %v", result)
	return nil
}

// runPeriodicDelete runs one apoc.periodic.iterate statement that returns the
// columns total, failedBatches and errorMessages, and returns the reported
// total.
//
// apoc.periodic.iterate reports per-batch failures in its result row, so the
// row has to be read: a positive failedBatches is turned into an error that
// carries errorMessages.
func runPeriodicDelete(ctx context.Context, tx neo4j.ManagedTransaction, query string, params map[string]any) (int64, error) {
	rows, err := tx.Run(ctx, query, params)
	if err != nil {
		return 0, err
	}

	var total int64
	for rows.Next(ctx) {
		record := rows.Record()

		// A missing column yields a nil value, which the checked assertions
		// below simply do not match.
		failed, _ := record.Get("failedBatches")
		if count, ok := failed.(int64); ok && count > 0 {
			messages, _ := record.Get("errorMessages")
			return 0, fmt.Errorf("%d batches failed: %v", count, messages)
		}

		value, _ := record.Get("total")
		if count, ok := value.(int64); ok {
			total += count
		}
	}
	if err := rows.Err(); err != nil {
		return 0, err
	}
	return total, nil
}
// SearchNode returns the sub-graph around nodes whose name contains any of
// the given strings.
//
// It matches every relationship, in either direction, between two nodes
// carrying the namespace labels where the first node's name CONTAINS at least
// one entry of nodes. Each distinct node name is reported once, with its
// "chunks" and "attributes" list properties stringified; every matched row
// adds one relation whose Node1 is the matching node and Node2 its neighbour.
//
// Returns (nil, nil) with a warning when the driver is nil, and an empty
// graph without querying when nodes is empty. Nodes lacking "chunks" or
// "attributes" are returned with empty lists; rows whose endpoints have no
// string "name" are skipped. Query and iteration errors are logged and
// returned.
func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) {
	if n.driver == nil {
		logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
		return nil, nil
	}
	// ANY(... IN $nodes ...) is never true for an empty list, so the query
	// could not return rows anyway.
	if len(nodes) == 0 {
		return &types.GraphData{}, nil
	}

	session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
	defer session.Close(ctx)

	labelExpr := n.Label(namespace)
	query := `
		MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `)
		WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText)
		RETURN n, r, m
	`

	result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (any, error) {
		rows, err := tx.Run(ctx, query, map[string]any{"nodes": nodes})
		if err != nil {
			return nil, fmt.Errorf("failed to run query: %w", err)
		}

		graphData := &types.GraphData{}
		seen := make(map[string]struct{})

		// addNode records a node the first time its name is encountered.
		addNode := func(name string, node neo4j.Node) {
			if _, dup := seen[name]; dup {
				return
			}
			seen[name] = struct{}{}
			// Endpoints created only by the relationship merge carry just
			// name and kg, so these properties may be absent; a failed
			// assertion leaves a nil list, which converts to an empty slice.
			chunks, _ := node.Props["chunks"].([]any)
			attributes, _ := node.Props["attributes"].([]any)
			graphData.Node = append(graphData.Node, &types.GraphNode{
				Name:       name,
				Chunks:     listI2listS(chunks),
				Attributes: listI2listS(attributes),
			})
		}

		for rows.Next(ctx) {
			record := rows.Record()

			// A missing column yields a nil value, which the checked
			// assertions below reject.
			rawSource, _ := record.Get("n")
			rawRel, _ := record.Get("r")
			rawTarget, _ := record.Get("m")

			source, sourceOK := rawSource.(neo4j.Node)
			rel, relOK := rawRel.(neo4j.Relationship)
			target, targetOK := rawTarget.(neo4j.Node)
			if !sourceOK || !relOK || !targetOK {
				return nil, fmt.Errorf("unexpected record shape: n=%T r=%T m=%T", rawSource, rawRel, rawTarget)
			}

			sourceName, sourceNamed := source.Props["name"].(string)
			targetName, targetNamed := target.Props["name"].(string)
			if !sourceNamed || !targetNamed {
				// Without a name the row cannot be represented in GraphData.
				continue
			}

			addNode(sourceName, source)
			addNode(targetName, target)
			graphData.Relation = append(graphData.Relation, &types.GraphRelation{
				Node1: sourceName,
				Node2: targetName,
				Type:  rel.Type,
			})
		}
		// Next returns false both at the end of the stream and on failure;
		// Err tells the two apart so partial data is never returned as success.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read query result: %w", err)
		}
		return graphData, nil
	})
	if err != nil {
		logger.Errorf(ctx, "search node failed: %v", err)
		return nil, err
	}
	return result.(*types.GraphData), nil
}
// listI2listS converts every element of list to its default fmt
// representation (the %v form) and returns them in the same order.
//
// The result is never nil; a nil or empty input gives an empty slice.
func listI2listS(list []any) []string {
	out := make([]string, 0, len(list))
	for _, item := range list {
		out = append(out, fmt.Sprint(item))
	}
	return out
}
// mapI2mapS returns a new map with the keys of prop, each value replaced by
// its default fmt representation (the %v form).
//
// The result is never nil; a nil or empty input gives an empty map.
func mapI2mapS(prop map[string]any) map[string]string {
	out := make(map[string]string, len(prop))
	for key := range prop {
		out[key] = fmt.Sprint(prop[key])
	}
	return out
}