42 Commits

Author SHA1 Message Date
qzero
ca704fa054 feat(MCP): MCP服务端添加使用文件创建知识的功能 2025-10-20 11:03:55 +08:00
begoniezhao
02b78a5908 feat: 新增异步任务提取服务 2025-10-16 17:48:21 +08:00
begoniezhao
de96a52d54 chore: 将 PostgreSQL 镜像版本更新为 v0.18.9-pg17 2025-10-15 11:48:49 +08:00
begoniezhao
f24cd817cb feat: 新增 URL 重复错误类型及 409 冲突处理 2025-09-22 17:29:15 +08:00
begoniezhao
4824e41361 chore: 移除构建镜像前的版本信息准备步骤 2025-09-19 12:29:03 +08:00
begoniezhao
bfd4fffbe3 chore: upspeed build docker 2025-09-19 12:29:03 +08:00
begoniezhao
a2902de6ce chore(build): Optimize Docker build configuration and process 2025-09-18 18:22:38 +08:00
begoniezhao
a5c3623a02 chore(build): Optimize Docker build configuration and process 2025-09-18 17:34:14 +08:00
begoniezhao
8f723b38fb chore(build): Optimize Docker build configuration and process, adjust task names 2025-09-18 17:34:14 +08:00
begoniezhao
7973128f4c chore(build): Optimize Makefile and Dockerfile.app build configuration 2025-09-18 14:46:10 +08:00
wizardchen
8ed050b8ec fix(ui): Fix Drag Upload File failed in knowlegebase 2025-09-18 13:17:54 +08:00
wizardchen
4ccbd2a127 fix(ui): Fix Drag Upload File Knowlegebase status check 2025-09-17 22:18:31 +08:00
wizardchen
512910584b fix(ui): Fix Ollama Model Download Progress 2025-09-17 21:35:33 +08:00
wizardchen
cd7e02e54a docs: Update CHANGELOG 2025-09-17 20:37:33 +08:00
wizardchen
c9b1f43ed7 chore: release v0.1.4 2025-09-17 20:30:02 +08:00
wizardchen
76fc64a807 fix(frontend): Fix Login Direct Page 2025-09-17 20:29:34 +08:00
wizardchen
947899ff10 fix(app): Update App LLM Model Check logic 2025-09-17 19:11:26 +08:00
wizardchen
5e0a99b127 fix: Get version script 2025-09-17 16:40:48 +08:00
lyingbug
b04566be32 Support multi knowledgebases operation 2025-09-17 16:36:21 +08:00
wizardchen
0157eb25bd merge main 2025-09-17 16:14:29 +08:00
wizardchen
91e65d6445 feat(ui): Support multi knowledgebases operation 2025-09-17 16:02:08 +08:00
begoniezhao
c589a911dc feat: Added multi-data source search engine configuration and optimization logic 2025-09-17 10:21:37 +08:00
wizardchen
66aec78960 chore(ui): Update Setting page 2025-09-16 20:30:09 +08:00
wizardchen
76fbfdf8ac chore(ui): Update Setting page 2025-09-16 20:18:47 +08:00
wizardchen
4137a63852 feat(ui): Add tenant info 2025-09-16 15:46:18 +08:00
wizardchen
d28f805707 fix(ui): Fix CSP error 2025-09-16 14:58:22 +08:00
v_wnxinfeng
2e395864b9 feat: 修复下载文件内容错误问题 2025-09-16 14:33:21 +08:00
wizardchen
4005aa3ded fix(ui): fix xss in thinking 2025-09-16 13:18:58 +08:00
wizardchen
5e22f96d37 chore: release v0.1.3
🔒 Security Features:
- Added login authentication functionality
- Fixed XSS vulnerabilities
- Enhanced security utilities and API key protection

🐛 Bug Fixes:
- Fixed OCR AVX support issues
- Improved Docker binary downloads
- Enhanced COS file service initialization

📚 Documentation:
- Added security notices to all README files
- Updated deployment recommendations

🚀 Features:
- Comprehensive user management system
- Enhanced authentication flow
- Improved logging and configuration
2025-09-16 11:11:32 +08:00
wizardchen
2237e1ee55 chore: release v0.1.3
- Add login authentication functionality
- Update security notices in all README files
- Update version badges and package.json
- Add deployment security recommendations
2025-09-16 11:08:43 +08:00
lyingbug
b11df52cfb Merge pull request #301 from lyingbug/login_page
feat: Add Login Page
2025-09-16 10:24:49 +08:00
lyingbug
c3744866fd Merge branch 'main' into login_page 2025-09-16 10:24:05 +08:00
begoniezhao
c2d52a9374 feat: Modify COS file service initialization parameters and URL processing logic 2025-09-16 10:15:50 +08:00
wizardchen
81bd2e6c2c feat: Add Login Page 2025-09-16 02:47:39 +08:00
wizardchen
0908f9c487 fix(ui): Fix xss attact 2025-09-15 20:02:25 +08:00
wizardchen
1aac37d3fd chore(docs): Fix Docs Spell 2025-09-15 19:15:15 +08:00
wizardchen
cd249df8c8 fix(ui): Ignore showing APIKEY for security 2025-09-15 14:52:00 +08:00
wizardchen
092b30af3e fix(docreader): Download binary by target arch in docker 2025-09-12 20:21:30 +08:00
wizardchen
74c121f7fb feat: Adjust App & Docreader log output 2025-09-11 23:14:23 +08:00
wizardchen
78088057fb fix: frontend depends app health 2025-09-11 14:22:29 +08:00
wizardchen
bff0e742fa fix: try fix ocr avx not support 2025-09-11 13:21:21 +08:00
wizardchen
6598baab2e chore: bump version to v0.1.2 2025-09-10 20:28:54 +08:00
97 changed files with 9597 additions and 3830 deletions

View File

@@ -121,6 +121,18 @@ COS_ENABLE_OLD_DOMAIN=true
# 如果解析网络连接使用Web代理需要配置以下参数
# WEB_PROXY=your_web_proxy
# Neo4j 开关
# NEO4J_ENABLE=false
# Neo4j的访问地址
# NEO4J_URI=neo4j://neo4j:7687
# Neo4j的用户名和密码
# NEO4J_USERNAME=neo4j
# Neo4j的密码
# NEO4J_PASSWORD=password
##############################################################
###### 注意: 以下配置不再生效已在Web“配置初始化”阶段完成 #########

View File

@@ -49,15 +49,7 @@ body:
请按照以下步骤收集相关日志:
**1. 应用模块日志:**
```bash
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
```
**2. 文档解析模块日志:**
```bash
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
```
docker compose logs -f --tail=1000 app docreader postgres
请重现问题并收集相关日志,然后粘贴到下面的日志字段中。
@@ -68,8 +60,7 @@ body:
description: 请按照上面的指南收集并粘贴相关日志
placeholder: |
请粘贴从以下命令收集的日志:
- docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
- docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
docker compose logs -f --tail=1000 app docreader postgres
render: shell
- type: input

View File

@@ -68,14 +68,8 @@ body:
如果问题涉及错误或需要调试,请收集相关日志:
**应用模块日志:**
```bash
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
```
**文档解析模块日志:**
```bash
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
docker compose logs -f --tail=1000 app docreader postgres
```
- type: textarea

View File

