mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 03:15:00 +08:00
232 lines
7.3 KiB
Go
232 lines
7.3 KiB
Go
package neo4j
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
|
)
|
|
|
|
type Neo4jRepository struct {
|
|
driver neo4j.Driver
|
|
nodePrefix string
|
|
}
|
|
|
|
func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository {
|
|
return &Neo4jRepository{driver: driver, nodePrefix: "ENTITY"}
|
|
}
|
|
|
|
// _remove_hyphen returns s with every "-" replaced by "_".
func _remove_hyphen(s string) string {
	const (
		hyphen     = "-"
		underscore = "_"
	)
	// A negative count replaces every occurrence.
	return strings.Replace(s, hyphen, underscore, -1)
}
|
|
|
|
func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string {
|
|
res := make([]string, 0)
|
|
for _, label := range namespace.Labels() {
|
|
res = append(res, n.nodePrefix+_remove_hyphen(label))
|
|
}
|
|
return res
|
|
}
|
|
|
|
func (n *Neo4jRepository) Label(namespace types.NameSpace) string {
|
|
labels := n.Labels(namespace)
|
|
return strings.Join(labels, ":")
|
|
}
|
|
|
|
// AddGraph implements interfaces.RetrieveGraphRepository.
|
|
func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error {
|
|
if n.driver == nil {
|
|
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
|
return nil
|
|
}
|
|
for _, graph := range graphs {
|
|
if err := n.addGraph(ctx, namespace, graph); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error {
|
|
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
|
|
defer session.Close(ctx)
|
|
|
|
_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
|
node_import_query := `
|
|
UNWIND $data AS row
|
|
CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node
|
|
SET node.chunks = apoc.coll.union(node.chunks, row.chunks)
|
|
RETURN distinct 'done' AS result
|
|
`
|
|
nodeData := []map[string]interface{}{}
|
|
for _, node := range graph.Node {
|
|
nodeData = append(nodeData, map[string]interface{}{
|
|
"name": node.Name,
|
|
"knowledge_id": namespace.Knowledge,
|
|
"props": map[string][]string{"attributes": node.Attributes},
|
|
"chunks": node.Chunks,
|
|
"labels": n.Labels(namespace),
|
|
})
|
|
}
|
|
if _, err := tx.Run(ctx, node_import_query, map[string]interface{}{"data": nodeData}); err != nil {
|
|
return nil, fmt.Errorf("failed to create nodes: %v", err)
|
|
}
|
|
|
|
rel_import_query := `
|
|
UNWIND $data AS row
|
|
CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source
|
|
CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target
|
|
CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel
|
|
RETURN distinct 'done'
|
|
`
|
|
relData := []map[string]interface{}{}
|
|
for _, rel := range graph.Relation {
|
|
relData = append(relData, map[string]interface{}{
|
|
"source": rel.Node1,
|
|
"target": rel.Node2,
|
|
"knowledge_id": namespace.Knowledge,
|
|
"type": rel.Type,
|
|
"source_labels": n.Labels(namespace),
|
|
"target_labels": n.Labels(namespace),
|
|
})
|
|
}
|
|
if _, err := tx.Run(ctx, rel_import_query, map[string]interface{}{"data": relData}); err != nil {
|
|
return nil, fmt.Errorf("failed to create relationships: %v", err)
|
|
}
|
|
return nil, nil
|
|
})
|
|
if err != nil {
|
|
logger.Errorf(ctx, "failed to add graph: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DelGraph implements interfaces.RetrieveGraphRepository.
|
|
func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error {
|
|
if n.driver == nil {
|
|
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
|
return nil
|
|
}
|
|
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
|
|
defer session.Close(ctx)
|
|
|
|
result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
|
for _, namespace := range namespaces {
|
|
labelExpr := n.Label(namespace)
|
|
|
|
deleteRelsQuery := `
|
|
CALL apoc.periodic.iterate(
|
|
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]-(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r",
|
|
"DELETE r",
|
|
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
|
|
) YIELD batches, total
|
|
RETURN total
|
|
`
|
|
if _, err := tx.Run(ctx, deleteRelsQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
|
|
return nil, fmt.Errorf("failed to delete relationships: %v", err)
|
|
}
|
|
|
|
deleteNodesQuery := `
|
|
CALL apoc.periodic.iterate(
|
|
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n",
|
|
"DELETE n",
|
|
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
|
|
) YIELD batches, total
|
|
RETURN total
|
|
`
|
|
if _, err := tx.Run(ctx, deleteNodesQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
|
|
return nil, fmt.Errorf("failed to delete nodes: %v", err)
|
|
}
|
|
}
|
|
return nil, nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
logger.Infof(ctx, "delete graph result: %v", result)
|
|
return nil
|
|
}
|
|
|
|
func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) {
|
|
if n.driver == nil {
|
|
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
|
return nil, nil
|
|
}
|
|
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
|
|
defer session.Close(ctx)
|
|
|
|
result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
|
labelExpr := n.Label(namespace)
|
|
query := `
|
|
MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `)
|
|
WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText)
|
|
RETURN n, r, m
|
|
`
|
|
params := map[string]interface{}{"nodes": nodes}
|
|
result, err := tx.Run(ctx, query, params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to run query: %v", err)
|
|
}
|
|
|
|
graphData := &types.GraphData{}
|
|
nodeSeen := make(map[string]bool)
|
|
for result.Next(ctx) {
|
|
record := result.Record()
|
|
node, _ := record.Get("n")
|
|
rel, _ := record.Get("r")
|
|
targetNode, _ := record.Get("m")
|
|
|
|
nodeData := node.(neo4j.Node)
|
|
targetNodeData := targetNode.(neo4j.Node)
|
|
|
|
// Convert node to types.Node
|
|
for _, n := range []neo4j.Node{nodeData, targetNodeData} {
|
|
nameStr := n.Props["name"].(string)
|
|
if _, ok := nodeSeen[nameStr]; !ok {
|
|
nodeSeen[nameStr] = true
|
|
graphData.Node = append(graphData.Node, &types.GraphNode{
|
|
Name: nameStr,
|
|
Chunks: listI2listS(n.Props["chunks"].([]interface{})),
|
|
Attributes: listI2listS(n.Props["attributes"].([]interface{})),
|
|
})
|
|
}
|
|
}
|
|
|
|
// Convert relationship to types.Relation
|
|
relData := rel.(neo4j.Relationship)
|
|
graphData.Relation = append(graphData.Relation, &types.GraphRelation{
|
|
Node1: nodeData.Props["name"].(string),
|
|
Node2: targetNodeData.Props["name"].(string),
|
|
Type: relData.Type,
|
|
})
|
|
}
|
|
return graphData, nil
|
|
})
|
|
if err != nil {
|
|
logger.Errorf(ctx, "search node failed: %v", err)
|
|
return nil, err
|
|
}
|
|
return result.(*types.GraphData), nil
|
|
}
|
|
|
|
// listI2listS converts every element of list to its default fmt string
// representation. The result has the same length and order as list and is
// never nil.
func listI2listS(list []any) []string {
	strs := make([]string, 0, len(list))
	for idx := range list {
		strs = append(strs, fmt.Sprint(list[idx]))
	}
	return strs
}
|
|
|
|
// mapI2mapS returns a new map with the same keys as prop, each value
// converted to its default fmt string representation. The result is never
// nil.
func mapI2mapS(prop map[string]any) map[string]string {
	converted := make(map[string]string, len(prop))
	for key := range prop {
		converted[key] = fmt.Sprint(prop[key])
	}
	return converted
}
|