@@ -1,6 +1,8 @@
name: Build and Push Docker Image
on:
push:
tags:
- "v*"
branches:
- main
@@ -9,51 +11,201 @@ concurrency:
cancel-in-progress: false
jobs:
build-app:
build-ui:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- service_name: ui
file: frontend/Dockerfile
context: ./frontend
platform: linux/amd64,linux/arm64
- service_name: app
file: docker/Dockerfile.app
context: .
platform: linux/amd64,linux/arm64
- service_name: docreader
file: docker/Dockerfile.docreader
context: .
platform: linux/amd64,linux/arm64
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Read VERSION file
run: echo "VERSION=$(cat VERSION)" >> $GITHUB_ENV
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui
- name: Build ui Image
uses: docker/build-push-action@v3
with:
push: true
platforms: linux/amd64,linux/arm64
file: frontend/Dockerfile
context: ./frontend
labels: ${{ steps.meta.outputs.labels }}
tags: ${{ steps.meta.outputs.tags }}
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui:cache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui:cache,mode=max
build-docreader:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader
- name: Build docreader Image
uses: docker/build-push-action@v3
with:
push: true
platforms: linux/amd64,linux/arm64
file: docker/Dockerfile.docreader
context: .
labels: ${{ steps.meta.outputs.labels }}
tags: ${{ steps.meta.outputs.tags }}
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader:cache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader:cache,mode=max
build-app:
strategy:
matrix:
include:
- arch: amd64
platform: linux/amd64
runs: ubuntu-latest
- arch: arm64
platform: linux/arm64
runs: ubuntu-24.04-arm
runs-on: ${{ matrix.runs }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
id: setup-buildx
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
- name: Prepare version info
id: version
run: |
# 使用统一的版本管理脚本
eval "$(./scripts/get_version.sh env)"
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "commit_id=$COMMIT_ID" >> $GITHUB_OUTPUT
echo "build_time=$BUILD_TIME" >> $GITHUB_OUTPUT
echo "go_version=$GO_VERSION" >> $GITHUB_OUTPUT
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
# 显示版本信息
./scripts/get_version.sh info
- name: Build ${{ matrix.service_name }} Image
- name: Build Cache for Docker
uses: actions/cache@v4
id: cache
with:
path: go-pkg-mod
key: ${{ env.PLATFORM_PAIR }}-go-build-cache-${{ hashFiles('**/go.sum') }}
- name: Inject go-build-cache
uses: reproducible-containers/buildkit-cache-dance@v3
with:
builder: ${{ steps.setup-buildx.outputs.name }}
cache-map: |
{
"go-pkg-mod": "/go/pkg/mod"
}
skip-extraction: ${{ steps.cache.outputs.cache-hit }}
- name: Build app Image
id: build
uses: docker/build-push-action@v3
with:
push: true
platforms: ${{ matrix.platform }}
file: ${{ matrix.file }}
context: ${{ matrix.context }}
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:latest
${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:${{ env.VERSION }}
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:cache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:cache,mode=max
file: docker/Dockerfile.app
context: .
build-args: |
${{ format('VERSION_ARG={0}', steps.version.outputs.version) }}
${{ format('COMMIT_ID_ARG={0}', steps.version.outputs.commit_id) }}
${{ format('BUILD_TIME_ARG={0}', steps.version.outputs.build_time) }}
${{ format('GO_VERSION_ARG={0}', steps.version.outputs.go_version) }}
labels: ${{ steps.meta.outputs.labels }}
tags: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:cache-${{ env.PLATFORM_PAIR }}
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:cache-${{ env.PLATFORM_PAIR }},mode=max
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p ${{ runner.temp }}/digests
digest="${{ steps.build.outputs.digest }}"
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-${{ env.PLATFORM_PAIR }}
path: ${{ runner.temp }}/digests/*
if-no-files-found: error
retention-days: 1
merge:
runs-on: ubuntu-latest
needs:
- build-app
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: ${{ runner.temp }}/digests
pattern: digests-*
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
- name: Create manifest list and push
working-directory: ${{ runner.temp }}/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ secrets.DOCKERHUB_USERNAME }}/weknora-app@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:${{ steps.meta.outputs.version }}

View File

@@ -2,6 +2,89 @@
All notable changes to this project will be documented in this file.
## [0.1.4] - 2025-09-17
### 🚀 Major Features
- **NEW**: Multi-knowledgebases operation support
- Added comprehensive multi-knowledgebase management functionality
- Implemented multi-data source search engine configuration and optimization logic
- Enhanced knowledge base switching and management in UI
- **NEW**: Enhanced tenant information management
- Added dedicated tenant information page
- Improved user and tenant management capabilities
### 🎨 UI/UX Improvements
- **REDESIGNED**: Settings page with improved layout and functionality
- **ENHANCED**: Menu component with multi-knowledgebase support
- **IMPROVED**: Initialization configuration page structure
- **OPTIMIZED**: Login page and authentication flow
### 🔒 Security Fixes
- **FIXED**: XSS attack vulnerabilities in thinking component
- **FIXED**: Content Security Policy (CSP) errors
- **ENHANCED**: Frontend security measures and input sanitization
### 🐛 Bug Fixes
- **FIXED**: Login direct page navigation issues
- **FIXED**: App LLM model check logic
- **FIXED**: Version script functionality
- **FIXED**: File download content errors
- **IMPROVED**: Document content component display
### 🧹 Code Cleanup
- **REMOVED**: Test data functionality and related APIs
- **SIMPLIFIED**: Initialization configuration components
- **CLEANED**: Redundant UI components and unused code
## [0.1.3] - 2025-09-16
### 🔒 Security Features
- **NEW**: Added login authentication functionality to enhance system security
- Implemented user authentication and authorization mechanisms
- Added session management and access control
- Fixed XSS attack vulnerabilities in frontend components
### 📚 Documentation Updates
- Added security notices in all README files (English, Chinese, Japanese)
- Updated deployment recommendations emphasizing internal/private network deployment
- Enhanced security guidelines to prevent information leakage risks
- Fixed documentation spelling issues
### 🛡️ Security Improvements
- Hide API keys in UI for security purposes
- Enhanced input sanitization and XSS protection
- Added comprehensive security utilities
### 🐛 Bug Fixes
- Fixed OCR AVX support issues
- Improved frontend health check dependencies
- Enhanced Docker binary downloads for target architecture
- Fixed COS file service initialization parameters and URL processing logic
### 🚀 Features & Enhancements
- Improved application and docreader log output
- Enhanced frontend routing and authentication flow
- Added comprehensive user management system
- Improved initialization configuration handling
### 🛡️ Security Recommendations
- Deploy WeKnora services in internal/private network environments
- Avoid direct exposure to public internet
- Configure proper firewall rules and access controls
- Regular updates for security patches and improvements
## [0.1.2] - 2025-09-10
- Fixed health check implementation for docreader service
- Improved query handling for empty queries
- Enhanced knowledge base column value update methods
- Optimized logging throughout the application
- Added process parsing documentation for markdown files
- Fixed OCR model pre-fetching in Docker containers
- Resolved image parser concurrency errors
- Added support for modifying listening port configuration
## [0.1.0] - 2025-09-08
- Initial public release of WeKnora.
@@ -14,4 +97,7 @@ All notable changes to this project will be documented in this file.
- Docker Compose for quick startup and service orchestration.
- MCP server support for integrating with MCP-compatible clients.
[0.1.4]: https://github.com/Tencent/WeKnora/tree/v0.1.4
[0.1.3]: https://github.com/Tencent/WeKnora/tree/v0.1.3
[0.1.2]: https://github.com/Tencent/WeKnora/tree/v0.1.2
[0.1.0]: https://github.com/Tencent/WeKnora/tree/v0.1.0

View File

@@ -85,7 +85,15 @@ clean:
# Build Docker image
docker-build-app:
docker build --platform $(PLATFORM) -f docker/Dockerfile.app -t $(DOCKER_IMAGE):$(DOCKER_TAG) .
@echo "获取版本信息..."
@eval $$(./scripts/get_version.sh env); \
./scripts/get_version.sh info; \
docker build --platform $(PLATFORM) \
--build-arg VERSION_ARG="$$VERSION" \
--build-arg COMMIT_ID_ARG="$$COMMIT_ID" \
--build-arg BUILD_TIME_ARG="$$BUILD_TIME" \
--build-arg GO_VERSION_ARG="$$GO_VERSION" \
-f docker/Dockerfile.app -t $(DOCKER_IMAGE):$(DOCKER_TAG) .
# Build docreader Docker image
docker-build-docreader:
@@ -168,7 +176,12 @@ deps:
# Build for production
build-prod:
GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o $(BINARY_NAME) $(MAIN_PATH)
VERSION=$${VERSION:-unknown}; \
COMMIT_ID=$${COMMIT_ID:-unknown}; \
BUILD_TIME=$${BUILD_TIME:-unknown}; \
GO_VERSION=$${GO_VERSION:-unknown}; \
LDFLAGS="-X 'github.com/Tencent/WeKnora/internal/handler.Version=$$VERSION' -X 'github.com/Tencent/WeKnora/internal/handler.CommitID=$$COMMIT_ID' -X 'github.com/Tencent/WeKnora/internal/handler.BuildTime=$$BUILD_TIME' -X 'github.com/Tencent/WeKnora/internal/handler.GoVersion=$$GO_VERSION'"; \
go build -ldflags="-w -s $$LDFLAGS" -o $(BINARY_NAME) $(MAIN_PATH)
clean-db:
@echo "Cleaning database..."

View File

@@ -15,7 +15,7 @@
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
</a>
<a href="./CHANGELOG.md">
<img alt="Version" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
<img alt="Version" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
</a>
</p>
@@ -41,6 +41,15 @@ It adopts a modular architecture that combines multimodal preprocessing, semanti
**Website:** https://weknora.weixin.qq.com
## 🔒 Security Notice
**Important:** Starting from v0.1.3, WeKnora includes login authentication functionality to enhance system security. For production deployments, we strongly recommend:
- Deploy WeKnora services in internal/private network environments rather than public internet
- Avoid exposing the service directly to public networks to prevent potential information leakage
- Configure proper firewall rules and access controls for your deployment environment
- Regularly update to the latest version for security patches and improvements
## 🏗️ Architecture
![weknora-pipeline.png](./docs/images/pipeline.jpg)

View File

@@ -15,7 +15,7 @@
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
</a>
<a href="./CHANGELOG.md">
<img alt="版本" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
<img alt="版本" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
</a>
</p>
@@ -41,6 +41,15 @@
**官网:** https://weknora.weixin.qq.com
## 🔒 安全声明
**重要提示:** 从 v0.1.3 版本开始WeKnora 提供了登录鉴权功能,以增强系统安全性。在生产环境部署时,我们强烈建议:
- 将 WeKnora 服务部署在内网/私有网络环境中,而非公网环境
- 避免将服务直接暴露在公网上,以防止重要信息泄露风险
- 为部署环境配置适当的防火墙规则和访问控制
- 定期更新到最新版本以获取安全补丁和改进
## 🏗️ 架构设计
![weknora-pipelone.png](./docs/images/pipeline.jpg)

View File

@@ -15,7 +15,7 @@
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
</a>
<a href="./CHANGELOG.md">
<img alt="バージョン" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
<img alt="バージョン" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
</a>
</p>
@@ -41,6 +41,15 @@
**公式サイト:** https://weknora.weixin.qq.com
## 🔒 セキュリティ通知
**重要:** v0.1.3バージョンより、WeKnoraにはシステムセキュリティを強化するためのログイン認証機能が含まれています。本番環境でのデプロイメントにおいて、以下を強く推奨します
- WeKnoraサービスはパブリックインターネットではなく、内部/プライベートネットワーク環境にデプロイしてください
- 重要な情報漏洩を防ぐため、サービスを直接パブリックネットワークに公開することは避けてください
- デプロイメント環境に適切なファイアウォールルールとアクセス制御を設定してください
- セキュリティパッチと改善のため、定期的に最新バージョンに更新してください
## 🏗️ アーキテクチャ設計
![weknora-pipelone.png](./docs/images/pipeline.jpg)

View File

@@ -1 +1 @@
0.1.0
0.1.4

View File

@@ -74,6 +74,9 @@ type UpdateImageInfoRequest struct {
// ErrDuplicateFile is returned when attempting to create a knowledge entry with a file that already exists
var ErrDuplicateFile = errors.New("file already exists")
// ErrDuplicateURL is returned when attempting to create a knowledge entry with a URL that already exists
var ErrDuplicateURL = errors.New("URL already exists")
// CreateKnowledgeFromFile creates a knowledge entry from a local file path
func (c *Client) CreateKnowledgeFromFile(ctx context.Context,
knowledgeBaseID string, filePath string, metadata map[string]string, enableMultimodel *bool,
@@ -186,7 +189,12 @@ func (c *Client) CreateKnowledgeFromURL(ctx context.Context, knowledgeBaseID str
}
var response KnowledgeResponse
if err := parseResponse(resp, &response); err != nil {
if resp.StatusCode == http.StatusConflict {
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &response.Data, ErrDuplicateURL
} else if err := parseResponse(resp, &response); err != nil {
return nil, err
}

View File

@@ -9,7 +9,7 @@ conversation:
keyword_threshold: 0.3
embedding_top_k: 10
vector_threshold: 0.5
rerank_threshold: 0.7
rerank_threshold: 0.5
rerank_top_k: 5
fallback_strategy: "fixed"
fallback_response: "抱歉,我无法回答这个问题。"
@@ -534,3 +534,69 @@ knowledge_base:
split_markers: ["\n\n", "\n", "。"]
image_processing:
enable_multimodal: true
extract:
extract_graph:
description: |
请基于给定文本,按以下步骤完成信息提取任务,确保逻辑清晰、信息完整准确:
## 一、实体提取与属性补充
1. **提取核心实体**:通读文本,按逻辑顺序(如文本叙述顺序、实体关联紧密程度)提取所有与任务相关的核心实体。
2. **补充实体详细属性**:针对每个提取的实体,全面补充其在文本中明确提及的详细属性,确保无关键属性遗漏。
## 二、关系提取与验证
1. **明确关系类型**:仅从指定关系列表中选择对应类型,限定关系类型为: %s。
2. **提取有效关系**:基于已提取的实体及属性,识别文本中真实存在的关系,确保关系符合文本事实、无虚假关联。
3. **明确关系主体**:对每一组提取的关系,清晰标注两个关联主体,避免主体混淆。
4. **补充关联属性**:若文本中存在与该关系直接相关的补充信息,需将该信息作为关系的关联属性补充,进一步完善关系信息。
tags:
- "作者"
- "别名"
examples:
- text: |
《红楼梦》又名《石头记》是清代作家曹雪芹创作的中国古典四大名著之一被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著后40回一般认为是高鹗所续。
小说以贾、史、王、薛四大家族的兴衰为背景,以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线,刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。
成书于乾隆年间1743年前后是中国文学史上现实主义的高峰对后世影响深远。
node:
- name: "红楼梦"
attributes:
- "中国古典四大名著之一"
- "又名《石头记》"
- "被誉为中国封建社会的百科全书"
- name: "石头记"
attributes:
- "《红楼梦》的别名"
- name: "曹雪芹"
attributes:
- "清代作家"
- "《红楼梦》前 80 回的作者"
- name: "高鹗"
attributes:
- "一般认为是《红楼梦》后 40 回的续写者"
relation:
- node1: "红楼梦"
node2: "曹雪芹"
type: "作者"
- node1: "红楼梦"
node2: "高鹗"
type: "作者"
- node1: "红楼梦"
node2: "石头记"
type: "别名"
extract_entity:
description: |
请基于用户给的问题,按以下步骤处理关键信息提取任务:
1. 梳理逻辑关联:首先完整分析文本内容,明确其核心逻辑关系,并简要标注该核心逻辑类型;
2. 提取关键实体:围绕梳理出的逻辑关系,精准提取文本中的关键信息并归类为明确实体,确保不遗漏核心信息、不添加冗余内容;
3. 排序实体优先级:按实体与文本核心主题的关联紧密程度排序,优先呈现对理解文本主旨最重要的实体;
examples:
- text: "《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。"
node:
- name: "红楼梦"
- name: "曹雪芹"
- name: "中国古典四大名著"
fabri_text:
with_tag: |
请随机生成一段文本,要求内容与 %s 等相关,字数在 [50-200] 之间,并且尽量包含一些与这些标签相关的专业术语或典型元素,使文本更具针对性和相关性。
with_no_tag: |
请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。

View File

@@ -7,21 +7,27 @@ services:
volumes:
- data-files:/data/files
- ./config/config.yaml:/app/config/config.yaml
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
environment:
- COS_SECRET_ID=${COS_SECRET_ID}
- COS_SECRET_KEY=${COS_SECRET_KEY}
- COS_REGION=${COS_REGION}
- COS_BUCKET_NAME=${COS_BUCKET_NAME}
- COS_APP_ID=${COS_APP_ID}
- COS_PATH_PREFIX=${COS_PATH_PREFIX}
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN}
- GIN_MODE=${GIN_MODE}
- COS_SECRET_ID=${COS_SECRET_ID:-}
- COS_SECRET_KEY=${COS_SECRET_KEY:-}
- COS_REGION=${COS_REGION:-}
- COS_BUCKET_NAME=${COS_BUCKET_NAME:-}
- COS_APP_ID=${COS_APP_ID:-}
- COS_PATH_PREFIX=${COS_PATH_PREFIX:-}
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN:-}
- GIN_MODE=${GIN_MODE:-}
- DB_DRIVER=postgres
- DB_HOST=postgres
- DB_PORT=5432
- DB_USER=${DB_USER}
- DB_PASSWORD=${DB_PASSWORD}
- DB_NAME=${DB_NAME}
- DB_USER=${DB_USER:-}
- DB_PASSWORD=${DB_PASSWORD:-}
- DB_NAME=${DB_NAME:-}
- TZ=Asia/Shanghai
- OTEL_EXPORTER_OTLP_ENDPOINT=jaeger:4317
- OTEL_SERVICE_NAME=WeKnora
@@ -29,38 +35,42 @@ services:
- OTEL_METRICS_EXPORTER=none
- OTEL_LOGS_EXPORTER=none
- OTEL_PROPAGATORS=tracecontext,baggage
- RETRIEVE_DRIVER=${RETRIEVE_DRIVER}
- ELASTICSEARCH_ADDR=${ELASTICSEARCH_ADDR}
- ELASTICSEARCH_USERNAME=${ELASTICSEARCH_USERNAME}
- ELASTICSEARCH_PASSWORD=${ELASTICSEARCH_PASSWORD}
- ELASTICSEARCH_INDEX=${ELASTICSEARCH_INDEX}
- RETRIEVE_DRIVER=${RETRIEVE_DRIVER:-}
- ELASTICSEARCH_ADDR=${ELASTICSEARCH_ADDR:-}
- ELASTICSEARCH_USERNAME=${ELASTICSEARCH_USERNAME:-}
- ELASTICSEARCH_PASSWORD=${ELASTICSEARCH_PASSWORD:-}
- ELASTICSEARCH_INDEX=${ELASTICSEARCH_INDEX:-}
- DOCREADER_ADDR=docreader:50051
- STORAGE_TYPE=${STORAGE_TYPE}
- LOCAL_STORAGE_BASE_DIR=${LOCAL_STORAGE_BASE_DIR}
- STORAGE_TYPE=${STORAGE_TYPE:-}
- LOCAL_STORAGE_BASE_DIR=${LOCAL_STORAGE_BASE_DIR:-}
- MINIO_ENDPOINT=minio:9000
- MINIO_ACCESS_KEY_ID=${MINIO_ACCESS_KEY_ID:-minioadmin}
- MINIO_SECRET_ACCESS_KEY=${MINIO_SECRET_ACCESS_KEY:-minioadmin}
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME}
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME:-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
- STREAM_MANAGER_TYPE=${STREAM_MANAGER_TYPE}
- STREAM_MANAGER_TYPE=${STREAM_MANAGER_TYPE:-}
- REDIS_ADDR=redis:6379
- REDIS_PASSWORD=${REDIS_PASSWORD}
- REDIS_DB=${REDIS_DB}
- REDIS_PREFIX=${REDIS_PREFIX}
- ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG}
- TENANT_AES_KEY=${TENANT_AES_KEY}
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-}
- REDIS_PREFIX=${REDIS_PREFIX:-}
- ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG:-}
- NEO4J_ENABLE=${NEO4J_ENABLE:-}
- NEO4J_URI=bolt://neo4j:7687
- NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j}
- NEO4J_PASSWORD=${NEO4J_PASSWORD:-password}
- TENANT_AES_KEY=${TENANT_AES_KEY:-}
- CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5}
- INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME}
- INIT_LLM_MODEL_BASE_URL=${INIT_LLM_MODEL_BASE_URL}
- INIT_LLM_MODEL_API_KEY=${INIT_LLM_MODEL_API_KEY}
- INIT_EMBEDDING_MODEL_NAME=${INIT_EMBEDDING_MODEL_NAME}
- INIT_EMBEDDING_MODEL_BASE_URL=${INIT_EMBEDDING_MODEL_BASE_URL}
- INIT_EMBEDDING_MODEL_API_KEY=${INIT_EMBEDDING_MODEL_API_KEY}
- INIT_EMBEDDING_MODEL_DIMENSION=${INIT_EMBEDDING_MODEL_DIMENSION}
- INIT_EMBEDDING_MODEL_ID=${INIT_EMBEDDING_MODEL_ID}
- INIT_RERANK_MODEL_NAME=${INIT_RERANK_MODEL_NAME}
- INIT_RERANK_MODEL_BASE_URL=${INIT_RERANK_MODEL_BASE_URL}
- INIT_RERANK_MODEL_API_KEY=${INIT_RERANK_MODEL_API_KEY}
- INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME:-}
- INIT_LLM_MODEL_BASE_URL=${INIT_LLM_MODEL_BASE_URL:-}
- INIT_LLM_MODEL_API_KEY=${INIT_LLM_MODEL_API_KEY:-}
- INIT_EMBEDDING_MODEL_NAME=${INIT_EMBEDDING_MODEL_NAME:-}
- INIT_EMBEDDING_MODEL_BASE_URL=${INIT_EMBEDDING_MODEL_BASE_URL:-}
- INIT_EMBEDDING_MODEL_API_KEY=${INIT_EMBEDDING_MODEL_API_KEY:-}
- INIT_EMBEDDING_MODEL_DIMENSION=${INIT_EMBEDDING_MODEL_DIMENSION:-}
- INIT_EMBEDDING_MODEL_ID=${INIT_EMBEDDING_MODEL_ID:-}
- INIT_RERANK_MODEL_NAME=${INIT_RERANK_MODEL_NAME:-}
- INIT_RERANK_MODEL_BASE_URL=${INIT_RERANK_MODEL_BASE_URL:-}
- INIT_RERANK_MODEL_API_KEY=${INIT_RERANK_MODEL_API_KEY:-}
depends_on:
redis:
condition: service_started
@@ -70,6 +80,8 @@ services:
condition: service_started
docreader:
condition: service_healthy
neo4j:
condition: service_started
networks:
- WeKnora-network
restart: unless-stopped
@@ -102,7 +114,8 @@ services:
ports:
- "${FRONTEND_PORT:-80}:80"
depends_on:
- app
app:
condition: service_healthy
networks:
- WeKnora-network
restart: unless-stopped
@@ -113,24 +126,24 @@ services:
ports:
- "${DOCREADER_PORT:-50051}:50051"
environment:
- COS_SECRET_ID=${COS_SECRET_ID}
- COS_SECRET_KEY=${COS_SECRET_KEY}
- COS_REGION=${COS_REGION}
- COS_BUCKET_NAME=${COS_BUCKET_NAME}
- COS_APP_ID=${COS_APP_ID}
- COS_PATH_PREFIX=${COS_PATH_PREFIX}
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN}
- VLM_MODEL_BASE_URL=${VLM_MODEL_BASE_URL}
- VLM_MODEL_NAME=${VLM_MODEL_NAME}
- VLM_MODEL_API_KEY=${VLM_MODEL_API_KEY}
- STORAGE_TYPE=${STORAGE_TYPE}
- COS_SECRET_ID=${COS_SECRET_ID:-}
- COS_SECRET_KEY=${COS_SECRET_KEY:-}
- COS_REGION=${COS_REGION:-}
- COS_BUCKET_NAME=${COS_BUCKET_NAME:-}
- COS_APP_ID=${COS_APP_ID:-}
- COS_PATH_PREFIX=${COS_PATH_PREFIX:-}
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN:-}
- VLM_MODEL_BASE_URL=${VLM_MODEL_BASE_URL:-}
- VLM_MODEL_NAME=${VLM_MODEL_NAME:-}
- VLM_MODEL_API_KEY=${VLM_MODEL_API_KEY:-}
- STORAGE_TYPE=${STORAGE_TYPE:-}
- MINIO_PUBLIC_ENDPOINT=http://localhost:${MINIO_PORT:-9000}
- MINIO_ENDPOINT=minio:9000
- MINIO_ACCESS_KEY_ID=${MINIO_ACCESS_KEY_ID:-minioadmin}
- MINIO_SECRET_ACCESS_KEY=${MINIO_SECRET_ACCESS_KEY:-minioadmin}
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME}
- MINIO_USE_SSL=${MINIO_USE_SSL}
- WEB_PROXY=${WEB_PROXY}
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME:-}
- MINIO_USE_SSL=${MINIO_USE_SSL:-}
- WEB_PROXY=${WEB_PROXY:-}
healthcheck:
test: ["CMD", "grpc_health_probe", "-addr=:50051"]
interval: 30s
@@ -165,7 +178,7 @@ services:
restart: unless-stopped
# 修改的PostgreSQL配置
postgres:
image: paradedb/paradedb:latest
image: paradedb/paradedb:v0.18.9-pg17
container_name: WeKnora-postgres
ports:
- "${DB_PORT}:5432"
@@ -202,6 +215,24 @@ services:
networks:
- WeKnora-network
neo4j:
image: neo4j:latest
container_name: WeKnora-neo4j
volumes:
- neo4j-data:/data
environment:
- NEO4J_AUTH=${NEO4J_USERNAME:-neo4j}/${NEO4J_PASSWORD:-password}
- NEO4J_apoc_export_file_enabled=true
- NEO4J_apoc_import_file_enabled=true
- NEO4J_apoc_import_file_use__neo4j__config=true
- NEO4JLABS_PLUGINS=["apoc"]
ports:
- "7474:7474"
- "7687:7687"
restart: always
networks:
- WeKnora-network
networks:
WeKnora-network:
driver: bridge
@@ -212,3 +243,4 @@ volumes:
jaeger_data:
redis_data:
minio_data:
neo4j-data:

View File

@@ -3,10 +3,6 @@ FROM golang:1.24-alpine AS builder
WORKDIR /app
# Install dependencies
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
apk add --no-cache git build-base
# 通过构建参数接收敏感信息
ARG GOPRIVATE_ARG
ARG GOPROXY_ARG
@@ -17,19 +13,33 @@ ENV GOPRIVATE=${GOPRIVATE_ARG}
ENV GOPROXY=${GOPROXY_ARG}
ENV GOSUMDB=${GOSUMDB_ARG}
# Copy go mod and sum files
COPY go.mod go.sum ./
RUN go mod download
# Install dependencies
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
apk add --no-cache git build-base
ENV CGO_ENABLED=1
# Install migrate tool
RUN go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
# Copy source code
# Copy go mod and sum files
COPY go.mod go.sum ./
RUN --mount=type=cache,target=/go/pkg/mod go mod download
COPY . .
# Build the application
RUN make build-prod
# Get version and commit info for build injection
ARG VERSION_ARG
ARG COMMIT_ID_ARG
ARG BUILD_TIME_ARG
ARG GO_VERSION_ARG
# Set build-time variables
ENV VERSION=${VERSION_ARG}
ENV COMMIT_ID=${COMMIT_ID_ARG}
ENV BUILD_TIME=${BUILD_TIME_ARG}
ENV GO_VERSION=${GO_VERSION_ARG}
# Build the application with version info
RUN --mount=type=cache,target=/go/pkg/mod make build-prod
RUN --mount=type=cache,target=/go/pkg/mod cp -r /go/pkg/mod/github.com/yanyiwu/ /app/yanyiwu/
# Final stage
FROM alpine:3.17
@@ -39,36 +49,31 @@ WORKDIR /app
# Install runtime dependencies
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
apk update && apk upgrade && \
apk add --no-cache build-base postgresql-client mysql-client ca-certificates tzdata sed curl bash supervisor vim wget
# Copy the binary from the builder stage
COPY --from=builder /app/WeKnora .
COPY --from=builder /app/config ./config
COPY --from=builder /app/scripts ./scripts
COPY --from=builder /app/migrations ./migrations
COPY --from=builder /app/dataset/samples ./dataset/samples
# Copy migrate tool from builder stage
COPY --from=builder /go/bin/migrate /usr/local/bin/
COPY --from=builder /go/pkg/mod/github.com/yanyiwu /go/pkg/mod/github.com/yanyiwu/
# Make scripts executable
RUN chmod +x ./scripts/*.sh
# Setup supervisor configuration
RUN mkdir -p /etc/supervisor.d/
COPY docker/config/supervisord.conf /etc/supervisor.d/supervisord.conf
# Expose ports
EXPOSE 8080
# Set environment variables
ENV CGO_ENABLED=1
apk add --no-cache build-base postgresql-client mysql-client ca-certificates tzdata sed curl bash vim wget
# Create a non-root user and switch to it
RUN mkdir -p /data/files && \
adduser -D -g '' appuser && \
chown -R appuser:appuser /app /data/files
# Run supervisor instead of direct application start
CMD ["supervisord", "-c", "/etc/supervisor.d/supervisord.conf"]
# Copy migrate tool from builder stage
COPY --from=builder /go/bin/migrate /usr/local/bin/
COPY --from=builder /app/yanyiwu/ /go/pkg/mod/github.com/yanyiwu/
# Copy the binary from the builder stage
COPY --from=builder /app/config ./config
COPY --from=builder /app/scripts ./scripts
COPY --from=builder /app/migrations ./migrations
COPY --from=builder /app/dataset/samples ./dataset/samples
COPY --from=builder /app/WeKnora .
# Make scripts executable
RUN chmod +x ./scripts/*.sh
# Expose ports
EXPOSE 8080
# Switch to non-root user and run the application directly
USER appuser
CMD ["./WeKnora"]

View File

@@ -26,22 +26,31 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# 检查是否存在本地protoc安装包如果存在则离线安装否则在线安装,其他安装包按需求添加
ARG TARGETARCH
COPY packages/ /app/packages/
RUN echo "检查本地protoc安装包..." && \
if [ -f "/app/packages/protoc-3.19.4-linux-x86_64.zip" ]; then \
# 根据目标架构选择正确的protoc包名
case ${TARGETARCH} in \
"amd64") PROTOC_ARCH="x86_64" ;; \
"arm64") PROTOC_ARCH="aarch_64" ;; \
"arm") PROTOC_ARCH="arm" ;; \
*) echo "Unsupported architecture for protoc: ${TARGETARCH}" && exit 1 ;; \
esac && \
PROTOC_PACKAGE="protoc-3.19.4-linux-${PROTOC_ARCH}.zip" && \
if [ -f "/app/packages/${PROTOC_PACKAGE}" ]; then \
echo "发现本地protoc安装包将进行离线安装"; \
# 离线安装:使用本地包(精确路径避免歧义)
cp /app/packages/protoc-*.zip /app/ && \
unzip -o /app/protoc-*.zip -d /usr/local && \
cp /app/packages/${PROTOC_PACKAGE} /app/ && \
unzip -o /app/${PROTOC_PACKAGE} -d /usr/local && \
chmod +x /usr/local/bin/protoc && \
rm -f /app/protoc-*.zip; \
rm -f /app/${PROTOC_PACKAGE}; \
else \
echo "未发现本地protoc安装包将进行在线安装"; \
# 在线安装:从网络下载
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/protoc-3.19.4-linux-x86_64.zip && \
unzip -o protoc-3.19.4-linux-x86_64.zip -d /usr/local && \
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/${PROTOC_PACKAGE} && \
unzip -o ${PROTOC_PACKAGE} -d /usr/local && \
chmod +x /usr/local/bin/protoc && \
rm -f protoc-3.19.4-linux-x86_64.zip; \
rm -f ${PROTOC_PACKAGE}; \
fi
# 复制依赖文件
@@ -102,7 +111,6 @@ RUN apt-get update && apt-get install -y \
libgl1 \
libglib2.0-0 \
antiword \
supervisor \
vim \
tar \
dpkg \
@@ -118,30 +126,35 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# 安装 grpc_health_probe
ARG TARGETARCH
RUN GRPC_HEALTH_PROBE_VERSION=v0.4.24 && \
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \
# 根据目标架构选择正确的二进制文件
case ${TARGETARCH} in \
"amd64") ARCH="amd64" ;; \
"arm64") ARCH="arm64" ;; \
"arm") ARCH="arm" ;; \
*) echo "Unsupported architecture: ${TARGETARCH}" && exit 1 ;; \
esac && \
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-${ARCH} && \
chmod +x /bin/grpc_health_probe
# 从构建阶段复制已安装的依赖和生成的代码
COPY --from=builder /usr/local/lib/python3.10/site-packages /usr/local/lib/python3.10/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
COPY --from=builder /root/.paddleocr /root/.paddleocr
COPY --from=builder /app/src /app/src
# 安装 Playwright 浏览器
RUN python -m playwright install webkit
RUN python -m playwright install-deps webkit
COPY --from=builder /app/src /app/src
# 设置 Python 路径
ENV PYTHONPATH=/app/src
RUN cd /app/src && python -m download_deps
# 创建supervisor配置
RUN mkdir -p /etc/supervisor/conf.d
COPY services/docreader/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# 暴露 gRPC 端口
EXPOSE 50051
# 使用supervisor启动服务
CMD ["supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
# 直接运行 Python 服务(日志输出到 stdout/stderr
CMD ["python", "/app/src/server/server.py"]

View File

@@ -2,11 +2,7 @@
## 1. 如何查看日志?
```bash
# 查看 主服务 日志
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
# 查看 文档解析模块 日志
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
docker compose logs -f app docreader postgres
```
## 2. 如何启动和停止服务?

BIN
docs/images/pipeline.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 504 KiB

View File

@@ -2,6 +2,16 @@ server {
listen 80;
server_name localhost;
client_max_body_size 50M;
# 安全头配置
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header X-XSS-Protection "1; mode=block" always;
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
# 错误日志配置
error_log /var/log/nginx/error.log warn;
access_log /var/log/nginx/access.log;
# 前端静态文件
location / {
@@ -18,6 +28,12 @@ server {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# 连接和重试配置
proxy_connect_timeout 30s; # 连接超时时间
proxy_next_upstream error timeout invalid_header http_500 http_502 http_503 http_504;
proxy_next_upstream_tries 3; # 重试次数
proxy_next_upstream_timeout 30s; # 重试超时时间
# SSE 相关配置
proxy_http_version 1.1; # 使用 HTTP/1.1
proxy_set_header Connection ""; # 禁用 Connection: close保持连接打开

View File

@@ -1,18 +1,21 @@
{
"name": "knowledage-base",
"version": "0.0.0",
"version": "0.1.3",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "knowledage-base",
"version": "0.0.0",
"version": "0.1.3",
"dependencies": {
"@microsoft/fetch-event-source": "^2.0.1",
"@types/dompurify": "^3.0.5",
"axios": "^1.8.4",
"dompurify": "^3.2.6",
"marked": "^5.1.2",
"pagefind": "^1.1.1",
"pinia": "^3.0.1",
"tdesign-icons-vue-next": "^0.4.1",
"tdesign-vue-next": "^1.11.5",
"vue": "^3.5.13",
"vue-router": "^4.5.0",
@@ -1274,6 +1277,15 @@
"dev": true,
"license": "MIT"
},
"node_modules/@types/dompurify": {
"version": "3.0.5",
"resolved": "https://mirrors.tencent.com/npm/@types/dompurify/-/dompurify-3.0.5.tgz",
"integrity": "sha512-1Wg0g3BtQF7sSb27fJQAKck1HECM6zV1EB66j8JH9i3LCjYabJa0FSdiSgsD5K/RbrsR0SiraKacLB+T8ZVYAg==",
"license": "MIT",
"dependencies": {
"@types/trusted-types": "*"
}
},
"node_modules/@types/eslint": {
"version": "9.6.1",
"resolved": "https://mirrors.tencent.com/npm/@types/eslint/-/eslint-9.6.1.tgz",
@@ -1346,6 +1358,12 @@
"resolved": "https://mirrors.tencent.com/npm/@types/tinycolor2/-/tinycolor2-1.4.6.tgz",
"integrity": "sha512-iEN8J0BoMnsWBqjVbWH/c0G0Hh7O21lpR2/+PrvAVgWdzL7eexIFm4JN/Wn10PTcmNdtS6U67r499mlWMXOxNw=="
},
"node_modules/@types/trusted-types": {
"version": "2.0.7",
"resolved": "https://mirrors.tencent.com/npm/@types/trusted-types/-/trusted-types-2.0.7.tgz",
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
"license": "MIT"
},
"node_modules/@types/validator": {
"version": "13.15.2",
"resolved": "https://mirrors.tencent.com/npm/@types/validator/-/validator-13.15.2.tgz",
@@ -2121,6 +2139,15 @@
"node": ">=0.4.0"
}
},
"node_modules/dompurify": {
"version": "3.2.6",
"resolved": "https://mirrors.tencent.com/npm/dompurify/-/dompurify-3.2.6.tgz",
"integrity": "sha512-/2GogDQlohXPZe6D6NOgQvXLPSYBqIWMnZ8zzOhn09REE4eyAzb+Hed3jhoM9OkuaJ8P6ZGTTVWQKAi8ieIzfQ==",
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
"@types/trusted-types": "^2.0.7"
}
},
"node_modules/dunder-proto": {
"version": "1.0.1",
"resolved": "https://mirrors.tencent.com/npm/dunder-proto/-/dunder-proto-1.0.1.tgz",
@@ -3374,9 +3401,10 @@
}
},
"node_modules/tdesign-icons-vue-next": {
"version": "0.3.6",
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.3.6.tgz",
"integrity": "sha512-X9u90dBv8tPhfpguUyx+BzF8CU2ef2L4RXOO7MYOj1ufHCHwBXTF8L3GPfq6KZd/2u4vMLYAA8lGURn4PZZICw==",
"version": "0.4.1",
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.4.1.tgz",
"integrity": "sha512-uDPuTLRORnGcTyVGNoentNaK4V+ZcBmhYwcY3KqDaQQ5rrPeLMxu0ZVmgOEf0JtF2QZiqAxY7vodNEiLUdoRKA==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.16.3"
},
@@ -3410,6 +3438,18 @@
"vue": ">=3.1.0"
}
},
"node_modules/tdesign-vue-next/node_modules/tdesign-icons-vue-next": {
"version": "0.3.7",
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.3.7.tgz",
"integrity": "sha512-Q5ebVty/TCqhBa0l/17kkhjC0pBAOGvn7C35MAt1xS+johKVM9QEDOy9R6XEl332AiwQ37MwqioczqjYC30ckw==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.16.3"
},
"peerDependencies": {
"vue": "^3.0.0"
}
},
"node_modules/terser": {
"version": "5.43.1",
"resolved": "https://mirrors.tencent.com/npm/terser/-/terser-5.43.1.tgz",

View File

@@ -1,6 +1,6 @@
{
"name": "knowledage-base",
"version": "0.1.0",
"version": "0.1.3",
"private": true,
"type": "module",
"scripts": {
@@ -13,10 +13,13 @@
},
"dependencies": {
"@microsoft/fetch-event-source": "^2.0.1",
"@types/dompurify": "^3.0.5",
"axios": "^1.8.4",
"dompurify": "^3.2.6",
"marked": "^5.1.2",
"pagefind": "^1.1.1",
"pinia": "^3.0.1",
"tdesign-icons-vue-next": "^0.4.1",
"tdesign-vue-next": "^1.11.5",
"vue": "^3.5.13",
"vue-router": "^4.5.0",

View File

@@ -0,0 +1,239 @@
import { post, get, put } from '@/utils/request'
// 用户登录接口
export interface LoginRequest {
email: string
password: string
}
export interface LoginResponse {
success: boolean
message?: string
user?: {
id: string
username: string
email: string
avatar?: string
tenant_id: number
is_active: boolean
created_at: string
updated_at: string
}
tenant?: {
id: number
name: string
description: string
api_key: string
status: string
business: string
storage_quota: number
storage_used: number
created_at: string
updated_at: string
}
token?: string
refresh_token?: string
}
// 用户注册接口
export interface RegisterRequest {
username: string
email: string
password: string
}
export interface RegisterResponse {
success: boolean
message?: string
data?: {
user: {
id: string
username: string
email: string
}
tenant: {
id: string
name: string
api_key: string
}
}
}
// 用户信息接口
export interface UserInfo {
id: string
username: string
email: string
avatar?: string
tenant_id: string
created_at: string
updated_at: string
}
// 租户信息接口
export interface TenantInfo {
id: string
name: string
description?: string
api_key: string
status?: string
business?: string
owner_id: string
storage_quota?: number
storage_used?: number
created_at: string
updated_at: string
knowledge_bases?: KnowledgeBaseInfo[]
}
// 知识库信息接口
export interface KnowledgeBaseInfo {
id: string
name: string
description: string
tenant_id: string
created_at: string
updated_at: string
document_count?: number
chunk_count?: number
}
// 模型信息接口
export interface ModelInfo {
id: string
name: string
type: string
source: string
description?: string
is_default?: boolean
created_at: string
updated_at: string
}
/**
* 用户登录
*/
export async function login(data: LoginRequest): Promise<LoginResponse> {
try {
const response = await post('/api/v1/auth/login', data)
return response as unknown as LoginResponse
} catch (error: any) {
return {
success: false,
message: error.message || '登录失败'
}
}
}
/**
* 用户注册
*/
export async function register(data: RegisterRequest): Promise<RegisterResponse> {
try {
const response = await post('/api/v1/auth/register', data)
return response as unknown as RegisterResponse
} catch (error: any) {
return {
success: false,
message: error.message || '注册失败'
}
}
}
/**
* 获取当前用户信息
*/
export async function getCurrentUser(): Promise<{ success: boolean; data?: { user: UserInfo; tenant: TenantInfo }; message?: string }> {
try {
const response = await get('/api/v1/auth/me')
return response as unknown as { success: boolean; data?: { user: UserInfo; tenant: TenantInfo }; message?: string }
} catch (error: any) {
return {
success: false,
message: error.message || '获取用户信息失败'
}
}
}
/**
* 获取当前租户信息
*/
export async function getCurrentTenant(): Promise<{ success: boolean; data?: TenantInfo; message?: string }> {
try {
const response = await get('/api/v1/auth/tenant')
return response as unknown as { success: boolean; data?: TenantInfo; message?: string }
} catch (error: any) {
return {
success: false,
message: error.message || '获取租户信息失败'
}
}
}
/**
* 刷新Token
*/
export async function refreshToken(refreshToken: string): Promise<{ success: boolean; data?: { token: string; refreshToken: string }; message?: string }> {
try {
const response: any = await post('/api/v1/auth/refresh', { refreshToken })
if (response && response.success) {
if (response.access_token || response.refresh_token) {
return {
success: true,
data: {
token: response.access_token,
refreshToken: response.refresh_token,
}
}
}
}
// 其他情况直接返回原始消息
return {
success: false,
message: response?.message || '刷新Token失败'
}
} catch (error: any) {
return {
success: false,
message: error.message || '刷新Token失败'
}
}
}
/**
* 用户登出
*/
export async function logout(): Promise<{ success: boolean; message?: string }> {
try {
await post('/api/v1/auth/logout', {})
return {
success: true
}
} catch (error: any) {
return {
success: false,
message: error.message || '登出失败'
}
}
}
/**
* 验证Token有效性
*/
export async function validateToken(): Promise<{ success: boolean; valid?: boolean; message?: string }> {
try {
const response = await get('/api/v1/auth/validate')
return response as unknown as { success: boolean; valid?: boolean; message?: string }
} catch (error: any) {
return {
success: false,
valid: false,
message: error.message || 'Token验证失败'
}
}
}

View File

@@ -1,54 +1,24 @@
import { get, post, put, del, postChat } from "../../utils/request";
import { loadTestData } from "../test-data";
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.apiKey && settings.endpoint) {
return settings;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
return null;
}
// 根据是否有设置决定是否需要加载测试数据
async function ensureConfigured() {
const settings = getSettings();
// 如果没有设置APIKey和Endpoint则加载测试数据
if (!settings) {
await loadTestData();
}
}
export async function createSessions(data = {}) {
await ensureConfigured();
return post("/api/v1/sessions", data);
}
export async function getSessionsList(page: number, page_size: number) {
await ensureConfigured();
return get(`/api/v1/sessions?page=${page}&page_size=${page_size}`);
}
export async function generateSessionsTitle(session_id: string, data: any) {
await ensureConfigured();
return post(`/api/v1/sessions/${session_id}/generate_title`, data);
}
export async function knowledgeChat(data: { session_id: string; query: string; }) {
await ensureConfigured();
return postChat(`/api/v1/knowledge-chat/${data.session_id}`, { query: data.query });
}
export async function getMessageList(data: { session_id: string; limit: number, created_at: string }) {
await ensureConfigured();
if (data.created_at) {
return get(`/api/v1/messages/${data.session_id}/load?before_time=${encodeURIComponent(data.created_at)}&limit=${data.limit}`);
} else {
@@ -57,6 +27,5 @@ export async function getMessageList(data: { session_id: string; limit: number,
}
export async function delSession(session_id: string) {
await ensureConfigured();
return del(`/api/v1/sessions/${session_id}`);
}

View File

@@ -1,22 +1,8 @@
import { fetchEventSource } from '@microsoft/fetch-event-source'
import { ref, type Ref, onUnmounted, nextTick } from 'vue'
import { generateRandomString } from '@/utils/index';
import { getTestData } from '@/utils/request';
import { loadTestData } from '@/api/test-data';
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
return settings;
} catch (e) {
console.error("解析设置失败:", e);
}
}
return null;
}
interface StreamOptions {
// 请求方法 (默认POST)
@@ -49,26 +35,15 @@ export function useStream() {
isStreaming.value = true;
isLoading.value = true;
// 获取设置信息
const settings = getSettings();
let apiUrl = '';
let apiKey = '';
// 如果有设置信息,优先使用设置信息
if (settings && settings.endpoint && settings.apiKey) {
apiUrl = settings.endpoint;
apiKey = settings.apiKey;
} else {
// 否则加载测试数据
await loadTestData();
const testData = getTestData();
if (!testData) {
error.value = "测试数据未初始化,无法进行聊天";
stopStream();
return;
}
apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
apiKey = testData.tenant.api_key;
// 获取API配置
const apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
// 获取JWT Token
const token = localStorage.getItem('weknora_token');
if (!token) {
error.value = "未找到登录令牌,请重新登录";
stopStream();
return;
}
try {
@@ -80,7 +55,7 @@ export function useStream() {
method: params.method,
headers: {
"Content-Type": "application/json",
"X-API-Key": apiKey,
"Authorization": `Bearer ${token}`,
"X-Request-ID": `${generateRandomString(12)}`,
},
body:

View File

@@ -19,6 +19,7 @@ export interface InitializationConfig {
modelName: string;
baseUrl: string;
apiKey?: string;
enabled: boolean;
};
multimodal: {
enabled: boolean;
@@ -49,6 +50,13 @@ export interface InitializationConfig {
};
// Frontend-only hint for storage selection UI
storageType?: 'cos' | 'minio';
nodeExtract: {
enabled: boolean,
text: string,
tags: string[],
nodes: Node[],
relations: Relation[]
}
}
// 下载任务状态类型
@@ -62,34 +70,18 @@ export interface DownloadTask {
endTime?: string;
}
// 系统初始化状态检查
export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
// 根据知识库ID执行配置更新
export function initializeSystemByKB(kbId: string, config: InitializationConfig): Promise<any> {
return new Promise((resolve, reject) => {
get('/api/v1/initialization/status')
console.log('开始知识库配置更新...', kbId, config);
post(`/api/v1/initialization/initialize/${kbId}`, config)
.then((response: any) => {
resolve(response.data || { initialized: false });
})
.catch((error: any) => {
console.warn('检查初始化状态失败,假设需要初始化:', error);
resolve({ initialized: false });
});
});
}
// 执行系统初始化
export function initializeSystem(config: InitializationConfig): Promise<any> {
return new Promise((resolve, reject) => {
console.log('开始系统初始化...', config);
post('/api/v1/initialization/initialize', config)
.then((response: any) => {
console.log('系统初始化完成', response);
// 设置本地初始化状态标记
localStorage.setItem('system_initialized', 'true');
console.log('知识库配置更新完成', response);
resolve(response);
})
.catch((error: any) => {
console.error('系统初始化失败:', error);
reject(error);
console.error('知识库配置更新失败:', error);
reject(error.error || error);
});
});
}
@@ -178,15 +170,15 @@ export function listDownloadTasks(): Promise<DownloadTask[]> {
});
}
// 获取当前系统配置
export function getCurrentConfig(): Promise<InitializationConfig & { hasFiles: boolean }> {
export function getCurrentConfigByKB(kbId: string): Promise<InitializationConfig & { hasFiles: boolean }> {
return new Promise((resolve, reject) => {
get('/api/v1/initialization/config')
get(`/api/v1/initialization/config/${kbId}`)
.then((response: any) => {
resolve(response.data || {});
})
.catch((error: any) => {
console.error('获取当前配置失败:', error);
console.error('获取知识库配置失败:', error);
reject(error);
});
});
@@ -311,9 +303,17 @@ export function testMultimodalFunction(testData: {
formData.append('chunk_overlap', testData.chunk_overlap.toString());
formData.append('separators', JSON.stringify(testData.separators));
// 获取鉴权Token
const token = localStorage.getItem('weknora_token');
const headers: Record<string, string> = {};
if (token) {
headers['Authorization'] = `Bearer ${token}`;
}
// 使用原生fetch因为需要发送FormData
fetch('/api/v1/initialization/multimodal/test', {
method: 'POST',
headers,
body: formData
})
.then(response => response.json())
@@ -329,4 +329,93 @@ export function testMultimodalFunction(testData: {
reject(error);
});
});
}
}
// 文本内容关系提取接口
export interface TextRelationExtractionRequest {
text: string;
tags: string[];
llmConfig: LLMConfig;
}
export interface Node {
name: string;
attributes: string[];
}
export interface Relation {
node1: string;
node2: string;
type: string;
}
export interface LLMConfig {
source: 'local' | 'remote';
modelName: string;
baseUrl: string;
apiKey: string;
}
export interface TextRelationExtractionResponse {
nodes: Node[];
relations: Relation[];
}
// 文本内容关系提取
export function extractTextRelations(request: TextRelationExtractionRequest): Promise<TextRelationExtractionResponse> {
return new Promise((resolve, reject) => {
post('/api/v1/initialization/extract/text-relation', request)
.then((response: any) => {
resolve(response.data || { nodes: [], relations: [] });
})
.catch((error: any) => {
console.error('文本内容关系提取失败:', error);
reject(error);
});
});
}
export interface FabriTextRequest {
tags: string[];
llmConfig: LLMConfig;
}
export interface FabriTextResponse {
text: string;
}
// 文本内容生成
export function fabriText(request: FabriTextRequest): Promise<FabriTextResponse> {
return new Promise((resolve, reject) => {
post('/api/v1/initialization/extract/fabri-text', request)
.then((response: any) => {
resolve(response.data || { text: '' });
})
.catch((error: any) => {
console.error('文本内容生成失败:', error);
reject(error);
});
});
}
export interface FabriTagRequest {
llmConfig: LLMConfig;
}
export interface FabriTagResponse {
tags: string[];
}
// 文本内容生成
export function fabriTag(request: FabriTagRequest): Promise<FabriTagResponse> {
return new Promise((resolve, reject) => {
post('/api/v1/initialization/extract/fabri-tag', request)
.then((response: any) => {
resolve(response.data || { tags: [] as string[] });
})
.catch((error: any) => {
console.error('标签生成失败:', error);
reject(error);
});
});
}

View File

@@ -1,62 +1,55 @@
import { get, post, put, del, postUpload, getDown, getTestData } from "../../utils/request";
import { loadTestData } from "../test-data";
import { get, post, put, del, postUpload, getDown } from "../../utils/request";
// 获取知识库ID优先从设置中获取
async function getKnowledgeBaseID() {
// 从localStorage获取设置中的知识库ID
const settingsStr = localStorage.getItem("WeKnora_settings");
let knowledgeBaseId = "";
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.knowledgeBaseId) {
return settings.knowledgeBaseId;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
// 如果设置中没有知识库ID则使用测试数据
await loadTestData();
const testData = getTestData();
if (!testData || testData.knowledge_bases.length === 0) {
console.error("测试数据未初始化或不包含知识库");
throw new Error("测试数据未初始化或不包含知识库");
}
return testData.knowledge_bases[0].id;
// 知识库管理 API列表、创建、获取、更新、删除、复制
export function listKnowledgeBases() {
return get(`/api/v1/knowledge-bases`);
}
export async function uploadKnowledgeBase(data = {}) {
const kbId = await getKnowledgeBaseID();
export function createKnowledgeBase(data: { name: string; description?: string; chunking_config?: any }) {
return post(`/api/v1/knowledge-bases`, data);
}
export function getKnowledgeBaseById(id: string) {
return get(`/api/v1/knowledge-bases/${id}`);
}
export function updateKnowledgeBase(id: string, data: { name: string; description?: string; config: any }) {
return put(`/api/v1/knowledge-bases/${id}` , data);
}
export function deleteKnowledgeBase(id: string) {
return del(`/api/v1/knowledge-bases/${id}`);
}
export function copyKnowledgeBase(data: { source_id: string; target_id?: string }) {
return post(`/api/v1/knowledge-bases/copy`, data);
}
// 知识文件 API基于具体知识库
export function uploadKnowledgeFile(kbId: string, data = {}) {
return postUpload(`/api/v1/knowledge-bases/${kbId}/knowledge/file`, data);
}
export async function getKnowledgeBase({page, page_size}) {
const kbId = await getKnowledgeBaseID();
return get(
`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`
);
export function listKnowledgeFiles(kbId: string, { page, page_size }: { page: number; page_size: number }) {
return get(`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`);
}
export function getKnowledgeDetails(id: any) {
export function getKnowledgeDetails(id: string) {
return get(`/api/v1/knowledge/${id}`);
}
export function delKnowledgeDetails(id: any) {
export function delKnowledgeDetails(id: string) {
return del(`/api/v1/knowledge/${id}`);
}
export function downKnowledgeDetails(id: any) {
export function downKnowledgeDetails(id: string) {
return getDown(`/api/v1/knowledge/${id}/download`);
}
export function batchQueryKnowledge(ids: any) {
return get(`/api/v1/knowledge/batch?${ids}`);
export function batchQueryKnowledge(idsQueryString: string) {
return get(`/api/v1/knowledge/batch?${idsQueryString}`);
}
export function getKnowledgeDetailsCon(id: any, page) {
export function getKnowledgeDetailsCon(id: string, page: number) {
return get(`/api/v1/chunks/${id}?page=${page}&page_size=25`);
}

View File

@@ -0,0 +1,12 @@
import { get } from '@/utils/request'
export interface SystemInfo {
version: string
commit_id?: string
build_time?: string
go_version?: string
}
export function getSystemInfo(): Promise<{ data: SystemInfo }> {
return get('/api/v1/system/info')
}

View File

@@ -1,55 +0,0 @@
import { get, setTestData } from '../../utils/request';
export interface TestDataResponse {
success: boolean;
data: {
tenant: {
id: number;
name: string;
api_key: string;
};
knowledge_bases: Array<{
id: string;
name: string;
description: string;
}>;
}
}
// 是否已加载测试数据
let isTestDataLoaded = false;
/**
* 加载测试数据
* 在API调用前调用此函数以确保测试数据已加载
* @returns Promise<boolean> 是否成功加载
*/
export async function loadTestData(): Promise<boolean> {
// 如果已经加载过,直接返回
if (isTestDataLoaded) {
return true;
}
try {
console.log('开始加载测试数据...');
const response = await get('/api/v1/test-data');
console.log('测试数据', response);
if (response && response.data) {
// 设置测试数据
setTestData({
tenant: response.data.tenant,
knowledge_bases: response.data.knowledge_bases
});
isTestDataLoaded = true;
console.log('测试数据加载成功');
return true;
} else {
console.warn('测试数据响应为空');
return false;
}
} catch (error) {
console.error('加载测试数据失败:', error);
return false;
}
}

View File

@@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none">
<path d="M10 3H6a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M17 16l4-4-4-4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M21 12H10" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 509 B

View File

@@ -0,0 +1,4 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="10" cy="6" r="3" stroke="#07C05F" stroke-width="1.5" fill="none"/>
<path d="M4 16c0-3.314 2.686-6 6-6s6 2.686 6 6" stroke="#07C05F" stroke-width="1.5" fill="none"/>
</svg>

After

Width:  |  Height:  |  Size: 284 B

View File

@@ -0,0 +1,4 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="10" cy="6" r="3" stroke="currentColor" stroke-width="1.5" fill="none"/>
<path d="M4 16c0-3.314 2.686-6 6-6s6 2.686 6 6" stroke="currentColor" stroke-width="1.5" fill="none"/>
</svg>

After

Width:  |  Height:  |  Size: 294 B

View File

@@ -4,6 +4,8 @@ import { onMounted, ref, nextTick, onUnmounted, onUpdated, watch } from "vue";
import { downKnowledgeDetails } from "@/api/knowledge-base/index";
import { MessagePlugin } from "tdesign-vue-next";
import picturePreview from '@/components/picture-preview.vue';
import { sanitizeHTML, safeMarkdownToHTML, createSafeImage, isValidImageURL } from '@/utils/security';
marked.use({
mangle: false,
headerIds: false,
@@ -37,10 +39,16 @@ const checkImage = (url) => {
});
};
renderer.image = function (href, title, text) {
// 自定义HTML结构图片展示带标题
// 安全地处理图片链接
if (!isValidImageURL(href)) {
return `<p>无效的图片链接</p>`;
}
// 使用安全的图片创建函数
const safeImage = createSafeImage(href, text || '', title || '');
return `<figure>
<img class="markdown-image" src="${href}" alt="${title}" title="${text}">
<figcaption style="text-align: left;">${text}</figcaption>
${safeImage}
<figcaption style="text-align: left;">${text || ''}</figcaption>
</figure>`;
};
const props = defineProps(["visible", "details"]);
@@ -66,14 +74,23 @@ watch(() => props.details.md, (newVal) => {
deep: true
})
// 处理 Markdown 中的图片
// 安全地处理 Markdown 内容
const processMarkdown = (markdownText) => {
// 自定义渲染器处理图片
if (!markdownText || typeof markdownText !== 'string') {
return '';
}
// 首先对 Markdown 内容进行安全处理
const safeMarkdown = safeMarkdownToHTML(markdownText);
// 使用安全的渲染器
marked.use({ renderer });
let html = marked.parse(markdownText);
const parser = new DOMParser();
const doc = parser.parseFromString(html, 'text/html');
return doc.body.innerHTML;
let html = marked.parse(safeMarkdown);
// 使用 DOMPurify 进行最终的安全清理
const sanitizedHTML = sanitizeHTML(html);
return sanitizedHTML;
};
const closePreImg = () => {
reviewImg.value = false
@@ -87,15 +104,19 @@ const downloadFile = () => {
downKnowledgeDetails(props.details.id)
.then((result) => {
if (result) {
if (url.value) {
URL.revokeObjectURL(url.value);
}
url.value = URL.createObjectURL(result);
down.value.click();
// const link = document.createElement("a");
// link.style.display = "none";
// link.setAttribute("href", url);
// link.setAttribute("download", props.details.title);
// link.click();
// document.body.removeChild(link);
window.URL.revokeObjectURL(url);
const link = document.createElement("a");
link.style.display = "none";
link.setAttribute("href", url.value);
link.setAttribute("download", props.details.title);
link.click();
nextTick(() => {
document.body.removeChild(link);
URL.revokeObjectURL(url.value);
})
}
})
.catch((err) => {

View File

@@ -1,68 +1,132 @@
<template>
<div class="aside_box">
<div class="logo_box">
<div class="logo_box" @click="router.push('/platform/knowledge-bases')" style="cursor: pointer;">
<img class="logo" src="@/assets/img/weknora.png" alt="">
</div>
<div class="menu_box" v-for="(item, index) in menuArr" :key="index">
<div @click="gotopage(item.path)"
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : item.path == currentpath ? 'menu_item_active' : '']">
<div class="menu_item-box">
<div class="menu_icon">
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : prefixIcon)" alt="">
<!-- 上半部分知识库和对话 -->
<div class="menu_top">
<div class="menu_box" :class="{ 'has-submenu': item.children }" v-for="(item, index) in topMenuItems" :key="index">
<div @click="handleMenuClick(item.path)"
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : isMenuItemActive(item.path) ? 'menu_item_active' : '']">
<div class="menu_item-box">
<div class="menu_icon">
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'logout' ? logoutIcon : item.icon == 'tenant' ? tenantIcon : prefixIcon)" alt="">
</div>
<span class="menu_title" :title="item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title">{{ item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title }}</span>
<!-- 知识库切换下拉箭头 -->
<div v-if="item.path === 'knowledge-bases' && isInKnowledgeBase"
class="kb-dropdown-icon"
:class="{
'rotate-180': showKbDropdown,
'active': isMenuItemActive(item.path)
}"
@click.stop="toggleKbDropdown">
<svg width="12" height="12" viewBox="0 0 12 12" fill="currentColor">
<path d="M2.5 4.5L6 8L9.5 4.5H2.5Z"/>
</svg>
</div>
</div>
<span class="menu_title">{{ item.title }}</span>
<!-- 知识库切换下拉菜单 -->
<div v-if="item.path === 'knowledge-bases' && showKbDropdown && isInKnowledgeBase"
class="kb-dropdown-menu">
<div v-for="kb in initializedKnowledgeBases"
:key="kb.id"
class="kb-dropdown-item"
:class="{ 'active': kb.name === currentKbName }"
@click.stop="switchKnowledgeBase(kb.id)">
{{ kb.name }}
</div>
</div>
<t-popup overlayInnerClassName="upload-popup" class="placement top center" content="上传知识"
placement="top" show-arrow destroy-on-close>
<div class="upload-file-wrap" @click.stop="uploadFile" variant="outline"
v-if="item.path === 'knowledge-bases' && $route.name === 'knowledgeBaseDetail'">
<img class="upload-file-icon" :class="[item.path == currentpath ? 'active-upload' : '']"
:src="getImgSrc(fileAddIcon)" alt="">
</div>
</t-popup>
</div>
<t-popup overlayInnerClassName="upload-popup" class="placement top center" content="上传知识"
placement="top" show-arrow destroy-on-close>
<div class="upload-file-wrap" @click="uploadFile" variant="outline"
v-if="item.path == 'knowledgeBase'">
<img class="upload-file-icon" :class="[item.path == currentpath ? 'active-upload' : '']"
:src="getImgSrc(fileAddIcon)" alt="">
</div>
</t-popup>
</div>
<div ref="submenuscrollContainer" @scroll="handleScroll" class="submenu" v-if="item.children">
<div class="submenu_item_p" v-for="(subitem, subindex) in item.children" :key="subindex"
@click="gotopage(subitem.path)">
<div :class="['submenu_item', currentSecondpath == subitem.path ? 'submenu_item_active' : '']"
@mouseenter="mouseenteBotDownr(subindex)" @mouseleave="mouseleaveBotDown">
<i v-if="currentSecondpath == subitem.path" class="dot"></i>
<span class="submenu_title"
:style="currentSecondpath == subitem.path ? 'margin-left:14px;max-width:160px;' : 'margin-left:18px;max-width:173px;'">
{{ subitem.title }}
</span>
<t-popup v-model:visible="subitem.isMore" @overlay-click="delCard(subindex, subitem)"
@visible-change="onVisibleChange" overlayClassName="del-menu-popup" trigger="click"
destroy-on-close placement="top-left">
<div v-if="(activeSubmenu == subindex) || (currentSecondpath == subitem.path) || subitem.isMore"
@click.stop="openMore(subindex)" variant="outline" class="menu-more-wrap">
<t-icon name="ellipsis" class="menu-more" />
</div>
<template #content>
<span class="del_submenu">删除记录</span>
</template>
</t-popup>
<div ref="submenuscrollContainer" @scroll="handleScroll" class="submenu" v-if="item.children">
<div class="submenu_item_p" v-for="(subitem, subindex) in item.children" :key="subindex"
@click="gotopage(subitem.path)">
<div :class="['submenu_item', currentSecondpath == subitem.path ? 'submenu_item_active' : '']"
@mouseenter="mouseenteBotDownr(subindex)" @mouseleave="mouseleaveBotDown">
<i v-if="currentSecondpath == subitem.path" class="dot"></i>
<span class="submenu_title"
:style="currentSecondpath == subitem.path ? 'margin-left:14px;max-width:160px;' : 'margin-left:18px;max-width:173px;'">
{{ subitem.title }}
</span>
<t-popup v-model:visible="subitem.isMore" @overlay-click="delCard(subindex, subitem)"
@visible-change="onVisibleChange" overlayClassName="del-menu-popup" trigger="click"
destroy-on-close placement="top-left">
<div v-if="(activeSubmenu == subindex) || (currentSecondpath == subitem.path) || subitem.isMore"
@click.stop="openMore(subindex)" variant="outline" class="menu-more-wrap">
<t-icon name="ellipsis" class="menu-more" />
</div>
<template #content>
<span class="del_submenu">删除记录</span>
</template>
</t-popup>
</div>
</div>
</div>
</div>
</div>
<!-- 下半部分账户信息系统设置退出登录 -->
<div class="menu_bottom">
<div class="menu_box" v-for="(item, index) in bottomMenuItems" :key="'bottom-' + index">
<div v-if="item.path === 'logout'">
<t-popconfirm
content="确定要退出登录吗?"
@confirm="handleLogout"
placement="top"
:show-arrow="true"
>
<div @mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
:class="['menu_item', 'logout-item']">
<div class="menu_item-box">
<div class="menu_icon">
<img class="icon" :src="getImgSrc(logoutIcon)" alt="">
</div>
<span class="menu_title">{{ item.title }}</span>
</div>
</div>
</t-popconfirm>
</div>
<div v-else @click="handleMenuClick(item.path)"
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : (item.path == currentpath) ? 'menu_item_active' : '']">
<div class="menu_item-box">
<div class="menu_icon">
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'tenant' ? tenantIcon : prefixIcon)" alt="">
</div>
<span class="menu_title">{{ item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title }}</span>
</div>
</div>
</div>
</div>
<input type="file" @change="upload" style="display: none" ref="uploadInput"
accept=".pdf,.docx,.doc,.txt,.md,.jpg,.jpeg,.png" />
</div>
</template>
<script setup>
<script setup lang="ts">
import { storeToRefs } from 'pinia';
import { onMounted, watch, computed, ref, reactive } from 'vue';
import { onMounted, watch, computed, ref, reactive, nextTick } from 'vue';
import { useRoute, useRouter } from 'vue-router';
import { getSessionsList, delSession } from "@/api/chat/index";
import { getKnowledgeBaseById, listKnowledgeBases, uploadKnowledgeFile } from '@/api/knowledge-base';
import { kbFileTypeVerification } from '@/utils/index';
import { useMenuStore } from '@/stores/menu';
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
import { useAuthStore } from '@/stores/auth';
import { MessagePlugin } from "tdesign-vue-next";
let { requestMethod } = useKnowledgeBase()
let uploadInput = ref();
const usemenuStore = useMenuStore();
const authStore = useAuthStore();
const route = useRoute();
const router = useRouter();
const currentpath = ref('');
@@ -74,39 +138,206 @@ const submenuscrollContainer = ref(null);
// 计算总页数
const totalPages = computed(() => Math.ceil(total.value / page_size.value));
const hasMore = computed(() => currentPage.value < totalPages.value);
type MenuItem = { title: string; icon: string; path: string; childrenPath?: string; children?: any[] };
const { menuArr } = storeToRefs(usemenuStore);
let activeSubmenu = ref(-1);
let activeSubmenu = ref<number>(-1);
// 是否处于知识库详情页
const isInKnowledgeBase = computed<boolean>(() => {
return route.name === 'knowledgeBaseDetail' ||
route.name === 'kbCreatChat' ||
route.name === 'chat' ||
route.name === 'knowledgeBaseSettings';
});
// 统一的菜单项激活状态判断
const isMenuItemActive = (itemPath: string): boolean => {
const currentRoute = route.name;
switch (itemPath) {
case 'knowledge-bases':
return currentRoute === 'knowledgeBaseList' ||
currentRoute === 'knowledgeBaseDetail' ||
currentRoute === 'knowledgeBaseSettings';
case 'creatChat':
return currentRoute === 'kbCreatChat';
case 'tenant':
return currentRoute === 'tenant';
default:
return itemPath === currentpath.value;
}
};
// 统一的图标激活状态判断
const getIconActiveState = (itemPath: string) => {
const currentRoute = route.name;
return {
isKbActive: itemPath === 'knowledge-bases' && (
currentRoute === 'knowledgeBaseList' ||
currentRoute === 'knowledgeBaseDetail' ||
currentRoute === 'knowledgeBaseSettings'
),
isCreatChatActive: itemPath === 'creatChat' && currentRoute === 'kbCreatChat',
isTenantActive: itemPath === 'tenant' && currentRoute === 'tenant',
isChatActive: itemPath === 'chat' && currentRoute === 'chat'
};
};
// 分离上下两部分菜单
const topMenuItems = computed<MenuItem[]>(() => {
return (menuArr.value as unknown as MenuItem[]).filter((item: MenuItem) =>
item.path === 'knowledge-bases' || (isInKnowledgeBase.value && item.path === 'creatChat')
);
});
const bottomMenuItems = computed<MenuItem[]>(() => {
return (menuArr.value as unknown as MenuItem[]).filter((item: MenuItem) => {
if (item.path === 'knowledge-bases' || item.path === 'creatChat') {
return false;
}
return true;
});
});
// 当前知识库名称和列表
const currentKbName = ref<string>('')
const allKnowledgeBases = ref<Array<{ id: string; name: string; embedding_model_id?: string; summary_model_id?: string }>>([])
const showKbDropdown = ref<boolean>(false)
// 过滤已初始化的知识库
const initializedKnowledgeBases = computed(() => {
return allKnowledgeBases.value.filter(kb =>
kb.embedding_model_id && kb.embedding_model_id !== '' &&
kb.summary_model_id && kb.summary_model_id !== ''
)
})
// 动态更新知识库菜单项标题
const kbMenuItem = computed(() => {
const kbItem = topMenuItems.value.find(item => item.path === 'knowledge-bases')
if (kbItem && isInKnowledgeBase.value && currentKbName.value) {
return { ...kbItem, title: currentKbName.value }
}
return kbItem
})
const loading = ref(false)
const uploadFile = () => {
const uploadFile = async () => {
// 获取当前知识库ID
const currentKbId = await getCurrentKbId();
// 检查当前知识库的初始化状态
if (currentKbId) {
try {
const kbResponse = await getKnowledgeBaseById(currentKbId);
const kb = kbResponse.data;
// 检查知识库是否已初始化(有 EmbeddingModelID 和 SummaryModelID
if (!kb.embedding_model_id || kb.embedding_model_id === '' ||
!kb.summary_model_id || kb.summary_model_id === '') {
MessagePlugin.warning("该知识库尚未完成初始化配置,请先前往设置页面配置模型信息后再上传文件");
return;
}
} catch (error) {
console.error('获取知识库信息失败:', error);
MessagePlugin.error("获取知识库信息失败,无法上传文件");
return;
}
}
uploadInput.value.click()
}
const upload = (e) => {
requestMethod(e.target.files[0], uploadInput)
const upload = async (e: any) => {
const file = e.target.files[0];
if (!file) return;
// 文件类型验证
if (kbFileTypeVerification(file)) {
return;
}
// 获取当前知识库ID
const currentKbId = (route.params as any)?.kbId as string;
if (!currentKbId) {
MessagePlugin.error("缺少知识库ID");
return;
}
try {
const result = await uploadKnowledgeFile(currentKbId, { file });
const responseData = result as any;
console.log('上传API返回结果:', responseData);
// 如果没有抛出异常,就认为上传成功,先触发刷新事件
console.log('文件上传完成发送事件通知页面刷新知识库ID:', currentKbId);
window.dispatchEvent(new CustomEvent('knowledgeFileUploaded', {
detail: { kbId: currentKbId }
}));
// 然后处理UI消息
// 判断上传是否成功 - 检查多种可能的成功标识
const isSuccess = responseData.success || responseData.code === 200 || responseData.status === 'success' || (!responseData.error && responseData);
if (isSuccess) {
MessagePlugin.info("上传成功!");
} else {
// 改进错误信息提取逻辑
let errorMessage = "上传失败!";
if (responseData.error && responseData.error.message) {
errorMessage = responseData.error.message;
} else if (responseData.message) {
errorMessage = responseData.message;
}
if (responseData.code === 'duplicate_file' || (responseData.error && responseData.error.code === 'duplicate_file')) {
errorMessage = "文件已存在";
}
MessagePlugin.error(errorMessage);
}
} catch (err: any) {
let errorMessage = "上传失败!";
if (err.code === 'duplicate_file') {
errorMessage = "文件已存在";
} else if (err.error && err.error.message) {
errorMessage = err.error.message;
} else if (err.message) {
errorMessage = err.message;
}
MessagePlugin.error(errorMessage);
} finally {
uploadInput.value.value = "";
}
}
const mouseenteBotDownr = (val) => {
const mouseenteBotDownr = (val: number) => {
activeSubmenu.value = val;
}
const mouseleaveBotDown = () => {
activeSubmenu.value = -1;
}
const onVisibleChange = (e) => {
const onVisibleChange = (_e: any) => {
}
const delCard = (index, item) => {
delSession(item.id).then(res => {
if (res && res.success) {
menuArr.value[1].children.splice(index, 1);
const delCard = (index: number, item: any) => {
delSession(item.id).then((res: any) => {
if (res && (res as any).success) {
(menuArr.value as any[])[1]?.children?.splice(index, 1);
if (item.id == route.params.chatid) {
router.push('/platform/creatChat');
// 删除当前会话后,跳转到当前知识库的创建聊天页面
const kbId = route.params.kbId;
if (kbId) {
router.push(`/platform/knowledge-bases/${kbId}/creatChat`);
} else {
router.push('/platform/knowledge-bases');
}
}
} else {
MessagePlugin.error("删除失败,请稍后再试!");
}
})
}
const debounce = (fn, delay) => {
let timer
return (...args) => {
const debounce = (fn: (...args: any[]) => void, delay: number) => {
let timer: ReturnType<typeof setTimeout>
return (...args: any[]) => {
clearTimeout(timer)
timer = setTimeout(() => fn(...args), delay)
}
@@ -124,80 +355,221 @@ const checkScrollBottom = () => {
}
}
const handleScroll = debounce(checkScrollBottom, 200)
const getMessageList = () => {
const getMessageList = async () => {
// 仅在知识库内部显示对话列表
if (!isInKnowledgeBase.value) {
usemenuStore.clearMenuArr();
currentKbName.value = '';
return;
}
let kbId = (route.params as any)?.kbId as string
// 新的路由格式:/platform/chat/:kbId/:chatid直接从路由参数获取知识库ID
if (!kbId) {
usemenuStore.clearMenuArr();
currentKbName.value = '';
return;
}
// 获取知识库名称和所有知识库列表
try {
const [kbRes, allKbRes]: any[] = await Promise.all([
getKnowledgeBaseById(kbId),
listKnowledgeBases()
])
if (kbRes?.data?.name) {
currentKbName.value = kbRes.data.name
}
if (allKbRes?.data) {
allKnowledgeBases.value = allKbRes.data
}
} catch {}
if (loading.value) return;
loading.value = true;
usemenuStore.clearMenuArr();
getSessionsList(currentPage.value, page_size.value).then(res => {
getSessionsList(currentPage.value, page_size.value).then((res: any) => {
if (res.data && res.data.length) {
res.data.forEach(item => {
let obj = { title: item.title ? item.title : "新会话", path: `chat/${item.id}`, id: item.id, isMore: false, isNoTitle: item.title ? false : true }
// 过滤出当前知识库的会话
const filtered = res.data.filter((s: any) => s.knowledge_base_id === kbId)
filtered.forEach((item: any) => {
let obj = { title: item.title ? item.title : "新会话", path: `chat/${kbId}/${item.id}`, id: item.id, isMore: false, isNoTitle: item.title ? false : true }
usemenuStore.updatemenuArr(obj)
});
loading.value = false;
}
if (res.total) {
total.value = res.total;
if ((res as any).total) {
total.value = (res as any).total;
}
})
}
const openMore = (e) => { }
const openMore = (_e: any) => { }
onMounted(() => {
currentpath.value = route.name;
if (route.params.chatid) {
currentSecondpath.value = `${route.name}/${route.params.chatid}`;
const routeName = typeof route.name === 'string' ? route.name : (route.name ? String(route.name) : '')
currentpath.value = routeName;
if (route.params.chatid && route.params.kbId) {
currentSecondpath.value = `chat/${route.params.kbId}/${route.params.chatid}`;
}
getMessageList();
});
watch([() => route.name, () => route.params], (newvalue) => {
currentpath.value = newvalue[0];
if (newvalue[1].chatid) {
currentSecondpath.value = `${newvalue[0]}/${newvalue[1].chatid}`;
const nameStr = typeof newvalue[0] === 'string' ? (newvalue[0] as string) : (newvalue[0] ? String(newvalue[0]) : '')
currentpath.value = nameStr;
if (newvalue[1].chatid && newvalue[1].kbId) {
currentSecondpath.value = `chat/${newvalue[1].kbId}/${newvalue[1].chatid}`;
} else {
currentSecondpath.value = "";
}
// 路由变化时刷新对话列表(仅在知识库内部)
getMessageList();
// 路由变化时更新图标状态
getIcon(nameStr);
});
let fileAddIcon = ref('file-add-green.svg');
let knowledgeIcon = ref('zhishiku-green.svg');
let prefixIcon = ref('prefixIcon.svg');
let settingIcon = ref('setting.svg');
let logoutIcon = ref('logout.svg');
let tenantIcon = ref('user.svg'); // 使用专门的用户图标
let pathPrefix = ref(route.name)
const getIcon = (path) => {
fileAddIcon.value = path == 'knowledgeBase' ? 'file-add-green.svg' : 'file-add.svg';
knowledgeIcon.value = path == 'knowledgeBase' ? 'zhishiku-green.svg' : 'zhishiku.svg';
prefixIcon.value = path == 'creatChat' ? 'prefixIcon-green.svg' : path == 'knowledgeBase' ? 'prefixIcon-grey.svg' : 'prefixIcon.svg';
settingIcon.value = path == 'settings' ? 'setting-green.svg' : 'setting.svg';
const getIcon = (path: string) => {
// 根据当前路由状态更新所有图标
const kbActiveState = getIconActiveState('knowledge-bases');
const creatChatActiveState = getIconActiveState('creatChat');
const tenantActiveState = getIconActiveState('tenant');
// 上传图标:只在知识库相关页面显示绿色
fileAddIcon.value = kbActiveState.isKbActive ? 'file-add-green.svg' : 'file-add.svg';
// 知识库图标:只在知识库页面显示绿色
knowledgeIcon.value = kbActiveState.isKbActive ? 'zhishiku-green.svg' : 'zhishiku.svg';
// 对话图标:只在对话创建页面显示绿色,在知识库页面显示灰色,其他情况显示默认
prefixIcon.value = creatChatActiveState.isCreatChatActive ? 'prefixIcon-green.svg' :
kbActiveState.isKbActive ? 'prefixIcon-grey.svg' :
'prefixIcon.svg';
// 租户图标:只在租户页面显示绿色
tenantIcon.value = tenantActiveState.isTenantActive ? 'user-green.svg' : 'user.svg';
// 退出图标:始终显示默认
logoutIcon.value = 'logout.svg';
}
getIcon(route.name)
const gotopage = (path) => {
pathPrefix.value = path;
// 如果是系统设置,跳转到初始化配置页面
if (path === 'settings') {
router.push('/initialization');
getIcon(typeof route.name === 'string' ? route.name as string : (route.name ? String(route.name) : ''))
const handleMenuClick = async (path: string) => {
if (path === 'knowledge-bases') {
// 知识库菜单项:如果在知识库内部,跳转到当前知识库文件页;否则跳转到知识库列表
const kbId = await getCurrentKbId()
if (kbId) {
router.push(`/platform/knowledge-bases/${kbId}`)
} else {
router.push('/platform/knowledge-bases')
}
} else {
router.push(`/platform/${path}`);
gotopage(path)
}
}
// 处理退出登录确认
const handleLogout = () => {
gotopage('logout')
}
const getCurrentKbId = async (): Promise<string | null> => {
let kbId = (route.params as any)?.kbId as string
// 新的路由格式:/platform/chat/:kbId/:chatid直接从路由参数获取
if (!kbId && route.name === 'chat' && (route.params as any)?.kbId) {
kbId = (route.params as any).kbId
}
return kbId || null
}
const gotopage = async (path: string) => {
pathPrefix.value = path;
// 处理退出登录
if (path === 'logout') {
authStore.logout();
router.push('/login');
return;
} else {
if (path === 'creatChat') {
const kbId = await getCurrentKbId()
if (kbId) {
router.push(`/platform/knowledge-bases/${kbId}/creatChat`)
} else {
router.push(`/platform/knowledge-bases`)
}
} else {
router.push(`/platform/${path}`);
}
}
getIcon(path)
}
const getImgSrc = (url) => {
const getImgSrc = (url: string) => {
return new URL(`/src/assets/img/${url}`, import.meta.url).href;
}
const mouseenteMenu = (path) => {
if (pathPrefix.value != 'knowledgeBase' && pathPrefix.value != 'creatChat' && path != 'knowledgeBase') {
const mouseenteMenu = (path: string) => {
if (pathPrefix.value != 'knowledge-bases' && pathPrefix.value != 'creatChat' && path != 'knowledge-bases') {
prefixIcon.value = 'prefixIcon-grey.svg';
}
}
const mouseleaveMenu = (path) => {
if (pathPrefix.value != 'knowledgeBase' && pathPrefix.value != 'creatChat' && path != 'knowledgeBase') {
getIcon(route.name)
const mouseleaveMenu = (path: string) => {
if (pathPrefix.value != 'knowledge-bases' && pathPrefix.value != 'creatChat' && path != 'knowledge-bases') {
const nameStr = typeof route.name === 'string' ? route.name as string : (route.name ? String(route.name) : '')
getIcon(nameStr)
}
}
// 知识库下拉相关方法
const toggleKbDropdown = (event?: Event) => {
if (event) {
event.stopPropagation()
}
showKbDropdown.value = !showKbDropdown.value
}
const switchKnowledgeBase = (kbId: string, event?: Event) => {
if (event) {
event.stopPropagation()
}
showKbDropdown.value = false
const currentRoute = route.name
// 路由跳转
if (currentRoute === 'knowledgeBaseDetail') {
router.push(`/platform/knowledge-bases/${kbId}`)
} else if (currentRoute === 'kbCreatChat') {
router.push(`/platform/knowledge-bases/${kbId}/creatChat`)
} else if (currentRoute === 'knowledgeBaseSettings') {
router.push(`/platform/knowledge-bases/${kbId}/settings`)
} else {
router.push(`/platform/knowledge-bases/${kbId}`)
}
// 刷新右侧内容 - 通过触发页面重新加载或发送事件
nextTick(() => {
// 发送全局事件通知页面刷新知识库内容
window.dispatchEvent(new CustomEvent('knowledgeBaseChanged', {
detail: { kbId }
}))
})
}
// 点击外部关闭下拉菜单
const handleClickOutside = () => {
showKbDropdown.value = false
}
onMounted(() => {
document.addEventListener('click', handleClickOutside)
})
watch(() => route.params.kbId, () => {
showKbDropdown.value = false
})
</script>
<style lang="less" scoped>
.del_submenu {
@@ -210,6 +582,10 @@ const mouseleaveMenu = (path) => {
padding: 8px;
background: #fff;
box-sizing: border-box;
height: 100vh;
overflow: hidden;
display: flex;
flex-direction: column;
.logo_box {
height: 80px;
@@ -239,9 +615,28 @@ const mouseleaveMenu = (path) => {
line-height: 21.7px;
}
.menu_top {
flex: 1;
display: flex;
flex-direction: column;
overflow: hidden;
min-height: 0;
}
.menu_bottom {
flex-shrink: 0;
display: flex;
flex-direction: column;
}
.menu_box {
display: flex;
flex-direction: column;
&.has-submenu {
flex: 1;
min-height: 0;
}
}
@@ -341,18 +736,21 @@ const mouseleaveMenu = (path) => {
font-style: normal;
font-weight: 600;
line-height: 22px;
overflow: hidden;
white-space: nowrap;
max-width: 120px;
flex: 1;
}
.submenu {
font-family: "PingFang SC";
font-size: 14px;
font-style: normal;
font-family: "PingFang SC";
font-size: 14px;
font-style: normal;
overflow-y: scroll;
overflow-y: auto;
scrollbar-width: none;
height: calc(98vh - 276px);
flex: 1;
min-height: 0;
margin-left: 4px;
}
.submenu_item_p {
@@ -427,6 +825,92 @@ const mouseleaveMenu = (path) => {
}
}
}
/* 知识库下拉菜单样式 */
.kb-dropdown-icon {
margin-left: auto;
color: #666;
transition: transform 0.3s ease, color 0.2s ease;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
width: 16px;
height: 16px;
&.rotate-180 {
transform: rotate(180deg);
}
&:hover {
color: #07c05f;
}
&.active {
color: #07c05f;
}
&.active:hover {
color: #05a04f;
}
svg {
width: 12px;
height: 12px;
transition: inherit;
}
}
.kb-dropdown-menu {
position: absolute;
top: 100%;
left: 0;
right: 0;
background: #fff;
border: 1px solid #e5e7eb;
border-radius: 6px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
z-index: 1000;
max-height: 200px;
overflow-y: auto;
}
.kb-dropdown-item {
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.2s ease;
font-size: 14px;
color: #333;
&:hover {
background-color: #f5f5f5;
}
&.active {
background-color: #07c05f1a;
color: #07c05f;
font-weight: 500;
}
&:first-child {
border-radius: 6px 6px 0 0;
}
&:last-child {
border-radius: 0 0 6px 6px;
}
}
.menu_item-box {
display: flex;
align-items: center;
width: 100%;
position: relative;
}
.menu_box {
position: relative;
}
</style>
<style lang="less">
.upload-popup {
@@ -456,4 +940,48 @@ const mouseleaveMenu = (path) => {
}
}
// 退出登录确认框样式
:deep(.t-popconfirm) {
.t-popconfirm__content {
background: #fff;
border: 1px solid #e7e7e7;
border-radius: 6px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
padding: 12px 16px;
font-size: 14px;
color: #333;
max-width: 200px;
}
.t-popconfirm__arrow {
border-bottom-color: #e7e7e7;
}
.t-popconfirm__arrow::after {
border-bottom-color: #fff;
}
.t-popconfirm__buttons {
margin-top: 8px;
display: flex;
justify-content: flex-end;
gap: 8px;
}
.t-button--variant-outline {
border-color: #d9d9d9;
color: #666;
}
.t-button--theme-danger {
background-color: #ff4d4f;
border-color: #ff4d4f;
}
.t-button--theme-danger:hover {
background-color: #ff7875;
border-color: #ff7875;
}
}
</style>

View File

@@ -1,52 +1,54 @@
import { ref, reactive, onMounted } from "vue";
import { ref, reactive } from "vue";
import { storeToRefs } from "pinia";
import { formatStringDate, kbFileTypeVerification } from "../utils/index";
import { MessagePlugin } from "tdesign-vue-next";
import {
uploadKnowledgeBase,
getKnowledgeBase,
uploadKnowledgeFile,
listKnowledgeFiles,
getKnowledgeDetails,
delKnowledgeDetails,
getKnowledgeDetailsCon,
} from "@/api/knowledge-base/index";
import { knowledgeStore } from "@/stores/knowledge";
import { useRoute } from 'vue-router';
const usemenuStore = knowledgeStore();
export default function () {
export default function (knowledgeBaseId?: string) {
const route = useRoute();
const { cardList, total } = storeToRefs(usemenuStore);
let moreIndex = ref(-1);
const details = reactive({
title: "",
time: "",
md: [],
md: [] as any[],
id: "",
total: 0
});
const getKnowled = (query = { page: 1, page_size: 35 }) => {
getKnowledgeBase(query)
const getKnowled = (query = { page: 1, page_size: 35 }, kbId?: string) => {
const targetKbId = kbId || knowledgeBaseId;
if (!targetKbId) return;
listKnowledgeFiles(targetKbId, query)
.then((result: any) => {
let { data, total: totalResult } = result;
let cardList_ = data.map((item) => {
item["file_name"] = item.file_name.substring(
0,
item.file_name.lastIndexOf(".")
);
return {
...item,
updated_at: formatStringDate(new Date(item.updated_at)),
isMore: false,
file_type: item.file_type.toLocaleUpperCase(),
};
});
if (query.page == 1) {
const { data, total: totalResult } = result;
const cardList_ = data.map((item: any) => ({
...item,
file_name: item.file_name.substring(0, item.file_name.lastIndexOf(".")),
updated_at: formatStringDate(new Date(item.updated_at)),
isMore: false,
file_type: item.file_type.toLocaleUpperCase(),
}));
if (query.page === 1) {
cardList.value = cardList_;
} else {
cardList.value.push(...cardList_);
}
total.value = totalResult;
})
.catch((err) => {});
.catch(() => {});
};
const delKnowledge = (index: number, item) => {
const delKnowledge = (index: number, item: any) => {
cardList.value[index].isMore = false;
moreIndex.value = -1;
delKnowledgeDetails(item.id)
@@ -58,7 +60,7 @@ export default function () {
MessagePlugin.error("知识删除失败!");
}
})
.catch((err) => {
.catch(() => {
MessagePlugin.error("知识删除失败!");
});
};
@@ -70,56 +72,48 @@ export default function () {
moreIndex.value = -1;
}
};
const requestMethod = (file: any, uploadInput) => {
if (file instanceof File && uploadInput) {
if (kbFileTypeVerification(file)) {
return;
}
uploadKnowledgeBase({ file })
.then((result: any) => {
if (result.success) {
MessagePlugin.info("上传成功!");
getKnowled();
} else {
// 改进错误信息提取逻辑
let errorMessage = "上传失败!";
// 优先从 error 对象中获取错误信息
if (result.error && result.error.message) {
errorMessage = result.error.message;
} else if (result.message) {
errorMessage = result.message;
}
// 检查错误码,如果是重复文件则显示特定提示
if (result.code === 'duplicate_file' || (result.error && result.error.code === 'duplicate_file')) {
errorMessage = "文件已存在";
}
MessagePlugin.error(errorMessage);
}
uploadInput.value.value = "";
})
.catch((err: any) => {
// 改进 catch 中的错误处理
let errorMessage = "上传失败!";
if (err.code === 'duplicate_file') {
errorMessage = "文件已存在";
} else if (err.error && err.error.message) {
errorMessage = err.error.message;
} else if (err.message) {
errorMessage = err.message;
}
MessagePlugin.error(errorMessage);
uploadInput.value.value = "";
});
} else {
MessagePlugin.error("file文件类型错误");
const requestMethod = (file: any, uploadInput: any) => {
if (!(file instanceof File) || !uploadInput) {
MessagePlugin.error("文件类型错误!");
return;
}
if (kbFileTypeVerification(file)) {
return;
}
// 获取当前知识库ID
let currentKbId: string | undefined = (route.params as any)?.kbId as string;
if (!currentKbId && typeof window !== 'undefined') {
const match = window.location.pathname.match(/knowledge-bases\/([^/]+)/);
if (match?.[1]) currentKbId = match[1];
}
if (!currentKbId) {
currentKbId = knowledgeBaseId;
}
if (!currentKbId) {
MessagePlugin.error("缺少知识库ID");
return;
}
uploadKnowledgeFile(currentKbId, { file })
.then((result: any) => {
if (result.success) {
MessagePlugin.info("上传成功!");
getKnowled({ page: 1, page_size: 35 }, currentKbId);
} else {
const errorMessage = result.error?.message || result.message || "上传失败!";
MessagePlugin.error(result.code === 'duplicate_file' ? "文件已存在" : errorMessage);
}
uploadInput.value.value = "";
})
.catch((err: any) => {
const errorMessage = err.error?.message || err.message || "上传失败!";
MessagePlugin.error(err.code === 'duplicate_file' ? "文件已存在" : errorMessage);
uploadInput.value.value = "";
});
};
const getCardDetails = (item) => {
const getCardDetails = (item: any) => {
Object.assign(details, {
title: "",
time: "",
@@ -129,7 +123,7 @@ export default function () {
getKnowledgeDetails(item.id)
.then((result: any) => {
if (result.success && result.data) {
let { data } = result;
const { data } = result;
Object.assign(details, {
title: data.file_name,
time: formatStringDate(new Date(data.updated_at)),
@@ -137,15 +131,16 @@ export default function () {
});
}
})
.catch((err) => {});
getfDetails(item.id, 1);
.catch(() => {});
getfDetails(item.id, 1);
};
const getfDetails = (id, page) => {
const getfDetails = (id: string, page: number) => {
getKnowledgeDetailsCon(id, page)
.then((result: any) => {
if (result.success && result.data) {
let { data, total: totalResult } = result;
if (page == 1) {
const { data, total: totalResult } = result;
if (page === 1) {
details.md = data;
} else {
details.md.push(...data);
@@ -153,7 +148,7 @@ export default function () {
details.total = totalResult;
}
})
.catch((err) => {});
.catch(() => {});
};
return {
cardList,

View File

@@ -1,89 +1,117 @@
import { createRouter, createWebHistory } from 'vue-router'
import { checkInitializationStatus } from '@/api/initialization'
import { listKnowledgeBases } from '@/api/knowledge-base'
import { useAuthStore } from '@/stores/auth'
import { validateToken } from '@/api/auth'
const router = createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
routes: [
{
path: "/",
redirect: "/platform",
redirect: "/platform/knowledge-bases",
},
{
path: "/initialization",
name: "initialization",
component: () => import("../views/initialization/InitializationConfig.vue"),
meta: { requiresInit: false } // 初始化页面不需要检查初始化状态
path: "/login",
name: "login",
component: () => import("../views/auth/Login.vue"),
meta: { requiresAuth: false, requiresInit: false }
},
{
path: "/knowledgeBase",
name: "home",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "/platform",
name: "Platform",
redirect: "/platform/knowledgeBase",
redirect: "/platform/knowledge-bases",
component: () => import("../views/platform/index.vue"),
meta: { requiresInit: true },
meta: { requiresInit: true, requiresAuth: true },
children: [
{
path: "knowledgeBase",
name: "knowledgeBase",
path: "tenant",
name: "tenant",
component: () => import("../views/tenant/TenantInfo.vue"),
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "knowledge-bases",
name: "knowledgeBaseList",
component: () => import("../views/knowledge/KnowledgeBaseList.vue"),
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "knowledge-bases/:kbId",
name: "knowledgeBaseDetail",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "creatChat",
name: "creatChat",
path: "knowledge-bases/:kbId/creatChat",
name: "kbCreatChat",
component: () => import("../views/creatChat/creatChat.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "chat/:chatid",
path: "knowledge-bases/:kbId/settings",
name: "knowledgeBaseSettings",
component: () => import("../views/initialization/InitializationContent.vue"),
props: { isKbSettings: true },
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "chat/:kbId/:chatid",
name: "chat",
component: () => import("../views/chat/index.vue"),
meta: { requiresInit: true }
},
{
path: "settings",
name: "settings",
component: () => import("../views/settings/Settings.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
],
},
],
});
// 路由守卫:检查系统初始化状态
// 路由守卫:检查认证状态和系统初始化状态
router.beforeEach(async (to, from, next) => {
// 如果访问的是初始化页面,直接放行
if (to.meta.requiresInit === false) {
next();
return;
}
1
try {
// 检查系统是否已初始化
const { initialized } = await checkInitializationStatus();
if (initialized) {
// 系统已初始化,记录到本地存储并正常跳转
localStorage.setItem('system_initialized', 'true');
next();
} else {
// 系统未初始化,跳转到初始化页面
console.log('系统未初始化,跳转到初始化页面');
next('/initialization');
const authStore = useAuthStore()
// 如果访问的是登录页面或初始化页面,直接放行
if (to.meta.requiresAuth === false || to.meta.requiresInit === false) {
// 如果已登录用户访问登录页面,重定向到知识库列表页面
if (to.path === '/login' && authStore.isLoggedIn) {
next('/platform/knowledge-bases')
return
}
} catch (error) {
console.error('检查初始化状态失败:', error);
// 如果检查失败,默认认为需要初始化
next('/initialization');
next()
return
}
// 检查用户认证状态
if (to.meta.requiresAuth !== false) {
if (!authStore.isLoggedIn) {
// 未登录,跳转到登录页面
next('/login')
return
}
// 验证Token有效性
// try {
// const { valid } = await validateToken()
// if (!valid) {
// // Token无效清空认证信息并跳转到登录页面
// authStore.logout()
// next('/login')
// return
// }
// } catch (error) {
// console.error('Token验证失败:', error)
// authStore.logout()
// next('/login')
// return
// }
}
next()
});
export default router

169
frontend/src/stores/auth.ts Normal file
View File

@@ -0,0 +1,169 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import type { UserInfo, TenantInfo, KnowledgeBaseInfo } from '@/api/auth'
export const useAuthStore = defineStore('auth', () => {
// 状态
const user = ref<UserInfo | null>(null)
const tenant = ref<TenantInfo | null>(null)
const token = ref<string>('')
const refreshToken = ref<string>('')
const knowledgeBases = ref<KnowledgeBaseInfo[]>([])
const currentKnowledgeBase = ref<KnowledgeBaseInfo | null>(null)
// 计算属性
const isLoggedIn = computed(() => {
return !!token.value && !!user.value
})
const hasValidTenant = computed(() => {
return !!tenant.value && !!tenant.value.api_key
})
const currentTenantId = computed(() => {
return tenant.value?.id || ''
})
const currentUserId = computed(() => {
return user.value?.id || ''
})
// 操作方法
const setUser = (userData: UserInfo) => {
user.value = userData
// 保存到localStorage
localStorage.setItem('weknora_user', JSON.stringify(userData))
}
const setTenant = (tenantData: TenantInfo) => {
tenant.value = tenantData
// 保存到localStorage
localStorage.setItem('weknora_tenant', JSON.stringify(tenantData))
}
const setToken = (tokenValue: string) => {
token.value = tokenValue
localStorage.setItem('weknora_token', tokenValue)
}
const setRefreshToken = (refreshTokenValue: string) => {
refreshToken.value = refreshTokenValue
localStorage.setItem('weknora_refresh_token', refreshTokenValue)
}
const setKnowledgeBases = (kbList: KnowledgeBaseInfo[]) => {
// 确保输入是数组
knowledgeBases.value = Array.isArray(kbList) ? kbList : []
localStorage.setItem('weknora_knowledge_bases', JSON.stringify(knowledgeBases.value))
}
const setCurrentKnowledgeBase = (kb: KnowledgeBaseInfo | null) => {
currentKnowledgeBase.value = kb
if (kb) {
localStorage.setItem('weknora_current_kb', JSON.stringify(kb))
} else {
localStorage.removeItem('weknora_current_kb')
}
}
const logout = () => {
// 清空状态
user.value = null
tenant.value = null
token.value = ''
refreshToken.value = ''
knowledgeBases.value = []
currentKnowledgeBase.value = null
// 清空localStorage
localStorage.removeItem('weknora_user')
localStorage.removeItem('weknora_tenant')
localStorage.removeItem('weknora_token')
localStorage.removeItem('weknora_refresh_token')
localStorage.removeItem('weknora_knowledge_bases')
localStorage.removeItem('weknora_current_kb')
}
const initFromStorage = () => {
// 从localStorage恢复状态
const storedUser = localStorage.getItem('weknora_user')
const storedTenant = localStorage.getItem('weknora_tenant')
const storedToken = localStorage.getItem('weknora_token')
const storedRefreshToken = localStorage.getItem('weknora_refresh_token')
const storedKnowledgeBases = localStorage.getItem('weknora_knowledge_bases')
const storedCurrentKb = localStorage.getItem('weknora_current_kb')
if (storedUser) {
try {
user.value = JSON.parse(storedUser)
} catch (e) {
console.error('解析用户信息失败:', e)
}
}
if (storedTenant) {
try {
tenant.value = JSON.parse(storedTenant)
} catch (e) {
console.error('解析租户信息失败:', e)
}
}
if (storedToken) {
token.value = storedToken
}
if (storedRefreshToken) {
refreshToken.value = storedRefreshToken
}
if (storedKnowledgeBases) {
try {
const parsed = JSON.parse(storedKnowledgeBases)
knowledgeBases.value = Array.isArray(parsed) ? parsed : []
} catch (e) {
console.error('解析知识库列表失败:', e)
knowledgeBases.value = []
}
}
if (storedCurrentKb) {
try {
currentKnowledgeBase.value = JSON.parse(storedCurrentKb)
} catch (e) {
console.error('解析当前知识库失败:', e)
}
}
}
// 初始化时从localStorage恢复状态
initFromStorage()
return {
// 状态
user,
tenant,
token,
refreshToken,
knowledgeBases,
currentKnowledgeBase,
// 计算属性
isLoggedIn,
hasValidTenant,
currentTenantId,
currentUserId,
// 方法
setUser,
setTenant,
setToken,
setRefreshToken,
setKnowledgeBases,
setCurrentKnowledgeBase,
logout,
initFromStorage
}
})

View File

@@ -4,8 +4,8 @@ import { defineStore } from "pinia";
export const knowledgeStore = defineStore("knowledge", {
state: () => ({
cardList: ref([]),
total: ref(0),
cardList: ref<any[]>([]),
total: ref<number>(0),
}),
actions: {},
});

View File

@@ -5,7 +5,7 @@ import { defineStore } from 'pinia';
export const useMenuStore = defineStore('menuStore', {
state: () => ({
menuArr: reactive([
{ title: '知识库', icon: 'zhishiku', path: 'knowledgeBase' },
{ title: '知识库', icon: 'zhishiku', path: 'knowledge-bases' },
{
title: '对话',
icon: 'prefixIcon',
@@ -13,7 +13,8 @@ export const useMenuStore = defineStore('menuStore', {
childrenPath: 'chat',
children: reactive<object[]>([]),
},
{ title: '系统设置', icon: 'setting', path: 'settings' }
{ title: '系统信息', icon: 'tenant', path: 'tenant' },
{ title: '退出登录', icon: 'logout', path: 'logout' }
]),
isFirstSession: false,
firstQuery: ''
@@ -30,7 +31,7 @@ export const useMenuStore = defineStore('menuStore', {
this.menuArr[1].children?.unshift(item)
},
updatasessionTitle(session_id: string, title: string) {
this.menuArr[1].children?.forEach(item => {
this.menuArr[1].children?.forEach((item: any) => {
if (item.id == session_id) {
item.title = title;
item.isNoTitle = false;

View File

@@ -10,19 +10,19 @@ export function generateRandomString(length: number) {
return result;
}
export function formatStringDate(date) {
export function formatStringDate(date: any) {
let data = new Date(date);
let year = data.getFullYear();
let month = data.getMonth() + 1;
let day = data.getDate();
let hour = data.getHours();
let minute = data.getMinutes();
let second = data.getSeconds();
let month = String(data.getMonth() + 1).padStart(2, '0');
let day = String(data.getDate()).padStart(2, '0');
let hour = String(data.getHours()).padStart(2, '0');
let minute = String(data.getMinutes()).padStart(2, '0');
let second = String(data.getSeconds()).padStart(2, '0');
return (
year + "-" + month + "-" + day + " " + hour + ":" + minute + ":" + second
);
}
export function kbFileTypeVerification(file) {
export function kbFileTypeVerification(file: any) {
let validTypes = ["pdf", "txt", "md", "docx", "doc", "jpg", "jpeg", "png"];
let type = file.name.substring(file.name.lastIndexOf(".") + 1);
if (!validTypes.includes(type)) {

View File

@@ -2,40 +2,9 @@
import axios from "axios";
import { generateRandomString } from "./index";
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
return JSON.parse(settingsStr);
} catch (e) {
console.error("解析设置失败:", e);
}
}
return {
endpoint: import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080",
apiKey: "",
knowledgeBaseId: "",
};
}
// API基础URL
const BASE_URL = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
// API基础URL优先使用设置中的endpoint
const settings = getSettings();
const BASE_URL = settings.endpoint;
// 测试数据
let testData: {
tenant: {
id: number;
name: string;
api_key: string;
};
knowledge_bases: Array<{
id: string;
name: string;
description: string;
}>;
} | null = null;
// 创建Axios实例
const instance = axios.create({
@@ -47,44 +16,41 @@ const instance = axios.create({
},
});
// 设置测试数据
export function setTestData(data: typeof testData) {
testData = data;
if (data) {
// 优先使用设置中的ApiKey如果没有则使用测试数据中的
const apiKey = settings.apiKey || (data?.tenant?.api_key || "");
if (apiKey) {
instance.defaults.headers["X-API-Key"] = apiKey;
}
}
}
// 获取测试数据
export function getTestData() {
return testData;
}
instance.interceptors.request.use(
(config) => {
// 每次请求前检查是否有更新的设置
const currentSettings = getSettings();
// 更新BaseURL (如果有变化)
if (currentSettings.endpoint && config.baseURL !== currentSettings.endpoint) {
config.baseURL = currentSettings.endpoint;
}
// 更新API Key (如果有)
if (currentSettings.apiKey) {
config.headers["X-API-Key"] = currentSettings.apiKey;
// 添加JWT token认证
const token = localStorage.getItem('weknora_token');
if (token) {
config.headers["Authorization"] = `Bearer ${token}`;
}
config.headers["X-Request-ID"] = `${generateRandomString(12)}`;
return config;
},
(error) => {}
(error) => {
return Promise.reject(error);
}
);
// Token刷新标志防止多个请求同时刷新token
let isRefreshing = false;
let failedQueue: Array<{ resolve: Function; reject: Function }> = [];
let hasRedirectedOn401 = false;
// 处理队列中的请求
const processQueue = (error: any, token: string | null = null) => {
failedQueue.forEach(({ resolve, reject }) => {
if (error) {
reject(error);
} else {
resolve(token);
}
});
failedQueue = [];
};
instance.interceptors.response.use(
(response) => {
// 根据业务状态码处理逻辑
@@ -95,12 +61,98 @@ instance.interceptors.response.use(
return Promise.reject(data);
}
},
(error: any) => {
async (error: any) => {
const originalRequest = error.config;
if (!error.response) {
return Promise.reject({ message: "网络错误,请检查您的网络连接" });
}
const { data } = error.response;
return Promise.reject(data);
// 如果是登录接口的401直接返回错误以便页面展示toast不做跳转
if (error.response.status === 401 && originalRequest?.url?.includes('/auth/login')) {
const { status, data } = error.response;
return Promise.reject({ status, message: (typeof data === 'object' ? data?.message : data) || '用户名或密码错误' });
}
// 如果是401错误且不是刷新token的请求尝试刷新token
if (error.response.status === 401 && !originalRequest._retry && !originalRequest.url?.includes('/auth/refresh')) {
if (isRefreshing) {
// 如果正在刷新token将请求加入队列
return new Promise((resolve, reject) => {
failedQueue.push({ resolve, reject });
}).then(token => {
originalRequest.headers['Authorization'] = 'Bearer ' + token;
return instance(originalRequest);
}).catch(err => {
return Promise.reject(err);
});
}
originalRequest._retry = true;
isRefreshing = true;
const refreshToken = localStorage.getItem('weknora_refresh_token');
if (refreshToken) {
try {
// 动态导入refresh token API
const { refreshToken: refreshTokenAPI } = await import('../api/auth/index');
const response = await refreshTokenAPI(refreshToken);
if (response.success && response.data) {
const { token, refreshToken: newRefreshToken } = response.data;
// 更新localStorage中的token
localStorage.setItem('weknora_token', token);
localStorage.setItem('weknora_refresh_token', newRefreshToken);
// 更新请求头
originalRequest.headers['Authorization'] = 'Bearer ' + token;
// 处理队列中的请求
processQueue(null, token);
return instance(originalRequest);
} else {
throw new Error(response.message || 'Token刷新失败');
}
} catch (refreshError) {
// 刷新失败清除所有token并跳转到登录页
localStorage.removeItem('weknora_token');
localStorage.removeItem('weknora_refresh_token');
localStorage.removeItem('weknora_user');
localStorage.removeItem('weknora_tenant');
processQueue(refreshError, null);
// 跳转到登录页
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
hasRedirectedOn401 = true;
window.location.href = '/login';
}
return Promise.reject(refreshError);
} finally {
isRefreshing = false;
}
} else {
// 没有refresh token直接跳转到登录页
localStorage.removeItem('weknora_token');
localStorage.removeItem('weknora_user');
localStorage.removeItem('weknora_tenant');
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
hasRedirectedOn401 = true;
window.location.href = '/login';
}
return Promise.reject({ message: '请重新登录' });
}
}
const { status, data } = error.response;
// 将HTTP状态码一并抛出方便上层判断401等场景
return Promise.reject({ status, ...(typeof data === 'object' ? data : { message: data }) });
}
);

View File

@@ -0,0 +1,207 @@
/**
* 安全工具类 - 防止 XSS 攻击
*/
import DOMPurify from 'dompurify';
// 配置 DOMPurify 的安全策略
const DOMPurifyConfig = {
// 允许的标签
ALLOWED_TAGS: [
'p', 'br', 'strong', 'em', 'u', 's', 'del', 'ins',
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
'ul', 'ol', 'li', 'blockquote', 'pre', 'code',
'a', 'img', 'table', 'thead', 'tbody', 'tr', 'th', 'td',
'div', 'span', 'figure', 'figcaption', 'think'
],
// 允许的属性
ALLOWED_ATTR: [
'href', 'title', 'alt', 'src', 'class', 'id', 'style',
'target', 'rel', 'width', 'height'
],
// 允许的协议
ALLOWED_URI_REGEXP: /^(?:(?:(?:f|ht)tps?|mailto|tel|callto|cid|xmpp):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i,
// 禁止的标签和属性
FORBID_TAGS: ['script', 'object', 'embed', 'form', 'input', 'button'],
FORBID_ATTR: ['onerror', 'onload', 'onclick', 'onmouseover', 'onfocus', 'onblur'],
// 其他安全配置
KEEP_CONTENT: true,
RETURN_DOM: false,
RETURN_DOM_FRAGMENT: false,
RETURN_DOM_IMPORT: false,
SANITIZE_DOM: true,
SANITIZE_NAMED_PROPS: true,
WHOLE_DOCUMENT: false,
// 自定义钩子函数
HOOKS: {
// 在清理前处理
beforeSanitizeElements: (currentNode: Element) => {
// 移除所有 script 标签
if (currentNode.tagName === 'SCRIPT') {
currentNode.remove();
return null;
}
// 移除所有事件处理器
const eventAttrs = ['onclick', 'onload', 'onerror', 'onmouseover', 'onfocus', 'onblur'];
eventAttrs.forEach(attr => {
if (currentNode.hasAttribute(attr)) {
currentNode.removeAttribute(attr);
}
});
},
// 在清理后处理
afterSanitizeElements: (currentNode: Element) => {
// 确保所有链接都有 rel="noopener noreferrer"
if (currentNode.tagName === 'A') {
const href = currentNode.getAttribute('href');
if (href && href.startsWith('http')) {
currentNode.setAttribute('rel', 'noopener noreferrer');
currentNode.setAttribute('target', '_blank');
}
}
// 确保所有图片都有 alt 属性
if (currentNode.tagName === 'IMG') {
if (!currentNode.getAttribute('alt')) {
currentNode.setAttribute('alt', '');
}
}
}
}
};
/**
* 安全地清理 HTML 内容
* @param html 需要清理的 HTML 字符串
* @returns 清理后的安全 HTML 字符串
*/
export function sanitizeHTML(html: string): string {
if (!html || typeof html !== 'string') {
return '';
}
try {
return DOMPurify.sanitize(html, DOMPurifyConfig);
} catch (error) {
console.error('HTML sanitization failed:', error);
// 如果清理失败,返回转义的纯文本
return escapeHTML(html);
}
}
/**
* 转义 HTML 特殊字符
* @param text 需要转义的文本
* @returns 转义后的文本
*/
export function escapeHTML(text: string): string {
if (!text || typeof text !== 'string') {
return '';
}
const map: { [key: string]: string } = {
'&': '&amp;',
'<': '&lt;',
'>': '&gt;',
'"': '&quot;',
"'": '&#x27;',
'/': '&#x2F;',
'`': '&#x60;',
'=': '&#x3D;'
};
return text.replace(/[&<>"'`=\/]/g, (s) => map[s]);
}
/**
* 验证 URL 是否安全
* @param url 需要验证的 URL
* @returns 是否为安全 URL
*/
export function isValidURL(url: string): boolean {
if (!url || typeof url !== 'string') {
return false;
}
try {
const urlObj = new URL(url);
// 只允许 http 和 https 协议
return ['http:', 'https:'].includes(urlObj.protocol);
} catch {
return false;
}
}
/**
* 安全地处理 Markdown 内容
* @param markdown Markdown 文本
* @returns 安全的 HTML 字符串
*/
export function safeMarkdownToHTML(markdown: string): string {
if (!markdown || typeof markdown !== 'string') {
return '';
}
// 首先转义可能的 HTML 标签
const escapedMarkdown = markdown
.replace(/<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>/gi, '')
.replace(/<iframe\b[^<]*(?:(?!<\/iframe>)<[^<]*)*<\/iframe>/gi, '')
.replace(/<object\b[^<]*(?:(?!<\/object>)<[^<]*)*<\/object>/gi, '')
.replace(/<embed\b[^<]*(?:(?!<\/embed>)<[^<]*)*<\/embed>/gi, '');
return escapedMarkdown;
}
/**
* 清理用户输入
* @param input 用户输入
* @returns 清理后的安全输入
*/
export function sanitizeUserInput(input: string): string {
if (!input || typeof input !== 'string') {
return '';
}
// 移除控制字符
let cleaned = input.replace(/[\x00-\x1F\x7F-\x9F]/g, '');
// 限制长度
if (cleaned.length > 10000) {
cleaned = cleaned.substring(0, 10000);
}
return cleaned.trim();
}
/**
* 验证图片 URL 是否安全
* @param url 图片 URL
* @returns 是否为安全的图片 URL
*/
export function isValidImageURL(url: string): boolean {
if (!isValidURL(url)) {
return false;
}
// 检查是否为图片文件
const imageExtensions = /\.(jpg|jpeg|png|gif|webp|svg|bmp|ico)(\?.*)?$/i;
return imageExtensions.test(url);
}
/**
* 创建安全的图片元素
* @param src 图片源
* @param alt 替代文本
* @param title 标题
* @returns 安全的图片 HTML
*/
export function createSafeImage(src: string, alt: string = '', title: string = ''): string {
if (!isValidImageURL(src)) {
return '';
}
const safeSrc = escapeHTML(src);
const safeAlt = escapeHTML(alt);
const safeTitle = escapeHTML(title);
return `<img src="${safeSrc}" alt="${safeAlt}" title="${safeTitle}" class="markdown-image" style="max-width: 100%; height: auto;">`;
}

View File

@@ -0,0 +1,553 @@
<template>
<div class="login-container">
<!-- 登录表单 -->
<div class="login-card" v-if="!isRegisterMode">
<!-- 系统Logo和标题 -->
<div class="login-header">
<div class="logo">
<img src="@/assets/img/weknora.png" alt="WeKnora" class="logo-img" />
</div>
<p class="login-subtitle">基于大模型的文档理解与语义检索框架</p>
</div>
<div class="login-form">
<t-form
ref="formRef"
:data="formData"
:rules="formRules"
@submit="handleLogin"
layout="vertical"
>
<t-form-item label="邮箱" name="email">
<t-input
v-model="formData.email"
placeholder="请输入邮箱地址"
type="email"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="密码" name="password">
<t-input
v-model="formData.password"
placeholder="请输入密码8-32位包含字母和数字"
type="password"
size="large"
:disabled="loading"
@keydown.enter="handleLogin"
/>
</t-form-item>
<t-button
type="submit"
theme="primary"
size="large"
block
:loading="loading"
class="login-button"
>
{{ loading ? '登录中...' : '登录' }}
</t-button>
</t-form>
<!-- 注册链接 -->
<div class="register-link">
<span>还没有账号</span>
<a href="#" @click.prevent="toggleMode" class="register-btn">
立即注册
</a>
</div>
</div>
</div>
<!-- 注册表单 -->
<div class="register-card" v-if="isRegisterMode">
<div class="login-header">
<h1 class="login-title">创建账号</h1>
<p class="login-subtitle">注册后系统将为您创建专属租户</p>
</div>
<div class="login-form">
<t-form
ref="registerFormRef"
:data="registerData"
:rules="registerRules"
@submit="handleRegister"
layout="vertical"
>
<t-form-item label="用户名" name="username">
<t-input
v-model="registerData.username"
placeholder="请输入用户名"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="邮箱" name="email">
<t-input
v-model="registerData.email"
placeholder="请输入邮箱地址"
type="email"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="密码" name="password">
<t-input
v-model="registerData.password"
placeholder="请输入密码8-32位包含字母和数字"
type="password"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="确认密码" name="confirmPassword">
<t-input
v-model="registerData.confirmPassword"
placeholder="请再次输入密码"
type="password"
size="large"
:disabled="loading"
@keydown.enter="handleRegister"
/>
</t-form-item>
<t-button
type="submit"
theme="primary"
size="large"
block
:loading="loading"
class="login-button"
>
{{ loading ? '注册中...' : '注册' }}
</t-button>
</t-form>
<!-- 返回登录 -->
<div class="register-link">
<span>已有账号</span>
<a href="#" @click.prevent="toggleMode" class="register-btn">
返回登录
</a>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, nextTick, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { MessagePlugin } from 'tdesign-vue-next'
import { login, register } from '@/api/auth'
import { useAuthStore } from '@/stores/auth'
const router = useRouter()
const authStore = useAuthStore()
// 表单引用
const formRef = ref()
const registerFormRef = ref()
// 状态管理
const loading = ref(false)
const isRegisterMode = ref(false)
// 登录表单数据
const formData = reactive<{[key: string]: any}>({
email: '',
password: '',
})
// 注册表单数据
const registerData = reactive<{[key: string]: any}>({
username: '',
email: '',
password: '',
confirmPassword: ''
})
// 登录表单验证规则
const formRules = {
email: [
{ required: true, message: '请输入邮箱地址', type: 'error' },
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
],
password: [
{ required: true, message: '请输入密码', type: 'error' },
{ min: 8, message: '密码至少8位', type: 'error' },
{ max: 32, message: '密码不能超过32位', type: 'error' },
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
]
}
// 注册表单验证规则
const registerRules = {
username: [
{ required: true, message: '请输入用户名', type: 'error' },
{ min: 2, message: '用户名至少2位', type: 'error' },
{ max: 20, message: '用户名不能超过20位', type: 'error' },
{
pattern: /^[a-zA-Z0-9_\u4e00-\u9fa5]+$/,
message: '用户名只能包含字母、数字、下划线和中文',
type: 'error'
}
],
email: [
{ required: true, message: '请输入邮箱地址', type: 'error' },
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
],
password: [
{ required: true, message: '请输入密码', type: 'error' },
{ min: 8, message: '密码至少8位', type: 'error' },
{ max: 32, message: '密码不能超过32位', type: 'error' },
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
],
confirmPassword: [
{ required: true, message: '请确认密码', type: 'error' },
{
validator: (val: string) => val === registerData.password,
message: '两次输入的密码不一致',
type: 'error'
}
]
}
// 切换登录/注册模式
const toggleMode = () => {
isRegisterMode.value = !isRegisterMode.value
Object.keys(registerData).forEach(key => {
(registerData as any)[key] = ''
})
}
// 处理登录
const handleLogin = async () => {
try {
const valid = await formRef.value?.validate()
if (!valid) return
loading.value = true
const response = await login({
email: formData.email,
password: formData.password,
})
if (response.success) {
// 保存用户信息和token
if (response.user && response.tenant && response.token) {
authStore.setUser({
id: response.user.id || '',
username: response.user.username || '',
email: response.user.email || '',
avatar: response.user.avatar,
tenant_id: String(response.tenant.id) || '',
created_at: response.user.created_at || new Date().toISOString(),
updated_at: response.user.updated_at || new Date().toISOString()
})
authStore.setToken(response.token)
if (response.refresh_token) {
authStore.setRefreshToken(response.refresh_token)
}
authStore.setTenant({
id: String(response.tenant.id) || '',
name: response.tenant.name || '',
api_key: response.tenant.api_key || '',
owner_id: response.user.id || '',
created_at: response.tenant.created_at || new Date().toISOString(),
updated_at: response.tenant.updated_at || new Date().toISOString()
})
}
MessagePlugin.success('登录成功!')
// 等待状态更新完成后再跳转
await nextTick()
router.replace('/platform/knowledge-bases')
} else {
MessagePlugin.error(response.message || '登录失败,请检查邮箱或密码')
}
} catch (error: any) {
console.error('登录错误:', error)
MessagePlugin.error(error.message || '登录失败,请稍后重试')
} finally {
loading.value = false
}
}
// 处理注册
const handleRegister = async () => {
try {
const valid = await registerFormRef.value?.validate()
if (!valid) return
loading.value = true
const response = await register({
username: registerData.username,
email: registerData.email,
password: registerData.password
})
if (response.success) {
MessagePlugin.success('注册成功!系统已为您创建专属租户,请登录使用')
// 切换到登录模式并填入邮箱
isRegisterMode.value = false
formData.email = registerData.email
// 清空注册表单
Object.keys(registerData).forEach(key => {
(registerData as any)[key] = ''
})
} else {
MessagePlugin.error(response.message || '注册失败')
}
} catch (error: any) {
console.error('注册错误:', error)
MessagePlugin.error(error.message || '注册失败,请稍后重试')
} finally {
loading.value = false
}
}
// 处理忘记密码
const handleForgotPassword = () => {
MessagePlugin.info('忘记密码功能暂未开放,请联系管理员')
}
// 检查是否已登录
onMounted(() => {
if (authStore.isLoggedIn) {
router.replace('/platform/tenant/knowledge-bases')
}
})
</script>
<style lang="less" scoped>
.login-container {
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 20px;
box-sizing: border-box;
}
.login-card,
.register-card {
width: 100%;
max-width: 440px;
background: #fff;
border-radius: 14px;
box-shadow: 0 10px 16px 0 #0000000f, 0 20px 24px -2px #0000001a;
padding: 40px;
box-sizing: border-box;
animation: fadeInUp .28s ease-out both;
}
.login-header {
text-align: center;
margin-bottom: 32px;
.logo {
margin-bottom: 16px;
.logo-img {
width: 180px;
height: auto;
border-radius: 12px;
}
}
.login-title {
font-size: 28px;
font-weight: 600;
color: #000000e6;
margin: 0 0 8px 0;
font-family: "PingFang SC";
}
.login-subtitle {
font-size: 16px;
color: #0000008c;
margin: 0;
font-family: "PingFang SC";
}
}
.login-form {
:deep(.t-form-item__label) {
font-size: 14px;
color: #000000e6;
font-weight: 500;
margin-bottom: 8px;
font-family: "PingFang SC";
display: block;
text-align: left;
}
:deep(.t-input) {
border: 1px solid #E7E7E7;
border-radius: 8px;
background: #fff;
&:focus-within {
border-color: #07C05F;
box-shadow: 0 0 0 2px rgba(7, 192, 95, 0.1);
}
&:hover {
border-color: #07C05F;
}
.t-input__inner {
border: none !important;
box-shadow: none !important;
outline: none !important;
background: transparent;
font-size: 16px;
font-family: "PingFang SC";
&:focus {
border: none !important;
box-shadow: none !important;
outline: none !important;
}
}
.t-input__wrap {
border: none !important;
box-shadow: none !important;
}
}
:deep(.t-form-item) {
margin-bottom: 20px;
&:last-child {
margin-bottom: 0;
}
}
:deep(.t-form-item__control) {
width: 100%;
}
}
.login-options {
display: flex;
justify-content: space-between;
align-items: center;
margin: 16px 0 24px 0;
width: 100%;
:deep(.t-checkbox) {
display: flex;
align-items: center;
.t-checkbox__input {
margin-right: 8px;
}
}
:deep(.t-checkbox__label) {
font-size: 14px;
color: #00000099;
font-family: "PingFang SC";
line-height: 1.4;
margin-left: 0;
}
.forgot-password {
font-size: 14px;
color: #07C05F;
text-decoration: none;
font-family: "PingFang SC";
line-height: 1.4;
&:hover {
text-decoration: underline;
}
}
}
.login-button {
height: 48px;
border-radius: 8px;
font-size: 16px;
font-weight: 500;
font-family: "PingFang SC";
margin: 16px 0 8px 0;
:deep(.t-button) {
background-color: #07C05F;
border-color: #07C05F;
&:hover {
background-color: #06a855;
border-color: #06a855;
}
}
}
.register-link {
text-align: center;
font-size: 14px;
color: #00000099;
font-family: "PingFang SC";
.register-btn {
color: #07C05F;
text-decoration: none;
margin-left: 4px;
&:hover {
text-decoration: underline;
}
}
}
// 响应式设计
@media (max-width: 480px) {
.login-container {
padding: 16px;
}
.login-card,
.register-card {
padding: 28px;
}
.login-header {
margin-bottom: 24px;
.login-title {
font-size: 24px;
}
}
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translate3d(0, 6px, 0);
}
to {
opacity: 1;
transform: translate3d(0, 0, 0);
}
}
</style>

View File

@@ -23,6 +23,8 @@ import { marked } from 'marked';
import docInfo from './docInfo.vue';
import deepThink from './deepThink.vue';
import picturePreview from '@/components/picture-preview.vue';
import { sanitizeHTML, safeMarkdownToHTML, createSafeImage, isValidImageURL } from '@/utils/security';
marked.use({
mangle: false,
headerIds: false,
@@ -89,36 +91,36 @@ const checkImage = (url) => {
img.src = url;
});
};
// 处理 Markdown 中的图片
// 安全地处理 Markdown 内容
const processMarkdown = (markdownText) => {
// 自定义渲染器处理图片
if (!markdownText || typeof markdownText !== 'string') {
return '';
}
// 首先对 Markdown 内容进行安全处理
const safeMarkdown = safeMarkdownToHTML(markdownText);
// 自定义安全的渲染器处理图片
const renderer = {
image(href, title, text) {
return `<img src="${href}" alt="${text}" title="${title || ''}" class="markdown-image" style="max-width: 708px;height: 230px;">`;
// 验证图片 URL 是否安全
if (!isValidImageURL(href)) {
return `<p>无效的图片链接</p>`;
}
// 使用安全的图片创建函数
return createSafeImage(href, text || '', title || '');
}
};
marked.use({ renderer });
// 第一次渲染
let html = marked.parse(markdownText);
// 安全地渲染 Markdown
let html = marked.parse(safeMarkdown);
// 创建虚拟 DOM 来操作
const parser = new DOMParser();
const doc = parser.parseFromString(html, 'text/html');
// 检查所有图片
// const images = doc.querySelectorAll('img');
// images.forEach(async item => {
// const isValid = await checkImage(item.src);
// if (!isValid) {
// item.remove();
// }
// });
// if (props.isFirstEnter) {
// emit('scroll-bottom')
// }
return doc.body.innerHTML;
// 使用 DOMPurify 进行最终的安全清理
const sanitizedHTML = sanitizeHTML(html);
return sanitizedHTML;
};
const handleImg = async (newVal) => {
let index = newVal.lastIndexOf('![');

View File

@@ -19,7 +19,7 @@
</div>
</template>
<div class="content">
<span v-html="deepSession.thinkContent.replace(/\n/g, '<br/>')"></span>
<span v-html="safeProcessThinkContent(deepSession.thinkContent)"></span>
</div>
</t-collapse-panel>
@@ -29,6 +29,7 @@
</template>
<script setup>
import { onMounted, watch, computed, ref, reactive, defineProps } from 'vue';
import { sanitizeHTML } from '@/utils/security';
const isFold = ref(true)
const props = defineProps({
@@ -51,8 +52,20 @@ const showHide = () => {
}
const handlePanelChange = (val) => {
isFold.value = !val.length ? true : false;
}
// 安全地处理思考内容防止XSS攻击
const safeProcessThinkContent = (content) => {
if (!content || typeof content !== 'string') return '';
// 先处理换行符
const contentWithBreaks = content.replace(/\n/g, '<br/>');
// 使用DOMPurify进行安全清理允许基本的文本格式化标签
const cleanContent = sanitizeHTML(contentWithBreaks);
return cleanContent;
};
</script>
<style lang="less" scoped>
.deep-think {

View File

@@ -15,7 +15,7 @@
trigger="click">
<template #content>
<div class="doc_content">
<div v-html="item.content.replace(/\n/g, '<br/>')"></div>
<div v-html="safeProcessContent(item.content)"></div>
</div>
</template>
<span class="doc">
@@ -28,6 +28,7 @@
</template>
<script setup>
import { onMounted, defineProps, computed, ref, reactive } from "vue";
import { sanitizeHTML } from '@/utils/security';
const props = defineProps({
// 必填项
content: {
@@ -44,6 +45,14 @@ const referBoxSwitch = () => {
showReferBox.value = !showReferBox.value;
};
// 安全地处理内容
const safeProcessContent = (content) => {
if (!content) return '';
// 先进行安全清理,然后处理换行
const sanitized = sanitizeHTML(content);
return sanitized.replace(/\n/g, '<br/>');
};
</script>
<style lang="less" scoped>
.refer {

View File

@@ -38,6 +38,7 @@ const { output, onChunk, isStreaming, isLoading, error, startStream, stopStream
const route = useRoute();
const router = useRouter();
const session_id = ref(route.params.chatid);
const knowledge_base_id = ref(route.params.kbId);
const created_at = ref('');
const limit = ref(20);
const messagesList = reactive([]);
@@ -57,6 +58,7 @@ watch([() => route.params], (newvalue) => {
}
messagesList.splice(0);
session_id.value = newvalue[0].chatid;
knowledge_base_id.value = newvalue[0].kbId;
checkmenuTitle(session_id.value)
let data = {
session_id: session_id.value,
@@ -154,7 +156,14 @@ const sendMsg = async (value) => {
loading.value = true;
messagesList.push({ content: value, role: 'user' });
scrollToBottom();
await startStream({ session_id: session_id.value, query: value, method: 'POST', url: '/api/v1/knowledge-chat' });
await startStream({
session_id: session_id.value,
knowledge_base_id: knowledge_base_id.value,
query: value,
method: 'POST',
url: '/api/v1/knowledge-chat'
});
}
// 处理流式数据

View File

@@ -1,5 +1,5 @@
<template>
<div v-show="cardList.length" class="dialogue-wrap">
<div class="dialogue-wrap">
<div class="dialogue-answers">
<div class="dialogue-title">
<span>基于知识库内容问答</span>
@@ -7,7 +7,17 @@
<InputField @send-msg="sendMsg"></InputField>
</div>
</div>
<EmptyKnowledge v-show="!cardList.length"></EmptyKnowledge>
<t-dialog v-model:visible="selectVisible" header="选择知识库" :confirmBtn="{ content: '开始对话', theme: 'primary' }" :onConfirm="confirmSelect" :onCancel="() => selectVisible = false">
<t-form :data="{ kb: selectedKbId }">
<t-form-item label="知识库">
<t-select v-model="selectedKbId" :loading="kbLoading" placeholder="请选择知识库">
<t-option v-for="kb in kbList" :key="kb.id" :value="kb.id" :label="kb.name" />
</t-select>
</t-form-item>
</t-form>
</t-dialog>
</template>
<script setup lang="ts">
import { ref, onUnmounted, watch } from 'vue';
@@ -17,55 +27,52 @@ import { getSessionsList, createSessions, generateSessionsTitle } from "@/api/ch
import { useMenuStore } from '@/stores/menu';
import { useRoute, useRouter } from 'vue-router';
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
import { getTestData } from '@/utils/request';
import { listKnowledgeBases } from '@/api/knowledge-base';
let { cardList } = useKnowledgeBase()
const router = useRouter();
const route = useRoute();
const usemenuStore = useMenuStore();
const sendMsg = (value: string) => {
createNewSession(value);
}
const selectVisible = ref(false)
const selectedKbId = ref<string>('')
const kbList = ref<Array<{ id: string; name: string }>>([])
const kbLoading = ref(false)
const ensureKbId = async (): Promise<string | null> => {
// 1) 优先使用当前路由上下文(如果来自某个知识库详情页)
const routeKb = (route.params as any)?.kbId as string
if (routeKb) return routeKb
// 3) 弹窗选择知识库(从接口拉取)
kbLoading.value = true
try {
const res: any = await listKnowledgeBases()
kbList.value = res?.data || []
if (kbList.value.length === 0) return null
selectedKbId.value = kbList.value[0].id
selectVisible.value = true
return null
} finally {
kbLoading.value = false
}
}
async function createNewSession(value: string) {
// 从localStorage获取设置中的知识库ID
const settingsStr = localStorage.getItem("WeKnora_settings");
let knowledgeBaseId = "";
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.knowledgeBaseId) {
knowledgeBaseId = settings.knowledgeBaseId;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
if (res.data && res.data.id) {
getTitle(res.data.id, value);
} else {
// 错误处理
console.error("创建会话失败");
}
}).catch(error => {
console.error("创建会话出错:", error);
});
return;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
// 如果设置中没有知识库ID则使用测试数据
const testData = getTestData();
if (!testData || testData.knowledge_bases.length === 0) {
console.error("测试数据未初始化或不包含知识库");
return;
let knowledgeBaseId = await ensureKbId()
if (!knowledgeBaseId) {
// 等待用户在弹窗中选择
pendingValue.value = value
return
}
// 使用第一个知识库ID
knowledgeBaseId = testData.knowledge_bases[0].id;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
createSessions({ knowledge_base_id: knowledgeBaseId }).then(async res => {
if (res.data && res.data.id) {
getTitle(res.data.id, value)
await getTitle(res.data.id, value)
} else {
// 错误处理
console.error("创建会话失败");
@@ -75,12 +82,33 @@ async function createNewSession(value: string) {
})
}
const getTitle = (session_id: string, value: string) => {
let obj = { title: '新会话', path: `chat/${session_id}`, id: session_id, isMore: false, isNoTitle: true }
const pendingValue = ref<string>('')
const confirmSelect = async () => {
if (!selectedKbId.value) return
const value = pendingValue.value
pendingValue.value = ''
selectVisible.value = false
createSessions({ knowledge_base_id: selectedKbId.value }).then(async res => {
if (res.data && res.data.id) {
await getTitle(res.data.id, value, selectedKbId.value)
} else {
console.error('创建会话失败')
}
}).catch((e:any) => console.error('创建会话出错:', e))
}
const getTitle = async (session_id: string, value: string, kbId?: string) => {
const finalKbId = kbId || await ensureKbId();
if (!finalKbId) {
console.error('无法获取知识库ID');
return;
}
let obj = { title: '新会话', path: `chat/${finalKbId}/${session_id}`, id: session_id, isMore: false, isNoTitle: true }
usemenuStore.updataMenuChildren(obj);
usemenuStore.changeIsFirstSession(true);
usemenuStore.changeFirstQuery(value);
router.push(`/platform/chat/${session_id}`);
router.push(`/platform/chat/${finalKbId}/${session_id}`);
}
</script>

View File

@@ -1,5 +1,5 @@
<script setup lang="ts">
import { ref, onMounted, watch, reactive } from "vue";
import { ref, onMounted, onUnmounted, watch, reactive, computed } from "vue";
import DocContent from "@/components/doc-content.vue";
import InputField from "@/components/Input-field.vue";
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
@@ -7,17 +7,21 @@ import { useRoute, useRouter } from 'vue-router';
import EmptyKnowledge from '@/components/empty-knowledge.vue';
import { getSessionsList, createSessions, generateSessionsTitle } from "@/api/chat/index";
import { useMenuStore } from '@/stores/menu';
import { getTestData } from '@/utils/request';
import { MessagePlugin } from 'tdesign-vue-next';
const usemenuStore = useMenuStore();
const router = useRouter();
import {
batchQueryKnowledge,
listKnowledgeFiles,
} from "@/api/knowledge-base/index";
let { cardList, total, moreIndex, details, getKnowled, delKnowledge, openMore, onVisibleChange, getCardDetails, getfDetails } = useKnowledgeBase()
import { formatStringDate } from "@/utils/index";
const route = useRoute();
const kbId = computed(() => (route.params as any).kbId as string || '');
let { cardList, total, moreIndex, details, getKnowled, delKnowledge, openMore, onVisibleChange, getCardDetails, getfDetails } = useKnowledgeBase(kbId.value)
let isCardDetails = ref(false);
let timeout = null;
let timeout: ReturnType<typeof setInterval> | null = null;
let delDialog = ref(false)
let knowledge = ref({})
let knowledge = ref<KnowledgeCard>({ id: '', parse_status: '' })
let knowledgeIndex = ref(-1)
let knowledgeScroll = ref()
let page = 1;
@@ -29,29 +33,93 @@ const getPageSize = () => {
pageSize = Math.max(35, itemsInView);
}
getPageSize()
// 直接调用 API 获取知识库文件列表
const loadKnowledgeFiles = async (kbIdValue: string) => {
if (!kbIdValue) return;
try {
const result = await listKnowledgeFiles(kbIdValue, { page: 1, page_size: pageSize });
// 由于响应拦截器已经返回了 data所以 result 就是响应的 data 部分
// 按照 useKnowledgeBase hook 中的方式处理
const { data, total: totalResult } = result as any;
if (!data || !Array.isArray(data)) {
console.error('Invalid data format. Expected array, got:', typeof data, data);
return;
}
const cardList_ = data.map((item: any) => {
item["file_name"] = item.file_name.substring(
0,
item.file_name.lastIndexOf(".")
);
return {
...item,
updated_at: formatStringDate(new Date(item.updated_at)),
isMore: false,
file_type: item.file_type.toLocaleUpperCase(),
};
});
cardList.value = cardList_ as any[];
total.value = totalResult;
} catch (err) {
console.error('Failed to load knowledge files:', err);
}
};
// 监听路由参数变化,重新获取知识库内容
watch(() => kbId.value, (newKbId, oldKbId) => {
if (newKbId && newKbId !== oldKbId) {
loadKnowledgeFiles(newKbId);
}
}, { immediate: false });
// 监听文件上传事件
const handleFileUploaded = (event: CustomEvent) => {
const uploadedKbId = event.detail.kbId;
console.log('接收到文件上传事件上传的知识库ID:', uploadedKbId, '当前知识库ID:', kbId.value);
if (uploadedKbId && uploadedKbId === kbId.value) {
console.log('匹配当前知识库,开始刷新文件列表');
// 如果上传的文件属于当前知识库,使用 loadKnowledgeFiles 刷新文件列表
loadKnowledgeFiles(uploadedKbId);
}
};
onMounted(() => {
getKnowled({ page: 1, page_size: pageSize });
// 监听文件上传事件
window.addEventListener('knowledgeFileUploaded', handleFileUploaded as EventListener);
});
onUnmounted(() => {
window.removeEventListener('knowledgeFileUploaded', handleFileUploaded as EventListener);
});
watch(() => cardList.value, (newValue) => {
let analyzeList = [];
analyzeList = newValue.filter(item => {
return item.parse_status == 'pending' || item.parse_status == 'processing';
})
clearInterval(timeout);
timeout = null;
if (timeout !== null) {
clearInterval(timeout);
timeout = null;
}
if (analyzeList.length) {
updateStatus(analyzeList)
}
}, { deep: true })
const updateStatus = (analyzeList) => {
type KnowledgeCard = { id: string; parse_status: string; description?: string; file_name?: string; updated_at?: string; file_type?: string; isMore?: boolean };
const updateStatus = (analyzeList: KnowledgeCard[]) => {
let query = ``;
for (let i = 0; i < analyzeList.length; i++) {
query += `ids=${analyzeList[i].id}&`;
}
timeout = setInterval(() => {
batchQueryKnowledge(query).then((result) => {
batchQueryKnowledge(query).then((result: any) => {
if (result.success && result.data) {
result.data.forEach(item => {
(result.data as KnowledgeCard[]).forEach((item: KnowledgeCard) => {
if (item.parse_status == 'failed' || item.parse_status == 'completed') {
let index = cardList.value.findIndex(card => card.id == item.id);
if (index != -1) {
@@ -70,12 +138,12 @@ const updateStatus = (analyzeList) => {
const closeDoc = () => {
isCardDetails.value = false;
};
const openCardDetails = (item) => {
const openCardDetails = (item: KnowledgeCard) => {
isCardDetails.value = true;
getCardDetails(item);
};
const delCard = (index, item) => {
const delCard = (index: number, item: KnowledgeCard) => {
knowledgeIndex.value = index;
knowledge.value = item;
delDialog.value = true;
@@ -94,7 +162,7 @@ const handleScroll = () => {
}
}
};
const getDoc = (page) => {
const getDoc = (page: number) => {
getfDetails(details.id, page)
};
@@ -108,51 +176,36 @@ const sendMsg = (value: string) => {
};
const getTitle = (session_id: string, value: string) => {
let obj = { title: '新会话', path: `chat/${session_id}`, id: session_id, isMore: false, isNoTitle: true };
let obj = { title: '新会话', path: `chat/${kbId.value}/${session_id}`, id: session_id, isMore: false, isNoTitle: true };
usemenuStore.updataMenuChildren(obj);
usemenuStore.changeIsFirstSession(true);
usemenuStore.changeFirstQuery(value);
router.push(`/platform/chat/${session_id}`);
router.push(`/platform/chat/${kbId.value}/${session_id}`);
};
async function createNewSession(value: string): Promise<void> {
// 从localStorage获取设置中的知识库ID
const settingsStr = localStorage.getItem("WeKnora_settings");
let knowledgeBaseId = "";
// 优先使用当前页面的知识库ID
let sessionKbId = kbId.value;
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.knowledgeBaseId) {
knowledgeBaseId = settings.knowledgeBaseId;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
if (res.data && res.data.id) {
getTitle(res.data.id, value);
} else {
// 错误处理
console.error("创建会话失败");
}
}).catch(error => {
console.error("创建会话出错:", error);
});
return;
// 如果当前页面没有知识库ID尝试从localStorage获取设置中的知识库ID
if (!sessionKbId) {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
sessionKbId = settings.knowledgeBaseId;
} catch (e) {
console.error("解析设置失败:", e);
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
// 如果设置中没有知识库ID则使用测试数据
const testData = getTestData();
if (!testData || !testData.knowledge_bases || testData.knowledge_bases.length === 0) {
console.error("测试数据未初始化或不包含知识库");
if (!sessionKbId) {
MessagePlugin.warning("请先选择一个知识库");
return;
}
// 使用第一个知识库ID
knowledgeBaseId = testData.knowledge_bases[0].id;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
createSessions({ knowledge_base_id: sessionKbId }).then(res => {
if (res.data && res.data.id) {
getTitle(res.data.id, value);
} else {

View File

@@ -0,0 +1,234 @@
<template>
<div class="kb-list-container">
<div class="header">
<h2>知识库</h2>
<t-button theme="primary" @click="openCreate">新建知识库</t-button>
</div>
<!-- 未初始化知识库提示 -->
<div v-if="hasUninitializedKbs" class="warning-banner">
<t-icon name="info-circle" size="16px" />
<span>部分知识库尚未初始化需要先在设置中配置模型信息才能添加知识文档</span>
</div>
<t-table :data="kbs" :columns="columns" row-key="id" size="medium" hover>
<template #status="{ row }">
<div class="status-cell">
<t-tag
:theme="isInitialized(row) ? 'success' : 'warning'"
size="small"
>
{{ isInitialized(row) ? '已初始化' : '未初始化' }}
</t-tag>
<t-tooltip
v-if="!isInitialized(row)"
content="需要先在设置中配置模型信息才能添加知识"
placement="top"
>
<span class="warning-icon"></span>
</t-tooltip>
</div>
</template>
<template #description="{ row }">
<div class="description-text">{{ row.description || '暂无描述' }}</div>
</template>
<template #op="{ row }">
<t-space size="small">
<t-button
size="small"
@click="goDetail(row.id)"
:disabled="!isInitialized(row)"
:theme="isInitialized(row) ? 'primary' : 'default'"
:variant="isInitialized(row) ? 'base' : 'outline'"
:title="!isInitialized(row) ? '请先在设置中配置模型信息' : ''"
>
文档
</t-button>
<t-button size="small" variant="outline" @click="goSettings(row.id)">设置</t-button>
<t-popconfirm content="确认删除该知识库?" @confirm="remove(row.id)">
<t-button size="small" theme="danger" variant="text">删除</t-button>
</t-popconfirm>
</t-space>
</template>
</t-table>
<t-dialog v-model:visible="createVisible" header="新建知识库" :footer="false">
<t-form :data="createForm" @submit="create">
<t-form-item label="名称" name="name" :rules="[{ required: true, message: '请输入名称' }]">
<t-input v-model="createForm.name" />
</t-form-item>
<t-form-item label="描述" name="description">
<t-textarea v-model="createForm.description" />
</t-form-item>
<t-form-item>
<t-space>
<t-button theme="primary" type="submit" :loading="creating">创建</t-button>
<t-button variant="outline" @click="createVisible = false">取消</t-button>
</t-space>
</t-form-item>
</t-form>
</t-dialog>
</div>
</template>
<script setup lang="ts">
import { onMounted, reactive, ref, computed } from 'vue'
import { useRouter } from 'vue-router'
import { MessagePlugin } from 'tdesign-vue-next'
import { listKnowledgeBases, createKnowledgeBase, deleteKnowledgeBase } from '@/api/knowledge-base'
import { formatStringDate } from '@/utils/index'
const router = useRouter()
interface KB {
id: string;
name: string;
description?: string;
updated_at?: string;
embedding_model_id?: string;
summary_model_id?: string;
}
const kbs = ref<KB[]>([])
const loading = ref(false)
const columns = [
{ colKey: 'name', title: '名称' },
{ colKey: 'description', title: '描述', cell: 'description', width: 300 },
{ colKey: 'status', title: '状态', cell: 'status', width: 100 },
{ colKey: 'updated_at', title: '更新时间' },
{ colKey: 'op', title: '操作', cell: 'op', width: 220 },
]
const fetchList = () => {
loading.value = true
listKnowledgeBases().then((res: any) => {
const data = res.data || []
// 格式化时间
kbs.value = data.map((kb: KB) => ({
...kb,
updated_at: kb.updated_at ? formatStringDate(new Date(kb.updated_at)) : ''
}))
}).finally(() => loading.value = false)
}
onMounted(fetchList)
const createVisible = ref(false)
const creating = ref(false)
const createForm = reactive({ name: '', description: '' })
const openCreate = () => {
createForm.name = ''
createForm.description = ''
createVisible.value = true
}
const create = () => {
if (!createForm.name) return
creating.value = true
const chunking_config = {
chunk_size: 512,
chunk_overlap: 100,
separators: ['.', '?', '!', '。', '', ''],
enable_multimodal: false
}
createKnowledgeBase({ name: createForm.name, description: createForm.description, chunking_config }).then((res: any) => {
if (res.success) {
MessagePlugin.success('创建成功')
createVisible.value = false
fetchList()
} else {
MessagePlugin.error(res.message || '创建失败')
}
}).catch((e: any) => {
MessagePlugin.error(e?.message || '创建失败')
}).finally(() => creating.value = false)
}
const remove = (id: string) => {
deleteKnowledgeBase(id).then((res: any) => {
if (res.success) {
MessagePlugin.success('已删除')
fetchList()
} else {
MessagePlugin.error(res.message || '删除失败')
}
}).catch((e: any) => MessagePlugin.error(e?.message || '删除失败'))
}
const isInitialized = (kb: KB) => {
return !!(kb.embedding_model_id && kb.embedding_model_id !== '' &&
kb.summary_model_id && kb.summary_model_id !== '')
}
// 计算是否有未初始化的知识库
const hasUninitializedKbs = computed(() => {
return kbs.value.some(kb => !isInitialized(kb))
})
const goDetail = (id: string) => {
router.push(`/platform/knowledge-bases/${id}`)
}
const goSettings = (id: string) => {
router.push(`/platform/knowledge-bases/${id}/settings`)
}
</script>
<style scoped lang="less">
.kb-list-container {
padding: 20px;
background: #fff;
margin: 0 20px 0 20px;
height: calc(100vh);
overflow-y: auto;
box-sizing: border-box;
flex: 1;
}
.header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 16px;
h2 { margin: 0; font-size: 20px; font-weight: 600; }
}
.warning-banner {
display: flex;
align-items: center;
gap: 8px;
padding: 12px 16px;
margin-bottom: 16px;
background: #fff7e6;
border: 1px solid #ffd591;
border-radius: 6px;
color: #d46b08;
font-size: 14px;
.t-icon {
color: #d46b08;
flex-shrink: 0;
}
}
.status-cell {
display: flex;
align-items: center;
gap: 8px;
.warning-icon {
color: #ff8800;
cursor: pointer;
font-size: 16px;
font-weight: bold;
transition: color 0.2s;
&:hover {
color: #d46b08;
}
}
}
.description-cell {
.description-text {
color: #000000e6;
font-size: 14px;
}
}
</style>

View File

@@ -1,5 +1,5 @@
<template>
<div class="main" ref="dropzone" @dragover="dragover" @drop="drop" @dragstart="dragstart">
<div class="main" ref="dropzone">
<Menu></Menu>
<RouterView />
<div class="upload-mask" v-show="ismask">
@@ -10,42 +10,104 @@
</template>
<script setup lang="ts">
import Menu from '@/components/menu.vue'
import { ref } from 'vue';
import { useRouter } from 'vue-router'
import { storeToRefs } from "pinia";
import { knowledgeStore } from "@/stores/knowledge";
const usemenuStore = knowledgeStore();
import { ref, onMounted, onUnmounted } from 'vue';
import { useRoute } from 'vue-router'
import useKnowledgeBase from '@/hooks/useKnowledgeBase'
import UploadMask from '@/components/upload-mask.vue'
import { getKnowledgeBaseById } from '@/api/knowledge-base/index'
import { MessagePlugin } from 'tdesign-vue-next'
let { requestMethod } = useKnowledgeBase()
const router = useRouter();
const route = useRoute();
let ismask = ref(false)
let dropzone = ref();
let uploadInput = ref();
const dragover = (event) => {
event.preventDefault();
ismask.value = true;
if (((window.innerWidth - event.clientX) < 50) || ((window.innerHeight - event.clientY) < 50) || event.clientX < 50 || event.clientY < 50) {
ismask.value = false
}
// 获取当前知识库ID
const getCurrentKbId = (): string | null => {
return (route.params as any)?.kbId as string || null
}
const drop = (event) => {
event.preventDefault();
ismask.value = false
const DataTransferItemList = event.dataTransfer.items;
for (const dataTransferItem of DataTransferItemList) {
const fileEntry = dataTransferItem.webkitGetAsEntry();
if (fileEntry) {
fileEntry.file((file: file) => {
requestMethod(file, uploadInput)
router.push('/platform/knowledgeBase?upload=true')
})
// 检查知识库初始化状态
const checkKnowledgeBaseInitialization = async (): Promise<boolean> => {
const currentKbId = getCurrentKbId();
if (!currentKbId) {
MessagePlugin.error("缺少知识库ID");
return false;
}
try {
const kbResponse = await getKnowledgeBaseById(currentKbId);
const kb = kbResponse.data;
if (!kb.embedding_model_id || !kb.summary_model_id) {
MessagePlugin.warning("该知识库尚未完成初始化配置,请先前往设置页面配置模型信息后再上传文件");
return false;
}
return true;
} catch (error) {
MessagePlugin.error("获取知识库信息失败,无法上传文件");
return false;
}
}
const dragstart = (event) => {
// 全局拖拽事件处理
const handleGlobalDragEnter = (event: DragEvent) => {
event.preventDefault();
if (event.dataTransfer) {
event.dataTransfer.effectAllowed = 'all';
}
ismask.value = true;
}
const handleGlobalDragOver = (event: DragEvent) => {
event.preventDefault();
if (event.dataTransfer) {
event.dataTransfer.dropEffect = 'copy';
}
ismask.value = true;
}
const handleGlobalDrop = async (event: DragEvent) => {
event.preventDefault();
ismask.value = false;
const DataTransferFiles = event.dataTransfer?.files ? Array.from(event.dataTransfer.files) : [];
const DataTransferItemList = event.dataTransfer?.items ? Array.from(event.dataTransfer.items) : [];
const isInitialized = await checkKnowledgeBaseInitialization();
if (!isInitialized) {
return;
}
if (DataTransferFiles.length > 0) {
DataTransferFiles.forEach(file => requestMethod(file, uploadInput));
} else if (DataTransferItemList.length > 0) {
DataTransferItemList.forEach(dataTransferItem => {
const fileEntry = dataTransferItem.webkitGetAsEntry() as FileSystemFileEntry | null;
if (fileEntry) {
fileEntry.file((file: File) => requestMethod(file, uploadInput));
}
});
} else {
MessagePlugin.warning('请拖拽文件而不是文本或链接');
}
}
// 组件挂载时添加全局事件监听器
onMounted(() => {
document.addEventListener('dragenter', handleGlobalDragEnter, true);
document.addEventListener('dragover', handleGlobalDragOver, true);
document.addEventListener('drop', handleGlobalDrop, true);
});
// 组件卸载时移除全局事件监听器
onUnmounted(() => {
document.removeEventListener('dragenter', handleGlobalDragEnter, true);
document.removeEventListener('dragover', handleGlobalDragOver, true);
document.removeEventListener('drop', handleGlobalDrop, true);
});
</script>
<style lang="less">
.main {

View File

@@ -0,0 +1,211 @@
<template>
<div class="system-settings-container">
<!-- 页面标题区域 -->
<div class="settings-header">
<h2>{{ isKbSettings ? '知识库设置' : '系统设置' }}</h2>
<p class="settings-subtitle">{{ isKbSettings ? '配置该知识库的模型与文档切分参数' : '管理和更新系统模型与服务配置' }}</p>
</div>
<!-- 配置内容 -->
<div class="settings-content">
<!-- 系统设置使用初始化配置 -->
<InitializationContent v-if="!isKbSettings" />
<!-- 知识库设置基础信息与文档切分配置 -->
<div v-else>
<t-form :data="kbForm" @submit="saveKb">
<div class="config-section">
<h3><span class="section-icon"></span>基础信息</h3>
<t-form-item label="名称" name="name" :rules="[{ required: true, message: '请输入名称' }]">
<t-input v-model="kbForm.name" />
</t-form-item>
<t-form-item label="描述" name="description">
<t-textarea v-model="kbForm.description" />
</t-form-item>
</div>
<div class="config-section">
<h3><span class="section-icon">📄</span>文档切分</h3>
<t-row :gutter="16">
<t-col :span="6">
<t-form-item label="Chunk Size" name="chunkSize">
<t-input-number v-model="kbForm.config.chunking_config.chunk_size" :min="1" />
</t-form-item>
</t-col>
<t-col :span="6">
<t-form-item label="Chunk Overlap" name="chunkOverlap">
<t-input-number v-model="kbForm.config.chunking_config.chunk_overlap" :min="0" />
</t-form-item>
</t-col>
</t-row>
</div>
<div class="submit-section">
<t-button theme="primary" type="submit" :loading="saving">保存</t-button>
</div>
</t-form>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { defineAsyncComponent, onMounted, reactive, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { MessagePlugin } from 'tdesign-vue-next'
import { getKnowledgeBaseById, updateKnowledgeBase } from '@/api/knowledge-base'
// 异步加载初始化配置组件
const InitializationContent = defineAsyncComponent(() => import('../initialization/InitializationContent.vue'))
const route = useRoute()
const router = useRouter()
const isKbSettings = ref<boolean>(false)
interface KbForm {
name: string
description?: string
config: { chunking_config: { chunk_size: number; chunk_overlap: number } }
}
const kbForm = reactive<KbForm>({
name: '',
description: '',
config: { chunking_config: { chunk_size: 512, chunk_overlap: 64 } }
})
const saving = ref(false)
const loadKb = () => {
const kbId = (route.params as any).kbId as string
if (!kbId) return
getKnowledgeBaseById(kbId).then((res: any) => {
if (res?.data) {
kbForm.name = res.data.name
kbForm.description = res.data.description
const cc = res.data.chunking_config || {}
kbForm.config.chunking_config.chunk_size = cc.chunk_size ?? 512
kbForm.config.chunking_config.chunk_overlap = cc.chunk_overlap ?? 64
}
})
}
onMounted(() => {
isKbSettings.value = route.name === 'knowledgeBaseSettings'
if (isKbSettings.value) loadKb()
})
const saveKb = () => {
const kbId = (route.params as any).kbId as string
if (!kbId) return
saving.value = true
updateKnowledgeBase(kbId, { name: kbForm.name, description: kbForm.description, config: { chunking_config: { chunk_size: kbForm.config.chunking_config.chunk_size, chunk_overlap: kbForm.config.chunking_config.chunk_overlap, separators: [], enable_multimodal: false }, image_processing_config: { model_id: '' } } })
.then((res: any) => {
if (res.success) {
MessagePlugin.success('保存成功')
} else {
MessagePlugin.error(res.message || '保存失败')
}
})
.catch((e: any) => MessagePlugin.error(e?.message || '保存失败'))
.finally(() => saving.value = false)
}
</script>
<style lang="less" scoped>
.system-settings-container {
padding: 20px;
background-color: #fff;
margin: 0 20px 0 20px;
height: calc(100vh);
overflow-y: auto;
box-sizing: border-box;
flex: 1;
}
.settings-header {
margin-bottom: 20px;
border-bottom: 1px solid #f0f0f0;
padding-bottom: 16px;
h2 {
font-size: 20px;
font-weight: 600;
color: #000000;
margin: 0 0 8px 0;
}
.settings-subtitle {
font-size: 14px;
color: #666666;
margin: 0;
}
}
.settings-content {
margin-top: 0;
}
/* 响应式设计 */
@media (max-width: 768px) {
.system-settings-container {
padding: 16px;
margin: 10px;
height: calc(100vh - 20px);
}
.settings-header h2 {
font-size: 18px;
}
}
/* 覆盖TDesign组件样式与账户信息页面保持一致 */
:deep(.t-card) {
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
border: 1px solid #e5e7eb;
}
/* 调整InitializationContent内部样式使每个配置区域显示为独立卡片 */
:deep(.config-section) {
background: #fff;
border: 1px solid #e5e7eb;
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
padding: 20px;
margin-bottom: 20px;
&:last-child {
margin-bottom: 0;
}
h3 {
font-size: 16px;
font-weight: 600;
color: #000000;
margin: 0 0 16px 0;
display: flex;
align-items: center;
padding: 0;
background: none;
border-left: none;
border-radius: 0;
border-bottom: 1px solid #f0f0f0;
padding-bottom: 12px;
.section-icon {
margin-right: 8px;
color: #07c05f;
font-size: 18px;
}
}
}
:deep(.ollama-summary-card) {
background: #fff;
border: 1px solid #e5e7eb;
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
padding: 20px;
margin-bottom: 20px;
}
:deep(.submit-section) {
margin-top: 20px;
text-align: center;
}
</style>

View File

@@ -0,0 +1,535 @@
<template>
<div class="tenant-info-container">
<div class="tenant-header">
<h2>系统信息</h2>
<p class="tenant-subtitle">查看系统版本信息和用户账户配置</p>
</div>
<div class="tenant-content" v-if="!loading && !error">
<!-- 系统信息卡片 -->
<t-card class="info-card" :bordered="false">
<template #header>
<div class="card-title">系统信息</div>
</template>
<div class="info-content">
<t-descriptions :column="1" layout="vertical">
<t-descriptions-item label="版本号">
{{ systemInfo?.version || '未知' }}
<span v-if="systemInfo?.commit_id" class="commit-info">
({{ systemInfo.commit_id }})
</span>
</t-descriptions-item>
<t-descriptions-item label="构建时间" v-if="systemInfo?.build_time">
{{ systemInfo.build_time }}
</t-descriptions-item>
<t-descriptions-item label="Go版本" v-if="systemInfo?.go_version">
{{ systemInfo.go_version }}
</t-descriptions-item>
</t-descriptions>
</div>
</t-card>
<!-- 用户信息卡片 -->
<t-card class="info-card" :bordered="false">
<template #header>
<div class="card-title">用户信息</div>
</template>
<div class="info-content">
<t-descriptions :column="1" layout="vertical">
<t-descriptions-item label="用户 ID">
{{ userInfo?.id }}
</t-descriptions-item>
<t-descriptions-item label="用户名">
{{ userInfo?.username }}
</t-descriptions-item>
<t-descriptions-item label="邮箱">
{{ userInfo?.email }}
</t-descriptions-item>
<t-descriptions-item label="创建时间">
{{ formatDate(userInfo?.created_at) }}
</t-descriptions-item>
</t-descriptions>
</div>
</t-card>
<!-- 租户信息卡片 -->
<t-card class="info-card" :bordered="false">
<template #header>
<div class="card-title">租户信息</div>
</template>
<div class="info-content">
<t-descriptions :column="1" layout="vertical">
<t-descriptions-item label="租户 ID">
{{ tenantInfo?.id }}
</t-descriptions-item>
<t-descriptions-item label="租户名称">
{{ tenantInfo?.name }}
</t-descriptions-item>
<t-descriptions-item label="描述">
{{ tenantInfo?.description || '暂无描述' }}
</t-descriptions-item>
<t-descriptions-item label="业务">
{{ tenantInfo?.business || '暂无' }}
</t-descriptions-item>
<t-descriptions-item label="状态">
<t-tag
:theme="getStatusTheme(tenantInfo?.status)"
variant="light"
>
{{ getStatusText(tenantInfo?.status) }}
</t-tag>
</t-descriptions-item>
<t-descriptions-item label="创建时间">
{{ formatDate(tenantInfo?.created_at) }}
</t-descriptions-item>
</t-descriptions>
</div>
</t-card>
<!-- API Key 卡片 -->
<t-card class="info-card" :bordered="false">
<template #header>
<div class="card-header-with-actions">
<div class="card-title">API Key</div>
</div>
</template>
<div class="api-key-content">
<t-input
v-model="displayApiKey"
readonly
class="api-key-input"
:type="showApiKey ? 'text' : 'password'"
/>
<t-alert theme="warning" :close="false" class="api-warning">
<template #icon>
<t-icon name="error-circle" />
</template>
请妥善保管您的 API Key不要在公共场所或代码仓库中暴露
</t-alert>
</div>
</t-card>
<!-- 存储信息卡片 -->
<t-card
class="info-card"
:bordered="false"
v-if="tenantInfo?.storage_quota !== undefined"
>
<template #header>
<div class="card-title">存储信息</div>
</template>
<div class="storage-content">
<t-descriptions :column="1" layout="vertical">
<t-descriptions-item label="存储配额">
{{ formatBytes(tenantInfo.storage_quota) }}
</t-descriptions-item>
<t-descriptions-item label="已使用">
{{ formatBytes(tenantInfo.storage_used || 0) }}
</t-descriptions-item>
<t-descriptions-item label="使用率">
<div class="usage-info">
<span class="usage-text">{{ getUsagePercentage() }}%</span>
<t-progress
:percentage="getUsagePercentage()"
:show-info="false"
size="medium"
:theme="getUsagePercentage() > 80 ? 'warning' : 'success'"
/>
</div>
</t-descriptions-item>
</t-descriptions>
</div>
</t-card>
<!-- API 开发文档卡片 -->
<t-card class="info-card" :bordered="false">
<template #header>
<div class="card-title">API 开发文档</div>
</template>
<div class="doc-content">
<p class="doc-description">使用您的 API Key 开始开发查看完整的 API 文档和示例代码</p>
<t-space class="doc-actions">
<t-button
theme="primary"
@click="openApiDoc"
>
<template #icon>
<t-icon name="link" />
</template>
查看 API 文档
</t-button>
</t-space>
</div>
</t-card>
</div>
<!-- 加载状态 -->
<div v-if="loading" class="loading-container">
<t-loading size="large" />
<p class="loading-text">正在加载账户信息...</p>
</div>
<!-- 错误状态 -->
<div v-if="error" class="error-container">
<t-result theme="error" title="加载失败" :description="error">
<template #extra>
<t-button theme="primary" @click="loadTenantInfo">重试</t-button>
</template>
</t-result>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, computed } from 'vue'
import { getCurrentUser, type TenantInfo, type UserInfo } from '@/api/auth'
import { getSystemInfo, type SystemInfo } from '@/api/system'
// 响应式数据
const tenantInfo = ref<TenantInfo | null>(null)
const userInfo = ref<UserInfo | null>(null)
const systemInfo = ref<SystemInfo | null>(null)
const loading = ref(true)
const error = ref('')
const showApiKey = ref(false)
const showApiExample = ref(false)
// API 基础 URL
const apiBaseUrl = window.location.origin
// 计算属性
const displayApiKey = computed(() => {
if (!tenantInfo.value?.api_key) return ''
return tenantInfo.value.api_key
})
// API示例代码
const apiExampleCode = computed(() => {
return `curl -X GET "${apiBaseUrl}/api/v1/tenants/${tenantInfo.value?.id}" \\
-H "Content-Type: application/json" \\
-H "X-API-Key: ${tenantInfo.value?.api_key}"`
})
// 方法
const loadTenantInfo = async () => {
try {
loading.value = true
error.value = ''
// 并行获取用户信息和系统信息
const [userResponse, systemResponse] = await Promise.all([
getCurrentUser(),
getSystemInfo().catch(() => ({ data: null })) // 系统信息获取失败不影响页面显示
])
if (userResponse.success && userResponse.data) {
userInfo.value = userResponse.data.user
tenantInfo.value = userResponse.data.tenant
} else {
error.value = userResponse.message || '获取用户信息失败'
}
if (systemResponse.data) {
systemInfo.value = systemResponse.data
}
} catch (err: any) {
error.value = err.message || '网络错误,请稍后重试'
} finally {
loading.value = false
}
}
const toggleApiKeyVisibility = () => {
showApiKey.value = !showApiKey.value
}
const copyApiKey = async () => {
if (!tenantInfo.value?.api_key) return
try {
await navigator.clipboard.writeText(tenantInfo.value.api_key)
// 使用TDesign的消息组件
import('tdesign-vue-next').then(({ MessagePlugin }) => {
MessagePlugin.success('API Key 已复制到剪贴板')
})
} catch (err) {
// 降级到传统方式
const textArea = document.createElement('textarea')
textArea.value = tenantInfo.value.api_key
document.body.appendChild(textArea)
textArea.select()
document.execCommand('copy')
document.body.removeChild(textArea)
import('tdesign-vue-next').then(({ MessagePlugin }) => {
MessagePlugin.success('API Key 已复制到剪贴板')
})
}
}
const openApiDoc = () => {
window.open('https://github.com/Tencent/WeKnora/blob/main/docs/API.md', '_blank')
}
const getStatusText = (status: string | undefined) => {
switch (status) {
case 'active':
return '活跃'
case 'inactive':
return '未激活'
case 'suspended':
return '已暂停'
default:
return '未知'
}
}
const getStatusTheme = (status: string | undefined) => {
switch (status) {
case 'active':
return 'success'
case 'inactive':
return 'warning'
case 'suspended':
return 'danger'
default:
return 'default'
}
}
const formatDate = (dateStr: string | undefined) => {
if (!dateStr) return '未知'
try {
const date = new Date(dateStr)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit'
})
} catch {
return '格式错误'
}
}
const formatBytes = (bytes: number) => {
if (bytes === 0) return '0 B'
const k = 1024
const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
const i = Math.floor(Math.log(bytes) / Math.log(k))
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
}
const getUsagePercentage = () => {
if (!tenantInfo.value?.storage_quota || tenantInfo.value.storage_quota === 0) {
return 0
}
const used = tenantInfo.value.storage_used || 0
const percentage = (used / tenantInfo.value.storage_quota) * 100
return Math.min(Math.round(percentage * 100) / 100, 100) // 保留两位小数最大100%
}
// 生命周期
onMounted(() => {
loadTenantInfo()
})
</script>
<style lang="less" scoped>
.tenant-info-container {
padding: 20px;
background-color: #fff;
margin: 0 20px 0 20px;
height: calc(100vh);
overflow-y: auto;
box-sizing: border-box;
flex: 1;
}
.tenant-header {
margin-bottom: 20px;
border-bottom: 1px solid #f0f0f0;
padding-bottom: 16px;
h2 {
font-size: 20px;
font-weight: 600;
color: #000000;
margin: 0 0 8px 0;
}
.tenant-subtitle {
font-size: 14px;
color: #666666;
margin: 0;
}
}
.tenant-content {
display: grid;
gap: 20px;
grid-template-columns: 1fr;
}
.info-card {
margin-bottom: 20px;
.card-title {
font-size: 16px;
font-weight: 600;
color: #07C05F;
}
.card-header-with-actions {
display: flex;
justify-content: space-between;
align-items: center;
}
}
.info-content,
.api-key-content,
.storage-content,
.doc-content {
margin-top: 0;
}
.api-key-input {
margin-bottom: 16px;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
}
.api-warning {
margin-top: 16px;
}
.usage-info {
display: flex;
flex-direction: column;
gap: 8px;
.usage-text {
font-weight: 500;
color: #000000;
}
}
.doc-description {
margin-bottom: 16px;
color: #666666;
font-size: 14px;
}
.doc-actions {
margin-bottom: 20px;
}
.api-example {
margin-top: 20px;
padding: 16px;
background-color: #f8f9fa;
border-radius: 6px;
.example-header h4 {
margin: 0 0 16px 0;
font-size: 16px;
font-weight: 600;
color: #000000;
}
.code-textarea {
margin-bottom: 16px;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
}
.example-note {
margin-top: 16px;
}
}
.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 60px 24px;
text-align: center;
.loading-text {
margin-top: 16px;
color: #666666;
font-size: 14px;
}
}
.error-container {
padding: 40px;
text-align: center;
}
/* 响应式设计 */
@media (max-width: 768px) {
.tenant-info-container {
padding: 16px;
margin: 10px;
height: calc(100vh - 20px);
}
.tenant-header h2 {
font-size: 18px;
}
.card-header-with-actions {
flex-direction: column;
align-items: flex-start !important;
gap: 12px;
}
.commit-info {
color: #666;
font-size: 12px;
margin-left: 8px;
}
.doc-actions {
:deep(.t-space) {
flex-direction: column;
width: 100%;
.t-button {
width: 100%;
}
}
}
}
/* 覆盖TDesign组件样式 */
:deep(.t-card) {
border: 1px solid #e5e7eb;
}
:deep(.t-descriptions-item__label) {
font-weight: 500;
color: #374151;
}
:deep(.t-descriptions-item__content) {
color: #000000;
}
:deep(.t-input__inner) {
font-family: inherit;
}
:deep(.code-textarea .t-textarea__inner) {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
font-size: 13px;
line-height: 1.4;
}
</style>

20
go.mod
View File

@@ -10,14 +10,16 @@ require (
github.com/gin-contrib/cors v1.7.5
github.com/gin-gonic/gin v1.10.0
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/hibiken/asynq v0.25.1
github.com/minio/minio-go/v7 v7.0.90
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1
github.com/ollama/ollama v0.11.4
github.com/panjf2000/ants/v2 v2.11.2
github.com/parquet-go/parquet-go v0.25.0
github.com/pgvector/pgvector-go v0.3.0
github.com/redis/go-redis/v9 v9.7.3
github.com/redis/go-redis/v9 v9.14.0
github.com/sashabaranov/go-openai v1.40.5
github.com/sirupsen/logrus v1.9.3
github.com/spf13/viper v1.20.1
@@ -31,9 +33,10 @@ require (
go.opentelemetry.io/otel/sdk v1.37.0
go.opentelemetry.io/otel/trace v1.37.0
go.uber.org/dig v1.18.1
golang.org/x/sync v0.15.0
golang.org/x/crypto v0.42.0
golang.org/x/sync v0.17.0
google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6
google.golang.org/protobuf v1.36.9
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.25.12
)
@@ -90,7 +93,7 @@ require (
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -100,11 +103,10 @@ require (
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.15.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/time v0.11.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.13.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

44
go.sum
View File

@@ -70,6 +70,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlnd
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -139,6 +141,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ=
github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1 h1:nV3ZdYJTi73jel0mm3dpWumNY3i3nwyo25y69SPGwyg=
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI=
@@ -155,8 +159,8 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -176,8 +180,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
@@ -253,22 +257,22 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw=
golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc=
@@ -276,8 +280,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -0,0 +1,231 @@
package neo4j
import (
"context"
"fmt"
"strings"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
)
type Neo4jRepository struct {
driver neo4j.Driver
nodePrefix string
}
func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository {
return &Neo4jRepository{driver: driver, nodePrefix: "ENTITY"}
}
func _remove_hyphen(s string) string {
return strings.ReplaceAll(s, "-", "_")
}
func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string {
res := make([]string, 0)
for _, label := range namespace.Labels() {
res = append(res, n.nodePrefix+_remove_hyphen(label))
}
return res
}
func (n *Neo4jRepository) Label(namespace types.NameSpace) string {
labels := n.Labels(namespace)
return strings.Join(labels, ":")
}
// AddGraph implements interfaces.RetrieveGraphRepository.
func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error {
if n.driver == nil {
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
return nil
}
for _, graph := range graphs {
if err := n.addGraph(ctx, namespace, graph); err != nil {
return err
}
}
return nil
}
func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error {
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
defer session.Close(ctx)
_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
node_import_query := `
UNWIND $data AS row
CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node
SET node.chunks = apoc.coll.union(node.chunks, row.chunks)
RETURN distinct 'done' AS result
`
nodeData := []map[string]interface{}{}
for _, node := range graph.Node {
nodeData = append(nodeData, map[string]interface{}{
"name": node.Name,
"knowledge_id": namespace.Knowledge,
"props": map[string][]string{"attributes": node.Attributes},
"chunks": node.Chunks,
"labels": n.Labels(namespace),
})
}
if _, err := tx.Run(ctx, node_import_query, map[string]interface{}{"data": nodeData}); err != nil {
return nil, fmt.Errorf("failed to create nodes: %v", err)
}
rel_import_query := `
UNWIND $data AS row
CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source
CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target
CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel
RETURN distinct 'done'
`
relData := []map[string]interface{}{}
for _, rel := range graph.Relation {
relData = append(relData, map[string]interface{}{
"source": rel.Node1,
"target": rel.Node2,
"knowledge_id": namespace.Knowledge,
"type": rel.Type,
"source_labels": n.Labels(namespace),
"target_labels": n.Labels(namespace),
})
}
if _, err := tx.Run(ctx, rel_import_query, map[string]interface{}{"data": relData}); err != nil {
return nil, fmt.Errorf("failed to create relationships: %v", err)
}
return nil, nil
})
if err != nil {
logger.Errorf(ctx, "failed to add graph: %v", err)
return err
}
return nil
}
// DelGraph implements interfaces.RetrieveGraphRepository.
func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error {
if n.driver == nil {
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
return nil
}
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
defer session.Close(ctx)
result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
for _, namespace := range namespaces {
labelExpr := n.Label(namespace)
deleteRelsQuery := `
CALL apoc.periodic.iterate(
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]-(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r",
"DELETE r",
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
) YIELD batches, total
RETURN total
`
if _, err := tx.Run(ctx, deleteRelsQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
return nil, fmt.Errorf("failed to delete relationships: %v", err)
}
deleteNodesQuery := `
CALL apoc.periodic.iterate(
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n",
"DELETE n",
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
) YIELD batches, total
RETURN total
`
if _, err := tx.Run(ctx, deleteNodesQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
return nil, fmt.Errorf("failed to delete nodes: %v", err)
}
}
return nil, nil
})
if err != nil {
return err
}
logger.Infof(ctx, "delete graph result: %v", result)
return nil
}
func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) {
if n.driver == nil {
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
return nil, nil
}
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
defer session.Close(ctx)
result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
labelExpr := n.Label(namespace)
query := `
MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `)
WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText)
RETURN n, r, m
`
params := map[string]interface{}{"nodes": nodes}
result, err := tx.Run(ctx, query, params)
if err != nil {
return nil, fmt.Errorf("failed to run query: %v", err)
}
graphData := &types.GraphData{}
nodeSeen := make(map[string]bool)
for result.Next(ctx) {
record := result.Record()
node, _ := record.Get("n")
rel, _ := record.Get("r")
targetNode, _ := record.Get("m")
nodeData := node.(neo4j.Node)
targetNodeData := targetNode.(neo4j.Node)
// Convert node to types.Node
for _, n := range []neo4j.Node{nodeData, targetNodeData} {
nameStr := n.Props["name"].(string)
if _, ok := nodeSeen[nameStr]; !ok {
nodeSeen[nameStr] = true
graphData.Node = append(graphData.Node, &types.GraphNode{
Name: nameStr,
Chunks: listI2listS(n.Props["chunks"].([]interface{})),
Attributes: listI2listS(n.Props["attributes"].([]interface{})),
})
}
}
// Convert relationship to types.Relation
relData := rel.(neo4j.Relationship)
graphData.Relation = append(graphData.Relation, &types.GraphRelation{
Node1: nodeData.Props["name"].(string),
Node2: targetNodeData.Props["name"].(string),
Type: relData.Type,
})
}
return graphData, nil
})
if err != nil {
logger.Errorf(ctx, "search node failed: %v", err)
return nil, err
}
return result.(*types.GraphData), nil
}
func listI2listS(list []any) []string {
result := make([]string, len(list))
for i, v := range list {
result[i] = fmt.Sprintf("%v", v)
}
return result
}
func mapI2mapS(prop map[string]any) map[string]string {
attributes := make(map[string]string)
for k, v := range prop {
attributes[k] = fmt.Sprintf("%v", v)
}
return attributes
}

View File

@@ -0,0 +1,154 @@
package repository
import (
"context"
"errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrUserAlreadyExists = errors.New("user already exists")
ErrTokenNotFound = errors.New("token not found")
)
// userRepository implements user repository interface
type userRepository struct {
db *gorm.DB
}
// NewUserRepository creates a new user repository
func NewUserRepository(db *gorm.DB) interfaces.UserRepository {
return &userRepository{db: db}
}
// CreateUser creates a user
func (r *userRepository) CreateUser(ctx context.Context, user *types.User) error {
logger.Infof(ctx, "Creating user in database: %s", user.Email)
return r.db.WithContext(ctx).Create(user).Error
}
// GetUserByID gets a user by ID
func (r *userRepository) GetUserByID(ctx context.Context, id string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByEmail gets a user by email
func (r *userRepository) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByUsername gets a user by username
func (r *userRepository) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// UpdateUser updates a user
func (r *userRepository) UpdateUser(ctx context.Context, user *types.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
// DeleteUser deletes a user
func (r *userRepository) DeleteUser(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.User{}).Error
}
// ListUsers lists users with pagination
func (r *userRepository) ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error) {
var users []*types.User
query := r.db.WithContext(ctx).Order("created_at DESC")
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
if err := query.Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
// authTokenRepository implements auth token repository interface
type authTokenRepository struct {
db *gorm.DB
}
// NewAuthTokenRepository creates a new auth token repository
func NewAuthTokenRepository(db *gorm.DB) interfaces.AuthTokenRepository {
return &authTokenRepository{db: db}
}
// CreateToken creates an auth token
func (r *authTokenRepository) CreateToken(ctx context.Context, token *types.AuthToken) error {
return r.db.WithContext(ctx).Create(token).Error
}
// GetTokenByValue gets a token by its value
func (r *authTokenRepository) GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error) {
var token types.AuthToken
if err := r.db.WithContext(ctx).Where("token = ?", tokenValue).First(&token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTokenNotFound
}
return nil, err
}
return &token, nil
}
// GetTokensByUserID gets all tokens for a user
func (r *authTokenRepository) GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error) {
var tokens []*types.AuthToken
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&tokens).Error; err != nil {
return nil, err
}
return tokens, nil
}
// UpdateToken updates a token
func (r *authTokenRepository) UpdateToken(ctx context.Context, token *types.AuthToken) error {
return r.db.WithContext(ctx).Save(token).Error
}
// DeleteToken deletes a token
func (r *authTokenRepository) DeleteToken(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.AuthToken{}).Error
}
// DeleteExpiredTokens deletes all expired tokens
func (r *authTokenRepository) DeleteExpiredTokens(ctx context.Context) error {
return r.db.WithContext(ctx).Where("expires_at < NOW()").Delete(&types.AuthToken{}).Error
}
// RevokeTokensByUserID revokes all tokens for a user
func (r *authTokenRepository) RevokeTokensByUserID(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&types.AuthToken{}).Where("user_id = ?", userID).Update("is_revoked", true).Error
}

View File

@@ -0,0 +1,499 @@
package chatpipline
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"regexp"
"strings"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginExtractEntity is a plugin for extracting entities from user queries
// It uses historical dialog context and large language models to identify key entities in the user's original query
type PluginExtractEntity struct {
modelService interfaces.ModelService // Model service for calling large language models
template *types.PromptTemplateStructured // Template for generating prompts
knowledgeBaseRepo interfaces.KnowledgeBaseRepository
}
// NewPluginRewrite creates a new query rewriting plugin instance
// Also registers the plugin with the event manager
func NewPluginExtractEntity(
eventManager *EventManager,
modelService interfaces.ModelService,
knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
config *config.Config,
) *PluginExtractEntity {
res := &PluginExtractEntity{
modelService: modelService,
template: config.ExtractManager.ExtractEntity,
knowledgeBaseRepo: knowledgeBaseRepo,
}
eventManager.Register(res)
return res
}
// ActivationEvents returns the list of event types this plugin responds to
// This plugin only responds to REWRITE_QUERY events
func (p *PluginExtractEntity) ActivationEvents() []types.EventType {
return []types.EventType{types.REWRITE_QUERY}
}
// OnEvent processes triggered events
// When receiving a REWRITE_QUERY event, it rewrites the user query using conversation history and the language model
func (p *PluginExtractEntity) OnEvent(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
logger.Debugf(ctx, "skipping extract entity, neo4j is disabled")
return next()
}
query := chatManage.Query
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err)
return next()
}
kb, err := p.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chatManage.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "failed to get knowledge base: %v", err)
return next()
}
if kb.ExtractConfig == nil {
logger.Warnf(ctx, "failed to get extract config")
return next()
}
template := &types.PromptTemplateStructured{
Description: p.template.Description,
Examples: p.template.Examples,
}
extractor := NewExtractor(model, template)
graph, err := extractor.Extract(ctx, query)
if err != nil {
logger.Errorf(ctx, "Failed to extract entities, session_id: %s, error: %v", chatManage.SessionID, err)
return next()
}
nodes := []string{}
for _, node := range graph.Node {
nodes = append(nodes, node.Name)
}
logger.Debugf(ctx, "extracted node: %v", nodes)
chatManage.Entity = nodes
return next()
}
type Extractor struct {
chat chat.Chat
formater *Formater
template *types.PromptTemplateStructured
chatOpt *chat.ChatOptions
}
func NewExtractor(
chatModel chat.Chat,
template *types.PromptTemplateStructured,
) Extractor {
think := false
return Extractor{
chat: chatModel,
formater: NewFormater(),
template: template,
chatOpt: &chat.ChatOptions{
Temperature: 0.3,
MaxTokens: 4096,
Thinking: &think,
},
}
}
func (e *Extractor) Extract(ctx context.Context, content string) (*types.GraphData, error) {
generator := NewQAPromptGenerator(e.formater, e.template)
// logger.Debugf(ctx, "chat system: %s", generator.System(ctx))
// logger.Debugf(ctx, "chat user: %s", generator.User(ctx, content))
chatResponse, err := e.chat.Chat(ctx, generator.Render(ctx, content), e.chatOpt)
if err != nil {
logger.Errorf(ctx, "failed to chat: %v", err)
return nil, err
}
graph, err := e.formater.ParseGraph(ctx, chatResponse.Content)
if err != nil {
logger.Errorf(ctx, "failed to parse graph: %v", err)
return nil, err
}
// e.RemoveUnknownRelation(ctx, graph)
return graph, nil
}
func (e *Extractor) RemoveUnknownRelation(ctx context.Context, graph *types.GraphData) {
relationType := make(map[string]bool)
for _, tag := range e.template.Tags {
relationType[tag] = true
}
relationNew := make([]*types.GraphRelation, 0)
for _, relation := range graph.Relation {
if _, ok := relationType[relation.Type]; ok {
relationNew = append(relationNew, relation)
} else {
logger.Infof(ctx, "Unknown relation type %s with %v, ignore it", relation.Type, e.template.Tags)
}
}
graph.Relation = relationNew
}
type QAPromptGenerator struct {
Formater *Formater
Template *types.PromptTemplateStructured
ExamplesHeading string
QuestionHeading string
QuestionPrefix string
AnswerPrefix string
}
func NewQAPromptGenerator(formater *Formater, template *types.PromptTemplateStructured) *QAPromptGenerator {
return &QAPromptGenerator{
Formater: formater,
Template: template,
ExamplesHeading: "# Examples",
QuestionHeading: "# Question",
QuestionPrefix: "Q: ",
AnswerPrefix: "A: ",
}
}
func (qa *QAPromptGenerator) System(ctx context.Context) string {
promptLines := []string{}
if len(qa.Template.Tags) == 0 {
promptLines = append(promptLines, qa.Template.Description)
} else {
tags, _ := json.Marshal(qa.Template.Tags)
promptLines = append(promptLines, fmt.Sprintf(qa.Template.Description, string(tags)))
}
if len(qa.Template.Examples) > 0 {
promptLines = append(promptLines, qa.ExamplesHeading)
for _, example := range qa.Template.Examples {
// Question
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, strings.TrimSpace(example.Text)))
// Answer
answer, err := qa.Formater.formatExtraction(example.Node, example.Relation)
if err != nil {
return ""
}
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.AnswerPrefix, answer))
// new line
promptLines = append(promptLines, "")
}
}
return strings.Join(promptLines, "\n")
}
func (qa *QAPromptGenerator) User(ctx context.Context, question string) string {
promptLines := []string{}
promptLines = append(promptLines, qa.QuestionHeading)
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, question))
promptLines = append(promptLines, qa.AnswerPrefix)
return strings.Join(promptLines, "\n")
}
func (qa *QAPromptGenerator) Render(ctx context.Context, question string) []chat.Message {
return []chat.Message{
{
Role: "system",
Content: qa.System(ctx),
},
{
Role: "user",
Content: qa.User(ctx, question),
},
}
}
type FormatType string
const (
FormatTypeJSON FormatType = "json"
FormatTypeYAML FormatType = "yaml"
)
const (
_FENCE_START = "```"
_LANGUAGE_TAG = `(?P<lang>[A-Za-z0-9_+-]+)?`
_FENCE_NEWLINE = `(?:\s*\n)?`
_FENCE_BODY = `(?P<body>[\s\S]*?)`
_FENCE_END = "```"
)
var _FENCE_RE = regexp.MustCompile(
_FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END,
)
type Formater struct {
attributeSuffix string
formatType FormatType
useFences bool
nodePrefix string
relationSource string
relationTarget string
relationPrefix string
}
func NewFormater() *Formater {
return &Formater{
attributeSuffix: "_attributes",
formatType: FormatTypeJSON,
useFences: true,
nodePrefix: "entity",
relationSource: "entity1",
relationTarget: "entity2",
relationPrefix: "relation",
}
}
func (f *Formater) formatExtraction(nodes []*types.GraphNode, relations []*types.GraphRelation) (string, error) {
items := make([]map[string]interface{}, 0)
for _, node := range nodes {
item := map[string]interface{}{
f.nodePrefix: node.Name,
}
if len(node.Attributes) > 0 {
item[fmt.Sprintf("%s%s", f.nodePrefix, f.attributeSuffix)] = node.Attributes
}
items = append(items, item)
}
for _, relation := range relations {
item := map[string]interface{}{
f.relationSource: relation.Node1,
f.relationTarget: relation.Node2,
f.relationPrefix: relation.Type,
}
items = append(items, item)
}
formatted := ""
switch f.formatType {
default:
formattedBytes, err := json.MarshalIndent(items, "", " ")
if err != nil {
return "", err
}
formatted = string(formattedBytes)
}
if f.useFences {
formatted = f.addFences(formatted)
}
return formatted, nil
}
func (f *Formater) parseOutput(ctx context.Context, text string) ([]map[string]interface{}, error) {
if text == "" {
return nil, errors.New("Empty or invalid input string.")
}
content := f.extractContent(ctx, text)
// logger.Debugf(ctx, "Extracted content: %s", content)
if content == "" {
return nil, errors.New("Empty or invalid input string.")
}
var parsed interface{}
var err error
if f.formatType == FormatTypeJSON {
err = json.Unmarshal([]byte(content), &parsed)
}
if err != nil {
return nil, fmt.Errorf("Failed to parse %s content: %s", strings.ToUpper(string(f.formatType)), err.Error())
}
if parsed == nil {
return nil, fmt.Errorf("Content must be a list of extractions or a dict.")
}
var items []interface{}
if parsedMap, ok := parsed.(map[string]interface{}); ok {
items = []interface{}{parsedMap}
} else if parsedList, ok := parsed.([]interface{}); ok {
items = parsedList
} else {
return nil, fmt.Errorf("Expected list or dict, got %T", parsed)
}
itemsList := make([]map[string]interface{}, 0)
for _, item := range items {
if itemMap, ok := item.(map[string]interface{}); ok {
itemsList = append(itemsList, itemMap)
} else {
return nil, fmt.Errorf("Each item in the sequence must be a mapping.")
}
}
return itemsList, nil
}
func (f *Formater) ParseGraph(ctx context.Context, text string) (*types.GraphData, error) {
matchData, err := f.parseOutput(ctx, text)
if err != nil {
return nil, err
}
if len(matchData) == 0 {
logger.Debugf(ctx, "Received empty extraction data.")
return &types.GraphData{}, nil
}
// mm, _ := json.Marshal(matchData)
// logger.Debugf(ctx, "Parsed graph data: %s", string(mm))
var nodes []*types.GraphNode
var relations []*types.GraphRelation
for _, group := range matchData {
switch {
case group[f.nodePrefix] != nil:
attributes := make([]string, 0)
attributesKey := f.nodePrefix + f.attributeSuffix
if attr, ok := group[attributesKey].([]interface{}); ok {
for _, v := range attr {
attributes = append(attributes, fmt.Sprintf("%v", v))
}
}
nodes = append(nodes, &types.GraphNode{
Name: fmt.Sprintf("%v", group[f.nodePrefix]),
Attributes: attributes,
})
case group[f.relationSource] != nil && group[f.relationTarget] != nil:
relations = append(relations, &types.GraphRelation{
Node1: fmt.Sprintf("%v", group[f.relationSource]),
Node2: fmt.Sprintf("%v", group[f.relationTarget]),
Type: fmt.Sprintf("%v", group[f.relationPrefix]),
})
default:
logger.Warnf(ctx, "Unsupported graph group: %v", group)
continue
}
}
graph := &types.GraphData{
Node: nodes,
Relation: relations,
}
f.rebuildGraph(ctx, graph)
return graph, nil
}
func (f *Formater) rebuildGraph(ctx context.Context, graph *types.GraphData) {
nodeMap := make(map[string]*types.GraphNode)
nodes := make([]*types.GraphNode, 0, len(graph.Node))
for _, node := range graph.Node {
if prenode, ok := nodeMap[node.Name]; ok {
logger.Infof(ctx, "Duplicate node ID: %s, merge attribute", node.Name)
// 修复panic检查Attributes是否为nil
if node.Attributes == nil {
node.Attributes = make([]string, 0)
}
if prenode.Attributes != nil {
for _, attr := range prenode.Attributes {
node.Attributes = append(node.Attributes, attr)
}
}
continue
}
nodeMap[node.Name] = node
nodes = append(nodes, node)
}
relations := make([]*types.GraphRelation, 0, len(graph.Relation))
for _, relation := range graph.Relation {
if relation.Node1 == relation.Node2 {
logger.Infof(ctx, "Duplicate relation, ignore it")
continue
}
if _, ok := nodeMap[relation.Node1]; !ok {
node := &types.GraphNode{Name: relation.Node1}
nodes = append(nodes, node)
nodeMap[relation.Node1] = node
logger.Infof(ctx, "Add unknown source node ID: %s", relation.Node1)
}
if _, ok := nodeMap[relation.Node2]; !ok {
node := &types.GraphNode{Name: relation.Node2}
nodes = append(nodes, node)
nodeMap[relation.Node2] = node
logger.Infof(ctx, "Add unknown target node ID: %s", relation.Node2)
}
relations = append(relations, relation)
}
*graph = types.GraphData{
Node: nodes,
Relation: relations,
}
}
func (f *Formater) extractContent(ctx context.Context, text string) string {
if !f.useFences {
return strings.TrimSpace(text)
}
validTags := map[FormatType]map[string]struct{}{
FormatTypeYAML: {"yaml": {}, "yml": {}},
FormatTypeJSON: {"json": {}},
}
matches := _FENCE_RE.FindAllStringSubmatch(text, -1)
var candidates []string
for _, match := range matches {
lang := match[1]
body := match[2]
if f.isValidLanguageTag(lang, validTags) {
candidates = append(candidates, body)
}
}
switch {
case len(candidates) == 1:
return strings.TrimSpace(candidates[0])
case len(candidates) > 1:
logger.Warnf(ctx, "multiple candidates found: %d", len(candidates))
return strings.TrimSpace(candidates[0])
case len(matches) == 1:
logger.Debugf(ctx, "no candidate found, use first match without language tag: %s", matches[0][1])
return strings.TrimSpace(matches[0][2])
case len(matches) > 1:
logger.Warnf(ctx, "multiple matches found: %d", len(matches))
return strings.TrimSpace(matches[0][2])
default:
logger.Warnf(ctx, "no match found")
return strings.TrimSpace(text)
}
}
func (f *Formater) addFences(content string) string {
content = strings.TrimSpace(content)
return fmt.Sprintf("```%s\n%s\n```", f.formatType, content)
}
func (f *Formater) isValidLanguageTag(lang string, validTags map[FormatType]map[string]struct{}) bool {
if lang == "" {
return true
}
tag := strings.TrimSpace(strings.ToLower(lang))
validSet, ok := validTags[f.formatType]
if !ok {
return false
}
_, exists := validSet[tag]
return exists
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
secutils "github.com/Tencent/WeKnora/internal/utils"
)
// PluginIntoChatMessage handles the transformation of search results into chat messages
@@ -50,9 +51,16 @@ func (p *PluginIntoChatMessage) OnEvent(ctx context.Context,
weekdayName := []string{"星期日", "星期一", "星期二", "星期三", "星期四", "星期五", "星期六"}
var userContent bytes.Buffer
// 验证用户查询的安全性
safeQuery, isValid := secutils.ValidateInput(chatManage.Query)
if !isValid {
logger.Errorf(ctx, "Invalid user query: %s", chatManage.Query)
return ErrTemplateExecute.WithError(fmt.Errorf("用户查询包含非法内容"))
}
// Execute template with context data
err = tmpl.Execute(&userContent, map[string]interface{}{
"Query": chatManage.Query, // User's original query
"Query": safeQuery, // User's original query
"Contexts": passages, // Extracted passages from search results
"CurrentTime": time.Now().Format("2006-01-02 15:04:05"), // Formatted current time
"CurrentWeek": weekdayName[time.Now().Weekday()], // Current weekday in Chinese

View File

@@ -77,6 +77,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
}
chatManage.SearchResult = append(chatManage.SearchResult, searchResults...)
}
// remove duplicate results
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)

View File

@@ -0,0 +1,136 @@
package chatpipline
import (
"context"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginSearch implements search functionality for chat pipeline
type PluginSearchEntity struct {
graphRepo interfaces.RetrieveGraphRepository
chunkRepo interfaces.ChunkRepository
knowledgeRepo interfaces.KnowledgeRepository
}
func NewPluginSearchEntity(
eventManager *EventManager,
graphRepository interfaces.RetrieveGraphRepository,
chunkRepository interfaces.ChunkRepository,
knowledgeRepository interfaces.KnowledgeRepository,
) *PluginSearchEntity {
res := &PluginSearchEntity{
graphRepo: graphRepository,
chunkRepo: chunkRepository,
knowledgeRepo: knowledgeRepository,
}
eventManager.Register(res)
return res
}
// ActivationEvents returns the event types this plugin handles
func (p *PluginSearchEntity) ActivationEvents() []types.EventType {
return []types.EventType{types.ENTITY_SEARCH}
}
// OnEvent handles search events in the chat pipeline
func (p *PluginSearchEntity) OnEvent(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
entity := chatManage.Entity
if len(entity) == 0 {
logger.Infof(ctx, "No entity found")
return next()
}
graph, err := p.graphRepo.SearchNode(ctx, types.NameSpace{KnowledgeBase: chatManage.KnowledgeBaseID}, entity)
if err != nil {
logger.Errorf(ctx, "Failed to search node, session_id: %s, error: %v", chatManage.SessionID, err)
return next()
}
chatManage.GraphResult = graph
logger.Infof(ctx, "search entity result count: %d", len(graph.Node))
// graphStr, _ := json.Marshal(graph)
// logger.Debugf(ctx, "search entity result: %s", string(graphStr))
chunkIDs := filterSeenChunk(ctx, graph, chatManage.SearchResult)
if len(chunkIDs) == 0 {
logger.Infof(ctx, "No new chunk found")
return next()
}
chunks, err := p.chunkRepo.ListChunksByID(ctx, ctx.Value(types.TenantIDContextKey).(uint), chunkIDs)
if err != nil {
logger.Errorf(ctx, "Failed to list chunks, session_id: %s, error: %v", chatManage.SessionID, err)
return next()
}
knowledgeIDs := []string{}
for _, chunk := range chunks {
knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID)
}
knowledges, err := p.knowledgeRepo.GetKnowledgeBatch(ctx, ctx.Value(types.TenantIDContextKey).(uint), knowledgeIDs)
if err != nil {
logger.Errorf(ctx, "Failed to list knowledge, session_id: %s, error: %v", chatManage.SessionID, err)
return next()
}
knowledgeMap := map[string]*types.Knowledge{}
for _, knowledge := range knowledges {
knowledgeMap[knowledge.ID] = knowledge
}
for _, chunk := range chunks {
searchResult := chunk2SearchResult(chunk, knowledgeMap[chunk.KnowledgeID])
chatManage.SearchResult = append(chatManage.SearchResult, searchResult)
}
// remove duplicate results
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
if len(chatManage.SearchResult) == 0 {
logger.Infof(ctx, "No new search result, session_id: %s", chatManage.SessionID)
return ErrSearchNothing
}
logger.Infof(ctx, "search entity result count: %d, session_id: %s", len(chatManage.SearchResult), chatManage.SessionID)
return next()
}
func filterSeenChunk(ctx context.Context, graph *types.GraphData, searchResult []*types.SearchResult) []string {
seen := map[string]bool{}
for _, chunk := range searchResult {
seen[chunk.ID] = true
}
logger.Infof(ctx, "filterSeenChunk: seen count: %d", len(seen))
chunkIDs := []string{}
for _, node := range graph.Node {
for _, chunkID := range node.Chunks {
if seen[chunkID] {
continue
}
seen[chunkID] = true
chunkIDs = append(chunkIDs, chunkID)
}
}
logger.Infof(ctx, "filterSeenChunk: new chunkIDs count: %d", len(chunkIDs))
return chunkIDs
}
func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.SearchResult {
return &types.SearchResult{
ID: chunk.ID,
Content: chunk.Content,
KnowledgeID: chunk.KnowledgeID,
ChunkIndex: chunk.ChunkIndex,
KnowledgeTitle: knowledge.Title,
StartAt: chunk.StartAt,
EndAt: chunk.EndAt,
Seq: chunk.ChunkIndex,
Score: 1.0,
MatchType: types.MatchTypeGraph,
Metadata: knowledge.GetMetadata(),
ChunkType: string(chunk.ChunkType),
ParentChunkID: chunk.ParentChunkID,
ImageInfo: chunk.ImageInfo,
KnowledgeFilename: knowledge.FileName,
KnowledgeSource: knowledge.Source,
}
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
@@ -30,7 +31,7 @@ type EvaluationService struct {
knowledgeBaseService interfaces.KnowledgeBaseService // Service for knowledge base operations
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
sessionService interfaces.SessionService // Service for chat sessions
testData *TestDataService // Service for test data
modelService interfaces.ModelService // Service for model operations
evaluationMemoryStorage *evaluationMemoryStorage // In-memory storage for evaluation tasks
}
@@ -41,7 +42,7 @@ func NewEvaluationService(
knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService,
sessionService interfaces.SessionService,
testData *TestDataService,
modelService interfaces.ModelService,
) interfaces.EvaluationService {
evaluationMemoryStorage := newEvaluationMemoryStorage()
return &EvaluationService{
@@ -50,7 +51,7 @@ func NewEvaluationService(
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
sessionService: sessionService,
testData: testData,
modelService: modelService,
evaluationMemoryStorage: evaluationMemoryStorage,
}
}
@@ -144,11 +145,32 @@ func (e *EvaluationService) Evaluation(ctx context.Context,
if knowledgeBaseID == "" {
logger.Info(ctx, "No knowledge base ID provided, creating new knowledge base")
// Create new knowledge base with default evaluation settings
// 获取默认的嵌入模型和LLM模型
models, err := e.modelService.ListModels(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to list models: %v", err)
return nil, err
}
var embeddingModelID, llmModelID string
for _, model := range models {
if model.Type == types.ModelTypeEmbedding {
embeddingModelID = model.ID
}
if model.Type == types.ModelTypeKnowledgeQA {
llmModelID = model.ID
}
}
if embeddingModelID == "" || llmModelID == "" {
return nil, fmt.Errorf("no default models found for evaluation")
}
kb, err := e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
Name: "evaluation",
Description: "evaluation",
EmbeddingModelID: e.testData.EmbedModel.GetModelID(),
SummaryModelID: e.testData.LLMModel.GetModelID(),
EmbeddingModelID: embeddingModelID,
SummaryModelID: llmModelID,
})
if err != nil {
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
@@ -186,12 +208,37 @@ func (e *EvaluationService) Evaluation(ctx context.Context,
}
if rerankModelID == "" {
rerankModelID = e.testData.RerankModel.GetModelID()
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
// 获取默认的重排模型
models, err := e.modelService.ListModels(ctx)
if err == nil {
for _, model := range models {
if model.Type == types.ModelTypeRerank {
rerankModelID = model.ID
break
}
}
}
if rerankModelID == "" {
logger.Warnf(ctx, "No rerank model found, skipping rerank")
} else {
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
}
}
if chatModelID == "" {
chatModelID = e.testData.LLMModel.GetModelID()
// 获取默认的LLM模型
models, err := e.modelService.ListModels(ctx)
if err == nil {
for _, model := range models {
if model.Type == types.ModelTypeKnowledgeQA {
chatModelID = model.ID
break
}
}
}
if chatModelID == "" {
return nil, fmt.Errorf("no default chat model found")
}
logger.Infof(ctx, "Using default chat model: %s", chatModelID)
}

View File

@@ -0,0 +1,137 @@
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/google/uuid"
"github.com/hibiken/asynq"
)
func NewChunkExtractTask(ctx context.Context, client *asynq.Client, tenantID uint, chunkID string, modelID string) error {
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
return nil
}
payload, err := json.Marshal(types.ExtractChunkPayload{
TenantID: tenantID,
ChunkID: chunkID,
ModelID: modelID,
})
if err != nil {
return err
}
task := asynq.NewTask(types.TypeChunkExtract, payload, asynq.MaxRetry(3))
info, err := client.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "failed to enqueue task: %v", err)
return fmt.Errorf("failed to enqueue task: %v", err)
}
logger.Infof(ctx, "enqueued task: id=%s queue=%s chunk=%s", info.ID, info.Queue, chunkID)
return nil
}
type ChunkExtractService struct {
template *types.PromptTemplateStructured
modelService interfaces.ModelService
knowledgeBaseRepo interfaces.KnowledgeBaseRepository
chunkRepo interfaces.ChunkRepository
graphEngine interfaces.RetrieveGraphRepository
}
func NewChunkExtractService(
config *config.Config,
modelService interfaces.ModelService,
knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
chunkRepo interfaces.ChunkRepository,
graphEngine interfaces.RetrieveGraphRepository,
) interfaces.Extracter {
generator := chatpipline.NewQAPromptGenerator(chatpipline.NewFormater(), config.ExtractManager.ExtractGraph)
ctx := context.Background()
logger.Debugf(ctx, "chunk extract system prompt: %s", generator.System(ctx))
logger.Debugf(ctx, "chunk extract user prompt: %s", generator.User(ctx, "demo"))
return &ChunkExtractService{
template: config.ExtractManager.ExtractGraph,
modelService: modelService,
knowledgeBaseRepo: knowledgeBaseRepo,
chunkRepo: chunkRepo,
graphEngine: graphEngine,
}
}
func (s *ChunkExtractService) Extract(ctx context.Context, t *asynq.Task) error {
var p types.ExtractChunkPayload
if err := json.Unmarshal(t.Payload(), &p); err != nil {
logger.Errorf(ctx, "failed to unmarshal task payload: %v", err)
return err
}
ctx = logger.WithRequestID(ctx, uuid.New().String())
ctx = logger.WithField(ctx, "extract", p.ChunkID)
ctx = context.WithValue(ctx, types.TenantIDContextKey, p.TenantID)
chunk, err := s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
if err != nil {
logger.Errorf(ctx, "failed to get chunk: %v", err)
return err
}
kb, err := s.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chunk.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "failed to get knowledge base: %v", err)
return err
}
if kb.ExtractConfig == nil {
logger.Warnf(ctx, "failed to get extract config")
return err
}
chatModel, err := s.modelService.GetChatModel(ctx, p.ModelID)
if err != nil {
logger.Errorf(ctx, "failed to get chat model: %v", err)
return err
}
template := &types.PromptTemplateStructured{
Description: s.template.Description,
Tags: kb.ExtractConfig.Tags,
Examples: []types.GraphData{
{
Text: kb.ExtractConfig.Text,
Node: kb.ExtractConfig.Nodes,
Relation: kb.ExtractConfig.Relations,
},
},
}
extractor := chatpipline.NewExtractor(chatModel, template)
graph, err := extractor.Extract(ctx, chunk.Content)
if err != nil {
return err
}
chunk, err = s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
if err != nil {
logger.Warnf(ctx, "graph ignore chunk %s: %v", p.ChunkID, err)
return nil
}
for _, node := range graph.Node {
node.Chunks = []string{chunk.ID}
}
if err = s.graphEngine.AddGraph(ctx,
types.NameSpace{KnowledgeBase: chunk.KnowledgeBaseID, Knowledge: chunk.KnowledgeID},
[]*types.GraphData{graph},
); err != nil {
logger.Errorf(ctx, "failed to add graph: %v", err)
return err
}
// gg, _ := json.Marshal(graph)
// logger.Infof(ctx, "extracted graph: %s", string(gg))
return nil
}

View File

@@ -23,8 +23,8 @@ type cosFileService struct {
}
// NewCosFileService creates a new COS file service instance
func NewCosFileService(appId, region, secretId, secretKey, cosPathPrefix string) (interfaces.FileService, error) {
bucketURL := fmt.Sprintf("https://%s.cos.%s.myqcloud.com", appId, region)
func NewCosFileService(bucketName, region, secretId, secretKey, cosPathPrefix string) (interfaces.FileService, error) {
bucketURL := fmt.Sprintf("https://%s.cos.%s.myqcloud.com/", bucketName, region)
u, err := url.Parse(bucketURL)
if err != nil {
return nil, fmt.Errorf("failed to parse bucketURL: %w", err)
@@ -59,7 +59,7 @@ func (s *cosFileService) SaveFile(ctx context.Context,
if err != nil {
return "", fmt.Errorf("failed to upload file to COS: %w", err)
}
return fmt.Sprintf("https://%s/%s", s.bucketURL, objectName), nil
return fmt.Sprintf("%s%s", s.bucketURL, objectName), nil
}
// GetFile retrieves a file from COS storage by its path URL

View File

@@ -25,9 +25,11 @@ import (
"github.com/Tencent/WeKnora/internal/tracing"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
"github.com/Tencent/WeKnora/services/docreader/src/client"
"github.com/Tencent/WeKnora/services/docreader/src/proto"
"github.com/google/uuid"
"github.com/hibiken/asynq"
"go.opentelemetry.io/otel/attribute"
"golang.org/x/sync/errgroup"
)
@@ -60,6 +62,8 @@ type knowledgeService struct {
chunkRepo interfaces.ChunkRepository
fileSvc interfaces.FileService
modelService interfaces.ModelService
task *asynq.Client
graphEngine interfaces.RetrieveGraphRepository
}
// NewKnowledgeService creates a new knowledge service instance
@@ -73,6 +77,8 @@ func NewKnowledgeService(
chunkRepo interfaces.ChunkRepository,
fileSvc interfaces.FileService,
modelService interfaces.ModelService,
task *asynq.Client,
graphEngine interfaces.RetrieveGraphRepository,
) (interfaces.KnowledgeService, error) {
return &knowledgeService{
config: config,
@@ -84,6 +90,8 @@ func NewKnowledgeService(
chunkRepo: chunkRepo,
fileSvc: fileSvc,
modelService: modelService,
task: task,
graphEngine: graphEngine,
}, nil
}
@@ -191,15 +199,22 @@ func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context,
metadataJSON = types.JSON(metadataBytes)
}
// 验证文件名安全性
safeFilename, isValid := secutils.ValidateInput(file.Filename)
if !isValid {
logger.Errorf(ctx, "Invalid filename: %s", file.Filename)
return nil, werrors.NewValidationError("文件名包含非法字符")
}
// Create knowledge record
logger.Info(ctx, "Creating knowledge record")
knowledge := &types.Knowledge{
TenantID: tenantID,
KnowledgeBaseID: kbID,
Type: "file",
Title: file.Filename,
FileName: file.Filename,
FileType: getFileType(file.Filename),
Title: safeFilename,
FileName: safeFilename,
FileType: getFileType(safeFilename),
FileSize: file.Size,
FileHash: hash,
ParseStatus: "pending",
@@ -258,10 +273,10 @@ func (s *knowledgeService) CreateKnowledgeFromURL(ctx context.Context,
return nil, err
}
// Validate URL format
// Validate URL format and security
logger.Info(ctx, "Validating URL")
if !isValidURL(url) {
logger.Error(ctx, "Invalid URL format")
if !isValidURL(url) || !secutils.IsValidURL(url) {
logger.Error(ctx, "Invalid or unsafe URL format")
return nil, ErrInvalidURL
}
@@ -339,6 +354,17 @@ func (s *knowledgeService) CreateKnowledgeFromPassage(ctx context.Context,
logger.Info(ctx, "Start creating knowledge from passage")
logger.Infof(ctx, "Knowledge base ID: %s, passage count: %d", kbID, len(passage))
// 验证段落内容安全性
safePassages := make([]string, 0, len(passage))
for i, p := range passage {
safePassage, isValid := secutils.ValidateInput(p)
if !isValid {
logger.Errorf(ctx, "Invalid passage content at index %d", i)
return nil, werrors.NewValidationError(fmt.Sprintf("段落 %d 包含非法内容", i+1))
}
safePassages = append(safePassages, safePassage)
}
// Get knowledge base configuration
logger.Info(ctx, "Getting knowledge base configuration")
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
@@ -370,7 +396,7 @@ func (s *knowledgeService) CreateKnowledgeFromPassage(ctx context.Context,
// Process passages asynchronously
logger.Info(ctx, "Starting asynchronous passage processing")
go s.processDocumentFromPassage(ctx, kb, knowledge, passage)
go s.processDocumentFromPassage(ctx, kb, knowledge, safePassages)
logger.Infof(ctx, "Knowledge from passage created successfully, ID: %s", knowledge.ID)
return knowledge, nil
@@ -469,6 +495,16 @@ func (s *knowledgeService) DeleteKnowledge(ctx context.Context, id string) error
return nil
})
// Delete the knowledge graph
wg.Go(func() error {
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil {
return err
}
@@ -542,6 +578,19 @@ func (s *knowledgeService) DeleteKnowledgeList(ctx context.Context, ids []string
return nil
})
// Delete the knowledge graph
wg.Go(func() error {
namespaces := []types.NameSpace{}
for _, knowledge := range knowledgeList {
namespaces = append(namespaces, types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID})
}
if err := s.graphEngine.DelGraph(ctx, namespaces); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil {
return err
}
@@ -1139,6 +1188,15 @@ func (s *knowledgeService) processChunks(ctx context.Context,
}
logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList))
logger.Infof(ctx, "processChunks create relationship rag task")
for _, chunk := range textChunks {
err := NewChunkExtractTask(ctx, s.task, chunk.TenantID, chunk.ID, kb.SummaryModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks create chunk extract task failed")
span.RecordError(err)
}
}
// Update knowledge status to completed
knowledge.ParseStatus = "completed"
knowledge.EnableStatus = "enabled"

View File

@@ -2,12 +2,11 @@ package service
import (
"context"
"encoding/json"
"errors"
"slices"
"time"
"encoding/json"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/logger"
@@ -376,8 +375,8 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
// processSearchResults handles the processing of search results, optimizing database queries
func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
chunks []*types.IndexWithScore) ([]*types.SearchResult, error) {
chunks []*types.IndexWithScore,
) ([]*types.SearchResult, error) {
if len(chunks) == 0 {
return nil, nil
}
@@ -527,8 +526,8 @@ func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, proces
func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
knowledge *types.Knowledge,
score float64,
matchType types.MatchType) *types.SearchResult {
matchType types.MatchType,
) *types.SearchResult {
return &types.SearchResult{
ID: chunk.ID,
Content: chunk.Content,

View File

@@ -18,7 +18,9 @@ import (
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
var apiKeySecret = []byte(os.Getenv("TENANT_AES_KEY"))
var apiKeySecret = func() []byte {
return []byte(os.Getenv("TENANT_AES_KEY"))
}
// ListTenantsParams defines parameters for listing tenants with filtering and pagination
type ListTenantsParams struct {
@@ -221,7 +223,7 @@ func (r *tenantService) generateApiKey(tenantID uint) string {
binary.LittleEndian.PutUint64(idBytes, uint64(tenantID))
// 2. Encrypt tenant_id using AES-GCM
block, err := aes.NewCipher(apiKeySecret)
block, err := aes.NewCipher(apiKeySecret())
if err != nil {
panic("Failed to create AES cipher: " + err.Error())
}
@@ -267,7 +269,7 @@ func (r *tenantService) ExtractTenantIDFromAPIKey(apiKey string) (uint, error) {
nonce, ciphertext := encryptedData[:12], encryptedData[12:]
// 4. Decrypt
block, err := aes.NewCipher(apiKeySecret)
block, err := aes.NewCipher(apiKeySecret())
if err != nil {
return 0, errors.New("decryption error")
}

View File

@@ -1,444 +0,0 @@
// Package service 提供应用程序的核心业务逻辑服务层
// 此包包含了知识库管理、用户租户管理、模型服务等核心功能实现
package service
import (
"context"
"fmt"
"os"
"strconv"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/models/rerank"
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// TestDataService 测试数据服务
// 负责初始化测试环境所需的数据,包括创建测试租户、测试知识库
// 以及配置必要的模型服务实例
type TestDataService struct {
config *config.Config // 应用程序配置
kbRepo interfaces.KnowledgeBaseRepository // 知识库存储库接口
tenantService interfaces.TenantService // 租户服务接口
ollamaService *ollama.OllamaService // Ollama模型服务
modelService interfaces.ModelService // 模型服务接口
EmbedModel embedding.Embedder // 嵌入模型实例
RerankModel rerank.Reranker // 重排模型实例
LLMModel chat.Chat // 大语言模型实例
}
// NewTestDataService 创建测试数据服务
// 注入所需的依赖服务和组件
func NewTestDataService(
config *config.Config,
kbRepo interfaces.KnowledgeBaseRepository,
tenantService interfaces.TenantService,
ollamaService *ollama.OllamaService,
modelService interfaces.ModelService,
) *TestDataService {
return &TestDataService{
config: config,
kbRepo: kbRepo,
tenantService: tenantService,
ollamaService: ollamaService,
modelService: modelService,
}
}
// initTenant 初始化测试租户
// 通过环境变量获取租户ID如果租户不存在则创建新租户否则更新现有租户
// 同时配置租户的检索引擎参数
func (s *TestDataService) initTenant(ctx context.Context) error {
logger.Info(ctx, "Start initializing test tenant")
// 从环境变量获取租户ID
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID)
// 将字符串ID转换为uint64
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
if err != nil {
logger.Errorf(ctx, "Failed to parse tenant ID: %v", err)
return err
}
// 创建租户配置
tenantConfig := &types.Tenant{
Name: "Test Tenant",
Description: "Test Tenant for Testing",
RetrieverEngines: types.RetrieverEngines{
Engines: []types.RetrieverEngineParams{
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
},
},
}
// 获取或创建测试租户
logger.Infof(ctx, "Attempting to get tenant with ID: %d", tenantIDUint)
tenant, err := s.tenantService.GetTenantByID(ctx, uint(tenantIDUint))
if err != nil {
// 租户不存在,创建新租户
logger.Info(ctx, "Tenant not found, creating a new test tenant")
tenant, err = s.tenantService.CreateTenant(ctx, tenantConfig)
if err != nil {
logger.Errorf(ctx, "Failed to create tenant: %v", err)
return err
}
logger.Infof(ctx, "Created new test tenant with ID: %d", tenant.ID)
} else {
// 租户存在,更新检索引擎配置
logger.Info(ctx, "Test tenant found, updating retriever engines")
tenant.RetrieverEngines = tenantConfig.RetrieverEngines
tenant, err = s.tenantService.UpdateTenant(ctx, tenant)
if err != nil {
logger.Errorf(ctx, "Failed to update tenant: %v", err)
return err
}
logger.Info(ctx, "Test tenant updated successfully")
}
logger.Infof(ctx, "Test tenant configured - ID: %d, Name: %s, API Key: %s",
tenant.ID, tenant.Name, tenant.APIKey)
return nil
}
// initKnowledgeBase 初始化测试知识库
// 从环境变量获取知识库ID创建或更新知识库
// 配置知识库的分块策略、嵌入模型和摘要模型
func (s *TestDataService) initKnowledgeBase(ctx context.Context) error {
logger.Info(ctx, "Start initializing test knowledge base")
// 检查上下文中的租户ID
if ctx.Value(types.TenantIDContextKey).(uint) == 0 {
logger.Warn(ctx, "Tenant ID is 0, skipping knowledge base initialization")
return nil
}
// 从环境变量获取知识库ID
knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID")
logger.Infof(ctx, "Test knowledge base ID from environment: %s", knowledgeBaseID)
// 创建知识库配置
kbConfig := &types.KnowledgeBase{
ID: knowledgeBaseID,
Name: "Test Knowledge Base",
Description: "Knowledge Base for Testing",
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
ChunkingConfig: types.ChunkingConfig{
ChunkSize: s.config.KnowledgeBase.ChunkSize,
ChunkOverlap: s.config.KnowledgeBase.ChunkOverlap,
Separators: s.config.KnowledgeBase.SplitMarkers,
EnableMultimodal: s.config.KnowledgeBase.ImageProcessing.EnableMultimodal,
},
EmbeddingModelID: s.EmbedModel.GetModelID(),
SummaryModelID: s.LLMModel.GetModelID(),
RerankModelID: s.RerankModel.GetModelID(),
}
// 初始化测试知识库
logger.Info(ctx, "Attempting to get existing knowledge base")
_, err := s.kbRepo.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
if err != nil {
// 知识库不存在,创建新知识库
logger.Info(ctx, "Knowledge base not found, creating a new one")
logger.Infof(ctx, "Creating knowledge base with ID: %s, tenant ID: %d",
kbConfig.ID, kbConfig.TenantID)
if err := s.kbRepo.CreateKnowledgeBase(ctx, kbConfig); err != nil {
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
return err
}
logger.Info(ctx, "Knowledge base created successfully")
} else {
// 知识库存在,更新配置
logger.Info(ctx, "Knowledge base found, updating configuration")
logger.Infof(ctx, "Updating knowledge base with ID: %s", kbConfig.ID)
err = s.kbRepo.UpdateKnowledgeBase(ctx, kbConfig)
if err != nil {
logger.Errorf(ctx, "Failed to update knowledge base: %v", err)
return err
}
logger.Info(ctx, "Knowledge base updated successfully")
}
logger.Infof(ctx, "Test knowledge base configured - ID: %s, Name: %s", kbConfig.ID, kbConfig.Name)
return nil
}
// InitializeTestData 初始化测试数据
// 这是对外暴露的主要方法,负责协调所有测试数据的初始化过程
// 包括初始化租户、嵌入模型、重排模型、LLM模型和知识库
func (s *TestDataService) InitializeTestData(ctx context.Context) error {
logger.Info(ctx, "Start initializing test data")
// 从环境变量获取租户ID
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID)
// 解析租户ID
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
if err != nil {
// 解析失败时使用默认值0
logger.Warn(ctx, "Failed to parse tenant ID, using default value 0")
tenantIDUint = 0
} else {
// 初始化租户
logger.Info(ctx, "Initializing tenant")
err = s.initTenant(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to initialize tenant: %v", err)
return err
}
logger.Info(ctx, "Tenant initialized successfully")
}
// 创建带有租户ID的新上下文
newCtx := context.Background()
newCtx = context.WithValue(newCtx, types.TenantIDContextKey, uint(tenantIDUint))
logger.Infof(ctx, "Created new context with tenant ID: %d", tenantIDUint)
// 初始化模型
modelInitFuncs := []struct {
name string
fn func(context.Context) error
}{
{"embedding model", s.initEmbeddingModel},
{"rerank model", s.initRerankModel},
{"LLM model", s.initLLMModel},
}
for _, initFunc := range modelInitFuncs {
logger.Infof(ctx, "Initializing %s", initFunc.name)
if err := initFunc.fn(newCtx); err != nil {
logger.Errorf(ctx, "Failed to initialize %s: %v", initFunc.name, err)
return err
}
logger.Infof(ctx, "%s initialized successfully", initFunc.name)
}
// 初始化知识库
logger.Info(ctx, "Initializing knowledge base")
if err := s.initKnowledgeBase(newCtx); err != nil {
logger.Errorf(ctx, "Failed to initialize knowledge base: %v", err)
return err
}
logger.Info(ctx, "Knowledge base initialized successfully")
logger.Info(ctx, "Test data initialization completed")
return nil
}
// getEnvOrError 获取环境变量值,如果不存在则返回错误
func (s *TestDataService) getEnvOrError(name string) (string, error) {
value := os.Getenv(name)
if value == "" {
return "", fmt.Errorf("%s environment variable is not set", name)
}
return value, nil
}
// updateOrCreateModel 更新或创建模型
func (s *TestDataService) updateOrCreateModel(ctx context.Context, modelConfig *types.Model) error {
model, err := s.modelService.GetModelByID(ctx, modelConfig.ID)
if err != nil {
// 模型不存在,创建新模型
return s.modelService.CreateModel(ctx, modelConfig)
}
// 模型存在,更新属性
model.TenantID = modelConfig.TenantID
model.Name = modelConfig.Name
model.Source = modelConfig.Source
model.Type = modelConfig.Type
model.Parameters = modelConfig.Parameters
model.Status = modelConfig.Status
return s.modelService.UpdateModel(ctx, model)
}
// initEmbeddingModel 初始化嵌入模型
func (s *TestDataService) initEmbeddingModel(ctx context.Context) error {
// 从环境变量获取模型参数
modelName, err := s.getEnvOrError("INIT_EMBEDDING_MODEL_NAME")
if err != nil {
return err
}
dimensionStr := os.Getenv("INIT_EMBEDDING_MODEL_DIMENSION")
dimension, err := strconv.Atoi(dimensionStr)
if err != nil || dimension == 0 {
return fmt.Errorf("invalid embedding model dimension: %s", dimensionStr)
}
baseURL := os.Getenv("INIT_EMBEDDING_MODEL_BASE_URL")
apiKey := os.Getenv("INIT_EMBEDDING_MODEL_API_KEY")
// 确定模型来源
source := types.ModelSourceRemote
if baseURL == "" {
source = types.ModelSourceLocal
}
// 确定模型ID
modelID := os.Getenv("INIT_EMBEDDING_MODEL_ID")
if modelID == "" {
modelID = fmt.Sprintf("builtin:%s:%d", modelName, dimension)
}
// 创建嵌入模型实例
s.EmbedModel, err = embedding.NewEmbedder(embedding.Config{
Source: source,
BaseURL: baseURL,
ModelName: modelName,
APIKey: apiKey,
Dimensions: dimension,
ModelID: modelID,
})
if err != nil {
return fmt.Errorf("failed to create embedder: %w", err)
}
// 如果是本地模型使用Ollama拉取模型
if source == types.ModelSourceLocal && s.ollamaService != nil {
if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil {
return fmt.Errorf("failed to pull embedding model: %w", err)
}
}
// 创建模型配置
modelConfig := &types.Model{
ID: modelID,
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
Name: modelName,
Source: source,
Type: types.ModelTypeEmbedding,
Parameters: types.ModelParameters{
BaseURL: baseURL,
APIKey: apiKey,
EmbeddingParameters: types.EmbeddingParameters{
Dimension: dimension,
},
},
Status: "active",
}
// 更新或创建模型
return s.updateOrCreateModel(ctx, modelConfig)
}
// initRerankModel 初始化重排模型
func (s *TestDataService) initRerankModel(ctx context.Context) error {
// 从环境变量获取模型参数
modelName, err := s.getEnvOrError("INIT_RERANK_MODEL_NAME")
if err != nil {
logger.Warnf(ctx, "Skip Rerank Model: %v", err)
return nil
}
baseURL, err := s.getEnvOrError("INIT_RERANK_MODEL_BASE_URL")
if err != nil {
return err
}
apiKey := os.Getenv("INIT_RERANK_MODEL_API_KEY")
modelID := fmt.Sprintf("builtin:%s:rerank:%s", types.ModelSourceRemote, modelName)
// 创建重排模型实例
s.RerankModel, err = rerank.NewReranker(&rerank.RerankerConfig{
Source: types.ModelSourceRemote,
BaseURL: baseURL,
ModelName: modelName,
APIKey: apiKey,
ModelID: modelID,
})
if err != nil {
return fmt.Errorf("failed to create reranker: %w", err)
}
// 创建模型配置
modelConfig := &types.Model{
ID: modelID,
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
Name: modelName,
Source: types.ModelSourceRemote,
Type: types.ModelTypeRerank,
Parameters: types.ModelParameters{
BaseURL: baseURL,
APIKey: apiKey,
},
Status: "active",
}
// 更新或创建模型
return s.updateOrCreateModel(ctx, modelConfig)
}
// initLLMModel 初始化大语言模型
func (s *TestDataService) initLLMModel(ctx context.Context) error {
// 从环境变量获取模型参数
modelName, err := s.getEnvOrError("INIT_LLM_MODEL_NAME")
if err != nil {
return err
}
baseURL := os.Getenv("INIT_LLM_MODEL_BASE_URL")
apiKey := os.Getenv("INIT_LLM_MODEL_API_KEY")
// 确定模型来源
source := types.ModelSourceRemote
if baseURL == "" {
source = types.ModelSourceLocal
}
// 确定模型ID
modelID := fmt.Sprintf("builtin:%s:llm:%s", source, modelName)
// 创建大语言模型实例
s.LLMModel, err = chat.NewChat(&chat.ChatConfig{
Source: source,
BaseURL: baseURL,
ModelName: modelName,
APIKey: apiKey,
ModelID: modelID,
})
if err != nil {
return fmt.Errorf("failed to create llm: %w", err)
}
// 如果是本地模型使用Ollama拉取模型
if source == types.ModelSourceLocal && s.ollamaService != nil {
if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil {
return fmt.Errorf("failed to pull llm model: %w", err)
}
}
// 创建模型配置
modelConfig := &types.Model{
ID: modelID,
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
Name: modelName,
Source: source,
Type: types.ModelTypeKnowledgeQA,
Parameters: types.ModelParameters{
BaseURL: baseURL,
APIKey: apiKey,
},
Status: "active",
}
// 更新或创建模型
return s.updateOrCreateModel(ctx, modelConfig)
}

View File

@@ -0,0 +1,449 @@
package service
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// JWT secret key - in production this should be from environment variable
var jwtSecret = []byte("your-secret-key")
// userService implements the UserService interface
type userService struct {
userRepo interfaces.UserRepository
tokenRepo interfaces.AuthTokenRepository
tenantService interfaces.TenantService
}
// NewUserService creates a new user service instance
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
return &userService{
userRepo: userRepo,
tokenRepo: tokenRepo,
tenantService: tenantService,
}
}
var engine = map[string][]types.RetrieverEngineParams{
"postgres": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
},
"elasticsearch_v7": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
},
"elasticsearch_v8": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
},
}
// Register creates a new user account
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
logger.Info(ctx, "Start user registration")
// Validate input
if req.Username == "" || req.Email == "" || req.Password == "" {
return nil, errors.New("username, email and password are required")
}
// Check if user already exists
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
if existingUser != nil {
return nil, errors.New("user with this email already exists")
}
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
if existingUser != nil {
return nil, errors.New("user with this username already exists")
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
logger.Errorf(ctx, "Failed to hash password: %v", err)
return nil, errors.New("failed to process password")
}
egs := []types.RetrieverEngineParams{}
for _, driver := range strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",") {
if val, ok := engine[driver]; ok {
egs = append(egs, val...)
}
}
egs = uniqueRetrieverEngine(egs)
logger.Debugf(ctx, "user register retriever engines: %v", egs)
// Create default tenant for the user
tenant := &types.Tenant{
Name: fmt.Sprintf("%s's Workspace", req.Username),
Description: "Default workspace",
Status: "active",
RetrieverEngines: types.RetrieverEngines{Engines: egs},
}
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
if err != nil {
logger.Errorf(ctx, "Failed to create tenant: %v", err)
return nil, errors.New("failed to create workspace")
}
// Create user
user := &types.User{
ID: uuid.New().String(),
Username: req.Username,
Email: req.Email,
PasswordHash: string(hashedPassword),
TenantID: createdTenant.ID,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
err = s.userRepo.CreateUser(ctx, user)
if err != nil {
logger.Errorf(ctx, "Failed to create user: %v", err)
return nil, errors.New("failed to create user")
}
logger.Infof(ctx, "User registered successfully: %s", user.Email)
return user, nil
}
// Login authenticates a user and returns tokens
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
logger.Infof(ctx, "Start user login for email: %s", req.Email)
// Get user by email
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
if err != nil {
logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
if user == nil {
logger.Warnf(ctx, "User not found for email: %s", req.Email)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)
// Check if user is active
if !user.IsActive {
logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
return &types.LoginResponse{
Success: false,
Message: "Account is disabled",
}, nil
}
// Verify password
logger.Infof(ctx, "Verifying password for user: %s", user.Email)
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil {
logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
logger.Infof(ctx, "Password verification successful for user: %s", user.Email)
// Generate tokens
logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
if err != nil {
logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Login failed",
}, nil
}
logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)
// Get tenant information
logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
if err != nil {
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
} else {
logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
}
logger.Infof(ctx, "User logged in successfully: %s", user.Email)
return &types.LoginResponse{
Success: true,
Message: "Login successful",
User: user,
Tenant: tenant,
Token: accessToken,
RefreshToken: refreshToken,
}, nil
}
// GetUserByID gets a user by ID
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return s.userRepo.GetUserByID(ctx, id)
}
// GetUserByEmail gets a user by email
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
return s.userRepo.GetUserByEmail(ctx, email)
}
// GetUserByUsername gets a user by username
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
return s.userRepo.GetUserByUsername(ctx, username)
}
// UpdateUser updates user information
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
user.UpdatedAt = time.Now()
return s.userRepo.UpdateUser(ctx, user)
}
// DeleteUser deletes a user
func (s *userService) DeleteUser(ctx context.Context, id string) error {
return s.userRepo.DeleteUser(ctx, id)
}
// ChangePassword changes user password
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return err
}
// Verify old password
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
if err != nil {
return errors.New("invalid old password")
}
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hashedPassword)
user.UpdatedAt = time.Now()
return s.userRepo.UpdateUser(ctx, user)
}
// ValidatePassword validates user password
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return err
}
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
// GenerateTokens generates access and refresh tokens for user
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
// Generate access token (expires in 24 hours)
accessClaims := jwt.MapClaims{
"user_id": user.ID,
"email": user.Email,
"tenant_id": user.TenantID,
"exp": time.Now().Add(24 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"type": "access",
}
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessToken, err = accessTokenObj.SignedString(jwtSecret)
if err != nil {
return "", "", err
}
// Generate refresh token (expires in 7 days)
refreshClaims := jwt.MapClaims{
"user_id": user.ID,
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"type": "refresh",
}
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
refreshToken, err = refreshTokenObj.SignedString(jwtSecret)
if err != nil {
return "", "", err
}
// Store tokens in database
accessTokenRecord := &types.AuthToken{
ID: uuid.New().String(),
UserID: user.ID,
Token: accessToken,
TokenType: "access_token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
refreshTokenRecord := &types.AuthToken{
ID: uuid.New().String(),
UserID: user.ID,
Token: refreshToken,
TokenType: "refresh_token",
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
return accessToken, refreshToken, nil
}
// ValidateToken validates an access token
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
return nil, errors.New("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("invalid token claims")
}
userID, ok := claims["user_id"].(string)
if !ok {
return nil, errors.New("invalid user ID in token")
}
// Check if token is revoked
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
return nil, errors.New("token is revoked")
}
return s.userRepo.GetUserByID(ctx, userID)
}
// RefreshToken refreshes access token using refresh token
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
return "", "", errors.New("invalid refresh token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", "", errors.New("invalid token claims")
}
tokenType, ok := claims["type"].(string)
if !ok || tokenType != "refresh" {
return "", "", errors.New("not a refresh token")
}
userID, ok := claims["user_id"].(string)
if !ok {
return "", "", errors.New("invalid user ID in token")
}
// Check if token is revoked
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
return "", "", errors.New("refresh token is revoked")
}
// Get user
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return "", "", err
}
// Revoke old refresh token
tokenRecord.IsRevoked = true
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
// Generate new tokens
return s.GenerateTokens(ctx, user)
}
// RevokeToken revokes a token
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
if err != nil {
return err
}
tokenRecord.IsRevoked = true
tokenRecord.UpdatedAt = time.Now()
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
}
// GetCurrentUser gets current user from context
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
user, ok := ctx.Value("user").(*types.User)
if !ok {
return nil, errors.New("user not found in context")
}
return user, nil
}
func uniqueRetrieverEngine(engine []types.RetrieverEngineParams) []types.RetrieverEngineParams {
seen := make(map[types.RetrieverEngineParams]bool)
var result []types.RetrieverEngineParams
for _, v := range engine {
if !seen[v] {
seen[v] = true
result = append(result, v)
}
}
return result
}

View File

@@ -1,72 +0,0 @@
package common
import (
"log"
"github.com/Tencent/WeKnora/internal/config"
"github.com/hibiken/asynq"
)
// client is the global asyncq client instance
var client *asynq.Client
// InitAsyncq initializes the asyncq client with configuration
// It creates a new client and starts the server in a goroutine
func InitAsyncq(config *config.Config) error {
cfg := config.Asynq
client = asynq.NewClient(asynq.RedisClientOpt{
Addr: cfg.Addr,
Username: cfg.Username,
Password: cfg.Password,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
})
go run(cfg)
return nil
}
// GetAsyncqClient returns the global asyncq client instance
func GetAsyncqClient() *asynq.Client {
return client
}
// handleFunc stores registered task handlers
var handleFunc = map[string]asynq.HandlerFunc{}
// RegisterHandlerFunc registers a handler function for a specific task type
func RegisterHandlerFunc(taskType string, handlerFunc asynq.HandlerFunc) {
handleFunc[taskType] = handlerFunc
}
// run starts the asyncq server with the given configuration
// It creates a new server, sets up handlers, and runs the server
func run(config *config.AsynqConfig) {
srv := asynq.NewServer(
asynq.RedisClientOpt{
Addr: config.Addr,
Username: config.Username,
Password: config.Password,
ReadTimeout: config.ReadTimeout,
WriteTimeout: config.WriteTimeout,
},
asynq.Config{
Concurrency: config.Concurrency,
Queues: map[string]int{
"critical": 6, // Highest priority queue
"default": 3, // Default priority queue
"low": 1, // Lowest priority queue
},
},
)
// Create a new mux and register all handlers
mux := asynq.NewServeMux()
for typ, handler := range handleFunc {
mux.HandleFunc(typ, handler)
}
// Start the server
if err := srv.Run(mux); err != nil {
log.Fatalf("could not run server: %v", err)
}
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/go-viper/mapstructure/v2"
"github.com/spf13/viper"
)
@@ -18,10 +19,10 @@ type Config struct {
KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"`
Tenant *TenantConfig `yaml:"tenant" json:"tenant"`
Models []ModelConfig `yaml:"models" json:"models"`
Asynq *AsynqConfig `yaml:"asynq" json:"asynq"`
VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"`
DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"`
StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"`
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
}
type DocReaderConfig struct {
@@ -109,15 +110,6 @@ type ModelConfig struct {
Parameters map[string]interface{} `yaml:"parameters" json:"parameters"`
}
type AsynqConfig struct {
Addr string `yaml:"addr" json:"addr"`
Username string `yaml:"username" json:"username"`
Password string `yaml:"password" json:"password"`
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"`
Concurrency int `yaml:"concurrency" json:"concurrency"`
}
// StreamManagerConfig 流管理器配置
type StreamManagerConfig struct {
Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis"
@@ -134,6 +126,18 @@ type RedisConfig struct {
TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时)
}
// ExtractManagerConfig 抽取管理器配置
type ExtractManagerConfig struct {
ExtractGraph *types.PromptTemplateStructured `yaml:"extract_graph" json:"extract_graph"`
ExtractEntity *types.PromptTemplateStructured `yaml:"extract_entity" json:"extract_entity"`
FabriText *FebriText `yaml:"fabri_text" json:"fabri_text"`
}
type FebriText struct {
WithTag string `yaml:"with_tag" json:"with_tag"`
WithNoTag string `yaml:"with_no_tag" json:"with_no_tag"`
}
// LoadConfig 从配置文件加载配置
func LoadConfig() (*Config, error) {
// 设置配置文件名和路径

View File

@@ -14,6 +14,7 @@ import (
esv7 "github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v8"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"github.com/panjf2000/ants/v2"
"go.uber.org/dig"
"gorm.io/driver/postgres"
@@ -22,6 +23,7 @@ import (
"github.com/Tencent/WeKnora/internal/application/repository"
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
neo4jRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/neo4j"
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
"github.com/Tencent/WeKnora/internal/application/service"
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
@@ -68,6 +70,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
// External service clients
must(container.Provide(initDocReaderClient))
must(container.Provide(initOllamaService))
must(container.Provide(initNeo4jClient))
must(container.Provide(stream.NewStreamManager))
// Data repositories layer
@@ -78,6 +81,9 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(repository.NewSessionRepository))
must(container.Provide(repository.NewMessageRepository))
must(container.Provide(repository.NewModelRepository))
must(container.Provide(repository.NewUserRepository))
must(container.Provide(repository.NewAuthTokenRepository))
must(container.Provide(neo4jRepo.NewNeo4jRepository))
// Business service layer
must(container.Provide(service.NewTenantService))
@@ -87,10 +93,11 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(service.NewMessageService))
must(container.Provide(service.NewChunkService))
must(container.Provide(embedding.NewBatchEmbedder))
must(container.Provide(service.NewTestDataService))
must(container.Provide(service.NewModelService))
must(container.Provide(service.NewDatasetService))
must(container.Provide(service.NewEvaluationService))
must(container.Provide(service.NewUserService))
must(container.Provide(service.NewChunkExtractService))
// Chat pipeline components for processing chat requests
must(container.Provide(chatpipline.NewEventManager))
@@ -105,6 +112,8 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Invoke(chatpipline.NewPluginFilterTopK))
must(container.Invoke(chatpipline.NewPluginPreprocess))
must(container.Invoke(chatpipline.NewPluginRewrite))
must(container.Invoke(chatpipline.NewPluginExtractEntity))
must(container.Invoke(chatpipline.NewPluginSearchEntity))
// HTTP handlers layer
must(container.Provide(handler.NewTenantHandler))
@@ -113,13 +122,17 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(handler.NewChunkHandler))
must(container.Provide(handler.NewSessionHandler))
must(container.Provide(handler.NewMessageHandler))
must(container.Provide(handler.NewTestDataHandler))
must(container.Provide(handler.NewModelHandler))
must(container.Provide(handler.NewEvaluationHandler))
must(container.Provide(handler.NewInitializationHandler))
must(container.Provide(handler.NewAuthHandler))
must(container.Provide(handler.NewSystemHandler))
// Router configuration
must(container.Provide(router.NewRouter))
must(container.Provide(router.NewAsyncqClient))
must(container.Provide(router.NewAsynqServer))
must(container.Invoke(router.RunAsynqServer))
return container
}
@@ -177,6 +190,16 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
return nil, err
}
// Auto-migrate database tables
err = db.AutoMigrate(
&types.User{},
&types.AuthToken{},
&types.KnowledgeBase{},
)
if err != nil {
return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err)
}
// Get underlying SQL DB object
sqlDB, err := db.DB()
if err != nil {
@@ -216,7 +239,7 @@ func initFileService(cfg *config.Config) (interfaces.FileService, error) {
false,
)
case "cos":
if os.Getenv("COS_APP_ID") == "" ||
if os.Getenv("COS_BUCKET_NAME") == "" ||
os.Getenv("COS_REGION") == "" ||
os.Getenv("COS_SECRET_ID") == "" ||
os.Getenv("COS_SECRET_KEY") == "" ||
@@ -224,7 +247,7 @@ func initFileService(cfg *config.Config) (interfaces.FileService, error) {
return nil, fmt.Errorf("missing COS configuration")
}
return file.NewCosFileService(
os.Getenv("COS_APP_ID"),
os.Getenv("COS_BUCKET_NAME"),
os.Getenv("COS_REGION"),
os.Getenv("COS_SECRET_ID"),
os.Getenv("COS_SECRET_KEY"),
@@ -373,3 +396,23 @@ func initOllamaService() (*ollama.OllamaService, error) {
// Get Ollama service from existing factory function
return ollama.GetOllamaService()
}
func initNeo4jClient() (neo4j.Driver, error) {
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
logger.Debugf(context.Background(), "NOT SUPPORT RETRIEVE GRAPH")
return nil, nil
}
uri := os.Getenv("NEO4J_URI")
username := os.Getenv("NEO4J_USERNAME")
password := os.Getenv("NEO4J_PASSWORD")
driver, err := neo4j.NewDriver(uri, neo4j.BasicAuth(username, password, ""))
if err != nil {
return nil, err
}
err = driver.VerifyAuthentication(context.Background(), nil)
if err != nil {
return nil, err
}
return driver, nil
}

343
internal/handler/auth.go Normal file
View File

@@ -0,0 +1,343 @@
package handler
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// AuthHandler implements HTTP request handlers for user authentication
// Provides functionality for user registration, login, logout, and token management
// through the REST API endpoints
type AuthHandler struct {
userService interfaces.UserService
tenantService interfaces.TenantService
}
// NewAuthHandler creates a new auth handler instance with the provided services
// Parameters:
// - userService: An implementation of the UserService interface for business logic
// - tenantService: An implementation of the TenantService interface for tenant management
//
// Returns a pointer to the newly created AuthHandler
func NewAuthHandler(userService interfaces.UserService, tenantService interfaces.TenantService) *AuthHandler {
return &AuthHandler{
userService: userService,
tenantService: tenantService,
}
}
// Register handles the HTTP request for user registration
// It deserializes the request body into a registration request object, validates it,
// calls the service to create the user, and returns the result
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Register(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user registration")
var req types.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse registration request parameters", err)
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
// Validate required fields
if req.Username == "" || req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required registration fields")
appErr := errors.NewValidationError("Username, email and password are required")
c.Error(appErr)
return
}
// Call service to register user
user, err := h.userService.Register(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to register user: %v", err)
appErr := errors.NewBadRequestError("Registration failed").WithDetails(err.Error())
c.Error(appErr)
return
}
// Return success response
response := &types.RegisterResponse{
Success: true,
Message: "Registration successful",
User: user,
}
logger.Infof(ctx, "User registered successfully: %s", user.Email)
c.JSON(http.StatusCreated, response)
}
// Login handles the HTTP request for user login
// It deserializes the request body into a login request object, validates it,
// calls the service to authenticate the user, and returns tokens
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Login(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user login")
var req types.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse login request parameters", err)
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
// Validate required fields
if req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required login fields")
appErr := errors.NewValidationError("Email and password are required")
c.Error(appErr)
return
}
// Call service to authenticate user
response, err := h.userService.Login(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to login user: %v", err)
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
c.Error(appErr)
return
}
// Check if login was successful
if !response.Success {
logger.Warnf(ctx, "Login failed: %s", response.Message)
c.JSON(http.StatusUnauthorized, response)
return
}
// User is already in the correct format from service
logger.Infof(ctx, "User logged in successfully: %s", req.Email)
c.JSON(http.StatusOK, response)
}
// Logout handles the HTTP request for user logout
// It extracts the token from the Authorization header and revokes it
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Logout(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user logout")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Revoke token
err := h.userService.RevokeToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to revoke token: %v", err)
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "User logged out successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Logout successful",
})
}
// RefreshToken handles the HTTP request for refreshing access tokens
// It extracts the refresh token from the request body and generates new tokens
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) RefreshToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token refresh")
var req struct {
RefreshToken string `json:"refreshToken" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse refresh token request", err)
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Call service to refresh token
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
if err != nil {
logger.Errorf(ctx, "Failed to refresh token: %v", err)
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "Token refreshed successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token refreshed successfully",
"access_token": accessToken,
"refresh_token": newRefreshToken,
})
}
// GetCurrentUser handles the HTTP request for getting current user information
// It extracts the user from the context (set by auth middleware) and returns user info
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
ctx := c.Request.Context()
logger.Debugf(ctx, "Get current user info")
// Get current user from service (which extracts from context)
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
// Get tenant information
var tenant *types.Tenant
if user.TenantID > 0 {
tenant, err = h.tenantService.GetTenantByID(ctx, user.TenantID)
if err != nil {
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %d: %v", user.Email, user.TenantID, err)
// Don't fail the request if tenant info is not available
} else {
logger.Debugf(ctx, "Retrieved tenant info for user %s: %s", user.Email, tenant.Name)
}
}
logger.Debugf(ctx, "Retrieved current user info: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"user": user.ToUserInfo(),
"tenant": tenant,
},
})
}
// ChangePassword handles the HTTP request for changing user password
// It extracts the current user and validates the old password before setting new one
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) ChangePassword(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start password change")
var req struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse password change request", err)
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Get current user
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
// Change password
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
if err != nil {
logger.Errorf(ctx, "Failed to change password: %v", err)
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Password changed successfully",
})
}
// ValidateToken handles the HTTP request for validating access tokens
// It extracts the token from the Authorization header and validates it
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) ValidateToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token validation")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Validate token
user, err := h.userService.ValidateToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to validate token: %v", err)
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token is valid",
"user": user.ToUserInfo(),
})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
"github.com/gin-gonic/gin"
)
@@ -52,6 +53,13 @@ func (h *ChunkHandler) ListKnowledgeChunks(c *gin.Context) {
return
}
// 对 chunk 内容进行安全清理
for _, chunk := range result.Data.([]*types.Chunk) {
if chunk.Content != "" {
chunk.Content = secutils.SanitizeForDisplay(chunk.Content)
}
}
logger.Infof(
ctx, "Successfully retrieved knowledge chunks list, knowledge ID: %s, total: %d",
knowledgeID, result.Total,

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"time"
"github.com/Tencent/WeKnora/internal/application/service"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
@@ -22,7 +21,6 @@ type SessionHandler struct {
sessionService interfaces.SessionService // Service for managing sessions
streamManager interfaces.StreamManager // Manager for handling streaming responses
config *config.Config // Application configuration
testDataService *service.TestDataService // Service for test data (models, etc.)
knowledgebaseService interfaces.KnowledgeBaseService
}
@@ -32,7 +30,6 @@ func NewSessionHandler(
messageService interfaces.MessageService,
streamManager interfaces.StreamManager,
config *config.Config,
testDataService *service.TestDataService,
knowledgebaseService interfaces.KnowledgeBaseService,
) *SessionHandler {
return &SessionHandler{
@@ -40,7 +37,6 @@ func NewSessionHandler(
messageService: messageService,
streamManager: streamManager,
config: config,
testDataService: testDataService,
knowledgebaseService: knowledgebaseService,
}
}
@@ -87,7 +83,7 @@ type CreateSessionRequest struct {
func (h *SessionHandler) CreateSession(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start creating session")
// logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation)
// Parse and validate the request body
var request CreateSessionRequest

View File

@@ -0,0 +1,49 @@
package handler
import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/gin-gonic/gin"
)
// SystemHandler handles system-related requests
type SystemHandler struct{}
// NewSystemHandler creates a new system handler
func NewSystemHandler() *SystemHandler {
return &SystemHandler{}
}
// GetSystemInfoResponse defines the response structure for system info
type GetSystemInfoResponse struct {
Version string `json:"version"`
CommitID string `json:"commit_id,omitempty"`
BuildTime string `json:"build_time,omitempty"`
GoVersion string `json:"go_version,omitempty"`
}
// 编译时注入的版本信息
var (
Version = "unknown"
CommitID = "unknown"
BuildTime = "unknown"
GoVersion = "unknown"
)
// GetSystemInfo gets system information including version and commit ID
func (h *SystemHandler) GetSystemInfo(c *gin.Context) {
ctx := logger.CloneContext(c.Request.Context())
response := GetSystemInfoResponse{
Version: Version,
CommitID: CommitID,
BuildTime: BuildTime,
GoVersion: GoVersion,
}
logger.Info(ctx, "System info retrieved successfully")
c.JSON(200, gin.H{
"code": 0,
"msg": "success",
"data": response,
})
}

View File

@@ -1,84 +0,0 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// TestDataHandler handles HTTP requests related to test data operations
// Used for development and testing purposes to provide sample data
type TestDataHandler struct {
config *config.Config
kbService interfaces.KnowledgeBaseService
tenantService interfaces.TenantService
}
// NewTestDataHandler creates a new instance of the test data handler
// Parameters:
// - config: Application configuration instance
// - kbService: Knowledge base service for accessing knowledge base data
// - tenantService: Tenant service for accessing tenant data
//
// Returns a pointer to the new TestDataHandler instance
func NewTestDataHandler(
config *config.Config,
kbService interfaces.KnowledgeBaseService,
tenantService interfaces.TenantService,
) *TestDataHandler {
return &TestDataHandler{
config: config,
kbService: kbService,
tenantService: tenantService,
}
}
// GetTestData handles the HTTP request to retrieve test data for development purposes
// It returns predefined test tenant and knowledge base information
// This endpoint is only available in non-production environments
// Parameters:
// - c: Gin context for the HTTP request
func (h *TestDataHandler) GetTestData(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start retrieving test data")
tenantID := uint(types.InitDefaultTenantID)
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
// Retrieve the test tenant data
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantID)
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(err)
return
}
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
// Retrieve the test knowledge base data
logger.Infof(ctx, "Retrieving test knowledge base, ID: %s", knowledgeBaseID)
knowledgeBase, err := h.kbService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(err)
return
}
logger.Info(ctx, "Test data retrieved successfully")
// Return the test data in the response
c.JSON(http.StatusOK, gin.H{
"data": gin.H{
"tenant": tenant,
"knowledge_bases": []types.KnowledgeBase{*knowledgeBase},
},
"success": true,
})
}

View File

@@ -16,9 +16,10 @@ import (
// 无需认证的API列表
var noAuthAPI = map[string][]string{
"/api/v1/test-data": {"GET"},
"/api/v1/tenants": {"POST"},
"/api/v1/initialization/*": {"GET", "POST"},
"/health": {"GET"},
"/api/v1/auth/register": {"POST"},
"/api/v1/auth/login": {"POST"},
"/api/v1/auth/refresh": {"POST"},
}
// 检查请求是否在无需认证的API列表中
@@ -37,7 +38,7 @@ func isNoAuthAPI(path string, method string) bool {
}
// Auth 认证中间件
func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc {
func Auth(tenantService interfaces.TenantService, userService interfaces.UserService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// ignore OPTIONS request
if c.Request.Method == "OPTIONS" {
@@ -51,53 +52,90 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle
return
}
// Get API Key from request header
// 尝试JWT Token认证
authHeader := c.GetHeader("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
token := strings.TrimPrefix(authHeader, "Bearer ")
user, err := userService.ValidateToken(c.Request.Context(), token)
if err == nil && user != nil {
// JWT Token认证成功
// 获取租户信息
tenant, err := tenantService.GetTenantByID(c.Request.Context(), user.TenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, userID: %s", err, user.TenantID, user.ID)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid tenant",
})
c.Abort()
return
}
// 存储用户和租户信息到上下文
c.Set(types.TenantIDContextKey.String(), user.TenantID)
c.Set(types.TenantInfoContextKey.String(), tenant)
c.Set("user", user)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, user.TenantID),
types.TenantInfoContextKey, tenant,
),
"user", user,
),
)
c.Next()
return
}
}
// 尝试X-API-Key认证兼容模式
apiKey := c.GetHeader("X-API-Key")
if apiKey == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
if apiKey != "" {
// Get tenant information
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key format",
})
c.Abort()
return
}
// Verify API key validity (matches the one in database)
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
if t == nil || t.APIKey != apiKey {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
// Store tenant ID in context
c.Set(types.TenantIDContextKey.String(), tenantID)
c.Set(types.TenantInfoContextKey.String(), t)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
types.TenantInfoContextKey, t,
),
)
c.Next()
return
}
// Get tenant information
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key format",
})
c.Abort()
return
}
// Verify API key validity (matches the one in database)
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
if t == nil || t.APIKey != apiKey {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
// Store tenant ID in context
c.Set(types.TenantIDContextKey.String(), tenantID)
c.Set(types.TenantInfoContextKey.String(), t)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
types.TenantInfoContextKey, t,
),
)
c.Next()
// 没有提供任何认证信息
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized: missing authentication"})
c.Abort()
}
}

View File

@@ -85,6 +85,7 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
Messages: c.convertMessages(messages),
Stream: isStream,
}
thinking := false
// 添加可选参数
if opts != nil {
@@ -106,6 +107,13 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
if opts.PresencePenalty > 0 {
req.PresencePenalty = float32(opts.PresencePenalty)
}
if opts.Thinking != nil {
thinking = *opts.Thinking
}
}
req.ChatTemplateKwargs = map[string]interface{}{
"enable_thinking": thinking,
}
return req

View File

@@ -18,6 +18,14 @@ type RouterParams struct {
dig.In
Config *config.Config
UserService interfaces.UserService
KBService interfaces.KnowledgeBaseService
KnowledgeService interfaces.KnowledgeService
ChunkService interfaces.ChunkService
SessionService interfaces.SessionService
MessageService interfaces.MessageService
ModelService interfaces.ModelService
EvaluationService interfaces.EvaluationService
KBHandler *handler.KnowledgeBaseHandler
KnowledgeHandler *handler.KnowledgeHandler
TenantHandler *handler.TenantHandler
@@ -25,10 +33,11 @@ type RouterParams struct {
ChunkHandler *handler.ChunkHandler
SessionHandler *handler.SessionHandler
MessageHandler *handler.MessageHandler
TestDataHandler *handler.TestDataHandler
ModelHandler *handler.ModelHandler
EvaluationHandler *handler.EvaluationHandler
AuthHandler *handler.AuthHandler
InitializationHandler *handler.InitializationHandler
SystemHandler *handler.SystemHandler
}
// NewRouter 创建新的路由
@@ -50,7 +59,7 @@ func NewRouter(params RouterParams) *gin.Engine {
r.Use(middleware.Logger())
r.Use(middleware.Recovery())
r.Use(middleware.ErrorHandler())
r.Use(middleware.Auth(params.TenantService, params.Config))
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
// 添加OpenTelemetry追踪中间件
r.Use(middleware.TracingMiddleware())
@@ -60,31 +69,10 @@ func NewRouter(params RouterParams) *gin.Engine {
c.JSON(200, gin.H{"status": "ok"})
})
// 测试数据接口(不需要认证)
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
// 初始化接口(不需要认证)
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
// Ollama相关接口不需要认证
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
// 远程API相关接口不需要认证
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
// 需要认证的API路由
v1 := r.Group("/api/v1")
{
RegisterAuthRoutes(v1, params.AuthHandler)
RegisterTenantRoutes(v1, params.TenantHandler)
RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
@@ -94,6 +82,8 @@ func NewRouter(params RouterParams) *gin.Engine {
RegisterMessageRoutes(v1, params.MessageHandler)
RegisterModelRoutes(v1, params.ModelHandler)
RegisterEvaluationRoutes(v1, params.EvaluationHandler)
RegisterInitializationRoutes(v1, params.InitializationHandler)
RegisterSystemRoutes(v1, params.SystemHandler)
}
return r
@@ -247,3 +237,46 @@ func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHan
evaluationRoutes.GET("/", handler.GetEvaluationResult)
}
}
// RegisterAuthRoutes registers authentication routes
func RegisterAuthRoutes(r *gin.RouterGroup, handler *handler.AuthHandler) {
r.POST("/auth/register", handler.Register)
r.POST("/auth/login", handler.Login)
r.POST("/auth/refresh", handler.RefreshToken)
r.GET("/auth/validate", handler.ValidateToken)
r.POST("/auth/logout", handler.Logout)
r.GET("/auth/me", handler.GetCurrentUser)
r.POST("/auth/change-password", handler.ChangePassword)
}
func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.InitializationHandler) {
// 初始化接口
r.GET("/initialization/config/:kbId", handler.GetCurrentConfigByKB)
r.POST("/initialization/initialize/:kbId", handler.InitializeByKB)
// Ollama相关接口
r.GET("/initialization/ollama/status", handler.CheckOllamaStatus)
r.GET("/initialization/ollama/models", handler.ListOllamaModels)
r.POST("/initialization/ollama/models/check", handler.CheckOllamaModels)
r.POST("/initialization/ollama/models/download", handler.DownloadOllamaModel)
r.GET("/initialization/ollama/download/progress/:taskId", handler.GetDownloadProgress)
r.GET("/initialization/ollama/download/tasks", handler.ListDownloadTasks)
// 远程API相关接口
r.POST("/initialization/remote/check", handler.CheckRemoteModel)
r.POST("/initialization/embedding/test", handler.TestEmbeddingModel)
r.POST("/initialization/rerank/check", handler.CheckRerankModel)
r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction)
r.POST("/initialization/extract/text-relation", handler.ExtractTextRelations)
r.POST("/initialization/extract/fabri-tag", handler.FabriTag)
r.POST("/initialization/extract/fabri-text", handler.FabriText)
}
// RegisterSystemRoutes registers system information routes
func RegisterSystemRoutes(r *gin.RouterGroup, handler *handler.SystemHandler) {
systemRoutes := r.Group("/system")
{
systemRoutes.GET("/info", handler.GetSystemInfo)
}
}

66
internal/router/task.go Normal file
View File

@@ -0,0 +1,66 @@
package router
import (
"log"
"os"
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/hibiken/asynq"
"go.uber.org/dig"
)
type AsynqTaskParams struct {
dig.In
Server *asynq.Server
Extracter interfaces.Extracter
}
func getAsynqRedisClientOpt() *asynq.RedisClientOpt {
opt := &asynq.RedisClientOpt{
Addr: os.Getenv("REDIS_ADDR"),
Password: os.Getenv("REDIS_PASSWORD"),
ReadTimeout: 100 * time.Millisecond,
WriteTimeout: 200 * time.Millisecond,
DB: 0,
}
return opt
}
func NewAsyncqClient() *asynq.Client {
opt := getAsynqRedisClientOpt()
client := asynq.NewClient(opt)
return client
}
func NewAsynqServer() *asynq.Server {
opt := getAsynqRedisClientOpt()
srv := asynq.NewServer(
opt,
asynq.Config{
Queues: map[string]int{
"critical": 6, // Highest priority queue
"default": 3, // Default priority queue
"low": 1, // Lowest priority queue
},
},
)
return srv
}
func RunAsynqServer(params AsynqTaskParams) *asynq.ServeMux {
// Create a new mux and register all handlers
mux := asynq.NewServeMux()
mux.HandleFunc(types.TypeChunkExtract, params.Extracter.Extract)
go func() {
// Start the server
if err := params.Server.Run(mux); err != nil {
log.Fatalf("could not run server: %v", err)
}
}()
return mux
}

View File

@@ -28,6 +28,8 @@ type ChatManage struct {
SearchResult []*SearchResult `json:"-"` // Results from search phase
RerankResult []*SearchResult `json:"-"` // Results after reranking
MergeResult []*SearchResult `json:"-"` // Final merged results after all processing
Entity []string `json:"-"` // List of identified entities
GraphResult *GraphData `json:"-"` // Graph data from search phase
UserContent string `json:"-"` // Processed user content
ChatResponse *ChatResponse `json:"-"` // Final response from chat model
ResponseChan <-chan StreamResponse `json:"-"` // Channel for streaming responses
@@ -75,6 +77,7 @@ const (
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
INTO_CHAT_MESSAGE EventType = "into_chat_message" // Convert chunks into chat messages
@@ -104,6 +107,7 @@ var Pipline = map[string][]EventType{
REWRITE_QUERY,
PREPROCESS_QUERY,
CHUNK_SEARCH,
ENTITY_SEARCH,
CHUNK_RERANK,
CHUNK_MERGE,
FILTER_TOP_K,

View File

@@ -19,6 +19,7 @@ const (
MatchTypeHistory
MatchTypeParentChunk // 父Chunk匹配类型
MatchTypeRelationChunk // 关系Chunk匹配类型
MatchTypeGraph
)
// IndexInfo contains information about indexed content

View File

@@ -0,0 +1,51 @@
package types
const (
TypeChunkExtract = "chunk:extract"
)
type ExtractChunkPayload struct {
TenantID uint `json:"tenant_id"`
ChunkID string `json:"chunk_id"`
ModelID string `json:"model_id"`
}
type PromptTemplateStructured struct {
Description string `json:"description"`
Tags []string `json:"tags"`
Examples []GraphData `json:"examples"`
}
type GraphNode struct {
Name string `json:"name,omitempty"`
Chunks []string `json:"chunks,omitempty"`
Attributes []string `json:"attributes,omitempty"`
}
type GraphRelation struct {
Node1 string `json:"node1,omitempty"`
Node2 string `json:"node2,omitempty"`
Type string `json:"type,omitempty"`
}
type GraphData struct {
Text string `json:"text,omitempty"`
Node []*GraphNode `json:"node,omitempty"`
Relation []*GraphRelation `json:"relation,omitempty"`
}
type NameSpace struct {
KnowledgeBase string `json:"knowledge_base"`
Knowledge string `json:"knowledge"`
}
func (n NameSpace) Labels() []string {
res := make([]string, 0)
if n.KnowledgeBase != "" {
res = append(res, n.KnowledgeBase)
}
if n.Knowledge != "" {
res = append(res, n.Knowledge)
}
return res
}

View File

@@ -0,0 +1,11 @@
package interfaces
import (
"context"
"github.com/hibiken/asynq"
)
type Extracter interface {
Extract(ctx context.Context, t *asynq.Task) error
}

View File

@@ -0,0 +1,13 @@
package interfaces
import (
"context"
"github.com/Tencent/WeKnora/internal/types"
)
type RetrieveGraphRepository interface {
AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error
DelGraph(ctx context.Context, namespace []types.NameSpace) error
SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error)
}

View File

@@ -0,0 +1,75 @@
package interfaces
import (
"context"
"github.com/Tencent/WeKnora/internal/types"
)
// UserService defines the user service interface
type UserService interface {
// Register creates a new user account
Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error)
// Login authenticates a user and returns tokens
Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error)
// GetUserByID gets a user by ID
GetUserByID(ctx context.Context, id string) (*types.User, error)
// GetUserByEmail gets a user by email
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
// GetUserByUsername gets a user by username
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
// UpdateUser updates user information
UpdateUser(ctx context.Context, user *types.User) error
// DeleteUser deletes a user
DeleteUser(ctx context.Context, id string) error
// ChangePassword changes user password
ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error
// ValidatePassword validates user password
ValidatePassword(ctx context.Context, userID string, password string) error
// GenerateTokens generates access and refresh tokens for user
GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error)
// ValidateToken validates an access token
ValidateToken(ctx context.Context, token string) (*types.User, error)
// RefreshToken refreshes access token using refresh token
RefreshToken(ctx context.Context, refreshToken string) (accessToken, newRefreshToken string, err error)
// RevokeToken revokes a token
RevokeToken(ctx context.Context, token string) error
// GetCurrentUser gets current user from context
GetCurrentUser(ctx context.Context) (*types.User, error)
}
// UserRepository defines the user repository interface
type UserRepository interface {
// CreateUser creates a user
CreateUser(ctx context.Context, user *types.User) error
// GetUserByID gets a user by ID
GetUserByID(ctx context.Context, id string) (*types.User, error)
// GetUserByEmail gets a user by email
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
// GetUserByUsername gets a user by username
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
// UpdateUser updates a user
UpdateUser(ctx context.Context, user *types.User) error
// DeleteUser deletes a user
DeleteUser(ctx context.Context, id string) error
// ListUsers lists users with pagination
ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error)
}
// AuthTokenRepository defines the auth token repository interface
type AuthTokenRepository interface {
// CreateToken creates an auth token
CreateToken(ctx context.Context, token *types.AuthToken) error
// GetTokenByValue gets a token by its value
GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error)
// GetTokensByUserID gets all tokens for a user
GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error)
// UpdateToken updates a token
UpdateToken(ctx context.Context, token *types.AuthToken) error
// DeleteToken deletes a token
DeleteToken(ctx context.Context, id string) error
// DeleteExpiredTokens deletes all expired tokens
DeleteExpiredTokens(ctx context.Context) error
// RevokeTokensByUserID revokes all tokens for a user
RevokeTokensByUserID(ctx context.Context, userID string) error
}

View File

@@ -38,6 +38,8 @@ type KnowledgeBase struct {
VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"`
// Storage config
StorageConfig StorageConfig `yaml:"cos_config" json:"cos_config" gorm:"column:cos_config;type:json"`
// Extract config
ExtractConfig *ExtractConfig `yaml:"extract_config" json:"extract_config" gorm:"column:extract_config;type:json"`
// Creation time of the knowledge base
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
// Last updated time of the knowledge base
@@ -167,3 +169,27 @@ func (c *VLMConfig) Scan(value interface{}) error {
}
return json.Unmarshal(b, c)
}
type ExtractConfig struct {
Text string `yaml:"text" json:"text"`
Tags []string `yaml:"tags" json:"tags"`
Nodes []*GraphNode `yaml:"nodes" json:"nodes"`
Relations []*GraphRelation `yaml:"relations" json:"relations"`
}
// Value implements the driver.Valuer interface, used to convert ExtractConfig to database value
func (e ExtractConfig) Value() (driver.Value, error) {
return json.Marshal(e)
}
// Scan implements the sql.Scanner interface, used to convert database value to ExtractConfig
func (e *ExtractConfig) Scan(value interface{}) error {
if value == nil {
return nil
}
b, ok := value.([]byte)
if !ok {
return nil
}
return json.Unmarshal(b, e)
}

114
internal/types/user.go Normal file
View File

@@ -0,0 +1,114 @@
package types
import (
"time"
"gorm.io/gorm"
)
// User represents a user in the system
type User struct {
// Unique identifier of the user
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
// Username of the user
Username string `json:"username" gorm:"type:varchar(100);uniqueIndex;not null"`
// Email address of the user
Email string `json:"email" gorm:"type:varchar(255);uniqueIndex;not null"`
// Hashed password of the user
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
// Avatar URL of the user
Avatar string `json:"avatar" gorm:"type:varchar(500)"`
// Tenant ID that the user belongs to
TenantID uint `json:"tenant_id" gorm:"index"`
// Whether the user is active
IsActive bool `json:"is_active" gorm:"default:true"`
// Creation time of the user
CreatedAt time.Time `json:"created_at"`
// Last updated time of the user
UpdatedAt time.Time `json:"updated_at"`
// Deletion time of the user
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
// Association relationship, not stored in the database
Tenant *Tenant `json:"tenant,omitempty" gorm:"foreignKey:TenantID"`
}
// AuthToken represents an authentication token
type AuthToken struct {
// Unique identifier of the token
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
// User ID that owns this token
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
// Token value (JWT or other format)
Token string `json:"token" gorm:"type:text;not null"`
// Token type (access_token, refresh_token)
TokenType string `json:"token_type" gorm:"type:varchar(50);not null"`
// Token expiration time
ExpiresAt time.Time `json:"expires_at"`
// Whether the token is revoked
IsRevoked bool `json:"is_revoked" gorm:"default:false"`
// Creation time of the token
CreatedAt time.Time `json:"created_at"`
// Last updated time of the token
UpdatedAt time.Time `json:"updated_at"`
// Association relationship
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
}
// LoginRequest represents a login request
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
// RegisterRequest represents a registration request
type RegisterRequest struct {
Username string `json:"username" binding:"required,min=3,max=50"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
User *User `json:"user,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"`
Token string `json:"token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// RegisterResponse represents a registration response
type RegisterResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
User *User `json:"user,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"`
}
// UserInfo represents user information for API responses
type UserInfo struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Avatar string `json:"avatar"`
TenantID uint `json:"tenant_id"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ToUserInfo converts User to UserInfo (without sensitive data)
func (u *User) ToUserInfo() *UserInfo {
return &UserInfo{
ID: u.ID,
Username: u.Username,
Email: u.Email,
Avatar: u.Avatar,
TenantID: u.TenantID,
IsActive: u.IsActive,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

171
internal/utils/security.go Normal file
View File

@@ -0,0 +1,171 @@
package utils
import (
"html"
"regexp"
"strings"
"unicode/utf8"
)
// XSS 防护相关正则表达式
var (
// 匹配潜在的 XSS 攻击模式
xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
regexp.MustCompile(`(?i)<iframe[^>]*>.*?</iframe>`),
regexp.MustCompile(`(?i)<object[^>]*>.*?</object>`),
regexp.MustCompile(`(?i)<embed[^>]*>.*?</embed>`),
regexp.MustCompile(`(?i)<embed[^>]*>`),
regexp.MustCompile(`(?i)<form[^>]*>.*?</form>`),
regexp.MustCompile(`(?i)<input[^>]*>`),
regexp.MustCompile(`(?i)<button[^>]*>.*?</button>`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)vbscript:`),
regexp.MustCompile(`(?i)onload\s*=`),
regexp.MustCompile(`(?i)onerror\s*=`),
regexp.MustCompile(`(?i)onclick\s*=`),
regexp.MustCompile(`(?i)onmouseover\s*=`),
regexp.MustCompile(`(?i)onfocus\s*=`),
regexp.MustCompile(`(?i)onblur\s*=`),
}
)
// SanitizeHTML 清理 HTML 内容,防止 XSS 攻击
func SanitizeHTML(input string) string {
if input == "" {
return ""
}
// 检查输入长度
if len(input) > 10000 {
input = input[:10000]
}
// 检查是否包含潜在的 XSS 攻击
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
// 如果包含恶意内容,进行 HTML 转义
return html.EscapeString(input)
}
}
// 如果内容相对安全,返回原内容
return input
}
// EscapeHTML 转义 HTML 特殊字符
func EscapeHTML(input string) string {
if input == "" {
return ""
}
return html.EscapeString(input)
}
// ValidateInput 验证用户输入
func ValidateInput(input string) (string, bool) {
if input == "" {
return "", true
}
// 检查长度
if len(input) > 10000 {
return "", false
}
// 检查是否包含控制字符
for _, r := range input {
if r < 32 && r != 9 && r != 10 && r != 13 {
return "", false
}
}
// 检查 UTF-8 有效性
if !utf8.ValidString(input) {
return "", false
}
// 检查是否包含潜在的 XSS 攻击
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
return "", false
}
}
return strings.TrimSpace(input), true
}
// IsValidURL 验证 URL 是否安全
func IsValidURL(url string) bool {
if url == "" {
return false
}
// 检查长度
if len(url) > 2048 {
return false
}
// 检查协议
if !strings.HasPrefix(strings.ToLower(url), "http://") &&
!strings.HasPrefix(strings.ToLower(url), "https://") {
return false
}
// 检查是否包含恶意内容
for _, pattern := range xssPatterns {
if pattern.MatchString(url) {
return false
}
}
return true
}
// IsValidImageURL 验证图片 URL 是否安全
func IsValidImageURL(url string) bool {
if !IsValidURL(url) {
return false
}
// 检查是否为图片文件
imageExtensions := []string{".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", ".bmp", ".ico"}
lowerURL := strings.ToLower(url)
for _, ext := range imageExtensions {
if strings.Contains(lowerURL, ext) {
return true
}
}
return false
}
// CleanMarkdown 清理 Markdown 内容
func CleanMarkdown(input string) string {
if input == "" {
return ""
}
// 移除潜在的恶意脚本
cleaned := input
for _, pattern := range xssPatterns {
cleaned = pattern.ReplaceAllString(cleaned, "")
}
return cleaned
}
// SanitizeForDisplay 为显示清理内容
func SanitizeForDisplay(input string) string {
if input == "" {
return ""
}
// 首先清理 Markdown
cleaned := CleanMarkdown(input)
// 然后进行 HTML 转义
escaped := html.EscapeString(cleaned)
return escaped
}

View File

@@ -304,6 +304,19 @@ async def handle_list_tools() -> list[types.Tool]:
),
# Knowledge Management
types.Tool(
name="create_knowledge_from_file",
description="Create knowledge from a local file on the server filesystem",
inputSchema={
"type": "object",
"properties": {
"kb_id": {"type": "string", "description": "Knowledge base ID"},
"file_path": {"type": "string", "description": "Absolute path to the local file on the server"},
"enable_multimodel": {"type": "boolean", "description": "Enable multimodal processing", "default": True}
},
"required": ["kb_id", "file_path"]
}
),
types.Tool(
name="create_knowledge_from_url",
description="Create knowledge from URL",
@@ -537,6 +550,12 @@ async def handle_call_tool(
result = client.hybrid_search(args["kb_id"], args["query"], config)
# Knowledge Management
elif name == "create_knowledge_from_file":
result = client.create_knowledge_from_file(
args["kb_id"],
args["file_path"],
args.get("enable_multimodel", True)
)
elif name == "create_knowledge_from_url":
result = client.create_knowledge_from_url(
args["kb_id"],

View File

@@ -51,6 +51,7 @@ CREATE TABLE knowledge_bases (
vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSON NOT NULL,
vlm_config JSON NOT NULL,
extract_config JSON NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL DEFAULT NULL

View File

@@ -62,6 +62,7 @@ CREATE TABLE IF NOT EXISTS knowledge_bases (
vlm_model_id VARCHAR(64) NOT NULL,
cos_config JSONB NOT NULL DEFAULT '{}',
vlm_config JSONB NOT NULL DEFAULT '{}',
extract_config JSONB NULL DEFAULT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP WITH TIME ZONE

View File

@@ -78,13 +78,49 @@ check_platform() {
log_info "检测系统平台信息..."
if [ "$(uname -m)" = "x86_64" ]; then
export PLATFORM="linux/amd64"
export TARGETARCH="amd64"
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
export PLATFORM="linux/arm64"
export TARGETARCH="arm64"
else
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
export PLATFORM="linux/amd64"
export TARGETARCH="amd64"
fi
log_info "当前平台:$PLATFORM"
log_info "当前架构:$TARGETARCH"
}
# 获取版本信息
get_version_info() {
# 从VERSION文件获取版本号
if [ -f "VERSION" ]; then
VERSION=$(cat VERSION | tr -d '\n\r')
else
VERSION="unknown"
fi
# 获取commit ID
if command -v git >/dev/null 2>&1; then
COMMIT_ID=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
else
COMMIT_ID="unknown"
fi
# 获取构建时间
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
# 获取Go版本
if command -v go >/dev/null 2>&1; then
GO_VERSION=$(go version 2>/dev/null || echo "unknown")
else
GO_VERSION="unknown"
fi
log_info "版本信息: $VERSION"
log_info "Commit ID: $COMMIT_ID"
log_info "构建时间: $BUILD_TIME"
log_info "Go版本: $GO_VERSION"
}
# 构建应用镜像
@@ -93,11 +129,18 @@ build_app_image() {
cd "$PROJECT_ROOT"
# 获取版本信息
get_version_info
docker build \
--platform $PLATFORM \
--build-arg GOPRIVATE_ARG=${GOPRIVATE:-""} \
--build-arg GOPROXY_ARG=${GOPROXY:-"https://goproxy.cn,direct"} \
--build-arg GOSUMDB_ARG=${GOSUMDB:-"off"} \
--build-arg VERSION_ARG="$VERSION" \
--build-arg COMMIT_ID_ARG="$COMMIT_ID" \
--build-arg BUILD_TIME_ARG="$BUILD_TIME" \
--build-arg GO_VERSION_ARG="$GO_VERSION" \
-f docker/Dockerfile.app \
-t wechatopenai/weknora-app:latest \
.
@@ -120,6 +163,7 @@ build_docreader_image() {
docker build \
--platform $PLATFORM \
--build-arg PLATFORM=$PLATFORM \
--build-arg TARGETARCH=$TARGETARCH \
-f docker/Dockerfile.docreader \
-t wechatopenai/weknora-docreader:latest \
.

86
scripts/get_version.sh Executable file
View File

@@ -0,0 +1,86 @@
#!/bin/bash
# 统一的版本信息获取脚本
# 支持本地构建和CI构建环境
# 设置默认值
VERSION="unknown"
COMMIT_ID="unknown"
BUILD_TIME="unknown"
GO_VERSION="unknown"
# 获取版本号
if [ -f "VERSION" ]; then
VERSION=$(cat VERSION | tr -d '\n\r')
fi
# 获取commit ID
if [ -n "$GITHUB_SHA" ]; then
# GitHub Actions环境
COMMIT_ID="${GITHUB_SHA:0:7}"
elif command -v git >/dev/null 2>&1; then
# 本地环境
COMMIT_ID=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
fi
# 获取构建时间
if [ -n "$GITHUB_ACTIONS" ]; then
# GitHub Actions环境使用标准时间格式
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
else
# 本地环境
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
fi
# 获取Go版本
if command -v go >/dev/null 2>&1; then
GO_VERSION=$(go version 2>/dev/null || echo "unknown")
fi
# 根据参数输出不同格式
case "${1:-env}" in
"env")
# 输出环境变量格式,对包含空格的值进行转义
echo "VERSION=$VERSION"
echo "COMMIT_ID=$COMMIT_ID"
echo "BUILD_TIME=\"$BUILD_TIME\""
echo "GO_VERSION=\"$GO_VERSION\""
;;
"json")
# 输出JSON格式
cat << EOF
{
"version": "$VERSION",
"commit_id": "$COMMIT_ID",
"build_time": "$BUILD_TIME",
"go_version": "$GO_VERSION"
}
EOF
;;
"docker-args")
# 输出Docker构建参数格式
echo "--build-arg VERSION_ARG=$VERSION"
echo "--build-arg COMMIT_ID_ARG=$COMMIT_ID"
echo "--build-arg BUILD_TIME_ARG=$BUILD_TIME"
echo "--build-arg GO_VERSION_ARG=$GO_VERSION"
;;
"ldflags")
# 输出Go ldflags格式
echo "-X 'github.com/Tencent/WeKnora/internal/handler.Version=$VERSION' -X 'github.com/Tencent/WeKnora/internal/handler.CommitID=$COMMIT_ID' -X 'github.com/Tencent/WeKnora/internal/handler.BuildTime=$BUILD_TIME' -X 'github.com/Tencent/WeKnora/internal/handler.GoVersion=$GO_VERSION'"
;;
"info")
# 输出信息格式
echo "版本信息: $VERSION"
echo "Commit ID: $COMMIT_ID"
echo "构建时间: $BUILD_TIME"
echo "Go版本: $GO_VERSION"
;;
*)
echo "用法: $0 [env|json|docker-args|ldflags|info]"
echo " env - 输出环境变量格式 (默认)"
echo " json - 输出JSON格式"
echo " docker-args - 输出Docker构建参数格式"
echo " ldflags - 输出Go ldflags格式"
echo " info - 输出信息格式"
exit 1
;;
esac

View File

@@ -18,8 +18,8 @@ SCRIPT_NAME=$(basename "$0")
# 显示帮助信息
show_help() {
echo -e "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
echo -e "${GREEN}用法:${NC} $0 [选项]"
printf "%b\n" "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
printf "%b\n" "${GREEN}用法:${NC} $0 [选项]"
echo "选项:"
echo " -h, --help 显示帮助信息"
echo " -o, --ollama 启动Ollama服务"
@@ -37,25 +37,25 @@ show_help() {
# 显示版本信息
show_version() {
echo -e "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
printf "%b\n" "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
exit 0
}
# 日志函数
log_info() {
echo -e "${BLUE}[INFO]${NC} $1"
printf "%b\n" "${BLUE}[INFO]${NC} $1"
}
log_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
printf "%b\n" "${YELLOW}[WARNING]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
printf "%b\n" "${RED}[ERROR]${NC} $1"
}
log_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
printf "%b\n" "${GREEN}[SUCCESS]${NC} $1"
}
# 选择可用的 Docker Compose 命令(优先 docker compose其次 docker-compose
@@ -397,7 +397,7 @@ list_containers() {
cd "$PROJECT_ROOT"
# 列出所有容器
echo -e "${BLUE}当前正在运行的容器:${NC}"
printf "%b\n" "${BLUE}当前正在运行的容器:${NC}"
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD ps --services | sort
return 0
@@ -703,20 +703,26 @@ else
if [ "$START_OLLAMA" = true ] && [ "$START_DOCKER" = true ]; then
if [ $OLLAMA_RESULT -eq 0 ] && [ $DOCKER_RESULT -eq 0 ]; then
log_success "所有服务启动完成,可通过以下地址访问:"
echo -e "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
echo -e "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
echo -e "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
printf "%b\n" "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
printf "%b\n" "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
printf "%b\n" "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
echo ""
log_info "正在持续输出容器日志(按 Ctrl+C 退出日志,容器不会停止)..."
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD logs app docreader postgres --since=10s -f
else
log_error "部分服务启动失败,请检查日志并修复问题"
fi
elif [ "$START_OLLAMA" = true ] && [ $OLLAMA_RESULT -eq 0 ]; then
log_success "Ollama服务启动完成可通过以下地址访问:"
echo -e "${GREEN} - Ollama API: http://localhost:$OLLAMA_PORT${NC}"
printf "%b\n" "${GREEN} - Ollama API: http://localhost:$OLLAMA_PORT${NC}"
elif [ "$START_DOCKER" = true ] && [ $DOCKER_RESULT -eq 0 ]; then
log_success "Docker容器启动完成可通过以下地址访问:"
echo -e "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
echo -e "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
echo -e "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
printf "%b\n" "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
printf "%b\n" "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
printf "%b\n" "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
echo ""
log_info "正在持续输出容器日志(按 Ctrl+C 退出日志,容器不会停止)..."
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD logs app docreader postgres --since=10s -f
fi
fi

View File

@@ -40,6 +40,31 @@ class PaddleOCRBackend(OCRBackend):
os.environ['CUDA_VISIBLE_DEVICES'] = ''
paddle.set_device('cpu')
# 尝试检测CPU是否支持AVX指令集
try:
import subprocess
import platform
# 检测CPU是否支持AVX
if platform.system() == "Linux":
try:
result = subprocess.run(['grep', '-o', 'avx', '/proc/cpuinfo'],
capture_output=True, text=True, timeout=5)
has_avx = 'avx' in result.stdout.lower()
if not has_avx:
logger.warning("CPU does not support AVX instructions, using compatibility mode")
# 进一步限制指令集使用
os.environ['FLAGS_use_avx2'] = '0'
os.environ['FLAGS_use_avx'] = '1'
except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError):
logger.warning("Could not detect AVX support, using compatibility mode")
os.environ['FLAGS_use_avx2'] = '0'
os.environ['FLAGS_use_avx'] = '1'
except Exception as e:
logger.warning(f"Error detecting CPU capabilities: {e}, using compatibility mode")
os.environ['FLAGS_use_avx2'] = '0'
os.environ['FLAGS_use_avx'] = '1'
from paddleocr import PaddleOCR
# OCR configuration with text orientation classification enabled
ocr_config = {
@@ -67,6 +92,13 @@ class PaddleOCRBackend(OCRBackend):
except ImportError as e:
logger.error(f"Failed to import paddleocr: {str(e)}. Please install it with 'pip install paddleocr'")
except OSError as e:
if "Illegal instruction" in str(e) or "core dumped" in str(e):
logger.error(f"PaddlePaddle crashed due to CPU instruction set incompatibility: {str(e)}")
logger.error("This usually happens when the CPU doesn't support AVX instructions.")
logger.error("Please try installing a CPU-only version of PaddlePaddle or use a different OCR backend.")
else:
logger.error(f"Failed to initialize PaddleOCR due to OS error: {str(e)}")
except Exception as e:
logger.error(f"Failed to initialize PaddleOCR: {str(e)}")