mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 11:29:31 +08:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca704fa054 | ||
|
|
02b78a5908 | ||
|
|
de96a52d54 | ||
|
|
f24cd817cb | ||
|
|
4824e41361 | ||
|
|
bfd4fffbe3 | ||
|
|
a2902de6ce | ||
|
|
a5c3623a02 | ||
|
|
8f723b38fb | ||
|
|
7973128f4c | ||
|
|
8ed050b8ec | ||
|
|
4ccbd2a127 | ||
|
|
512910584b | ||
|
|
cd7e02e54a | ||
|
|
c9b1f43ed7 | ||
|
|
76fc64a807 | ||
|
|
947899ff10 | ||
|
|
5e0a99b127 | ||
|
|
b04566be32 | ||
|
|
0157eb25bd | ||
|
|
91e65d6445 | ||
|
|
c589a911dc | ||
|
|
66aec78960 | ||
|
|
76fbfdf8ac | ||
|
|
4137a63852 | ||
|
|
d28f805707 | ||
|
|
2e395864b9 | ||
|
|
4005aa3ded | ||
|
|
5e22f96d37 | ||
|
|
2237e1ee55 | ||
|
|
b11df52cfb | ||
|
|
c3744866fd | ||
|
|
c2d52a9374 | ||
|
|
81bd2e6c2c | ||
|
|
0908f9c487 | ||
|
|
1aac37d3fd | ||
|
|
cd249df8c8 | ||
|
|
092b30af3e | ||
|
|
74c121f7fb | ||
|
|
78088057fb | ||
|
|
bff0e742fa | ||
|
|
6598baab2e |
12
.env.example
12
.env.example
@@ -121,6 +121,18 @@ COS_ENABLE_OLD_DOMAIN=true
|
||||
# 如果解析网络连接使用Web代理,需要配置以下参数
|
||||
# WEB_PROXY=your_web_proxy
|
||||
|
||||
# Neo4j 开关
|
||||
# NEO4J_ENABLE=false
|
||||
|
||||
# Neo4j的访问地址
|
||||
# NEO4J_URI=neo4j://neo4j:7687
|
||||
|
||||
# Neo4j的用户名和密码
|
||||
# NEO4J_USERNAME=neo4j
|
||||
|
||||
# Neo4j的密码
|
||||
# NEO4J_PASSWORD=password
|
||||
|
||||
##############################################################
|
||||
|
||||
###### 注意: 以下配置不再生效,已在Web“配置初始化”阶段完成 #########
|
||||
|
||||
13
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
13
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -49,15 +49,7 @@ body:
|
||||
|
||||
请按照以下步骤收集相关日志:
|
||||
|
||||
**1. 应用模块日志:**
|
||||
```bash
|
||||
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
|
||||
```
|
||||
|
||||
**2. 文档解析模块日志:**
|
||||
```bash
|
||||
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
|
||||
```
|
||||
docker compose logs -f --tail=1000 app docreader postgres
|
||||
|
||||
请重现问题并收集相关日志,然后粘贴到下面的日志字段中。
|
||||
|
||||
@@ -68,8 +60,7 @@ body:
|
||||
description: 请按照上面的指南收集并粘贴相关日志
|
||||
placeholder: |
|
||||
请粘贴从以下命令收集的日志:
|
||||
- docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
|
||||
- docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
|
||||
docker compose logs -f --tail=1000 app docreader postgres
|
||||
render: shell
|
||||
|
||||
- type: input
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/question.yml
vendored
8
.github/ISSUE_TEMPLATE/question.yml
vendored
@@ -68,14 +68,8 @@ body:
|
||||
|
||||
如果问题涉及错误或需要调试,请收集相关日志:
|
||||
|
||||
**应用模块日志:**
|
||||
```bash
|
||||
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
|
||||
```
|
||||
|
||||
**文档解析模块日志:**
|
||||
```bash
|
||||
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
|
||||
docker compose logs -f --tail=1000 app docreader postgres
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
|
||||
212
.github/workflows/docker-image.yml
vendored
212
.github/workflows/docker-image.yml
vendored
@@ -1,6 +1,8 @@
|
||||
name: Build and Push Docker Image
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
branches:
|
||||
- main
|
||||
|
||||
@@ -9,51 +11,201 @@ concurrency:
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build-app:
|
||||
build-ui:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: ui
|
||||
file: frontend/Dockerfile
|
||||
context: ./frontend
|
||||
platform: linux/amd64,linux/arm64
|
||||
- service_name: app
|
||||
file: docker/Dockerfile.app
|
||||
context: .
|
||||
platform: linux/amd64,linux/arm64
|
||||
- service_name: docreader
|
||||
file: docker/Dockerfile.docreader
|
||||
context: .
|
||||
platform: linux/amd64,linux/arm64
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Read VERSION file
|
||||
run: echo "VERSION=$(cat VERSION)" >> $GITHUB_ENV
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui
|
||||
|
||||
- name: Build ui Image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
file: frontend/Dockerfile
|
||||
context: ./frontend
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui:cache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-ui:cache,mode=max
|
||||
|
||||
build-docreader:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader
|
||||
|
||||
- name: Build docreader Image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
file: docker/Dockerfile.docreader
|
||||
context: .
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader:cache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-docreader:cache,mode=max
|
||||
|
||||
build-app:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- arch: amd64
|
||||
platform: linux/amd64
|
||||
runs: ubuntu-latest
|
||||
- arch: arm64
|
||||
platform: linux/arm64
|
||||
runs: ubuntu-24.04-arm
|
||||
runs-on: ${{ matrix.runs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
id: setup-buildx
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
|
||||
|
||||
- name: Prepare version info
|
||||
id: version
|
||||
run: |
|
||||
# 使用统一的版本管理脚本
|
||||
eval "$(./scripts/get_version.sh env)"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "commit_id=$COMMIT_ID" >> $GITHUB_OUTPUT
|
||||
echo "build_time=$BUILD_TIME" >> $GITHUB_OUTPUT
|
||||
echo "go_version=$GO_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
# 显示版本信息
|
||||
./scripts/get_version.sh info
|
||||
|
||||
- name: Build ${{ matrix.service_name }} Image
|
||||
- name: Build Cache for Docker
|
||||
uses: actions/cache@v4
|
||||
id: cache
|
||||
with:
|
||||
path: go-pkg-mod
|
||||
key: ${{ env.PLATFORM_PAIR }}-go-build-cache-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
- name: Inject go-build-cache
|
||||
uses: reproducible-containers/buildkit-cache-dance@v3
|
||||
with:
|
||||
builder: ${{ steps.setup-buildx.outputs.name }}
|
||||
cache-map: |
|
||||
{
|
||||
"go-pkg-mod": "/go/pkg/mod"
|
||||
}
|
||||
skip-extraction: ${{ steps.cache.outputs.cache-hit }}
|
||||
|
||||
- name: Build app Image
|
||||
id: build
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
push: true
|
||||
platforms: ${{ matrix.platform }}
|
||||
file: ${{ matrix.file }}
|
||||
context: ${{ matrix.context }}
|
||||
tags: |
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:latest
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:${{ env.VERSION }}
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:cache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-${{ matrix.service_name }}:cache,mode=max
|
||||
file: docker/Dockerfile.app
|
||||
context: .
|
||||
build-args: |
|
||||
${{ format('VERSION_ARG={0}', steps.version.outputs.version) }}
|
||||
${{ format('COMMIT_ID_ARG={0}', steps.version.outputs.commit_id) }}
|
||||
${{ format('BUILD_TIME_ARG={0}', steps.version.outputs.build_time) }}
|
||||
${{ format('GO_VERSION_ARG={0}', steps.version.outputs.go_version) }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:cache-${{ env.PLATFORM_PAIR }}
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:cache-${{ env.PLATFORM_PAIR }},mode=max
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p ${{ runner.temp }}/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "${{ runner.temp }}/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ env.PLATFORM_PAIR }}
|
||||
path: ${{ runner.temp }}/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build-app
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: ${{ runner.temp }}/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: ${{ runner.temp }}/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ secrets.DOCKERHUB_USERNAME }}/weknora-app@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ secrets.DOCKERHUB_USERNAME }}/weknora-app:${{ steps.meta.outputs.version }}
|
||||
|
||||
86
CHANGELOG.md
86
CHANGELOG.md
@@ -2,6 +2,89 @@
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [0.1.4] - 2025-09-17
|
||||
|
||||
### 🚀 Major Features
|
||||
- **NEW**: Multi-knowledgebases operation support
|
||||
- Added comprehensive multi-knowledgebase management functionality
|
||||
- Implemented multi-data source search engine configuration and optimization logic
|
||||
- Enhanced knowledge base switching and management in UI
|
||||
- **NEW**: Enhanced tenant information management
|
||||
- Added dedicated tenant information page
|
||||
- Improved user and tenant management capabilities
|
||||
|
||||
### 🎨 UI/UX Improvements
|
||||
- **REDESIGNED**: Settings page with improved layout and functionality
|
||||
- **ENHANCED**: Menu component with multi-knowledgebase support
|
||||
- **IMPROVED**: Initialization configuration page structure
|
||||
- **OPTIMIZED**: Login page and authentication flow
|
||||
|
||||
### 🔒 Security Fixes
|
||||
- **FIXED**: XSS attack vulnerabilities in thinking component
|
||||
- **FIXED**: Content Security Policy (CSP) errors
|
||||
- **ENHANCED**: Frontend security measures and input sanitization
|
||||
|
||||
### 🐛 Bug Fixes
|
||||
- **FIXED**: Login direct page navigation issues
|
||||
- **FIXED**: App LLM model check logic
|
||||
- **FIXED**: Version script functionality
|
||||
- **FIXED**: File download content errors
|
||||
- **IMPROVED**: Document content component display
|
||||
|
||||
### 🧹 Code Cleanup
|
||||
- **REMOVED**: Test data functionality and related APIs
|
||||
- **SIMPLIFIED**: Initialization configuration components
|
||||
- **CLEANED**: Redundant UI components and unused code
|
||||
|
||||
|
||||
## [0.1.3] - 2025-09-16
|
||||
|
||||
### 🔒 Security Features
|
||||
- **NEW**: Added login authentication functionality to enhance system security
|
||||
- Implemented user authentication and authorization mechanisms
|
||||
- Added session management and access control
|
||||
- Fixed XSS attack vulnerabilities in frontend components
|
||||
|
||||
### 📚 Documentation Updates
|
||||
- Added security notices in all README files (English, Chinese, Japanese)
|
||||
- Updated deployment recommendations emphasizing internal/private network deployment
|
||||
- Enhanced security guidelines to prevent information leakage risks
|
||||
- Fixed documentation spelling issues
|
||||
|
||||
### 🛡️ Security Improvements
|
||||
- Hide API keys in UI for security purposes
|
||||
- Enhanced input sanitization and XSS protection
|
||||
- Added comprehensive security utilities
|
||||
|
||||
### 🐛 Bug Fixes
|
||||
- Fixed OCR AVX support issues
|
||||
- Improved frontend health check dependencies
|
||||
- Enhanced Docker binary downloads for target architecture
|
||||
- Fixed COS file service initialization parameters and URL processing logic
|
||||
|
||||
### 🚀 Features & Enhancements
|
||||
- Improved application and docreader log output
|
||||
- Enhanced frontend routing and authentication flow
|
||||
- Added comprehensive user management system
|
||||
- Improved initialization configuration handling
|
||||
|
||||
### 🛡️ Security Recommendations
|
||||
- Deploy WeKnora services in internal/private network environments
|
||||
- Avoid direct exposure to public internet
|
||||
- Configure proper firewall rules and access controls
|
||||
- Regular updates for security patches and improvements
|
||||
|
||||
## [0.1.2] - 2025-09-10
|
||||
|
||||
- Fixed health check implementation for docreader service
|
||||
- Improved query handling for empty queries
|
||||
- Enhanced knowledge base column value update methods
|
||||
- Optimized logging throughout the application
|
||||
- Added process parsing documentation for markdown files
|
||||
- Fixed OCR model pre-fetching in Docker containers
|
||||
- Resolved image parser concurrency errors
|
||||
- Added support for modifying listening port configuration
|
||||
|
||||
## [0.1.0] - 2025-09-08
|
||||
|
||||
- Initial public release of WeKnora.
|
||||
@@ -14,4 +97,7 @@ All notable changes to this project will be documented in this file.
|
||||
- Docker Compose for quick startup and service orchestration.
|
||||
- MCP server support for integrating with MCP-compatible clients.
|
||||
|
||||
[0.1.4]: https://github.com/Tencent/WeKnora/tree/v0.1.4
|
||||
[0.1.3]: https://github.com/Tencent/WeKnora/tree/v0.1.3
|
||||
[0.1.2]: https://github.com/Tencent/WeKnora/tree/v0.1.2
|
||||
[0.1.0]: https://github.com/Tencent/WeKnora/tree/v0.1.0
|
||||
|
||||
17
Makefile
17
Makefile
@@ -85,7 +85,15 @@ clean:
|
||||
|
||||
# Build Docker image
|
||||
docker-build-app:
|
||||
docker build --platform $(PLATFORM) -f docker/Dockerfile.app -t $(DOCKER_IMAGE):$(DOCKER_TAG) .
|
||||
@echo "获取版本信息..."
|
||||
@eval $$(./scripts/get_version.sh env); \
|
||||
./scripts/get_version.sh info; \
|
||||
docker build --platform $(PLATFORM) \
|
||||
--build-arg VERSION_ARG="$$VERSION" \
|
||||
--build-arg COMMIT_ID_ARG="$$COMMIT_ID" \
|
||||
--build-arg BUILD_TIME_ARG="$$BUILD_TIME" \
|
||||
--build-arg GO_VERSION_ARG="$$GO_VERSION" \
|
||||
-f docker/Dockerfile.app -t $(DOCKER_IMAGE):$(DOCKER_TAG) .
|
||||
|
||||
# Build docreader Docker image
|
||||
docker-build-docreader:
|
||||
@@ -168,7 +176,12 @@ deps:
|
||||
|
||||
# Build for production
|
||||
build-prod:
|
||||
GOOS=linux go build -a -installsuffix cgo -ldflags="-w -s" -o $(BINARY_NAME) $(MAIN_PATH)
|
||||
VERSION=$${VERSION:-unknown}; \
|
||||
COMMIT_ID=$${COMMIT_ID:-unknown}; \
|
||||
BUILD_TIME=$${BUILD_TIME:-unknown}; \
|
||||
GO_VERSION=$${GO_VERSION:-unknown}; \
|
||||
LDFLAGS="-X 'github.com/Tencent/WeKnora/internal/handler.Version=$$VERSION' -X 'github.com/Tencent/WeKnora/internal/handler.CommitID=$$COMMIT_ID' -X 'github.com/Tencent/WeKnora/internal/handler.BuildTime=$$BUILD_TIME' -X 'github.com/Tencent/WeKnora/internal/handler.GoVersion=$$GO_VERSION'"; \
|
||||
go build -ldflags="-w -s $$LDFLAGS" -o $(BINARY_NAME) $(MAIN_PATH)
|
||||
|
||||
clean-db:
|
||||
@echo "Cleaning database..."
|
||||
|
||||
11
README.md
11
README.md
@@ -15,7 +15,7 @@
|
||||
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
|
||||
</a>
|
||||
<a href="./CHANGELOG.md">
|
||||
<img alt="Version" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
|
||||
<img alt="Version" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -41,6 +41,15 @@ It adopts a modular architecture that combines multimodal preprocessing, semanti
|
||||
|
||||
**Website:** https://weknora.weixin.qq.com
|
||||
|
||||
## 🔒 Security Notice
|
||||
|
||||
**Important:** Starting from v0.1.3, WeKnora includes login authentication functionality to enhance system security. For production deployments, we strongly recommend:
|
||||
|
||||
- Deploy WeKnora services in internal/private network environments rather than public internet
|
||||
- Avoid exposing the service directly to public networks to prevent potential information leakage
|
||||
- Configure proper firewall rules and access controls for your deployment environment
|
||||
- Regularly update to the latest version for security patches and improvements
|
||||
|
||||
## 🏗️ Architecture
|
||||
|
||||

|
||||
|
||||
11
README_CN.md
11
README_CN.md
@@ -15,7 +15,7 @@
|
||||
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
|
||||
</a>
|
||||
<a href="./CHANGELOG.md">
|
||||
<img alt="版本" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
|
||||
<img alt="版本" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -41,6 +41,15 @@
|
||||
|
||||
**官网:** https://weknora.weixin.qq.com
|
||||
|
||||
## 🔒 安全声明
|
||||
|
||||
**重要提示:** 从 v0.1.3 版本开始,WeKnora 提供了登录鉴权功能,以增强系统安全性。在生产环境部署时,我们强烈建议:
|
||||
|
||||
- 将 WeKnora 服务部署在内网/私有网络环境中,而非公网环境
|
||||
- 避免将服务直接暴露在公网上,以防止重要信息泄露风险
|
||||
- 为部署环境配置适当的防火墙规则和访问控制
|
||||
- 定期更新到最新版本以获取安全补丁和改进
|
||||
|
||||
## 🏗️ 架构设计
|
||||
|
||||

|
||||
|
||||
11
README_JA.md
11
README_JA.md
@@ -15,7 +15,7 @@
|
||||
<img src="https://img.shields.io/badge/License-MIT-ffffff?labelColor=d4eaf7&color=2e6cc4" alt="License">
|
||||
</a>
|
||||
<a href="./CHANGELOG.md">
|
||||
<img alt="バージョン" src="https://img.shields.io/badge/version-0.1.0-2e6cc4?labelColor=d4eaf7">
|
||||
<img alt="バージョン" src="https://img.shields.io/badge/version-0.1.3-2e6cc4?labelColor=d4eaf7">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -41,6 +41,15 @@
|
||||
|
||||
**公式サイト:** https://weknora.weixin.qq.com
|
||||
|
||||
## 🔒 セキュリティ通知
|
||||
|
||||
**重要:** v0.1.3バージョンより、WeKnoraにはシステムセキュリティを強化するためのログイン認証機能が含まれています。本番環境でのデプロイメントにおいて、以下を強く推奨します:
|
||||
|
||||
- WeKnoraサービスはパブリックインターネットではなく、内部/プライベートネットワーク環境にデプロイしてください
|
||||
- 重要な情報漏洩を防ぐため、サービスを直接パブリックネットワークに公開することは避けてください
|
||||
- デプロイメント環境に適切なファイアウォールルールとアクセス制御を設定してください
|
||||
- セキュリティパッチと改善のため、定期的に最新バージョンに更新してください
|
||||
|
||||
## 🏗️ アーキテクチャ設計
|
||||
|
||||

|
||||
|
||||
@@ -74,6 +74,9 @@ type UpdateImageInfoRequest struct {
|
||||
// ErrDuplicateFile is returned when attempting to create a knowledge entry with a file that already exists
|
||||
var ErrDuplicateFile = errors.New("file already exists")
|
||||
|
||||
// ErrDuplicateURL is returned when attempting to create a knowledge entry with a URL that already exists
|
||||
var ErrDuplicateURL = errors.New("URL already exists")
|
||||
|
||||
// CreateKnowledgeFromFile creates a knowledge entry from a local file path
|
||||
func (c *Client) CreateKnowledgeFromFile(ctx context.Context,
|
||||
knowledgeBaseID string, filePath string, metadata map[string]string, enableMultimodel *bool,
|
||||
@@ -186,7 +189,12 @@ func (c *Client) CreateKnowledgeFromURL(ctx context.Context, knowledgeBaseID str
|
||||
}
|
||||
|
||||
var response KnowledgeResponse
|
||||
if err := parseResponse(resp, &response); err != nil {
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
return &response.Data, ErrDuplicateURL
|
||||
} else if err := parseResponse(resp, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ conversation:
|
||||
keyword_threshold: 0.3
|
||||
embedding_top_k: 10
|
||||
vector_threshold: 0.5
|
||||
rerank_threshold: 0.7
|
||||
rerank_threshold: 0.5
|
||||
rerank_top_k: 5
|
||||
fallback_strategy: "fixed"
|
||||
fallback_response: "抱歉,我无法回答这个问题。"
|
||||
@@ -534,3 +534,69 @@ knowledge_base:
|
||||
split_markers: ["\n\n", "\n", "。"]
|
||||
image_processing:
|
||||
enable_multimodal: true
|
||||
|
||||
extract:
|
||||
extract_graph:
|
||||
description: |
|
||||
请基于给定文本,按以下步骤完成信息提取任务,确保逻辑清晰、信息完整准确:
|
||||
|
||||
## 一、实体提取与属性补充
|
||||
1. **提取核心实体**:通读文本,按逻辑顺序(如文本叙述顺序、实体关联紧密程度)提取所有与任务相关的核心实体。
|
||||
2. **补充实体详细属性**:针对每个提取的实体,全面补充其在文本中明确提及的详细属性,确保无关键属性遗漏。
|
||||
|
||||
## 二、关系提取与验证
|
||||
1. **明确关系类型**:仅从指定关系列表中选择对应类型,限定关系类型为: %s。
|
||||
2. **提取有效关系**:基于已提取的实体及属性,识别文本中真实存在的关系,确保关系符合文本事实、无虚假关联。
|
||||
3. **明确关系主体**:对每一组提取的关系,清晰标注两个关联主体,避免主体混淆。
|
||||
4. **补充关联属性**:若文本中存在与该关系直接相关的补充信息,需将该信息作为关系的关联属性补充,进一步完善关系信息。
|
||||
tags:
|
||||
- "作者"
|
||||
- "别名"
|
||||
examples:
|
||||
- text: |
|
||||
《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。该书前80回由曹雪芹所著,后40回一般认为是高鹗所续。
|
||||
小说以贾、史、王、薛四大家族的兴衰为背景,以贾宝玉、林黛玉和薛宝钗的爱情悲剧为主线,刻画了以贾宝玉和金陵十二钗为中心的正邪两赋、贤愚并出的高度复杂的人物群像。
|
||||
成书于乾隆年间(1743年前后),是中国文学史上现实主义的高峰,对后世影响深远。
|
||||
node:
|
||||
- name: "红楼梦"
|
||||
attributes:
|
||||
- "中国古典四大名著之一"
|
||||
- "又名《石头记》"
|
||||
- "被誉为中国封建社会的百科全书"
|
||||
- name: "石头记"
|
||||
attributes:
|
||||
- "《红楼梦》的别名"
|
||||
- name: "曹雪芹"
|
||||
attributes:
|
||||
- "清代作家"
|
||||
- "《红楼梦》前 80 回的作者"
|
||||
- name: "高鹗"
|
||||
attributes:
|
||||
- "一般认为是《红楼梦》后 40 回的续写者"
|
||||
relation:
|
||||
- node1: "红楼梦"
|
||||
node2: "曹雪芹"
|
||||
type: "作者"
|
||||
- node1: "红楼梦"
|
||||
node2: "高鹗"
|
||||
type: "作者"
|
||||
- node1: "红楼梦"
|
||||
node2: "石头记"
|
||||
type: "别名"
|
||||
extract_entity:
|
||||
description: |
|
||||
请基于用户给的问题,按以下步骤处理关键信息提取任务:
|
||||
1. 梳理逻辑关联:首先完整分析文本内容,明确其核心逻辑关系,并简要标注该核心逻辑类型;
|
||||
2. 提取关键实体:围绕梳理出的逻辑关系,精准提取文本中的关键信息并归类为明确实体,确保不遗漏核心信息、不添加冗余内容;
|
||||
3. 排序实体优先级:按实体与文本核心主题的关联紧密程度排序,优先呈现对理解文本主旨最重要的实体;
|
||||
examples:
|
||||
- text: "《红楼梦》,又名《石头记》,是清代作家曹雪芹创作的中国古典四大名著之一,被誉为中国封建社会的百科全书。"
|
||||
node:
|
||||
- name: "红楼梦"
|
||||
- name: "曹雪芹"
|
||||
- name: "中国古典四大名著"
|
||||
fabri_text:
|
||||
with_tag: |
|
||||
请随机生成一段文本,要求内容与 %s 等相关,字数在 [50-200] 之间,并且尽量包含一些与这些标签相关的专业术语或典型元素,使文本更具针对性和相关性。
|
||||
with_no_tag: |
|
||||
请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。
|
||||
@@ -7,21 +7,27 @@ services:
|
||||
volumes:
|
||||
- data-files:/data/files
|
||||
- ./config/config.yaml:/app/config/config.yaml
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
environment:
|
||||
- COS_SECRET_ID=${COS_SECRET_ID}
|
||||
- COS_SECRET_KEY=${COS_SECRET_KEY}
|
||||
- COS_REGION=${COS_REGION}
|
||||
- COS_BUCKET_NAME=${COS_BUCKET_NAME}
|
||||
- COS_APP_ID=${COS_APP_ID}
|
||||
- COS_PATH_PREFIX=${COS_PATH_PREFIX}
|
||||
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN}
|
||||
- GIN_MODE=${GIN_MODE}
|
||||
- COS_SECRET_ID=${COS_SECRET_ID:-}
|
||||
- COS_SECRET_KEY=${COS_SECRET_KEY:-}
|
||||
- COS_REGION=${COS_REGION:-}
|
||||
- COS_BUCKET_NAME=${COS_BUCKET_NAME:-}
|
||||
- COS_APP_ID=${COS_APP_ID:-}
|
||||
- COS_PATH_PREFIX=${COS_PATH_PREFIX:-}
|
||||
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN:-}
|
||||
- GIN_MODE=${GIN_MODE:-}
|
||||
- DB_DRIVER=postgres
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_USER=${DB_USER}
|
||||
- DB_PASSWORD=${DB_PASSWORD}
|
||||
- DB_NAME=${DB_NAME}
|
||||
- DB_USER=${DB_USER:-}
|
||||
- DB_PASSWORD=${DB_PASSWORD:-}
|
||||
- DB_NAME=${DB_NAME:-}
|
||||
- TZ=Asia/Shanghai
|
||||
- OTEL_EXPORTER_OTLP_ENDPOINT=jaeger:4317
|
||||
- OTEL_SERVICE_NAME=WeKnora
|
||||
@@ -29,38 +35,42 @@ services:
|
||||
- OTEL_METRICS_EXPORTER=none
|
||||
- OTEL_LOGS_EXPORTER=none
|
||||
- OTEL_PROPAGATORS=tracecontext,baggage
|
||||
- RETRIEVE_DRIVER=${RETRIEVE_DRIVER}
|
||||
- ELASTICSEARCH_ADDR=${ELASTICSEARCH_ADDR}
|
||||
- ELASTICSEARCH_USERNAME=${ELASTICSEARCH_USERNAME}
|
||||
- ELASTICSEARCH_PASSWORD=${ELASTICSEARCH_PASSWORD}
|
||||
- ELASTICSEARCH_INDEX=${ELASTICSEARCH_INDEX}
|
||||
- RETRIEVE_DRIVER=${RETRIEVE_DRIVER:-}
|
||||
- ELASTICSEARCH_ADDR=${ELASTICSEARCH_ADDR:-}
|
||||
- ELASTICSEARCH_USERNAME=${ELASTICSEARCH_USERNAME:-}
|
||||
- ELASTICSEARCH_PASSWORD=${ELASTICSEARCH_PASSWORD:-}
|
||||
- ELASTICSEARCH_INDEX=${ELASTICSEARCH_INDEX:-}
|
||||
- DOCREADER_ADDR=docreader:50051
|
||||
- STORAGE_TYPE=${STORAGE_TYPE}
|
||||
- LOCAL_STORAGE_BASE_DIR=${LOCAL_STORAGE_BASE_DIR}
|
||||
- STORAGE_TYPE=${STORAGE_TYPE:-}
|
||||
- LOCAL_STORAGE_BASE_DIR=${LOCAL_STORAGE_BASE_DIR:-}
|
||||
- MINIO_ENDPOINT=minio:9000
|
||||
- MINIO_ACCESS_KEY_ID=${MINIO_ACCESS_KEY_ID:-minioadmin}
|
||||
- MINIO_SECRET_ACCESS_KEY=${MINIO_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME}
|
||||
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME:-}
|
||||
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
|
||||
- STREAM_MANAGER_TYPE=${STREAM_MANAGER_TYPE}
|
||||
- STREAM_MANAGER_TYPE=${STREAM_MANAGER_TYPE:-}
|
||||
- REDIS_ADDR=redis:6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD}
|
||||
- REDIS_DB=${REDIS_DB}
|
||||
- REDIS_PREFIX=${REDIS_PREFIX}
|
||||
- ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG}
|
||||
- TENANT_AES_KEY=${TENANT_AES_KEY}
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-}
|
||||
- REDIS_PREFIX=${REDIS_PREFIX:-}
|
||||
- ENABLE_GRAPH_RAG=${ENABLE_GRAPH_RAG:-}
|
||||
- NEO4J_ENABLE=${NEO4J_ENABLE:-}
|
||||
- NEO4J_URI=bolt://neo4j:7687
|
||||
- NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j}
|
||||
- NEO4J_PASSWORD=${NEO4J_PASSWORD:-password}
|
||||
- TENANT_AES_KEY=${TENANT_AES_KEY:-}
|
||||
- CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5}
|
||||
- INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME}
|
||||
- INIT_LLM_MODEL_BASE_URL=${INIT_LLM_MODEL_BASE_URL}
|
||||
- INIT_LLM_MODEL_API_KEY=${INIT_LLM_MODEL_API_KEY}
|
||||
- INIT_EMBEDDING_MODEL_NAME=${INIT_EMBEDDING_MODEL_NAME}
|
||||
- INIT_EMBEDDING_MODEL_BASE_URL=${INIT_EMBEDDING_MODEL_BASE_URL}
|
||||
- INIT_EMBEDDING_MODEL_API_KEY=${INIT_EMBEDDING_MODEL_API_KEY}
|
||||
- INIT_EMBEDDING_MODEL_DIMENSION=${INIT_EMBEDDING_MODEL_DIMENSION}
|
||||
- INIT_EMBEDDING_MODEL_ID=${INIT_EMBEDDING_MODEL_ID}
|
||||
- INIT_RERANK_MODEL_NAME=${INIT_RERANK_MODEL_NAME}
|
||||
- INIT_RERANK_MODEL_BASE_URL=${INIT_RERANK_MODEL_BASE_URL}
|
||||
- INIT_RERANK_MODEL_API_KEY=${INIT_RERANK_MODEL_API_KEY}
|
||||
- INIT_LLM_MODEL_NAME=${INIT_LLM_MODEL_NAME:-}
|
||||
- INIT_LLM_MODEL_BASE_URL=${INIT_LLM_MODEL_BASE_URL:-}
|
||||
- INIT_LLM_MODEL_API_KEY=${INIT_LLM_MODEL_API_KEY:-}
|
||||
- INIT_EMBEDDING_MODEL_NAME=${INIT_EMBEDDING_MODEL_NAME:-}
|
||||
- INIT_EMBEDDING_MODEL_BASE_URL=${INIT_EMBEDDING_MODEL_BASE_URL:-}
|
||||
- INIT_EMBEDDING_MODEL_API_KEY=${INIT_EMBEDDING_MODEL_API_KEY:-}
|
||||
- INIT_EMBEDDING_MODEL_DIMENSION=${INIT_EMBEDDING_MODEL_DIMENSION:-}
|
||||
- INIT_EMBEDDING_MODEL_ID=${INIT_EMBEDDING_MODEL_ID:-}
|
||||
- INIT_RERANK_MODEL_NAME=${INIT_RERANK_MODEL_NAME:-}
|
||||
- INIT_RERANK_MODEL_BASE_URL=${INIT_RERANK_MODEL_BASE_URL:-}
|
||||
- INIT_RERANK_MODEL_API_KEY=${INIT_RERANK_MODEL_API_KEY:-}
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_started
|
||||
@@ -70,6 +80,8 @@ services:
|
||||
condition: service_started
|
||||
docreader:
|
||||
condition: service_healthy
|
||||
neo4j:
|
||||
condition: service_started
|
||||
networks:
|
||||
- WeKnora-network
|
||||
restart: unless-stopped
|
||||
@@ -102,7 +114,8 @@ services:
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-80}:80"
|
||||
depends_on:
|
||||
- app
|
||||
app:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- WeKnora-network
|
||||
restart: unless-stopped
|
||||
@@ -113,24 +126,24 @@ services:
|
||||
ports:
|
||||
- "${DOCREADER_PORT:-50051}:50051"
|
||||
environment:
|
||||
- COS_SECRET_ID=${COS_SECRET_ID}
|
||||
- COS_SECRET_KEY=${COS_SECRET_KEY}
|
||||
- COS_REGION=${COS_REGION}
|
||||
- COS_BUCKET_NAME=${COS_BUCKET_NAME}
|
||||
- COS_APP_ID=${COS_APP_ID}
|
||||
- COS_PATH_PREFIX=${COS_PATH_PREFIX}
|
||||
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN}
|
||||
- VLM_MODEL_BASE_URL=${VLM_MODEL_BASE_URL}
|
||||
- VLM_MODEL_NAME=${VLM_MODEL_NAME}
|
||||
- VLM_MODEL_API_KEY=${VLM_MODEL_API_KEY}
|
||||
- STORAGE_TYPE=${STORAGE_TYPE}
|
||||
- COS_SECRET_ID=${COS_SECRET_ID:-}
|
||||
- COS_SECRET_KEY=${COS_SECRET_KEY:-}
|
||||
- COS_REGION=${COS_REGION:-}
|
||||
- COS_BUCKET_NAME=${COS_BUCKET_NAME:-}
|
||||
- COS_APP_ID=${COS_APP_ID:-}
|
||||
- COS_PATH_PREFIX=${COS_PATH_PREFIX:-}
|
||||
- COS_ENABLE_OLD_DOMAIN=${COS_ENABLE_OLD_DOMAIN:-}
|
||||
- VLM_MODEL_BASE_URL=${VLM_MODEL_BASE_URL:-}
|
||||
- VLM_MODEL_NAME=${VLM_MODEL_NAME:-}
|
||||
- VLM_MODEL_API_KEY=${VLM_MODEL_API_KEY:-}
|
||||
- STORAGE_TYPE=${STORAGE_TYPE:-}
|
||||
- MINIO_PUBLIC_ENDPOINT=http://localhost:${MINIO_PORT:-9000}
|
||||
- MINIO_ENDPOINT=minio:9000
|
||||
- MINIO_ACCESS_KEY_ID=${MINIO_ACCESS_KEY_ID:-minioadmin}
|
||||
- MINIO_SECRET_ACCESS_KEY=${MINIO_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME}
|
||||
- MINIO_USE_SSL=${MINIO_USE_SSL}
|
||||
- WEB_PROXY=${WEB_PROXY}
|
||||
- MINIO_BUCKET_NAME=${MINIO_BUCKET_NAME:-}
|
||||
- MINIO_USE_SSL=${MINIO_USE_SSL:-}
|
||||
- WEB_PROXY=${WEB_PROXY:-}
|
||||
healthcheck:
|
||||
test: ["CMD", "grpc_health_probe", "-addr=:50051"]
|
||||
interval: 30s
|
||||
@@ -165,7 +178,7 @@ services:
|
||||
restart: unless-stopped
|
||||
# 修改的PostgreSQL配置
|
||||
postgres:
|
||||
image: paradedb/paradedb:latest
|
||||
image: paradedb/paradedb:v0.18.9-pg17
|
||||
container_name: WeKnora-postgres
|
||||
ports:
|
||||
- "${DB_PORT}:5432"
|
||||
@@ -202,6 +215,24 @@ services:
|
||||
networks:
|
||||
- WeKnora-network
|
||||
|
||||
neo4j:
|
||||
image: neo4j:latest
|
||||
container_name: WeKnora-neo4j
|
||||
volumes:
|
||||
- neo4j-data:/data
|
||||
environment:
|
||||
- NEO4J_AUTH=${NEO4J_USERNAME:-neo4j}/${NEO4J_PASSWORD:-password}
|
||||
- NEO4J_apoc_export_file_enabled=true
|
||||
- NEO4J_apoc_import_file_enabled=true
|
||||
- NEO4J_apoc_import_file_use__neo4j__config=true
|
||||
- NEO4JLABS_PLUGINS=["apoc"]
|
||||
ports:
|
||||
- "7474:7474"
|
||||
- "7687:7687"
|
||||
restart: always
|
||||
networks:
|
||||
- WeKnora-network
|
||||
|
||||
networks:
|
||||
WeKnora-network:
|
||||
driver: bridge
|
||||
@@ -212,3 +243,4 @@ volumes:
|
||||
jaeger_data:
|
||||
redis_data:
|
||||
minio_data:
|
||||
neo4j-data:
|
||||
|
||||
@@ -3,10 +3,6 @@ FROM golang:1.24-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install dependencies
|
||||
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
|
||||
apk add --no-cache git build-base
|
||||
|
||||
# 通过构建参数接收敏感信息
|
||||
ARG GOPRIVATE_ARG
|
||||
ARG GOPROXY_ARG
|
||||
@@ -17,19 +13,33 @@ ENV GOPRIVATE=${GOPRIVATE_ARG}
|
||||
ENV GOPROXY=${GOPROXY_ARG}
|
||||
ENV GOSUMDB=${GOSUMDB_ARG}
|
||||
|
||||
# Copy go mod and sum files
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
# Install dependencies
|
||||
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
|
||||
apk add --no-cache git build-base
|
||||
|
||||
ENV CGO_ENABLED=1
|
||||
# Install migrate tool
|
||||
RUN go install -tags 'postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@latest
|
||||
|
||||
# Copy source code
|
||||
# Copy go mod and sum files
|
||||
COPY go.mod go.sum ./
|
||||
RUN --mount=type=cache,target=/go/pkg/mod go mod download
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN make build-prod
|
||||
# Get version and commit info for build injection
|
||||
ARG VERSION_ARG
|
||||
ARG COMMIT_ID_ARG
|
||||
ARG BUILD_TIME_ARG
|
||||
ARG GO_VERSION_ARG
|
||||
|
||||
# Set build-time variables
|
||||
ENV VERSION=${VERSION_ARG}
|
||||
ENV COMMIT_ID=${COMMIT_ID_ARG}
|
||||
ENV BUILD_TIME=${BUILD_TIME_ARG}
|
||||
ENV GO_VERSION=${GO_VERSION_ARG}
|
||||
|
||||
# Build the application with version info
|
||||
RUN --mount=type=cache,target=/go/pkg/mod make build-prod
|
||||
RUN --mount=type=cache,target=/go/pkg/mod cp -r /go/pkg/mod/github.com/yanyiwu/ /app/yanyiwu/
|
||||
|
||||
# Final stage
|
||||
FROM alpine:3.17
|
||||
@@ -39,36 +49,31 @@ WORKDIR /app
|
||||
# Install runtime dependencies
|
||||
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apk/repositories && \
|
||||
apk update && apk upgrade && \
|
||||
apk add --no-cache build-base postgresql-client mysql-client ca-certificates tzdata sed curl bash supervisor vim wget
|
||||
|
||||
# Copy the binary from the builder stage
|
||||
COPY --from=builder /app/WeKnora .
|
||||
COPY --from=builder /app/config ./config
|
||||
COPY --from=builder /app/scripts ./scripts
|
||||
COPY --from=builder /app/migrations ./migrations
|
||||
COPY --from=builder /app/dataset/samples ./dataset/samples
|
||||
|
||||
# Copy migrate tool from builder stage
|
||||
COPY --from=builder /go/bin/migrate /usr/local/bin/
|
||||
COPY --from=builder /go/pkg/mod/github.com/yanyiwu /go/pkg/mod/github.com/yanyiwu/
|
||||
|
||||
# Make scripts executable
|
||||
RUN chmod +x ./scripts/*.sh
|
||||
|
||||
# Setup supervisor configuration
|
||||
RUN mkdir -p /etc/supervisor.d/
|
||||
COPY docker/config/supervisord.conf /etc/supervisor.d/supervisord.conf
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 8080
|
||||
|
||||
# Set environment variables
|
||||
ENV CGO_ENABLED=1
|
||||
apk add --no-cache build-base postgresql-client mysql-client ca-certificates tzdata sed curl bash vim wget
|
||||
|
||||
# Create a non-root user and switch to it
|
||||
RUN mkdir -p /data/files && \
|
||||
adduser -D -g '' appuser && \
|
||||
chown -R appuser:appuser /app /data/files
|
||||
|
||||
# Run supervisor instead of direct application start
|
||||
CMD ["supervisord", "-c", "/etc/supervisor.d/supervisord.conf"]
|
||||
# Copy migrate tool from builder stage
|
||||
COPY --from=builder /go/bin/migrate /usr/local/bin/
|
||||
COPY --from=builder /app/yanyiwu/ /go/pkg/mod/github.com/yanyiwu/
|
||||
|
||||
# Copy the binary from the builder stage
|
||||
COPY --from=builder /app/config ./config
|
||||
COPY --from=builder /app/scripts ./scripts
|
||||
COPY --from=builder /app/migrations ./migrations
|
||||
COPY --from=builder /app/dataset/samples ./dataset/samples
|
||||
COPY --from=builder /app/WeKnora .
|
||||
|
||||
# Make scripts executable
|
||||
RUN chmod +x ./scripts/*.sh
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 8080
|
||||
|
||||
# Switch to non-root user and run the application directly
|
||||
USER appuser
|
||||
|
||||
CMD ["./WeKnora"]
|
||||
@@ -26,22 +26,31 @@ RUN apt-get update && apt-get install -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 检查是否存在本地protoc安装包,如果存在则离线安装,否则在线安装,其他安装包按需求添加
|
||||
ARG TARGETARCH
|
||||
COPY packages/ /app/packages/
|
||||
RUN echo "检查本地protoc安装包..." && \
|
||||
if [ -f "/app/packages/protoc-3.19.4-linux-x86_64.zip" ]; then \
|
||||
# 根据目标架构选择正确的protoc包名
|
||||
case ${TARGETARCH} in \
|
||||
"amd64") PROTOC_ARCH="x86_64" ;; \
|
||||
"arm64") PROTOC_ARCH="aarch_64" ;; \
|
||||
"arm") PROTOC_ARCH="arm" ;; \
|
||||
*) echo "Unsupported architecture for protoc: ${TARGETARCH}" && exit 1 ;; \
|
||||
esac && \
|
||||
PROTOC_PACKAGE="protoc-3.19.4-linux-${PROTOC_ARCH}.zip" && \
|
||||
if [ -f "/app/packages/${PROTOC_PACKAGE}" ]; then \
|
||||
echo "发现本地protoc安装包,将进行离线安装"; \
|
||||
# 离线安装:使用本地包(精确路径避免歧义)
|
||||
cp /app/packages/protoc-*.zip /app/ && \
|
||||
unzip -o /app/protoc-*.zip -d /usr/local && \
|
||||
cp /app/packages/${PROTOC_PACKAGE} /app/ && \
|
||||
unzip -o /app/${PROTOC_PACKAGE} -d /usr/local && \
|
||||
chmod +x /usr/local/bin/protoc && \
|
||||
rm -f /app/protoc-*.zip; \
|
||||
rm -f /app/${PROTOC_PACKAGE}; \
|
||||
else \
|
||||
echo "未发现本地protoc安装包,将进行在线安装"; \
|
||||
# 在线安装:从网络下载
|
||||
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/protoc-3.19.4-linux-x86_64.zip && \
|
||||
unzip -o protoc-3.19.4-linux-x86_64.zip -d /usr/local && \
|
||||
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/${PROTOC_PACKAGE} && \
|
||||
unzip -o ${PROTOC_PACKAGE} -d /usr/local && \
|
||||
chmod +x /usr/local/bin/protoc && \
|
||||
rm -f protoc-3.19.4-linux-x86_64.zip; \
|
||||
rm -f ${PROTOC_PACKAGE}; \
|
||||
fi
|
||||
|
||||
# 复制依赖文件
|
||||
@@ -102,7 +111,6 @@ RUN apt-get update && apt-get install -y \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
antiword \
|
||||
supervisor \
|
||||
vim \
|
||||
tar \
|
||||
dpkg \
|
||||
@@ -118,30 +126,35 @@ RUN apt-get update && apt-get install -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 安装 grpc_health_probe
|
||||
ARG TARGETARCH
|
||||
RUN GRPC_HEALTH_PROBE_VERSION=v0.4.24 && \
|
||||
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \
|
||||
# 根据目标架构选择正确的二进制文件
|
||||
case ${TARGETARCH} in \
|
||||
"amd64") ARCH="amd64" ;; \
|
||||
"arm64") ARCH="arm64" ;; \
|
||||
"arm") ARCH="arm" ;; \
|
||||
*) echo "Unsupported architecture: ${TARGETARCH}" && exit 1 ;; \
|
||||
esac && \
|
||||
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-${ARCH} && \
|
||||
chmod +x /bin/grpc_health_probe
|
||||
|
||||
# 从构建阶段复制已安装的依赖和生成的代码
|
||||
COPY --from=builder /usr/local/lib/python3.10/site-packages /usr/local/lib/python3.10/site-packages
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
COPY --from=builder /root/.paddleocr /root/.paddleocr
|
||||
COPY --from=builder /app/src /app/src
|
||||
|
||||
# 安装 Playwright 浏览器
|
||||
RUN python -m playwright install webkit
|
||||
RUN python -m playwright install-deps webkit
|
||||
|
||||
COPY --from=builder /app/src /app/src
|
||||
|
||||
# 设置 Python 路径
|
||||
ENV PYTHONPATH=/app/src
|
||||
RUN cd /app/src && python -m download_deps
|
||||
|
||||
# 创建supervisor配置
|
||||
RUN mkdir -p /etc/supervisor/conf.d
|
||||
COPY services/docreader/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# 暴露 gRPC 端口
|
||||
EXPOSE 50051
|
||||
|
||||
# 使用supervisor启动服务
|
||||
CMD ["supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
||||
# 直接运行 Python 服务(日志输出到 stdout/stderr)
|
||||
CMD ["python", "/app/src/server/server.py"]
|
||||
@@ -2,11 +2,7 @@
|
||||
|
||||
## 1. 如何查看日志?
|
||||
```bash
|
||||
# 查看 主服务 日志
|
||||
docker exec -it WeKnora-app tail -f /var/log/WeKnora.log
|
||||
|
||||
# 查看 文档解析模块 日志
|
||||
docker exec -it WeKnora-docreader tail -f /var/log/docreader.log
|
||||
docker compose logs -f app docreader postgres
|
||||
```
|
||||
|
||||
## 2. 如何启动和停止服务?
|
||||
|
||||
BIN
docs/images/pipeline.png
Normal file
BIN
docs/images/pipeline.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 504 KiB |
@@ -2,6 +2,16 @@ server {
|
||||
listen 80;
|
||||
server_name localhost;
|
||||
client_max_body_size 50M;
|
||||
|
||||
# 安全头配置
|
||||
add_header X-Frame-Options "SAMEORIGIN" always;
|
||||
add_header X-Content-Type-Options "nosniff" always;
|
||||
add_header X-XSS-Protection "1; mode=block" always;
|
||||
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
|
||||
|
||||
# 错误日志配置
|
||||
error_log /var/log/nginx/error.log warn;
|
||||
access_log /var/log/nginx/access.log;
|
||||
|
||||
# 前端静态文件
|
||||
location / {
|
||||
@@ -18,6 +28,12 @@ server {
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# 连接和重试配置
|
||||
proxy_connect_timeout 30s; # 连接超时时间
|
||||
proxy_next_upstream error timeout invalid_header http_500 http_502 http_503 http_504;
|
||||
proxy_next_upstream_tries 3; # 重试次数
|
||||
proxy_next_upstream_timeout 30s; # 重试超时时间
|
||||
|
||||
# SSE 相关配置
|
||||
proxy_http_version 1.1; # 使用 HTTP/1.1
|
||||
proxy_set_header Connection ""; # 禁用 Connection: close,保持连接打开
|
||||
|
||||
50
frontend/package-lock.json
generated
50
frontend/package-lock.json
generated
@@ -1,18 +1,21 @@
|
||||
{
|
||||
"name": "knowledage-base",
|
||||
"version": "0.0.0",
|
||||
"version": "0.1.3",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "knowledage-base",
|
||||
"version": "0.0.0",
|
||||
"version": "0.1.3",
|
||||
"dependencies": {
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@types/dompurify": "^3.0.5",
|
||||
"axios": "^1.8.4",
|
||||
"dompurify": "^3.2.6",
|
||||
"marked": "^5.1.2",
|
||||
"pagefind": "^1.1.1",
|
||||
"pinia": "^3.0.1",
|
||||
"tdesign-icons-vue-next": "^0.4.1",
|
||||
"tdesign-vue-next": "^1.11.5",
|
||||
"vue": "^3.5.13",
|
||||
"vue-router": "^4.5.0",
|
||||
@@ -1274,6 +1277,15 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/dompurify": {
|
||||
"version": "3.0.5",
|
||||
"resolved": "https://mirrors.tencent.com/npm/@types/dompurify/-/dompurify-3.0.5.tgz",
|
||||
"integrity": "sha512-1Wg0g3BtQF7sSb27fJQAKck1HECM6zV1EB66j8JH9i3LCjYabJa0FSdiSgsD5K/RbrsR0SiraKacLB+T8ZVYAg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/trusted-types": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/eslint": {
|
||||
"version": "9.6.1",
|
||||
"resolved": "https://mirrors.tencent.com/npm/@types/eslint/-/eslint-9.6.1.tgz",
|
||||
@@ -1346,6 +1358,12 @@
|
||||
"resolved": "https://mirrors.tencent.com/npm/@types/tinycolor2/-/tinycolor2-1.4.6.tgz",
|
||||
"integrity": "sha512-iEN8J0BoMnsWBqjVbWH/c0G0Hh7O21lpR2/+PrvAVgWdzL7eexIFm4JN/Wn10PTcmNdtS6U67r499mlWMXOxNw=="
|
||||
},
|
||||
"node_modules/@types/trusted-types": {
|
||||
"version": "2.0.7",
|
||||
"resolved": "https://mirrors.tencent.com/npm/@types/trusted-types/-/trusted-types-2.0.7.tgz",
|
||||
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/validator": {
|
||||
"version": "13.15.2",
|
||||
"resolved": "https://mirrors.tencent.com/npm/@types/validator/-/validator-13.15.2.tgz",
|
||||
@@ -2121,6 +2139,15 @@
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/dompurify": {
|
||||
"version": "3.2.6",
|
||||
"resolved": "https://mirrors.tencent.com/npm/dompurify/-/dompurify-3.2.6.tgz",
|
||||
"integrity": "sha512-/2GogDQlohXPZe6D6NOgQvXLPSYBqIWMnZ8zzOhn09REE4eyAzb+Hed3jhoM9OkuaJ8P6ZGTTVWQKAi8ieIzfQ==",
|
||||
"license": "(MPL-2.0 OR Apache-2.0)",
|
||||
"optionalDependencies": {
|
||||
"@types/trusted-types": "^2.0.7"
|
||||
}
|
||||
},
|
||||
"node_modules/dunder-proto": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://mirrors.tencent.com/npm/dunder-proto/-/dunder-proto-1.0.1.tgz",
|
||||
@@ -3374,9 +3401,10 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tdesign-icons-vue-next": {
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.3.6.tgz",
|
||||
"integrity": "sha512-X9u90dBv8tPhfpguUyx+BzF8CU2ef2L4RXOO7MYOj1ufHCHwBXTF8L3GPfq6KZd/2u4vMLYAA8lGURn4PZZICw==",
|
||||
"version": "0.4.1",
|
||||
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.4.1.tgz",
|
||||
"integrity": "sha512-uDPuTLRORnGcTyVGNoentNaK4V+ZcBmhYwcY3KqDaQQ5rrPeLMxu0ZVmgOEf0JtF2QZiqAxY7vodNEiLUdoRKA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.16.3"
|
||||
},
|
||||
@@ -3410,6 +3438,18 @@
|
||||
"vue": ">=3.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/tdesign-vue-next/node_modules/tdesign-icons-vue-next": {
|
||||
"version": "0.3.7",
|
||||
"resolved": "https://mirrors.tencent.com/npm/tdesign-icons-vue-next/-/tdesign-icons-vue-next-0.3.7.tgz",
|
||||
"integrity": "sha512-Q5ebVty/TCqhBa0l/17kkhjC0pBAOGvn7C35MAt1xS+johKVM9QEDOy9R6XEl332AiwQ37MwqioczqjYC30ckw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.16.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"vue": "^3.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/terser": {
|
||||
"version": "5.43.1",
|
||||
"resolved": "https://mirrors.tencent.com/npm/terser/-/terser-5.43.1.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "knowledage-base",
|
||||
"version": "0.1.0",
|
||||
"version": "0.1.3",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
@@ -13,10 +13,13 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@types/dompurify": "^3.0.5",
|
||||
"axios": "^1.8.4",
|
||||
"dompurify": "^3.2.6",
|
||||
"marked": "^5.1.2",
|
||||
"pagefind": "^1.1.1",
|
||||
"pinia": "^3.0.1",
|
||||
"tdesign-icons-vue-next": "^0.4.1",
|
||||
"tdesign-vue-next": "^1.11.5",
|
||||
"vue": "^3.5.13",
|
||||
"vue-router": "^4.5.0",
|
||||
|
||||
239
frontend/src/api/auth/index.ts
Normal file
239
frontend/src/api/auth/index.ts
Normal file
@@ -0,0 +1,239 @@
|
||||
import { post, get, put } from '@/utils/request'
|
||||
|
||||
// 用户登录接口
|
||||
export interface LoginRequest {
|
||||
email: string
|
||||
password: string
|
||||
}
|
||||
|
||||
export interface LoginResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
user?: {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
avatar?: string
|
||||
tenant_id: number
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
tenant?: {
|
||||
id: number
|
||||
name: string
|
||||
description: string
|
||||
api_key: string
|
||||
status: string
|
||||
business: string
|
||||
storage_quota: number
|
||||
storage_used: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
token?: string
|
||||
refresh_token?: string
|
||||
}
|
||||
|
||||
// 用户注册接口
|
||||
export interface RegisterRequest {
|
||||
username: string
|
||||
email: string
|
||||
password: string
|
||||
}
|
||||
|
||||
export interface RegisterResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
data?: {
|
||||
user: {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
}
|
||||
tenant: {
|
||||
id: string
|
||||
name: string
|
||||
api_key: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 用户信息接口
|
||||
export interface UserInfo {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
avatar?: string
|
||||
tenant_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
// 租户信息接口
|
||||
export interface TenantInfo {
|
||||
id: string
|
||||
name: string
|
||||
description?: string
|
||||
api_key: string
|
||||
status?: string
|
||||
business?: string
|
||||
owner_id: string
|
||||
storage_quota?: number
|
||||
storage_used?: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
knowledge_bases?: KnowledgeBaseInfo[]
|
||||
}
|
||||
|
||||
// 知识库信息接口
|
||||
export interface KnowledgeBaseInfo {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
tenant_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
document_count?: number
|
||||
chunk_count?: number
|
||||
}
|
||||
|
||||
// 模型信息接口
|
||||
export interface ModelInfo {
|
||||
id: string
|
||||
name: string
|
||||
type: string
|
||||
source: string
|
||||
description?: string
|
||||
is_default?: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登录
|
||||
*/
|
||||
export async function login(data: LoginRequest): Promise<LoginResponse> {
|
||||
try {
|
||||
const response = await post('/api/v1/auth/login', data)
|
||||
return response as unknown as LoginResponse
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '登录失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户注册
|
||||
*/
|
||||
export async function register(data: RegisterRequest): Promise<RegisterResponse> {
|
||||
try {
|
||||
const response = await post('/api/v1/auth/register', data)
|
||||
return response as unknown as RegisterResponse
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '注册失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前用户信息
|
||||
*/
|
||||
export async function getCurrentUser(): Promise<{ success: boolean; data?: { user: UserInfo; tenant: TenantInfo }; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/me')
|
||||
return response as unknown as { success: boolean; data?: { user: UserInfo; tenant: TenantInfo }; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '获取用户信息失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前租户信息
|
||||
*/
|
||||
export async function getCurrentTenant(): Promise<{ success: boolean; data?: TenantInfo; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/tenant')
|
||||
return response as unknown as { success: boolean; data?: TenantInfo; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '获取租户信息失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新Token
|
||||
*/
|
||||
export async function refreshToken(refreshToken: string): Promise<{ success: boolean; data?: { token: string; refreshToken: string }; message?: string }> {
|
||||
try {
|
||||
const response: any = await post('/api/v1/auth/refresh', { refreshToken })
|
||||
if (response && response.success) {
|
||||
if (response.access_token || response.refresh_token) {
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
token: response.access_token,
|
||||
refreshToken: response.refresh_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 其他情况直接返回原始消息
|
||||
return {
|
||||
success: false,
|
||||
message: response?.message || '刷新Token失败'
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '刷新Token失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登出
|
||||
*/
|
||||
export async function logout(): Promise<{ success: boolean; message?: string }> {
|
||||
try {
|
||||
await post('/api/v1/auth/logout', {})
|
||||
return {
|
||||
success: true
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '登出失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证Token有效性
|
||||
*/
|
||||
export async function validateToken(): Promise<{ success: boolean; valid?: boolean; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/validate')
|
||||
return response as unknown as { success: boolean; valid?: boolean; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
valid: false,
|
||||
message: error.message || 'Token验证失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,54 +1,24 @@
|
||||
import { get, post, put, del, postChat } from "../../utils/request";
|
||||
import { loadTestData } from "../test-data";
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.apiKey && settings.endpoint) {
|
||||
return settings;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// 根据是否有设置决定是否需要加载测试数据
|
||||
async function ensureConfigured() {
|
||||
const settings = getSettings();
|
||||
// 如果没有设置APIKey和Endpoint,则加载测试数据
|
||||
if (!settings) {
|
||||
await loadTestData();
|
||||
}
|
||||
}
|
||||
|
||||
export async function createSessions(data = {}) {
|
||||
await ensureConfigured();
|
||||
return post("/api/v1/sessions", data);
|
||||
}
|
||||
|
||||
export async function getSessionsList(page: number, page_size: number) {
|
||||
await ensureConfigured();
|
||||
return get(`/api/v1/sessions?page=${page}&page_size=${page_size}`);
|
||||
}
|
||||
|
||||
export async function generateSessionsTitle(session_id: string, data: any) {
|
||||
await ensureConfigured();
|
||||
return post(`/api/v1/sessions/${session_id}/generate_title`, data);
|
||||
}
|
||||
|
||||
export async function knowledgeChat(data: { session_id: string; query: string; }) {
|
||||
await ensureConfigured();
|
||||
return postChat(`/api/v1/knowledge-chat/${data.session_id}`, { query: data.query });
|
||||
}
|
||||
|
||||
export async function getMessageList(data: { session_id: string; limit: number, created_at: string }) {
|
||||
await ensureConfigured();
|
||||
|
||||
if (data.created_at) {
|
||||
return get(`/api/v1/messages/${data.session_id}/load?before_time=${encodeURIComponent(data.created_at)}&limit=${data.limit}`);
|
||||
} else {
|
||||
@@ -57,6 +27,5 @@ export async function getMessageList(data: { session_id: string; limit: number,
|
||||
}
|
||||
|
||||
export async function delSession(session_id: string) {
|
||||
await ensureConfigured();
|
||||
return del(`/api/v1/sessions/${session_id}`);
|
||||
}
|
||||
@@ -1,22 +1,8 @@
|
||||
import { fetchEventSource } from '@microsoft/fetch-event-source'
|
||||
import { ref, type Ref, onUnmounted, nextTick } from 'vue'
|
||||
import { generateRandomString } from '@/utils/index';
|
||||
import { getTestData } from '@/utils/request';
|
||||
import { loadTestData } from '@/api/test-data';
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
return settings;
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
interface StreamOptions {
|
||||
// 请求方法 (默认POST)
|
||||
@@ -49,26 +35,15 @@ export function useStream() {
|
||||
isStreaming.value = true;
|
||||
isLoading.value = true;
|
||||
|
||||
// 获取设置信息
|
||||
const settings = getSettings();
|
||||
let apiUrl = '';
|
||||
let apiKey = '';
|
||||
|
||||
// 如果有设置信息,优先使用设置信息
|
||||
if (settings && settings.endpoint && settings.apiKey) {
|
||||
apiUrl = settings.endpoint;
|
||||
apiKey = settings.apiKey;
|
||||
} else {
|
||||
// 否则加载测试数据
|
||||
await loadTestData();
|
||||
const testData = getTestData();
|
||||
if (!testData) {
|
||||
error.value = "测试数据未初始化,无法进行聊天";
|
||||
stopStream();
|
||||
return;
|
||||
}
|
||||
apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
apiKey = testData.tenant.api_key;
|
||||
// 获取API配置
|
||||
const apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
|
||||
// 获取JWT Token
|
||||
const token = localStorage.getItem('weknora_token');
|
||||
if (!token) {
|
||||
error.value = "未找到登录令牌,请重新登录";
|
||||
stopStream();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -80,7 +55,7 @@ export function useStream() {
|
||||
method: params.method,
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"X-API-Key": apiKey,
|
||||
"Authorization": `Bearer ${token}`,
|
||||
"X-Request-ID": `${generateRandomString(12)}`,
|
||||
},
|
||||
body:
|
||||
|
||||
@@ -19,6 +19,7 @@ export interface InitializationConfig {
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
enabled: boolean;
|
||||
};
|
||||
multimodal: {
|
||||
enabled: boolean;
|
||||
@@ -49,6 +50,13 @@ export interface InitializationConfig {
|
||||
};
|
||||
// Frontend-only hint for storage selection UI
|
||||
storageType?: 'cos' | 'minio';
|
||||
nodeExtract: {
|
||||
enabled: boolean,
|
||||
text: string,
|
||||
tags: string[],
|
||||
nodes: Node[],
|
||||
relations: Relation[]
|
||||
}
|
||||
}
|
||||
|
||||
// 下载任务状态类型
|
||||
@@ -62,34 +70,18 @@ export interface DownloadTask {
|
||||
endTime?: string;
|
||||
}
|
||||
|
||||
// 系统初始化状态检查
|
||||
export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
|
||||
// 根据知识库ID执行配置更新
|
||||
export function initializeSystemByKB(kbId: string, config: InitializationConfig): Promise<any> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/status')
|
||||
console.log('开始知识库配置更新...', kbId, config);
|
||||
post(`/api/v1/initialization/initialize/${kbId}`, config)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { initialized: false });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.warn('检查初始化状态失败,假设需要初始化:', error);
|
||||
resolve({ initialized: false });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 执行系统初始化
|
||||
export function initializeSystem(config: InitializationConfig): Promise<any> {
|
||||
return new Promise((resolve, reject) => {
|
||||
console.log('开始系统初始化...', config);
|
||||
post('/api/v1/initialization/initialize', config)
|
||||
.then((response: any) => {
|
||||
console.log('系统初始化完成', response);
|
||||
// 设置本地初始化状态标记
|
||||
localStorage.setItem('system_initialized', 'true');
|
||||
console.log('知识库配置更新完成', response);
|
||||
resolve(response);
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('系统初始化失败:', error);
|
||||
reject(error);
|
||||
console.error('知识库配置更新失败:', error);
|
||||
reject(error.error || error);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -178,15 +170,15 @@ export function listDownloadTasks(): Promise<DownloadTask[]> {
|
||||
});
|
||||
}
|
||||
|
||||
// 获取当前系统配置
|
||||
export function getCurrentConfig(): Promise<InitializationConfig & { hasFiles: boolean }> {
|
||||
|
||||
export function getCurrentConfigByKB(kbId: string): Promise<InitializationConfig & { hasFiles: boolean }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
get('/api/v1/initialization/config')
|
||||
get(`/api/v1/initialization/config/${kbId}`)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || {});
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('获取当前配置失败:', error);
|
||||
console.error('获取知识库配置失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
@@ -311,9 +303,17 @@ export function testMultimodalFunction(testData: {
|
||||
formData.append('chunk_overlap', testData.chunk_overlap.toString());
|
||||
formData.append('separators', JSON.stringify(testData.separators));
|
||||
|
||||
// 获取鉴权Token
|
||||
const token = localStorage.getItem('weknora_token');
|
||||
const headers: Record<string, string> = {};
|
||||
if (token) {
|
||||
headers['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
// 使用原生fetch因为需要发送FormData
|
||||
fetch('/api/v1/initialization/multimodal/test', {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: formData
|
||||
})
|
||||
.then(response => response.json())
|
||||
@@ -329,4 +329,93 @@ export function testMultimodalFunction(testData: {
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 文本内容关系提取接口
|
||||
export interface TextRelationExtractionRequest {
|
||||
text: string;
|
||||
tags: string[];
|
||||
llmConfig: LLMConfig;
|
||||
}
|
||||
|
||||
export interface Node {
|
||||
name: string;
|
||||
attributes: string[];
|
||||
}
|
||||
|
||||
export interface Relation {
|
||||
node1: string;
|
||||
node2: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
export interface LLMConfig {
|
||||
source: 'local' | 'remote';
|
||||
modelName: string;
|
||||
baseUrl: string;
|
||||
apiKey: string;
|
||||
}
|
||||
|
||||
export interface TextRelationExtractionResponse {
|
||||
nodes: Node[];
|
||||
relations: Relation[];
|
||||
}
|
||||
|
||||
// 文本内容关系提取
|
||||
export function extractTextRelations(request: TextRelationExtractionRequest): Promise<TextRelationExtractionResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/extract/text-relation', request)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { nodes: [], relations: [] });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('文本内容关系提取失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export interface FabriTextRequest {
|
||||
tags: string[];
|
||||
llmConfig: LLMConfig;
|
||||
}
|
||||
|
||||
export interface FabriTextResponse {
|
||||
text: string;
|
||||
}
|
||||
|
||||
// 文本内容生成
|
||||
export function fabriText(request: FabriTextRequest): Promise<FabriTextResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/extract/fabri-text', request)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { text: '' });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('文本内容生成失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export interface FabriTagRequest {
|
||||
llmConfig: LLMConfig;
|
||||
}
|
||||
|
||||
export interface FabriTagResponse {
|
||||
tags: string[];
|
||||
}
|
||||
|
||||
// 文本内容生成
|
||||
export function fabriTag(request: FabriTagRequest): Promise<FabriTagResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
post('/api/v1/initialization/extract/fabri-tag', request)
|
||||
.then((response: any) => {
|
||||
resolve(response.data || { tags: [] as string[] });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
console.error('标签生成失败:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1,62 +1,55 @@
|
||||
import { get, post, put, del, postUpload, getDown, getTestData } from "../../utils/request";
|
||||
import { loadTestData } from "../test-data";
|
||||
import { get, post, put, del, postUpload, getDown } from "../../utils/request";
|
||||
|
||||
// 获取知识库ID(优先从设置中获取)
|
||||
async function getKnowledgeBaseID() {
|
||||
// 从localStorage获取设置中的知识库ID
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
let knowledgeBaseId = "";
|
||||
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.knowledgeBaseId) {
|
||||
return settings.knowledgeBaseId;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置中没有知识库ID,则使用测试数据
|
||||
await loadTestData();
|
||||
|
||||
const testData = getTestData();
|
||||
if (!testData || testData.knowledge_bases.length === 0) {
|
||||
console.error("测试数据未初始化或不包含知识库");
|
||||
throw new Error("测试数据未初始化或不包含知识库");
|
||||
}
|
||||
return testData.knowledge_bases[0].id;
|
||||
// 知识库管理 API(列表、创建、获取、更新、删除、复制)
|
||||
export function listKnowledgeBases() {
|
||||
return get(`/api/v1/knowledge-bases`);
|
||||
}
|
||||
|
||||
export async function uploadKnowledgeBase(data = {}) {
|
||||
const kbId = await getKnowledgeBaseID();
|
||||
export function createKnowledgeBase(data: { name: string; description?: string; chunking_config?: any }) {
|
||||
return post(`/api/v1/knowledge-bases`, data);
|
||||
}
|
||||
|
||||
export function getKnowledgeBaseById(id: string) {
|
||||
return get(`/api/v1/knowledge-bases/${id}`);
|
||||
}
|
||||
|
||||
export function updateKnowledgeBase(id: string, data: { name: string; description?: string; config: any }) {
|
||||
return put(`/api/v1/knowledge-bases/${id}` , data);
|
||||
}
|
||||
|
||||
export function deleteKnowledgeBase(id: string) {
|
||||
return del(`/api/v1/knowledge-bases/${id}`);
|
||||
}
|
||||
|
||||
export function copyKnowledgeBase(data: { source_id: string; target_id?: string }) {
|
||||
return post(`/api/v1/knowledge-bases/copy`, data);
|
||||
}
|
||||
|
||||
// 知识文件 API(基于具体知识库)
|
||||
export function uploadKnowledgeFile(kbId: string, data = {}) {
|
||||
return postUpload(`/api/v1/knowledge-bases/${kbId}/knowledge/file`, data);
|
||||
}
|
||||
|
||||
export async function getKnowledgeBase({page, page_size}) {
|
||||
const kbId = await getKnowledgeBaseID();
|
||||
return get(
|
||||
`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`
|
||||
);
|
||||
export function listKnowledgeFiles(kbId: string, { page, page_size }: { page: number; page_size: number }) {
|
||||
return get(`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`);
|
||||
}
|
||||
|
||||
export function getKnowledgeDetails(id: any) {
|
||||
export function getKnowledgeDetails(id: string) {
|
||||
return get(`/api/v1/knowledge/${id}`);
|
||||
}
|
||||
|
||||
export function delKnowledgeDetails(id: any) {
|
||||
export function delKnowledgeDetails(id: string) {
|
||||
return del(`/api/v1/knowledge/${id}`);
|
||||
}
|
||||
|
||||
export function downKnowledgeDetails(id: any) {
|
||||
export function downKnowledgeDetails(id: string) {
|
||||
return getDown(`/api/v1/knowledge/${id}/download`);
|
||||
}
|
||||
|
||||
export function batchQueryKnowledge(ids: any) {
|
||||
return get(`/api/v1/knowledge/batch?${ids}`);
|
||||
export function batchQueryKnowledge(idsQueryString: string) {
|
||||
return get(`/api/v1/knowledge/batch?${idsQueryString}`);
|
||||
}
|
||||
|
||||
export function getKnowledgeDetailsCon(id: any, page) {
|
||||
export function getKnowledgeDetailsCon(id: string, page: number) {
|
||||
return get(`/api/v1/chunks/${id}?page=${page}&page_size=25`);
|
||||
}
|
||||
12
frontend/src/api/system/index.ts
Normal file
12
frontend/src/api/system/index.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { get } from '@/utils/request'
|
||||
|
||||
export interface SystemInfo {
|
||||
version: string
|
||||
commit_id?: string
|
||||
build_time?: string
|
||||
go_version?: string
|
||||
}
|
||||
|
||||
export function getSystemInfo(): Promise<{ data: SystemInfo }> {
|
||||
return get('/api/v1/system/info')
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
import { get, setTestData } from '../../utils/request';
|
||||
|
||||
export interface TestDataResponse {
|
||||
success: boolean;
|
||||
data: {
|
||||
tenant: {
|
||||
id: number;
|
||||
name: string;
|
||||
api_key: string;
|
||||
};
|
||||
knowledge_bases: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
}>;
|
||||
}
|
||||
}
|
||||
|
||||
// 是否已加载测试数据
|
||||
let isTestDataLoaded = false;
|
||||
|
||||
/**
|
||||
* 加载测试数据
|
||||
* 在API调用前调用此函数以确保测试数据已加载
|
||||
* @returns Promise<boolean> 是否成功加载
|
||||
*/
|
||||
export async function loadTestData(): Promise<boolean> {
|
||||
// 如果已经加载过,直接返回
|
||||
if (isTestDataLoaded) {
|
||||
return true;
|
||||
}
|
||||
|
||||
try {
|
||||
console.log('开始加载测试数据...');
|
||||
const response = await get('/api/v1/test-data');
|
||||
console.log('测试数据', response);
|
||||
|
||||
if (response && response.data) {
|
||||
// 设置测试数据
|
||||
setTestData({
|
||||
tenant: response.data.tenant,
|
||||
knowledge_bases: response.data.knowledge_bases
|
||||
});
|
||||
isTestDataLoaded = true;
|
||||
console.log('测试数据加载成功');
|
||||
return true;
|
||||
} else {
|
||||
console.warn('测试数据响应为空');
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载测试数据失败:', error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
6
frontend/src/assets/img/logout.svg
Normal file
6
frontend/src/assets/img/logout.svg
Normal file
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||
<path d="M10 3H6a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M17 16l4-4-4-4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M21 12H10" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
||||
|
After Width: | Height: | Size: 509 B |
4
frontend/src/assets/img/user-green.svg
Normal file
4
frontend/src/assets/img/user-green.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="10" cy="6" r="3" stroke="#07C05F" stroke-width="1.5" fill="none"/>
|
||||
<path d="M4 16c0-3.314 2.686-6 6-6s6 2.686 6 6" stroke="#07C05F" stroke-width="1.5" fill="none"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 284 B |
4
frontend/src/assets/img/user.svg
Normal file
4
frontend/src/assets/img/user.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="10" cy="6" r="3" stroke="currentColor" stroke-width="1.5" fill="none"/>
|
||||
<path d="M4 16c0-3.314 2.686-6 6-6s6 2.686 6 6" stroke="currentColor" stroke-width="1.5" fill="none"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 294 B |
@@ -4,6 +4,8 @@ import { onMounted, ref, nextTick, onUnmounted, onUpdated, watch } from "vue";
|
||||
import { downKnowledgeDetails } from "@/api/knowledge-base/index";
|
||||
import { MessagePlugin } from "tdesign-vue-next";
|
||||
import picturePreview from '@/components/picture-preview.vue';
|
||||
import { sanitizeHTML, safeMarkdownToHTML, createSafeImage, isValidImageURL } from '@/utils/security';
|
||||
|
||||
marked.use({
|
||||
mangle: false,
|
||||
headerIds: false,
|
||||
@@ -37,10 +39,16 @@ const checkImage = (url) => {
|
||||
});
|
||||
};
|
||||
renderer.image = function (href, title, text) {
|
||||
// 自定义HTML结构,图片展示带标题
|
||||
// 安全地处理图片链接
|
||||
if (!isValidImageURL(href)) {
|
||||
return `<p>无效的图片链接</p>`;
|
||||
}
|
||||
|
||||
// 使用安全的图片创建函数
|
||||
const safeImage = createSafeImage(href, text || '', title || '');
|
||||
return `<figure>
|
||||
<img class="markdown-image" src="${href}" alt="${title}" title="${text}">
|
||||
<figcaption style="text-align: left;">${text}</figcaption>
|
||||
${safeImage}
|
||||
<figcaption style="text-align: left;">${text || ''}</figcaption>
|
||||
</figure>`;
|
||||
};
|
||||
const props = defineProps(["visible", "details"]);
|
||||
@@ -66,14 +74,23 @@ watch(() => props.details.md, (newVal) => {
|
||||
deep: true
|
||||
})
|
||||
|
||||
// 处理 Markdown 中的图片
|
||||
// 安全地处理 Markdown 内容
|
||||
const processMarkdown = (markdownText) => {
|
||||
// 自定义渲染器处理图片
|
||||
if (!markdownText || typeof markdownText !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
// 首先对 Markdown 内容进行安全处理
|
||||
const safeMarkdown = safeMarkdownToHTML(markdownText);
|
||||
|
||||
// 使用安全的渲染器
|
||||
marked.use({ renderer });
|
||||
let html = marked.parse(markdownText);
|
||||
const parser = new DOMParser();
|
||||
const doc = parser.parseFromString(html, 'text/html');
|
||||
return doc.body.innerHTML;
|
||||
let html = marked.parse(safeMarkdown);
|
||||
|
||||
// 使用 DOMPurify 进行最终的安全清理
|
||||
const sanitizedHTML = sanitizeHTML(html);
|
||||
|
||||
return sanitizedHTML;
|
||||
};
|
||||
const closePreImg = () => {
|
||||
reviewImg.value = false
|
||||
@@ -87,15 +104,19 @@ const downloadFile = () => {
|
||||
downKnowledgeDetails(props.details.id)
|
||||
.then((result) => {
|
||||
if (result) {
|
||||
if (url.value) {
|
||||
URL.revokeObjectURL(url.value);
|
||||
}
|
||||
url.value = URL.createObjectURL(result);
|
||||
down.value.click();
|
||||
// const link = document.createElement("a");
|
||||
// link.style.display = "none";
|
||||
// link.setAttribute("href", url);
|
||||
// link.setAttribute("download", props.details.title);
|
||||
// link.click();
|
||||
// document.body.removeChild(link);
|
||||
window.URL.revokeObjectURL(url);
|
||||
const link = document.createElement("a");
|
||||
link.style.display = "none";
|
||||
link.setAttribute("href", url.value);
|
||||
link.setAttribute("download", props.details.title);
|
||||
link.click();
|
||||
nextTick(() => {
|
||||
document.body.removeChild(link);
|
||||
URL.revokeObjectURL(url.value);
|
||||
})
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
|
||||
@@ -1,68 +1,132 @@
|
||||
<template>
|
||||
<div class="aside_box">
|
||||
<div class="logo_box">
|
||||
<div class="logo_box" @click="router.push('/platform/knowledge-bases')" style="cursor: pointer;">
|
||||
<img class="logo" src="@/assets/img/weknora.png" alt="">
|
||||
</div>
|
||||
<div class="menu_box" v-for="(item, index) in menuArr" :key="index">
|
||||
<div @click="gotopage(item.path)"
|
||||
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
|
||||
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : item.path == currentpath ? 'menu_item_active' : '']">
|
||||
<div class="menu_item-box">
|
||||
<div class="menu_icon">
|
||||
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : prefixIcon)" alt="">
|
||||
|
||||
<!-- 上半部分:知识库和对话 -->
|
||||
<div class="menu_top">
|
||||
<div class="menu_box" :class="{ 'has-submenu': item.children }" v-for="(item, index) in topMenuItems" :key="index">
|
||||
<div @click="handleMenuClick(item.path)"
|
||||
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
|
||||
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : isMenuItemActive(item.path) ? 'menu_item_active' : '']">
|
||||
<div class="menu_item-box">
|
||||
<div class="menu_icon">
|
||||
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'logout' ? logoutIcon : item.icon == 'tenant' ? tenantIcon : prefixIcon)" alt="">
|
||||
</div>
|
||||
<span class="menu_title" :title="item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title">{{ item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title }}</span>
|
||||
<!-- 知识库切换下拉箭头 -->
|
||||
<div v-if="item.path === 'knowledge-bases' && isInKnowledgeBase"
|
||||
class="kb-dropdown-icon"
|
||||
:class="{
|
||||
'rotate-180': showKbDropdown,
|
||||
'active': isMenuItemActive(item.path)
|
||||
}"
|
||||
@click.stop="toggleKbDropdown">
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="currentColor">
|
||||
<path d="M2.5 4.5L6 8L9.5 4.5H2.5Z"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<span class="menu_title">{{ item.title }}</span>
|
||||
<!-- 知识库切换下拉菜单 -->
|
||||
<div v-if="item.path === 'knowledge-bases' && showKbDropdown && isInKnowledgeBase"
|
||||
class="kb-dropdown-menu">
|
||||
<div v-for="kb in initializedKnowledgeBases"
|
||||
:key="kb.id"
|
||||
class="kb-dropdown-item"
|
||||
:class="{ 'active': kb.name === currentKbName }"
|
||||
@click.stop="switchKnowledgeBase(kb.id)">
|
||||
{{ kb.name }}
|
||||
</div>
|
||||
</div>
|
||||
<t-popup overlayInnerClassName="upload-popup" class="placement top center" content="上传知识"
|
||||
placement="top" show-arrow destroy-on-close>
|
||||
<div class="upload-file-wrap" @click.stop="uploadFile" variant="outline"
|
||||
v-if="item.path === 'knowledge-bases' && $route.name === 'knowledgeBaseDetail'">
|
||||
<img class="upload-file-icon" :class="[item.path == currentpath ? 'active-upload' : '']"
|
||||
:src="getImgSrc(fileAddIcon)" alt="">
|
||||
</div>
|
||||
</t-popup>
|
||||
</div>
|
||||
<t-popup overlayInnerClassName="upload-popup" class="placement top center" content="上传知识"
|
||||
placement="top" show-arrow destroy-on-close>
|
||||
<div class="upload-file-wrap" @click="uploadFile" variant="outline"
|
||||
v-if="item.path == 'knowledgeBase'">
|
||||
<img class="upload-file-icon" :class="[item.path == currentpath ? 'active-upload' : '']"
|
||||
:src="getImgSrc(fileAddIcon)" alt="">
|
||||
</div>
|
||||
</t-popup>
|
||||
</div>
|
||||
<div ref="submenuscrollContainer" @scroll="handleScroll" class="submenu" v-if="item.children">
|
||||
<div class="submenu_item_p" v-for="(subitem, subindex) in item.children" :key="subindex"
|
||||
@click="gotopage(subitem.path)">
|
||||
<div :class="['submenu_item', currentSecondpath == subitem.path ? 'submenu_item_active' : '']"
|
||||
@mouseenter="mouseenteBotDownr(subindex)" @mouseleave="mouseleaveBotDown">
|
||||
<i v-if="currentSecondpath == subitem.path" class="dot"></i>
|
||||
<span class="submenu_title"
|
||||
:style="currentSecondpath == subitem.path ? 'margin-left:14px;max-width:160px;' : 'margin-left:18px;max-width:173px;'">
|
||||
{{ subitem.title }}
|
||||
</span>
|
||||
<t-popup v-model:visible="subitem.isMore" @overlay-click="delCard(subindex, subitem)"
|
||||
@visible-change="onVisibleChange" overlayClassName="del-menu-popup" trigger="click"
|
||||
destroy-on-close placement="top-left">
|
||||
<div v-if="(activeSubmenu == subindex) || (currentSecondpath == subitem.path) || subitem.isMore"
|
||||
@click.stop="openMore(subindex)" variant="outline" class="menu-more-wrap">
|
||||
<t-icon name="ellipsis" class="menu-more" />
|
||||
</div>
|
||||
<template #content>
|
||||
<span class="del_submenu">删除记录</span>
|
||||
</template>
|
||||
</t-popup>
|
||||
<div ref="submenuscrollContainer" @scroll="handleScroll" class="submenu" v-if="item.children">
|
||||
<div class="submenu_item_p" v-for="(subitem, subindex) in item.children" :key="subindex"
|
||||
@click="gotopage(subitem.path)">
|
||||
<div :class="['submenu_item', currentSecondpath == subitem.path ? 'submenu_item_active' : '']"
|
||||
@mouseenter="mouseenteBotDownr(subindex)" @mouseleave="mouseleaveBotDown">
|
||||
<i v-if="currentSecondpath == subitem.path" class="dot"></i>
|
||||
<span class="submenu_title"
|
||||
:style="currentSecondpath == subitem.path ? 'margin-left:14px;max-width:160px;' : 'margin-left:18px;max-width:173px;'">
|
||||
{{ subitem.title }}
|
||||
</span>
|
||||
<t-popup v-model:visible="subitem.isMore" @overlay-click="delCard(subindex, subitem)"
|
||||
@visible-change="onVisibleChange" overlayClassName="del-menu-popup" trigger="click"
|
||||
destroy-on-close placement="top-left">
|
||||
<div v-if="(activeSubmenu == subindex) || (currentSecondpath == subitem.path) || subitem.isMore"
|
||||
@click.stop="openMore(subindex)" variant="outline" class="menu-more-wrap">
|
||||
<t-icon name="ellipsis" class="menu-more" />
|
||||
</div>
|
||||
<template #content>
|
||||
<span class="del_submenu">删除记录</span>
|
||||
</template>
|
||||
</t-popup>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 下半部分:账户信息、系统设置、退出登录 -->
|
||||
<div class="menu_bottom">
|
||||
<div class="menu_box" v-for="(item, index) in bottomMenuItems" :key="'bottom-' + index">
|
||||
<div v-if="item.path === 'logout'">
|
||||
<t-popconfirm
|
||||
content="确定要退出登录吗?"
|
||||
@confirm="handleLogout"
|
||||
placement="top"
|
||||
:show-arrow="true"
|
||||
>
|
||||
<div @mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
|
||||
:class="['menu_item', 'logout-item']">
|
||||
<div class="menu_item-box">
|
||||
<div class="menu_icon">
|
||||
<img class="icon" :src="getImgSrc(logoutIcon)" alt="">
|
||||
</div>
|
||||
<span class="menu_title">{{ item.title }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</t-popconfirm>
|
||||
</div>
|
||||
<div v-else @click="handleMenuClick(item.path)"
|
||||
@mouseenter="mouseenteMenu(item.path)" @mouseleave="mouseleaveMenu(item.path)"
|
||||
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : (item.path == currentpath) ? 'menu_item_active' : '']">
|
||||
<div class="menu_item-box">
|
||||
<div class="menu_icon">
|
||||
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'tenant' ? tenantIcon : prefixIcon)" alt="">
|
||||
</div>
|
||||
<span class="menu_title">{{ item.path === 'knowledge-bases' && kbMenuItem ? kbMenuItem.title : item.title }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input type="file" @change="upload" style="display: none" ref="uploadInput"
|
||||
accept=".pdf,.docx,.doc,.txt,.md,.jpg,.jpeg,.png" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
<script setup lang="ts">
|
||||
import { storeToRefs } from 'pinia';
|
||||
import { onMounted, watch, computed, ref, reactive } from 'vue';
|
||||
import { onMounted, watch, computed, ref, reactive, nextTick } from 'vue';
|
||||
import { useRoute, useRouter } from 'vue-router';
|
||||
import { getSessionsList, delSession } from "@/api/chat/index";
|
||||
import { getKnowledgeBaseById, listKnowledgeBases, uploadKnowledgeFile } from '@/api/knowledge-base';
|
||||
import { kbFileTypeVerification } from '@/utils/index';
|
||||
import { useMenuStore } from '@/stores/menu';
|
||||
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
import { MessagePlugin } from "tdesign-vue-next";
|
||||
let { requestMethod } = useKnowledgeBase()
|
||||
let uploadInput = ref();
|
||||
const usemenuStore = useMenuStore();
|
||||
const authStore = useAuthStore();
|
||||
const route = useRoute();
|
||||
const router = useRouter();
|
||||
const currentpath = ref('');
|
||||
@@ -74,39 +138,206 @@ const submenuscrollContainer = ref(null);
|
||||
// 计算总页数
|
||||
const totalPages = computed(() => Math.ceil(total.value / page_size.value));
|
||||
const hasMore = computed(() => currentPage.value < totalPages.value);
|
||||
type MenuItem = { title: string; icon: string; path: string; childrenPath?: string; children?: any[] };
|
||||
const { menuArr } = storeToRefs(usemenuStore);
|
||||
let activeSubmenu = ref(-1);
|
||||
let activeSubmenu = ref<number>(-1);
|
||||
|
||||
// 是否处于知识库详情页
|
||||
const isInKnowledgeBase = computed<boolean>(() => {
|
||||
return route.name === 'knowledgeBaseDetail' ||
|
||||
route.name === 'kbCreatChat' ||
|
||||
route.name === 'chat' ||
|
||||
route.name === 'knowledgeBaseSettings';
|
||||
});
|
||||
|
||||
// 统一的菜单项激活状态判断
|
||||
const isMenuItemActive = (itemPath: string): boolean => {
|
||||
const currentRoute = route.name;
|
||||
|
||||
switch (itemPath) {
|
||||
case 'knowledge-bases':
|
||||
return currentRoute === 'knowledgeBaseList' ||
|
||||
currentRoute === 'knowledgeBaseDetail' ||
|
||||
currentRoute === 'knowledgeBaseSettings';
|
||||
case 'creatChat':
|
||||
return currentRoute === 'kbCreatChat';
|
||||
case 'tenant':
|
||||
return currentRoute === 'tenant';
|
||||
default:
|
||||
return itemPath === currentpath.value;
|
||||
}
|
||||
};
|
||||
|
||||
// 统一的图标激活状态判断
|
||||
const getIconActiveState = (itemPath: string) => {
|
||||
const currentRoute = route.name;
|
||||
|
||||
return {
|
||||
isKbActive: itemPath === 'knowledge-bases' && (
|
||||
currentRoute === 'knowledgeBaseList' ||
|
||||
currentRoute === 'knowledgeBaseDetail' ||
|
||||
currentRoute === 'knowledgeBaseSettings'
|
||||
),
|
||||
isCreatChatActive: itemPath === 'creatChat' && currentRoute === 'kbCreatChat',
|
||||
isTenantActive: itemPath === 'tenant' && currentRoute === 'tenant',
|
||||
isChatActive: itemPath === 'chat' && currentRoute === 'chat'
|
||||
};
|
||||
};
|
||||
|
||||
// 分离上下两部分菜单
|
||||
const topMenuItems = computed<MenuItem[]>(() => {
|
||||
return (menuArr.value as unknown as MenuItem[]).filter((item: MenuItem) =>
|
||||
item.path === 'knowledge-bases' || (isInKnowledgeBase.value && item.path === 'creatChat')
|
||||
);
|
||||
});
|
||||
|
||||
const bottomMenuItems = computed<MenuItem[]>(() => {
|
||||
return (menuArr.value as unknown as MenuItem[]).filter((item: MenuItem) => {
|
||||
if (item.path === 'knowledge-bases' || item.path === 'creatChat') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
});
|
||||
|
||||
// 当前知识库名称和列表
|
||||
const currentKbName = ref<string>('')
|
||||
const allKnowledgeBases = ref<Array<{ id: string; name: string; embedding_model_id?: string; summary_model_id?: string }>>([])
|
||||
const showKbDropdown = ref<boolean>(false)
|
||||
|
||||
// 过滤已初始化的知识库
|
||||
const initializedKnowledgeBases = computed(() => {
|
||||
return allKnowledgeBases.value.filter(kb =>
|
||||
kb.embedding_model_id && kb.embedding_model_id !== '' &&
|
||||
kb.summary_model_id && kb.summary_model_id !== ''
|
||||
)
|
||||
})
|
||||
|
||||
// 动态更新知识库菜单项标题
|
||||
const kbMenuItem = computed(() => {
|
||||
const kbItem = topMenuItems.value.find(item => item.path === 'knowledge-bases')
|
||||
if (kbItem && isInKnowledgeBase.value && currentKbName.value) {
|
||||
return { ...kbItem, title: currentKbName.value }
|
||||
}
|
||||
return kbItem
|
||||
})
|
||||
|
||||
const loading = ref(false)
|
||||
const uploadFile = () => {
|
||||
const uploadFile = async () => {
|
||||
// 获取当前知识库ID
|
||||
const currentKbId = await getCurrentKbId();
|
||||
|
||||
// 检查当前知识库的初始化状态
|
||||
if (currentKbId) {
|
||||
try {
|
||||
const kbResponse = await getKnowledgeBaseById(currentKbId);
|
||||
const kb = kbResponse.data;
|
||||
|
||||
// 检查知识库是否已初始化(有 EmbeddingModelID 和 SummaryModelID)
|
||||
if (!kb.embedding_model_id || kb.embedding_model_id === '' ||
|
||||
!kb.summary_model_id || kb.summary_model_id === '') {
|
||||
MessagePlugin.warning("该知识库尚未完成初始化配置,请先前往设置页面配置模型信息后再上传文件");
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取知识库信息失败:', error);
|
||||
MessagePlugin.error("获取知识库信息失败,无法上传文件");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
uploadInput.value.click()
|
||||
}
|
||||
const upload = (e) => {
|
||||
requestMethod(e.target.files[0], uploadInput)
|
||||
const upload = async (e: any) => {
|
||||
const file = e.target.files[0];
|
||||
if (!file) return;
|
||||
|
||||
// 文件类型验证
|
||||
if (kbFileTypeVerification(file)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取当前知识库ID
|
||||
const currentKbId = (route.params as any)?.kbId as string;
|
||||
if (!currentKbId) {
|
||||
MessagePlugin.error("缺少知识库ID");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await uploadKnowledgeFile(currentKbId, { file });
|
||||
const responseData = result as any;
|
||||
console.log('上传API返回结果:', responseData);
|
||||
|
||||
// 如果没有抛出异常,就认为上传成功,先触发刷新事件
|
||||
console.log('文件上传完成,发送事件通知页面刷新,知识库ID:', currentKbId);
|
||||
window.dispatchEvent(new CustomEvent('knowledgeFileUploaded', {
|
||||
detail: { kbId: currentKbId }
|
||||
}));
|
||||
|
||||
// 然后处理UI消息
|
||||
// 判断上传是否成功 - 检查多种可能的成功标识
|
||||
const isSuccess = responseData.success || responseData.code === 200 || responseData.status === 'success' || (!responseData.error && responseData);
|
||||
|
||||
if (isSuccess) {
|
||||
MessagePlugin.info("上传成功!");
|
||||
} else {
|
||||
// 改进错误信息提取逻辑
|
||||
let errorMessage = "上传失败!";
|
||||
if (responseData.error && responseData.error.message) {
|
||||
errorMessage = responseData.error.message;
|
||||
} else if (responseData.message) {
|
||||
errorMessage = responseData.message;
|
||||
}
|
||||
if (responseData.code === 'duplicate_file' || (responseData.error && responseData.error.code === 'duplicate_file')) {
|
||||
errorMessage = "文件已存在";
|
||||
}
|
||||
MessagePlugin.error(errorMessage);
|
||||
}
|
||||
} catch (err: any) {
|
||||
let errorMessage = "上传失败!";
|
||||
if (err.code === 'duplicate_file') {
|
||||
errorMessage = "文件已存在";
|
||||
} else if (err.error && err.error.message) {
|
||||
errorMessage = err.error.message;
|
||||
} else if (err.message) {
|
||||
errorMessage = err.message;
|
||||
}
|
||||
MessagePlugin.error(errorMessage);
|
||||
} finally {
|
||||
uploadInput.value.value = "";
|
||||
}
|
||||
}
|
||||
const mouseenteBotDownr = (val) => {
|
||||
const mouseenteBotDownr = (val: number) => {
|
||||
activeSubmenu.value = val;
|
||||
}
|
||||
const mouseleaveBotDown = () => {
|
||||
activeSubmenu.value = -1;
|
||||
}
|
||||
const onVisibleChange = (e) => {
|
||||
const onVisibleChange = (_e: any) => {
|
||||
}
|
||||
|
||||
const delCard = (index, item) => {
|
||||
delSession(item.id).then(res => {
|
||||
if (res && res.success) {
|
||||
menuArr.value[1].children.splice(index, 1);
|
||||
const delCard = (index: number, item: any) => {
|
||||
delSession(item.id).then((res: any) => {
|
||||
if (res && (res as any).success) {
|
||||
(menuArr.value as any[])[1]?.children?.splice(index, 1);
|
||||
if (item.id == route.params.chatid) {
|
||||
router.push('/platform/creatChat');
|
||||
// 删除当前会话后,跳转到当前知识库的创建聊天页面
|
||||
const kbId = route.params.kbId;
|
||||
if (kbId) {
|
||||
router.push(`/platform/knowledge-bases/${kbId}/creatChat`);
|
||||
} else {
|
||||
router.push('/platform/knowledge-bases');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MessagePlugin.error("删除失败,请稍后再试!");
|
||||
}
|
||||
})
|
||||
}
|
||||
const debounce = (fn, delay) => {
|
||||
let timer
|
||||
return (...args) => {
|
||||
const debounce = (fn: (...args: any[]) => void, delay: number) => {
|
||||
let timer: ReturnType<typeof setTimeout>
|
||||
return (...args: any[]) => {
|
||||
clearTimeout(timer)
|
||||
timer = setTimeout(() => fn(...args), delay)
|
||||
}
|
||||
@@ -124,80 +355,221 @@ const checkScrollBottom = () => {
|
||||
}
|
||||
}
|
||||
const handleScroll = debounce(checkScrollBottom, 200)
|
||||
const getMessageList = () => {
|
||||
const getMessageList = async () => {
|
||||
// 仅在知识库内部显示对话列表
|
||||
if (!isInKnowledgeBase.value) {
|
||||
usemenuStore.clearMenuArr();
|
||||
currentKbName.value = '';
|
||||
return;
|
||||
}
|
||||
let kbId = (route.params as any)?.kbId as string
|
||||
// 新的路由格式:/platform/chat/:kbId/:chatid,直接从路由参数获取知识库ID
|
||||
if (!kbId) {
|
||||
usemenuStore.clearMenuArr();
|
||||
currentKbName.value = '';
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取知识库名称和所有知识库列表
|
||||
try {
|
||||
const [kbRes, allKbRes]: any[] = await Promise.all([
|
||||
getKnowledgeBaseById(kbId),
|
||||
listKnowledgeBases()
|
||||
])
|
||||
if (kbRes?.data?.name) {
|
||||
currentKbName.value = kbRes.data.name
|
||||
}
|
||||
if (allKbRes?.data) {
|
||||
allKnowledgeBases.value = allKbRes.data
|
||||
}
|
||||
} catch {}
|
||||
|
||||
if (loading.value) return;
|
||||
loading.value = true;
|
||||
usemenuStore.clearMenuArr();
|
||||
getSessionsList(currentPage.value, page_size.value).then(res => {
|
||||
getSessionsList(currentPage.value, page_size.value).then((res: any) => {
|
||||
if (res.data && res.data.length) {
|
||||
res.data.forEach(item => {
|
||||
let obj = { title: item.title ? item.title : "新会话", path: `chat/${item.id}`, id: item.id, isMore: false, isNoTitle: item.title ? false : true }
|
||||
// 过滤出当前知识库的会话
|
||||
const filtered = res.data.filter((s: any) => s.knowledge_base_id === kbId)
|
||||
filtered.forEach((item: any) => {
|
||||
let obj = { title: item.title ? item.title : "新会话", path: `chat/${kbId}/${item.id}`, id: item.id, isMore: false, isNoTitle: item.title ? false : true }
|
||||
usemenuStore.updatemenuArr(obj)
|
||||
});
|
||||
loading.value = false;
|
||||
}
|
||||
if (res.total) {
|
||||
total.value = res.total;
|
||||
if ((res as any).total) {
|
||||
total.value = (res as any).total;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const openMore = (e) => { }
|
||||
const openMore = (_e: any) => { }
|
||||
onMounted(() => {
|
||||
currentpath.value = route.name;
|
||||
if (route.params.chatid) {
|
||||
currentSecondpath.value = `${route.name}/${route.params.chatid}`;
|
||||
const routeName = typeof route.name === 'string' ? route.name : (route.name ? String(route.name) : '')
|
||||
currentpath.value = routeName;
|
||||
if (route.params.chatid && route.params.kbId) {
|
||||
currentSecondpath.value = `chat/${route.params.kbId}/${route.params.chatid}`;
|
||||
}
|
||||
getMessageList();
|
||||
});
|
||||
|
||||
watch([() => route.name, () => route.params], (newvalue) => {
|
||||
currentpath.value = newvalue[0];
|
||||
if (newvalue[1].chatid) {
|
||||
currentSecondpath.value = `${newvalue[0]}/${newvalue[1].chatid}`;
|
||||
const nameStr = typeof newvalue[0] === 'string' ? (newvalue[0] as string) : (newvalue[0] ? String(newvalue[0]) : '')
|
||||
currentpath.value = nameStr;
|
||||
if (newvalue[1].chatid && newvalue[1].kbId) {
|
||||
currentSecondpath.value = `chat/${newvalue[1].kbId}/${newvalue[1].chatid}`;
|
||||
} else {
|
||||
currentSecondpath.value = "";
|
||||
}
|
||||
|
||||
// 路由变化时刷新对话列表(仅在知识库内部)
|
||||
getMessageList();
|
||||
// 路由变化时更新图标状态
|
||||
getIcon(nameStr);
|
||||
});
|
||||
let fileAddIcon = ref('file-add-green.svg');
|
||||
let knowledgeIcon = ref('zhishiku-green.svg');
|
||||
let prefixIcon = ref('prefixIcon.svg');
|
||||
let settingIcon = ref('setting.svg');
|
||||
let logoutIcon = ref('logout.svg');
|
||||
let tenantIcon = ref('user.svg'); // 使用专门的用户图标
|
||||
let pathPrefix = ref(route.name)
|
||||
const getIcon = (path) => {
|
||||
fileAddIcon.value = path == 'knowledgeBase' ? 'file-add-green.svg' : 'file-add.svg';
|
||||
knowledgeIcon.value = path == 'knowledgeBase' ? 'zhishiku-green.svg' : 'zhishiku.svg';
|
||||
prefixIcon.value = path == 'creatChat' ? 'prefixIcon-green.svg' : path == 'knowledgeBase' ? 'prefixIcon-grey.svg' : 'prefixIcon.svg';
|
||||
settingIcon.value = path == 'settings' ? 'setting-green.svg' : 'setting.svg';
|
||||
const getIcon = (path: string) => {
|
||||
// 根据当前路由状态更新所有图标
|
||||
const kbActiveState = getIconActiveState('knowledge-bases');
|
||||
const creatChatActiveState = getIconActiveState('creatChat');
|
||||
const tenantActiveState = getIconActiveState('tenant');
|
||||
|
||||
// 上传图标:只在知识库相关页面显示绿色
|
||||
fileAddIcon.value = kbActiveState.isKbActive ? 'file-add-green.svg' : 'file-add.svg';
|
||||
|
||||
// 知识库图标:只在知识库页面显示绿色
|
||||
knowledgeIcon.value = kbActiveState.isKbActive ? 'zhishiku-green.svg' : 'zhishiku.svg';
|
||||
|
||||
// 对话图标:只在对话创建页面显示绿色,在知识库页面显示灰色,其他情况显示默认
|
||||
prefixIcon.value = creatChatActiveState.isCreatChatActive ? 'prefixIcon-green.svg' :
|
||||
kbActiveState.isKbActive ? 'prefixIcon-grey.svg' :
|
||||
'prefixIcon.svg';
|
||||
|
||||
// 租户图标:只在租户页面显示绿色
|
||||
tenantIcon.value = tenantActiveState.isTenantActive ? 'user-green.svg' : 'user.svg';
|
||||
|
||||
// 退出图标:始终显示默认
|
||||
logoutIcon.value = 'logout.svg';
|
||||
}
|
||||
getIcon(route.name)
|
||||
const gotopage = (path) => {
|
||||
pathPrefix.value = path;
|
||||
// 如果是系统设置,跳转到初始化配置页面
|
||||
if (path === 'settings') {
|
||||
router.push('/initialization');
|
||||
getIcon(typeof route.name === 'string' ? route.name as string : (route.name ? String(route.name) : ''))
|
||||
const handleMenuClick = async (path: string) => {
|
||||
if (path === 'knowledge-bases') {
|
||||
// 知识库菜单项:如果在知识库内部,跳转到当前知识库文件页;否则跳转到知识库列表
|
||||
const kbId = await getCurrentKbId()
|
||||
if (kbId) {
|
||||
router.push(`/platform/knowledge-bases/${kbId}`)
|
||||
} else {
|
||||
router.push('/platform/knowledge-bases')
|
||||
}
|
||||
} else {
|
||||
router.push(`/platform/${path}`);
|
||||
gotopage(path)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理退出登录确认
|
||||
const handleLogout = () => {
|
||||
gotopage('logout')
|
||||
}
|
||||
|
||||
const getCurrentKbId = async (): Promise<string | null> => {
|
||||
let kbId = (route.params as any)?.kbId as string
|
||||
// 新的路由格式:/platform/chat/:kbId/:chatid,直接从路由参数获取
|
||||
if (!kbId && route.name === 'chat' && (route.params as any)?.kbId) {
|
||||
kbId = (route.params as any).kbId
|
||||
}
|
||||
return kbId || null
|
||||
}
|
||||
|
||||
const gotopage = async (path: string) => {
|
||||
pathPrefix.value = path;
|
||||
// 处理退出登录
|
||||
if (path === 'logout') {
|
||||
authStore.logout();
|
||||
router.push('/login');
|
||||
return;
|
||||
} else {
|
||||
if (path === 'creatChat') {
|
||||
const kbId = await getCurrentKbId()
|
||||
if (kbId) {
|
||||
router.push(`/platform/knowledge-bases/${kbId}/creatChat`)
|
||||
} else {
|
||||
router.push(`/platform/knowledge-bases`)
|
||||
}
|
||||
} else {
|
||||
router.push(`/platform/${path}`);
|
||||
}
|
||||
}
|
||||
getIcon(path)
|
||||
}
|
||||
|
||||
const getImgSrc = (url) => {
|
||||
const getImgSrc = (url: string) => {
|
||||
return new URL(`/src/assets/img/${url}`, import.meta.url).href;
|
||||
}
|
||||
|
||||
const mouseenteMenu = (path) => {
|
||||
if (pathPrefix.value != 'knowledgeBase' && pathPrefix.value != 'creatChat' && path != 'knowledgeBase') {
|
||||
const mouseenteMenu = (path: string) => {
|
||||
if (pathPrefix.value != 'knowledge-bases' && pathPrefix.value != 'creatChat' && path != 'knowledge-bases') {
|
||||
prefixIcon.value = 'prefixIcon-grey.svg';
|
||||
}
|
||||
}
|
||||
const mouseleaveMenu = (path) => {
|
||||
if (pathPrefix.value != 'knowledgeBase' && pathPrefix.value != 'creatChat' && path != 'knowledgeBase') {
|
||||
getIcon(route.name)
|
||||
const mouseleaveMenu = (path: string) => {
|
||||
if (pathPrefix.value != 'knowledge-bases' && pathPrefix.value != 'creatChat' && path != 'knowledge-bases') {
|
||||
const nameStr = typeof route.name === 'string' ? route.name as string : (route.name ? String(route.name) : '')
|
||||
getIcon(nameStr)
|
||||
}
|
||||
}
|
||||
|
||||
// 知识库下拉相关方法
|
||||
const toggleKbDropdown = (event?: Event) => {
|
||||
if (event) {
|
||||
event.stopPropagation()
|
||||
}
|
||||
showKbDropdown.value = !showKbDropdown.value
|
||||
}
|
||||
|
||||
const switchKnowledgeBase = (kbId: string, event?: Event) => {
|
||||
if (event) {
|
||||
event.stopPropagation()
|
||||
}
|
||||
showKbDropdown.value = false
|
||||
const currentRoute = route.name
|
||||
|
||||
// 路由跳转
|
||||
if (currentRoute === 'knowledgeBaseDetail') {
|
||||
router.push(`/platform/knowledge-bases/${kbId}`)
|
||||
} else if (currentRoute === 'kbCreatChat') {
|
||||
router.push(`/platform/knowledge-bases/${kbId}/creatChat`)
|
||||
} else if (currentRoute === 'knowledgeBaseSettings') {
|
||||
router.push(`/platform/knowledge-bases/${kbId}/settings`)
|
||||
} else {
|
||||
router.push(`/platform/knowledge-bases/${kbId}`)
|
||||
}
|
||||
|
||||
// 刷新右侧内容 - 通过触发页面重新加载或发送事件
|
||||
nextTick(() => {
|
||||
// 发送全局事件通知页面刷新知识库内容
|
||||
window.dispatchEvent(new CustomEvent('knowledgeBaseChanged', {
|
||||
detail: { kbId }
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
// 点击外部关闭下拉菜单
|
||||
const handleClickOutside = () => {
|
||||
showKbDropdown.value = false
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
document.addEventListener('click', handleClickOutside)
|
||||
})
|
||||
|
||||
watch(() => route.params.kbId, () => {
|
||||
showKbDropdown.value = false
|
||||
})
|
||||
|
||||
</script>
|
||||
<style lang="less" scoped>
|
||||
.del_submenu {
|
||||
@@ -210,6 +582,10 @@ const mouseleaveMenu = (path) => {
|
||||
padding: 8px;
|
||||
background: #fff;
|
||||
box-sizing: border-box;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
.logo_box {
|
||||
height: 80px;
|
||||
@@ -239,9 +615,28 @@ const mouseleaveMenu = (path) => {
|
||||
line-height: 21.7px;
|
||||
}
|
||||
|
||||
.menu_top {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.menu_bottom {
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.menu_box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
&.has-submenu {
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -341,18 +736,21 @@ const mouseleaveMenu = (path) => {
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
line-height: 22px;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
max-width: 120px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.submenu {
|
||||
font-family: "PingFang SC";
|
||||
font-size: 14px;
|
||||
font-style: normal;
|
||||
font-family: "PingFang SC";
|
||||
font-size: 14px;
|
||||
font-style: normal;
|
||||
overflow-y: scroll;
|
||||
overflow-y: auto;
|
||||
scrollbar-width: none;
|
||||
height: calc(98vh - 276px);
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.submenu_item_p {
|
||||
@@ -427,6 +825,92 @@ const mouseleaveMenu = (path) => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* 知识库下拉菜单样式 */
|
||||
.kb-dropdown-icon {
|
||||
margin-left: auto;
|
||||
color: #666;
|
||||
transition: transform 0.3s ease, color 0.2s ease;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
|
||||
&.rotate-180 {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
&:hover {
|
||||
color: #07c05f;
|
||||
}
|
||||
|
||||
&.active {
|
||||
color: #07c05f;
|
||||
}
|
||||
|
||||
&.active:hover {
|
||||
color: #05a04f;
|
||||
}
|
||||
|
||||
svg {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
transition: inherit;
|
||||
}
|
||||
}
|
||||
|
||||
.kb-dropdown-menu {
|
||||
position: absolute;
|
||||
top: 100%;
|
||||
left: 0;
|
||||
right: 0;
|
||||
background: #fff;
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 6px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
z-index: 1000;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.kb-dropdown-item {
|
||||
padding: 8px 16px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
font-size: 14px;
|
||||
color: #333;
|
||||
|
||||
&:hover {
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
|
||||
&.active {
|
||||
background-color: #07c05f1a;
|
||||
color: #07c05f;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
&:first-child {
|
||||
border-radius: 6px 6px 0 0;
|
||||
}
|
||||
|
||||
&:last-child {
|
||||
border-radius: 0 0 6px 6px;
|
||||
}
|
||||
}
|
||||
|
||||
.menu_item-box {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.menu_box {
|
||||
position: relative;
|
||||
}
|
||||
</style>
|
||||
<style lang="less">
|
||||
.upload-popup {
|
||||
@@ -456,4 +940,48 @@ const mouseleaveMenu = (path) => {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 退出登录确认框样式
|
||||
:deep(.t-popconfirm) {
|
||||
.t-popconfirm__content {
|
||||
background: #fff;
|
||||
border: 1px solid #e7e7e7;
|
||||
border-radius: 6px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
||||
padding: 12px 16px;
|
||||
font-size: 14px;
|
||||
color: #333;
|
||||
max-width: 200px;
|
||||
}
|
||||
|
||||
.t-popconfirm__arrow {
|
||||
border-bottom-color: #e7e7e7;
|
||||
}
|
||||
|
||||
.t-popconfirm__arrow::after {
|
||||
border-bottom-color: #fff;
|
||||
}
|
||||
|
||||
.t-popconfirm__buttons {
|
||||
margin-top: 8px;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.t-button--variant-outline {
|
||||
border-color: #d9d9d9;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.t-button--theme-danger {
|
||||
background-color: #ff4d4f;
|
||||
border-color: #ff4d4f;
|
||||
}
|
||||
|
||||
.t-button--theme-danger:hover {
|
||||
background-color: #ff7875;
|
||||
border-color: #ff7875;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -1,52 +1,54 @@
|
||||
import { ref, reactive, onMounted } from "vue";
|
||||
import { ref, reactive } from "vue";
|
||||
import { storeToRefs } from "pinia";
|
||||
import { formatStringDate, kbFileTypeVerification } from "../utils/index";
|
||||
import { MessagePlugin } from "tdesign-vue-next";
|
||||
import {
|
||||
uploadKnowledgeBase,
|
||||
getKnowledgeBase,
|
||||
uploadKnowledgeFile,
|
||||
listKnowledgeFiles,
|
||||
getKnowledgeDetails,
|
||||
delKnowledgeDetails,
|
||||
getKnowledgeDetailsCon,
|
||||
} from "@/api/knowledge-base/index";
|
||||
import { knowledgeStore } from "@/stores/knowledge";
|
||||
import { useRoute } from 'vue-router';
|
||||
|
||||
const usemenuStore = knowledgeStore();
|
||||
export default function () {
|
||||
export default function (knowledgeBaseId?: string) {
|
||||
const route = useRoute();
|
||||
const { cardList, total } = storeToRefs(usemenuStore);
|
||||
let moreIndex = ref(-1);
|
||||
const details = reactive({
|
||||
title: "",
|
||||
time: "",
|
||||
md: [],
|
||||
md: [] as any[],
|
||||
id: "",
|
||||
total: 0
|
||||
});
|
||||
const getKnowled = (query = { page: 1, page_size: 35 }) => {
|
||||
getKnowledgeBase(query)
|
||||
const getKnowled = (query = { page: 1, page_size: 35 }, kbId?: string) => {
|
||||
const targetKbId = kbId || knowledgeBaseId;
|
||||
if (!targetKbId) return;
|
||||
|
||||
listKnowledgeFiles(targetKbId, query)
|
||||
.then((result: any) => {
|
||||
let { data, total: totalResult } = result;
|
||||
let cardList_ = data.map((item) => {
|
||||
item["file_name"] = item.file_name.substring(
|
||||
0,
|
||||
item.file_name.lastIndexOf(".")
|
||||
);
|
||||
return {
|
||||
...item,
|
||||
updated_at: formatStringDate(new Date(item.updated_at)),
|
||||
isMore: false,
|
||||
file_type: item.file_type.toLocaleUpperCase(),
|
||||
};
|
||||
});
|
||||
if (query.page == 1) {
|
||||
const { data, total: totalResult } = result;
|
||||
const cardList_ = data.map((item: any) => ({
|
||||
...item,
|
||||
file_name: item.file_name.substring(0, item.file_name.lastIndexOf(".")),
|
||||
updated_at: formatStringDate(new Date(item.updated_at)),
|
||||
isMore: false,
|
||||
file_type: item.file_type.toLocaleUpperCase(),
|
||||
}));
|
||||
|
||||
if (query.page === 1) {
|
||||
cardList.value = cardList_;
|
||||
} else {
|
||||
cardList.value.push(...cardList_);
|
||||
}
|
||||
total.value = totalResult;
|
||||
})
|
||||
.catch((err) => {});
|
||||
.catch(() => {});
|
||||
};
|
||||
const delKnowledge = (index: number, item) => {
|
||||
const delKnowledge = (index: number, item: any) => {
|
||||
cardList.value[index].isMore = false;
|
||||
moreIndex.value = -1;
|
||||
delKnowledgeDetails(item.id)
|
||||
@@ -58,7 +60,7 @@ export default function () {
|
||||
MessagePlugin.error("知识删除失败!");
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
.catch(() => {
|
||||
MessagePlugin.error("知识删除失败!");
|
||||
});
|
||||
};
|
||||
@@ -70,56 +72,48 @@ export default function () {
|
||||
moreIndex.value = -1;
|
||||
}
|
||||
};
|
||||
const requestMethod = (file: any, uploadInput) => {
|
||||
if (file instanceof File && uploadInput) {
|
||||
if (kbFileTypeVerification(file)) {
|
||||
return;
|
||||
}
|
||||
uploadKnowledgeBase({ file })
|
||||
.then((result: any) => {
|
||||
if (result.success) {
|
||||
MessagePlugin.info("上传成功!");
|
||||
getKnowled();
|
||||
} else {
|
||||
// 改进错误信息提取逻辑
|
||||
let errorMessage = "上传失败!";
|
||||
|
||||
// 优先从 error 对象中获取错误信息
|
||||
if (result.error && result.error.message) {
|
||||
errorMessage = result.error.message;
|
||||
} else if (result.message) {
|
||||
errorMessage = result.message;
|
||||
}
|
||||
|
||||
// 检查错误码,如果是重复文件则显示特定提示
|
||||
if (result.code === 'duplicate_file' || (result.error && result.error.code === 'duplicate_file')) {
|
||||
errorMessage = "文件已存在";
|
||||
}
|
||||
|
||||
MessagePlugin.error(errorMessage);
|
||||
}
|
||||
uploadInput.value.value = "";
|
||||
})
|
||||
.catch((err: any) => {
|
||||
// 改进 catch 中的错误处理
|
||||
let errorMessage = "上传失败!";
|
||||
|
||||
if (err.code === 'duplicate_file') {
|
||||
errorMessage = "文件已存在";
|
||||
} else if (err.error && err.error.message) {
|
||||
errorMessage = err.error.message;
|
||||
} else if (err.message) {
|
||||
errorMessage = err.message;
|
||||
}
|
||||
|
||||
MessagePlugin.error(errorMessage);
|
||||
uploadInput.value.value = "";
|
||||
});
|
||||
} else {
|
||||
MessagePlugin.error("file文件类型错误!");
|
||||
const requestMethod = (file: any, uploadInput: any) => {
|
||||
if (!(file instanceof File) || !uploadInput) {
|
||||
MessagePlugin.error("文件类型错误!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (kbFileTypeVerification(file)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取当前知识库ID
|
||||
let currentKbId: string | undefined = (route.params as any)?.kbId as string;
|
||||
if (!currentKbId && typeof window !== 'undefined') {
|
||||
const match = window.location.pathname.match(/knowledge-bases\/([^/]+)/);
|
||||
if (match?.[1]) currentKbId = match[1];
|
||||
}
|
||||
if (!currentKbId) {
|
||||
currentKbId = knowledgeBaseId;
|
||||
}
|
||||
if (!currentKbId) {
|
||||
MessagePlugin.error("缺少知识库ID");
|
||||
return;
|
||||
}
|
||||
|
||||
uploadKnowledgeFile(currentKbId, { file })
|
||||
.then((result: any) => {
|
||||
if (result.success) {
|
||||
MessagePlugin.info("上传成功!");
|
||||
getKnowled({ page: 1, page_size: 35 }, currentKbId);
|
||||
} else {
|
||||
const errorMessage = result.error?.message || result.message || "上传失败!";
|
||||
MessagePlugin.error(result.code === 'duplicate_file' ? "文件已存在" : errorMessage);
|
||||
}
|
||||
uploadInput.value.value = "";
|
||||
})
|
||||
.catch((err: any) => {
|
||||
const errorMessage = err.error?.message || err.message || "上传失败!";
|
||||
MessagePlugin.error(err.code === 'duplicate_file' ? "文件已存在" : errorMessage);
|
||||
uploadInput.value.value = "";
|
||||
});
|
||||
};
|
||||
const getCardDetails = (item) => {
|
||||
const getCardDetails = (item: any) => {
|
||||
Object.assign(details, {
|
||||
title: "",
|
||||
time: "",
|
||||
@@ -129,7 +123,7 @@ export default function () {
|
||||
getKnowledgeDetails(item.id)
|
||||
.then((result: any) => {
|
||||
if (result.success && result.data) {
|
||||
let { data } = result;
|
||||
const { data } = result;
|
||||
Object.assign(details, {
|
||||
title: data.file_name,
|
||||
time: formatStringDate(new Date(data.updated_at)),
|
||||
@@ -137,15 +131,16 @@ export default function () {
|
||||
});
|
||||
}
|
||||
})
|
||||
.catch((err) => {});
|
||||
getfDetails(item.id, 1);
|
||||
.catch(() => {});
|
||||
getfDetails(item.id, 1);
|
||||
};
|
||||
const getfDetails = (id, page) => {
|
||||
|
||||
const getfDetails = (id: string, page: number) => {
|
||||
getKnowledgeDetailsCon(id, page)
|
||||
.then((result: any) => {
|
||||
if (result.success && result.data) {
|
||||
let { data, total: totalResult } = result;
|
||||
if (page == 1) {
|
||||
const { data, total: totalResult } = result;
|
||||
if (page === 1) {
|
||||
details.md = data;
|
||||
} else {
|
||||
details.md.push(...data);
|
||||
@@ -153,7 +148,7 @@ export default function () {
|
||||
details.total = totalResult;
|
||||
}
|
||||
})
|
||||
.catch((err) => {});
|
||||
.catch(() => {});
|
||||
};
|
||||
return {
|
||||
cardList,
|
||||
|
||||
@@ -1,89 +1,117 @@
|
||||
import { createRouter, createWebHistory } from 'vue-router'
|
||||
import { checkInitializationStatus } from '@/api/initialization'
|
||||
import { listKnowledgeBases } from '@/api/knowledge-base'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { validateToken } from '@/api/auth'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHistory(import.meta.env.BASE_URL),
|
||||
routes: [
|
||||
{
|
||||
path: "/",
|
||||
redirect: "/platform",
|
||||
redirect: "/platform/knowledge-bases",
|
||||
},
|
||||
{
|
||||
path: "/initialization",
|
||||
name: "initialization",
|
||||
component: () => import("../views/initialization/InitializationConfig.vue"),
|
||||
meta: { requiresInit: false } // 初始化页面不需要检查初始化状态
|
||||
path: "/login",
|
||||
name: "login",
|
||||
component: () => import("../views/auth/Login.vue"),
|
||||
meta: { requiresAuth: false, requiresInit: false }
|
||||
},
|
||||
{
|
||||
path: "/knowledgeBase",
|
||||
name: "home",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "/platform",
|
||||
name: "Platform",
|
||||
redirect: "/platform/knowledgeBase",
|
||||
redirect: "/platform/knowledge-bases",
|
||||
component: () => import("../views/platform/index.vue"),
|
||||
meta: { requiresInit: true },
|
||||
meta: { requiresInit: true, requiresAuth: true },
|
||||
children: [
|
||||
{
|
||||
path: "knowledgeBase",
|
||||
name: "knowledgeBase",
|
||||
path: "tenant",
|
||||
name: "tenant",
|
||||
component: () => import("../views/tenant/TenantInfo.vue"),
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "knowledge-bases",
|
||||
name: "knowledgeBaseList",
|
||||
component: () => import("../views/knowledge/KnowledgeBaseList.vue"),
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "knowledge-bases/:kbId",
|
||||
name: "knowledgeBaseDetail",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "creatChat",
|
||||
name: "creatChat",
|
||||
path: "knowledge-bases/:kbId/creatChat",
|
||||
name: "kbCreatChat",
|
||||
component: () => import("../views/creatChat/creatChat.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "chat/:chatid",
|
||||
path: "knowledge-bases/:kbId/settings",
|
||||
name: "knowledgeBaseSettings",
|
||||
component: () => import("../views/initialization/InitializationContent.vue"),
|
||||
props: { isKbSettings: true },
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "chat/:kbId/:chatid",
|
||||
name: "chat",
|
||||
component: () => import("../views/chat/index.vue"),
|
||||
meta: { requiresInit: true }
|
||||
},
|
||||
{
|
||||
path: "settings",
|
||||
name: "settings",
|
||||
component: () => import("../views/settings/Settings.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// 路由守卫:检查系统初始化状态
|
||||
// 路由守卫:检查认证状态和系统初始化状态
|
||||
router.beforeEach(async (to, from, next) => {
|
||||
// 如果访问的是初始化页面,直接放行
|
||||
if (to.meta.requiresInit === false) {
|
||||
next();
|
||||
return;
|
||||
}
|
||||
|
||||
1
|
||||
|
||||
try {
|
||||
// 检查系统是否已初始化
|
||||
const { initialized } = await checkInitializationStatus();
|
||||
|
||||
if (initialized) {
|
||||
// 系统已初始化,记录到本地存储并正常跳转
|
||||
localStorage.setItem('system_initialized', 'true');
|
||||
next();
|
||||
} else {
|
||||
// 系统未初始化,跳转到初始化页面
|
||||
console.log('系统未初始化,跳转到初始化页面');
|
||||
next('/initialization');
|
||||
const authStore = useAuthStore()
|
||||
|
||||
// 如果访问的是登录页面或初始化页面,直接放行
|
||||
if (to.meta.requiresAuth === false || to.meta.requiresInit === false) {
|
||||
// 如果已登录用户访问登录页面,重定向到知识库列表页面
|
||||
if (to.path === '/login' && authStore.isLoggedIn) {
|
||||
next('/platform/knowledge-bases')
|
||||
return
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查初始化状态失败:', error);
|
||||
// 如果检查失败,默认认为需要初始化
|
||||
next('/initialization');
|
||||
next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户认证状态
|
||||
if (to.meta.requiresAuth !== false) {
|
||||
if (!authStore.isLoggedIn) {
|
||||
// 未登录,跳转到登录页面
|
||||
next('/login')
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Token有效性
|
||||
// try {
|
||||
// const { valid } = await validateToken()
|
||||
// if (!valid) {
|
||||
// // Token无效,清空认证信息并跳转到登录页面
|
||||
// authStore.logout()
|
||||
// next('/login')
|
||||
// return
|
||||
// }
|
||||
// } catch (error) {
|
||||
// console.error('Token验证失败:', error)
|
||||
// authStore.logout()
|
||||
// next('/login')
|
||||
// return
|
||||
// }
|
||||
}
|
||||
|
||||
next()
|
||||
});
|
||||
|
||||
export default router
|
||||
|
||||
169
frontend/src/stores/auth.ts
Normal file
169
frontend/src/stores/auth.ts
Normal file
@@ -0,0 +1,169 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import type { UserInfo, TenantInfo, KnowledgeBaseInfo } from '@/api/auth'
|
||||
|
||||
export const useAuthStore = defineStore('auth', () => {
|
||||
// 状态
|
||||
const user = ref<UserInfo | null>(null)
|
||||
const tenant = ref<TenantInfo | null>(null)
|
||||
const token = ref<string>('')
|
||||
const refreshToken = ref<string>('')
|
||||
const knowledgeBases = ref<KnowledgeBaseInfo[]>([])
|
||||
const currentKnowledgeBase = ref<KnowledgeBaseInfo | null>(null)
|
||||
|
||||
// 计算属性
|
||||
const isLoggedIn = computed(() => {
|
||||
return !!token.value && !!user.value
|
||||
})
|
||||
|
||||
const hasValidTenant = computed(() => {
|
||||
return !!tenant.value && !!tenant.value.api_key
|
||||
})
|
||||
|
||||
const currentTenantId = computed(() => {
|
||||
return tenant.value?.id || ''
|
||||
})
|
||||
|
||||
const currentUserId = computed(() => {
|
||||
return user.value?.id || ''
|
||||
})
|
||||
|
||||
// 操作方法
|
||||
const setUser = (userData: UserInfo) => {
|
||||
user.value = userData
|
||||
// 保存到localStorage
|
||||
localStorage.setItem('weknora_user', JSON.stringify(userData))
|
||||
}
|
||||
|
||||
const setTenant = (tenantData: TenantInfo) => {
|
||||
tenant.value = tenantData
|
||||
// 保存到localStorage
|
||||
localStorage.setItem('weknora_tenant', JSON.stringify(tenantData))
|
||||
}
|
||||
|
||||
const setToken = (tokenValue: string) => {
|
||||
token.value = tokenValue
|
||||
localStorage.setItem('weknora_token', tokenValue)
|
||||
}
|
||||
|
||||
const setRefreshToken = (refreshTokenValue: string) => {
|
||||
refreshToken.value = refreshTokenValue
|
||||
localStorage.setItem('weknora_refresh_token', refreshTokenValue)
|
||||
}
|
||||
|
||||
const setKnowledgeBases = (kbList: KnowledgeBaseInfo[]) => {
|
||||
// 确保输入是数组
|
||||
knowledgeBases.value = Array.isArray(kbList) ? kbList : []
|
||||
localStorage.setItem('weknora_knowledge_bases', JSON.stringify(knowledgeBases.value))
|
||||
}
|
||||
|
||||
const setCurrentKnowledgeBase = (kb: KnowledgeBaseInfo | null) => {
|
||||
currentKnowledgeBase.value = kb
|
||||
if (kb) {
|
||||
localStorage.setItem('weknora_current_kb', JSON.stringify(kb))
|
||||
} else {
|
||||
localStorage.removeItem('weknora_current_kb')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const logout = () => {
|
||||
// 清空状态
|
||||
user.value = null
|
||||
tenant.value = null
|
||||
token.value = ''
|
||||
refreshToken.value = ''
|
||||
knowledgeBases.value = []
|
||||
currentKnowledgeBase.value = null
|
||||
|
||||
// 清空localStorage
|
||||
localStorage.removeItem('weknora_user')
|
||||
localStorage.removeItem('weknora_tenant')
|
||||
localStorage.removeItem('weknora_token')
|
||||
localStorage.removeItem('weknora_refresh_token')
|
||||
localStorage.removeItem('weknora_knowledge_bases')
|
||||
localStorage.removeItem('weknora_current_kb')
|
||||
|
||||
}
|
||||
|
||||
const initFromStorage = () => {
|
||||
// 从localStorage恢复状态
|
||||
const storedUser = localStorage.getItem('weknora_user')
|
||||
const storedTenant = localStorage.getItem('weknora_tenant')
|
||||
const storedToken = localStorage.getItem('weknora_token')
|
||||
const storedRefreshToken = localStorage.getItem('weknora_refresh_token')
|
||||
const storedKnowledgeBases = localStorage.getItem('weknora_knowledge_bases')
|
||||
const storedCurrentKb = localStorage.getItem('weknora_current_kb')
|
||||
|
||||
if (storedUser) {
|
||||
try {
|
||||
user.value = JSON.parse(storedUser)
|
||||
} catch (e) {
|
||||
console.error('解析用户信息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
if (storedTenant) {
|
||||
try {
|
||||
tenant.value = JSON.parse(storedTenant)
|
||||
} catch (e) {
|
||||
console.error('解析租户信息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
if (storedToken) {
|
||||
token.value = storedToken
|
||||
}
|
||||
|
||||
if (storedRefreshToken) {
|
||||
refreshToken.value = storedRefreshToken
|
||||
}
|
||||
|
||||
if (storedKnowledgeBases) {
|
||||
try {
|
||||
const parsed = JSON.parse(storedKnowledgeBases)
|
||||
knowledgeBases.value = Array.isArray(parsed) ? parsed : []
|
||||
} catch (e) {
|
||||
console.error('解析知识库列表失败:', e)
|
||||
knowledgeBases.value = []
|
||||
}
|
||||
}
|
||||
|
||||
if (storedCurrentKb) {
|
||||
try {
|
||||
currentKnowledgeBase.value = JSON.parse(storedCurrentKb)
|
||||
} catch (e) {
|
||||
console.error('解析当前知识库失败:', e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化时从localStorage恢复状态
|
||||
initFromStorage()
|
||||
|
||||
return {
|
||||
// 状态
|
||||
user,
|
||||
tenant,
|
||||
token,
|
||||
refreshToken,
|
||||
knowledgeBases,
|
||||
currentKnowledgeBase,
|
||||
|
||||
// 计算属性
|
||||
isLoggedIn,
|
||||
hasValidTenant,
|
||||
currentTenantId,
|
||||
currentUserId,
|
||||
|
||||
// 方法
|
||||
setUser,
|
||||
setTenant,
|
||||
setToken,
|
||||
setRefreshToken,
|
||||
setKnowledgeBases,
|
||||
setCurrentKnowledgeBase,
|
||||
logout,
|
||||
initFromStorage
|
||||
}
|
||||
})
|
||||
@@ -4,8 +4,8 @@ import { defineStore } from "pinia";
|
||||
|
||||
export const knowledgeStore = defineStore("knowledge", {
|
||||
state: () => ({
|
||||
cardList: ref([]),
|
||||
total: ref(0),
|
||||
cardList: ref<any[]>([]),
|
||||
total: ref<number>(0),
|
||||
}),
|
||||
actions: {},
|
||||
});
|
||||
|
||||
@@ -5,7 +5,7 @@ import { defineStore } from 'pinia';
|
||||
export const useMenuStore = defineStore('menuStore', {
|
||||
state: () => ({
|
||||
menuArr: reactive([
|
||||
{ title: '知识库', icon: 'zhishiku', path: 'knowledgeBase' },
|
||||
{ title: '知识库', icon: 'zhishiku', path: 'knowledge-bases' },
|
||||
{
|
||||
title: '对话',
|
||||
icon: 'prefixIcon',
|
||||
@@ -13,7 +13,8 @@ export const useMenuStore = defineStore('menuStore', {
|
||||
childrenPath: 'chat',
|
||||
children: reactive<object[]>([]),
|
||||
},
|
||||
{ title: '系统设置', icon: 'setting', path: 'settings' }
|
||||
{ title: '系统信息', icon: 'tenant', path: 'tenant' },
|
||||
{ title: '退出登录', icon: 'logout', path: 'logout' }
|
||||
]),
|
||||
isFirstSession: false,
|
||||
firstQuery: ''
|
||||
@@ -30,7 +31,7 @@ export const useMenuStore = defineStore('menuStore', {
|
||||
this.menuArr[1].children?.unshift(item)
|
||||
},
|
||||
updatasessionTitle(session_id: string, title: string) {
|
||||
this.menuArr[1].children?.forEach(item => {
|
||||
this.menuArr[1].children?.forEach((item: any) => {
|
||||
if (item.id == session_id) {
|
||||
item.title = title;
|
||||
item.isNoTitle = false;
|
||||
|
||||
@@ -10,19 +10,19 @@ export function generateRandomString(length: number) {
|
||||
return result;
|
||||
}
|
||||
|
||||
export function formatStringDate(date) {
|
||||
export function formatStringDate(date: any) {
|
||||
let data = new Date(date);
|
||||
let year = data.getFullYear();
|
||||
let month = data.getMonth() + 1;
|
||||
let day = data.getDate();
|
||||
let hour = data.getHours();
|
||||
let minute = data.getMinutes();
|
||||
let second = data.getSeconds();
|
||||
let month = String(data.getMonth() + 1).padStart(2, '0');
|
||||
let day = String(data.getDate()).padStart(2, '0');
|
||||
let hour = String(data.getHours()).padStart(2, '0');
|
||||
let minute = String(data.getMinutes()).padStart(2, '0');
|
||||
let second = String(data.getSeconds()).padStart(2, '0');
|
||||
return (
|
||||
year + "-" + month + "-" + day + " " + hour + ":" + minute + ":" + second
|
||||
);
|
||||
}
|
||||
export function kbFileTypeVerification(file) {
|
||||
export function kbFileTypeVerification(file: any) {
|
||||
let validTypes = ["pdf", "txt", "md", "docx", "doc", "jpg", "jpeg", "png"];
|
||||
let type = file.name.substring(file.name.lastIndexOf(".") + 1);
|
||||
if (!validTypes.includes(type)) {
|
||||
|
||||
@@ -2,40 +2,9 @@
|
||||
import axios from "axios";
|
||||
import { generateRandomString } from "./index";
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
return JSON.parse(settingsStr);
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return {
|
||||
endpoint: import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080",
|
||||
apiKey: "",
|
||||
knowledgeBaseId: "",
|
||||
};
|
||||
}
|
||||
// API基础URL
|
||||
const BASE_URL = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
|
||||
// API基础URL,优先使用设置中的endpoint
|
||||
const settings = getSettings();
|
||||
const BASE_URL = settings.endpoint;
|
||||
|
||||
// 测试数据
|
||||
let testData: {
|
||||
tenant: {
|
||||
id: number;
|
||||
name: string;
|
||||
api_key: string;
|
||||
};
|
||||
knowledge_bases: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
}>;
|
||||
} | null = null;
|
||||
|
||||
// 创建Axios实例
|
||||
const instance = axios.create({
|
||||
@@ -47,44 +16,41 @@ const instance = axios.create({
|
||||
},
|
||||
});
|
||||
|
||||
// 设置测试数据
|
||||
export function setTestData(data: typeof testData) {
|
||||
testData = data;
|
||||
if (data) {
|
||||
// 优先使用设置中的ApiKey,如果没有则使用测试数据中的
|
||||
const apiKey = settings.apiKey || (data?.tenant?.api_key || "");
|
||||
if (apiKey) {
|
||||
instance.defaults.headers["X-API-Key"] = apiKey;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取测试数据
|
||||
export function getTestData() {
|
||||
return testData;
|
||||
}
|
||||
|
||||
instance.interceptors.request.use(
|
||||
(config) => {
|
||||
// 每次请求前检查是否有更新的设置
|
||||
const currentSettings = getSettings();
|
||||
|
||||
// 更新BaseURL (如果有变化)
|
||||
if (currentSettings.endpoint && config.baseURL !== currentSettings.endpoint) {
|
||||
config.baseURL = currentSettings.endpoint;
|
||||
}
|
||||
|
||||
// 更新API Key (如果有)
|
||||
if (currentSettings.apiKey) {
|
||||
config.headers["X-API-Key"] = currentSettings.apiKey;
|
||||
// 添加JWT token认证
|
||||
const token = localStorage.getItem('weknora_token');
|
||||
if (token) {
|
||||
config.headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
config.headers["X-Request-ID"] = `${generateRandomString(12)}`;
|
||||
return config;
|
||||
},
|
||||
(error) => {}
|
||||
(error) => {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Token刷新标志,防止多个请求同时刷新token
|
||||
let isRefreshing = false;
|
||||
let failedQueue: Array<{ resolve: Function; reject: Function }> = [];
|
||||
let hasRedirectedOn401 = false;
|
||||
|
||||
// 处理队列中的请求
|
||||
const processQueue = (error: any, token: string | null = null) => {
|
||||
failedQueue.forEach(({ resolve, reject }) => {
|
||||
if (error) {
|
||||
reject(error);
|
||||
} else {
|
||||
resolve(token);
|
||||
}
|
||||
});
|
||||
|
||||
failedQueue = [];
|
||||
};
|
||||
|
||||
instance.interceptors.response.use(
|
||||
(response) => {
|
||||
// 根据业务状态码处理逻辑
|
||||
@@ -95,12 +61,98 @@ instance.interceptors.response.use(
|
||||
return Promise.reject(data);
|
||||
}
|
||||
},
|
||||
(error: any) => {
|
||||
async (error: any) => {
|
||||
const originalRequest = error.config;
|
||||
|
||||
if (!error.response) {
|
||||
return Promise.reject({ message: "网络错误,请检查您的网络连接" });
|
||||
}
|
||||
const { data } = error.response;
|
||||
return Promise.reject(data);
|
||||
|
||||
// 如果是登录接口的401,直接返回错误以便页面展示toast,不做跳转
|
||||
if (error.response.status === 401 && originalRequest?.url?.includes('/auth/login')) {
|
||||
const { status, data } = error.response;
|
||||
return Promise.reject({ status, message: (typeof data === 'object' ? data?.message : data) || '用户名或密码错误' });
|
||||
}
|
||||
|
||||
// 如果是401错误且不是刷新token的请求,尝试刷新token
|
||||
if (error.response.status === 401 && !originalRequest._retry && !originalRequest.url?.includes('/auth/refresh')) {
|
||||
if (isRefreshing) {
|
||||
// 如果正在刷新token,将请求加入队列
|
||||
return new Promise((resolve, reject) => {
|
||||
failedQueue.push({ resolve, reject });
|
||||
}).then(token => {
|
||||
originalRequest.headers['Authorization'] = 'Bearer ' + token;
|
||||
return instance(originalRequest);
|
||||
}).catch(err => {
|
||||
return Promise.reject(err);
|
||||
});
|
||||
}
|
||||
|
||||
originalRequest._retry = true;
|
||||
isRefreshing = true;
|
||||
|
||||
const refreshToken = localStorage.getItem('weknora_refresh_token');
|
||||
|
||||
if (refreshToken) {
|
||||
try {
|
||||
// 动态导入refresh token API
|
||||
const { refreshToken: refreshTokenAPI } = await import('../api/auth/index');
|
||||
const response = await refreshTokenAPI(refreshToken);
|
||||
|
||||
if (response.success && response.data) {
|
||||
const { token, refreshToken: newRefreshToken } = response.data;
|
||||
|
||||
// 更新localStorage中的token
|
||||
localStorage.setItem('weknora_token', token);
|
||||
localStorage.setItem('weknora_refresh_token', newRefreshToken);
|
||||
|
||||
// 更新请求头
|
||||
originalRequest.headers['Authorization'] = 'Bearer ' + token;
|
||||
|
||||
// 处理队列中的请求
|
||||
processQueue(null, token);
|
||||
|
||||
return instance(originalRequest);
|
||||
} else {
|
||||
throw new Error(response.message || 'Token刷新失败');
|
||||
}
|
||||
} catch (refreshError) {
|
||||
// 刷新失败,清除所有token并跳转到登录页
|
||||
localStorage.removeItem('weknora_token');
|
||||
localStorage.removeItem('weknora_refresh_token');
|
||||
localStorage.removeItem('weknora_user');
|
||||
localStorage.removeItem('weknora_tenant');
|
||||
|
||||
processQueue(refreshError, null);
|
||||
|
||||
// 跳转到登录页
|
||||
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
|
||||
hasRedirectedOn401 = true;
|
||||
window.location.href = '/login';
|
||||
}
|
||||
|
||||
return Promise.reject(refreshError);
|
||||
} finally {
|
||||
isRefreshing = false;
|
||||
}
|
||||
} else {
|
||||
// 没有refresh token,直接跳转到登录页
|
||||
localStorage.removeItem('weknora_token');
|
||||
localStorage.removeItem('weknora_user');
|
||||
localStorage.removeItem('weknora_tenant');
|
||||
|
||||
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
|
||||
hasRedirectedOn401 = true;
|
||||
window.location.href = '/login';
|
||||
}
|
||||
|
||||
return Promise.reject({ message: '请重新登录' });
|
||||
}
|
||||
}
|
||||
|
||||
const { status, data } = error.response;
|
||||
// 将HTTP状态码一并抛出,方便上层判断401等场景
|
||||
return Promise.reject({ status, ...(typeof data === 'object' ? data : { message: data }) });
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
207
frontend/src/utils/security.ts
Normal file
207
frontend/src/utils/security.ts
Normal file
@@ -0,0 +1,207 @@
|
||||
/**
|
||||
* 安全工具类 - 防止 XSS 攻击
|
||||
*/
|
||||
|
||||
import DOMPurify from 'dompurify';
|
||||
|
||||
// 配置 DOMPurify 的安全策略
|
||||
const DOMPurifyConfig = {
|
||||
// 允许的标签
|
||||
ALLOWED_TAGS: [
|
||||
'p', 'br', 'strong', 'em', 'u', 's', 'del', 'ins',
|
||||
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
|
||||
'ul', 'ol', 'li', 'blockquote', 'pre', 'code',
|
||||
'a', 'img', 'table', 'thead', 'tbody', 'tr', 'th', 'td',
|
||||
'div', 'span', 'figure', 'figcaption', 'think'
|
||||
],
|
||||
// 允许的属性
|
||||
ALLOWED_ATTR: [
|
||||
'href', 'title', 'alt', 'src', 'class', 'id', 'style',
|
||||
'target', 'rel', 'width', 'height'
|
||||
],
|
||||
// 允许的协议
|
||||
ALLOWED_URI_REGEXP: /^(?:(?:(?:f|ht)tps?|mailto|tel|callto|cid|xmpp):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i,
|
||||
// 禁止的标签和属性
|
||||
FORBID_TAGS: ['script', 'object', 'embed', 'form', 'input', 'button'],
|
||||
FORBID_ATTR: ['onerror', 'onload', 'onclick', 'onmouseover', 'onfocus', 'onblur'],
|
||||
// 其他安全配置
|
||||
KEEP_CONTENT: true,
|
||||
RETURN_DOM: false,
|
||||
RETURN_DOM_FRAGMENT: false,
|
||||
RETURN_DOM_IMPORT: false,
|
||||
SANITIZE_DOM: true,
|
||||
SANITIZE_NAMED_PROPS: true,
|
||||
WHOLE_DOCUMENT: false,
|
||||
// 自定义钩子函数
|
||||
HOOKS: {
|
||||
// 在清理前处理
|
||||
beforeSanitizeElements: (currentNode: Element) => {
|
||||
// 移除所有 script 标签
|
||||
if (currentNode.tagName === 'SCRIPT') {
|
||||
currentNode.remove();
|
||||
return null;
|
||||
}
|
||||
// 移除所有事件处理器
|
||||
const eventAttrs = ['onclick', 'onload', 'onerror', 'onmouseover', 'onfocus', 'onblur'];
|
||||
eventAttrs.forEach(attr => {
|
||||
if (currentNode.hasAttribute(attr)) {
|
||||
currentNode.removeAttribute(attr);
|
||||
}
|
||||
});
|
||||
},
|
||||
// 在清理后处理
|
||||
afterSanitizeElements: (currentNode: Element) => {
|
||||
// 确保所有链接都有 rel="noopener noreferrer"
|
||||
if (currentNode.tagName === 'A') {
|
||||
const href = currentNode.getAttribute('href');
|
||||
if (href && href.startsWith('http')) {
|
||||
currentNode.setAttribute('rel', 'noopener noreferrer');
|
||||
currentNode.setAttribute('target', '_blank');
|
||||
}
|
||||
}
|
||||
// 确保所有图片都有 alt 属性
|
||||
if (currentNode.tagName === 'IMG') {
|
||||
if (!currentNode.getAttribute('alt')) {
|
||||
currentNode.setAttribute('alt', '');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* 安全地清理 HTML 内容
|
||||
* @param html 需要清理的 HTML 字符串
|
||||
* @returns 清理后的安全 HTML 字符串
|
||||
*/
|
||||
export function sanitizeHTML(html: string): string {
|
||||
if (!html || typeof html !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
try {
|
||||
return DOMPurify.sanitize(html, DOMPurifyConfig);
|
||||
} catch (error) {
|
||||
console.error('HTML sanitization failed:', error);
|
||||
// 如果清理失败,返回转义的纯文本
|
||||
return escapeHTML(html);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转义 HTML 特殊字符
|
||||
* @param text 需要转义的文本
|
||||
* @returns 转义后的文本
|
||||
*/
|
||||
export function escapeHTML(text: string): string {
|
||||
if (!text || typeof text !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
const map: { [key: string]: string } = {
|
||||
'&': '&',
|
||||
'<': '<',
|
||||
'>': '>',
|
||||
'"': '"',
|
||||
"'": ''',
|
||||
'/': '/',
|
||||
'`': '`',
|
||||
'=': '='
|
||||
};
|
||||
|
||||
return text.replace(/[&<>"'`=\/]/g, (s) => map[s]);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 URL 是否安全
|
||||
* @param url 需要验证的 URL
|
||||
* @returns 是否为安全 URL
|
||||
*/
|
||||
export function isValidURL(url: string): boolean {
|
||||
if (!url || typeof url !== 'string') {
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const urlObj = new URL(url);
|
||||
// 只允许 http 和 https 协议
|
||||
return ['http:', 'https:'].includes(urlObj.protocol);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 安全地处理 Markdown 内容
|
||||
* @param markdown Markdown 文本
|
||||
* @returns 安全的 HTML 字符串
|
||||
*/
|
||||
export function safeMarkdownToHTML(markdown: string): string {
|
||||
if (!markdown || typeof markdown !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
// 首先转义可能的 HTML 标签
|
||||
const escapedMarkdown = markdown
|
||||
.replace(/<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>/gi, '')
|
||||
.replace(/<iframe\b[^<]*(?:(?!<\/iframe>)<[^<]*)*<\/iframe>/gi, '')
|
||||
.replace(/<object\b[^<]*(?:(?!<\/object>)<[^<]*)*<\/object>/gi, '')
|
||||
.replace(/<embed\b[^<]*(?:(?!<\/embed>)<[^<]*)*<\/embed>/gi, '');
|
||||
|
||||
return escapedMarkdown;
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理用户输入
|
||||
* @param input 用户输入
|
||||
* @returns 清理后的安全输入
|
||||
*/
|
||||
export function sanitizeUserInput(input: string): string {
|
||||
if (!input || typeof input !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
// 移除控制字符
|
||||
let cleaned = input.replace(/[\x00-\x1F\x7F-\x9F]/g, '');
|
||||
|
||||
// 限制长度
|
||||
if (cleaned.length > 10000) {
|
||||
cleaned = cleaned.substring(0, 10000);
|
||||
}
|
||||
|
||||
return cleaned.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证图片 URL 是否安全
|
||||
* @param url 图片 URL
|
||||
* @returns 是否为安全的图片 URL
|
||||
*/
|
||||
export function isValidImageURL(url: string): boolean {
|
||||
if (!isValidURL(url)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查是否为图片文件
|
||||
const imageExtensions = /\.(jpg|jpeg|png|gif|webp|svg|bmp|ico)(\?.*)?$/i;
|
||||
return imageExtensions.test(url);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建安全的图片元素
|
||||
* @param src 图片源
|
||||
* @param alt 替代文本
|
||||
* @param title 标题
|
||||
* @returns 安全的图片 HTML
|
||||
*/
|
||||
export function createSafeImage(src: string, alt: string = '', title: string = ''): string {
|
||||
if (!isValidImageURL(src)) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const safeSrc = escapeHTML(src);
|
||||
const safeAlt = escapeHTML(alt);
|
||||
const safeTitle = escapeHTML(title);
|
||||
|
||||
return `<img src="${safeSrc}" alt="${safeAlt}" title="${safeTitle}" class="markdown-image" style="max-width: 100%; height: auto;">`;
|
||||
}
|
||||
553
frontend/src/views/auth/Login.vue
Normal file
553
frontend/src/views/auth/Login.vue
Normal file
@@ -0,0 +1,553 @@
|
||||
<template>
|
||||
<div class="login-container">
|
||||
<!-- 登录表单 -->
|
||||
<div class="login-card" v-if="!isRegisterMode">
|
||||
<!-- 系统Logo和标题 -->
|
||||
<div class="login-header">
|
||||
<div class="logo">
|
||||
<img src="@/assets/img/weknora.png" alt="WeKnora" class="logo-img" />
|
||||
</div>
|
||||
<p class="login-subtitle">基于大模型的文档理解与语义检索框架</p>
|
||||
</div>
|
||||
|
||||
<div class="login-form">
|
||||
<t-form
|
||||
ref="formRef"
|
||||
:data="formData"
|
||||
:rules="formRules"
|
||||
@submit="handleLogin"
|
||||
layout="vertical"
|
||||
>
|
||||
<t-form-item label="邮箱" name="email">
|
||||
<t-input
|
||||
v-model="formData.email"
|
||||
placeholder="请输入邮箱地址"
|
||||
type="email"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="密码" name="password">
|
||||
<t-input
|
||||
v-model="formData.password"
|
||||
placeholder="请输入密码(8-32位,包含字母和数字)"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
@keydown.enter="handleLogin"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-button
|
||||
type="submit"
|
||||
theme="primary"
|
||||
size="large"
|
||||
block
|
||||
:loading="loading"
|
||||
class="login-button"
|
||||
>
|
||||
{{ loading ? '登录中...' : '登录' }}
|
||||
</t-button>
|
||||
</t-form>
|
||||
|
||||
<!-- 注册链接 -->
|
||||
<div class="register-link">
|
||||
<span>还没有账号?</span>
|
||||
<a href="#" @click.prevent="toggleMode" class="register-btn">
|
||||
立即注册
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 注册表单 -->
|
||||
<div class="register-card" v-if="isRegisterMode">
|
||||
<div class="login-header">
|
||||
<h1 class="login-title">创建账号</h1>
|
||||
<p class="login-subtitle">注册后系统将为您创建专属租户</p>
|
||||
</div>
|
||||
|
||||
<div class="login-form">
|
||||
<t-form
|
||||
ref="registerFormRef"
|
||||
:data="registerData"
|
||||
:rules="registerRules"
|
||||
@submit="handleRegister"
|
||||
layout="vertical"
|
||||
>
|
||||
<t-form-item label="用户名" name="username">
|
||||
<t-input
|
||||
v-model="registerData.username"
|
||||
placeholder="请输入用户名"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="邮箱" name="email">
|
||||
<t-input
|
||||
v-model="registerData.email"
|
||||
placeholder="请输入邮箱地址"
|
||||
type="email"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="密码" name="password">
|
||||
<t-input
|
||||
v-model="registerData.password"
|
||||
placeholder="请输入密码(8-32位,包含字母和数字)"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="确认密码" name="confirmPassword">
|
||||
<t-input
|
||||
v-model="registerData.confirmPassword"
|
||||
placeholder="请再次输入密码"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
@keydown.enter="handleRegister"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-button
|
||||
type="submit"
|
||||
theme="primary"
|
||||
size="large"
|
||||
block
|
||||
:loading="loading"
|
||||
class="login-button"
|
||||
>
|
||||
{{ loading ? '注册中...' : '注册' }}
|
||||
</t-button>
|
||||
</t-form>
|
||||
|
||||
<!-- 返回登录 -->
|
||||
<div class="register-link">
|
||||
<span>已有账号?</span>
|
||||
<a href="#" @click.prevent="toggleMode" class="register-btn">
|
||||
返回登录
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, computed, nextTick, onMounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { login, register } from '@/api/auth'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
|
||||
const router = useRouter()
|
||||
const authStore = useAuthStore()
|
||||
|
||||
// 表单引用
|
||||
const formRef = ref()
|
||||
const registerFormRef = ref()
|
||||
|
||||
// 状态管理
|
||||
const loading = ref(false)
|
||||
const isRegisterMode = ref(false)
|
||||
|
||||
|
||||
// 登录表单数据
|
||||
const formData = reactive<{[key: string]: any}>({
|
||||
email: '',
|
||||
password: '',
|
||||
})
|
||||
|
||||
// 注册表单数据
|
||||
const registerData = reactive<{[key: string]: any}>({
|
||||
username: '',
|
||||
email: '',
|
||||
password: '',
|
||||
confirmPassword: ''
|
||||
})
|
||||
|
||||
// 登录表单验证规则
|
||||
const formRules = {
|
||||
email: [
|
||||
{ required: true, message: '请输入邮箱地址', type: 'error' },
|
||||
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
|
||||
],
|
||||
password: [
|
||||
{ required: true, message: '请输入密码', type: 'error' },
|
||||
{ min: 8, message: '密码至少8位', type: 'error' },
|
||||
{ max: 32, message: '密码不能超过32位', type: 'error' },
|
||||
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
|
||||
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
|
||||
]
|
||||
}
|
||||
|
||||
// 注册表单验证规则
|
||||
const registerRules = {
|
||||
username: [
|
||||
{ required: true, message: '请输入用户名', type: 'error' },
|
||||
{ min: 2, message: '用户名至少2位', type: 'error' },
|
||||
{ max: 20, message: '用户名不能超过20位', type: 'error' },
|
||||
{
|
||||
pattern: /^[a-zA-Z0-9_\u4e00-\u9fa5]+$/,
|
||||
message: '用户名只能包含字母、数字、下划线和中文',
|
||||
type: 'error'
|
||||
}
|
||||
],
|
||||
email: [
|
||||
{ required: true, message: '请输入邮箱地址', type: 'error' },
|
||||
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
|
||||
],
|
||||
password: [
|
||||
{ required: true, message: '请输入密码', type: 'error' },
|
||||
{ min: 8, message: '密码至少8位', type: 'error' },
|
||||
{ max: 32, message: '密码不能超过32位', type: 'error' },
|
||||
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
|
||||
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
|
||||
],
|
||||
confirmPassword: [
|
||||
{ required: true, message: '请确认密码', type: 'error' },
|
||||
{
|
||||
validator: (val: string) => val === registerData.password,
|
||||
message: '两次输入的密码不一致',
|
||||
type: 'error'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// 切换登录/注册模式
|
||||
const toggleMode = () => {
|
||||
isRegisterMode.value = !isRegisterMode.value
|
||||
|
||||
Object.keys(registerData).forEach(key => {
|
||||
(registerData as any)[key] = ''
|
||||
})
|
||||
}
|
||||
|
||||
// 处理登录
|
||||
const handleLogin = async () => {
|
||||
try {
|
||||
const valid = await formRef.value?.validate()
|
||||
if (!valid) return
|
||||
|
||||
loading.value = true
|
||||
|
||||
const response = await login({
|
||||
email: formData.email,
|
||||
password: formData.password,
|
||||
})
|
||||
|
||||
if (response.success) {
|
||||
// 保存用户信息和token
|
||||
if (response.user && response.tenant && response.token) {
|
||||
authStore.setUser({
|
||||
id: response.user.id || '',
|
||||
username: response.user.username || '',
|
||||
email: response.user.email || '',
|
||||
avatar: response.user.avatar,
|
||||
tenant_id: String(response.tenant.id) || '',
|
||||
created_at: response.user.created_at || new Date().toISOString(),
|
||||
updated_at: response.user.updated_at || new Date().toISOString()
|
||||
})
|
||||
authStore.setToken(response.token)
|
||||
if (response.refresh_token) {
|
||||
authStore.setRefreshToken(response.refresh_token)
|
||||
}
|
||||
authStore.setTenant({
|
||||
id: String(response.tenant.id) || '',
|
||||
name: response.tenant.name || '',
|
||||
api_key: response.tenant.api_key || '',
|
||||
owner_id: response.user.id || '',
|
||||
created_at: response.tenant.created_at || new Date().toISOString(),
|
||||
updated_at: response.tenant.updated_at || new Date().toISOString()
|
||||
})
|
||||
}
|
||||
|
||||
MessagePlugin.success('登录成功!')
|
||||
|
||||
|
||||
// 等待状态更新完成后再跳转
|
||||
await nextTick()
|
||||
router.replace('/platform/knowledge-bases')
|
||||
} else {
|
||||
MessagePlugin.error(response.message || '登录失败,请检查邮箱或密码')
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('登录错误:', error)
|
||||
MessagePlugin.error(error.message || '登录失败,请稍后重试')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 处理注册
|
||||
const handleRegister = async () => {
|
||||
try {
|
||||
const valid = await registerFormRef.value?.validate()
|
||||
if (!valid) return
|
||||
|
||||
loading.value = true
|
||||
|
||||
const response = await register({
|
||||
username: registerData.username,
|
||||
email: registerData.email,
|
||||
password: registerData.password
|
||||
})
|
||||
|
||||
if (response.success) {
|
||||
MessagePlugin.success('注册成功!系统已为您创建专属租户,请登录使用')
|
||||
|
||||
// 切换到登录模式并填入邮箱
|
||||
isRegisterMode.value = false
|
||||
formData.email = registerData.email
|
||||
|
||||
// 清空注册表单
|
||||
Object.keys(registerData).forEach(key => {
|
||||
(registerData as any)[key] = ''
|
||||
})
|
||||
} else {
|
||||
MessagePlugin.error(response.message || '注册失败')
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('注册错误:', error)
|
||||
MessagePlugin.error(error.message || '注册失败,请稍后重试')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 处理忘记密码
|
||||
const handleForgotPassword = () => {
|
||||
MessagePlugin.info('忘记密码功能暂未开放,请联系管理员')
|
||||
}
|
||||
|
||||
// 检查是否已登录
|
||||
onMounted(() => {
|
||||
if (authStore.isLoggedIn) {
|
||||
router.replace('/platform/tenant/knowledge-bases')
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style lang="less" scoped>
|
||||
.login-container {
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
||||
padding: 20px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.login-card,
|
||||
.register-card {
|
||||
width: 100%;
|
||||
max-width: 440px;
|
||||
background: #fff;
|
||||
border-radius: 14px;
|
||||
box-shadow: 0 10px 16px 0 #0000000f, 0 20px 24px -2px #0000001a;
|
||||
padding: 40px;
|
||||
box-sizing: border-box;
|
||||
animation: fadeInUp .28s ease-out both;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
text-align: center;
|
||||
margin-bottom: 32px;
|
||||
|
||||
.logo {
|
||||
margin-bottom: 16px;
|
||||
|
||||
.logo-img {
|
||||
width: 180px;
|
||||
height: auto;
|
||||
border-radius: 12px;
|
||||
}
|
||||
}
|
||||
|
||||
.login-title {
|
||||
font-size: 28px;
|
||||
font-weight: 600;
|
||||
color: #000000e6;
|
||||
margin: 0 0 8px 0;
|
||||
font-family: "PingFang SC";
|
||||
}
|
||||
|
||||
.login-subtitle {
|
||||
font-size: 16px;
|
||||
color: #0000008c;
|
||||
margin: 0;
|
||||
font-family: "PingFang SC";
|
||||
}
|
||||
}
|
||||
|
||||
.login-form {
|
||||
:deep(.t-form-item__label) {
|
||||
font-size: 14px;
|
||||
color: #000000e6;
|
||||
font-weight: 500;
|
||||
margin-bottom: 8px;
|
||||
font-family: "PingFang SC";
|
||||
display: block;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
:deep(.t-input) {
|
||||
border: 1px solid #E7E7E7;
|
||||
border-radius: 8px;
|
||||
background: #fff;
|
||||
|
||||
&:focus-within {
|
||||
border-color: #07C05F;
|
||||
box-shadow: 0 0 0 2px rgba(7, 192, 95, 0.1);
|
||||
}
|
||||
|
||||
&:hover {
|
||||
border-color: #07C05F;
|
||||
}
|
||||
|
||||
.t-input__inner {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
background: transparent;
|
||||
font-size: 16px;
|
||||
font-family: "PingFang SC";
|
||||
|
||||
&:focus {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
.t-input__wrap {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-form-item) {
|
||||
margin-bottom: 20px;
|
||||
|
||||
&:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-form-item__control) {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
.login-options {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin: 16px 0 24px 0;
|
||||
width: 100%;
|
||||
|
||||
:deep(.t-checkbox) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
.t-checkbox__input {
|
||||
margin-right: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-checkbox__label) {
|
||||
font-size: 14px;
|
||||
color: #00000099;
|
||||
font-family: "PingFang SC";
|
||||
line-height: 1.4;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.forgot-password {
|
||||
font-size: 14px;
|
||||
color: #07C05F;
|
||||
text-decoration: none;
|
||||
font-family: "PingFang SC";
|
||||
line-height: 1.4;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.login-button {
|
||||
height: 48px;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
font-family: "PingFang SC";
|
||||
margin: 16px 0 8px 0;
|
||||
|
||||
:deep(.t-button) {
|
||||
background-color: #07C05F;
|
||||
border-color: #07C05F;
|
||||
|
||||
&:hover {
|
||||
background-color: #06a855;
|
||||
border-color: #06a855;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.register-link {
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
color: #00000099;
|
||||
font-family: "PingFang SC";
|
||||
|
||||
.register-btn {
|
||||
color: #07C05F;
|
||||
text-decoration: none;
|
||||
margin-left: 4px;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应式设计
|
||||
@media (max-width: 480px) {
|
||||
.login-container {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.login-card,
|
||||
.register-card {
|
||||
padding: 28px;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
margin-bottom: 24px;
|
||||
|
||||
.login-title {
|
||||
font-size: 24px;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translate3d(0, 6px, 0);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translate3d(0, 0, 0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -23,6 +23,8 @@ import { marked } from 'marked';
|
||||
import docInfo from './docInfo.vue';
|
||||
import deepThink from './deepThink.vue';
|
||||
import picturePreview from '@/components/picture-preview.vue';
|
||||
import { sanitizeHTML, safeMarkdownToHTML, createSafeImage, isValidImageURL } from '@/utils/security';
|
||||
|
||||
marked.use({
|
||||
mangle: false,
|
||||
headerIds: false,
|
||||
@@ -89,36 +91,36 @@ const checkImage = (url) => {
|
||||
img.src = url;
|
||||
});
|
||||
};
|
||||
// 处理 Markdown 中的图片
|
||||
// 安全地处理 Markdown 内容
|
||||
const processMarkdown = (markdownText) => {
|
||||
// 自定义渲染器处理图片
|
||||
if (!markdownText || typeof markdownText !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
// 首先对 Markdown 内容进行安全处理
|
||||
const safeMarkdown = safeMarkdownToHTML(markdownText);
|
||||
|
||||
// 自定义安全的渲染器处理图片
|
||||
const renderer = {
|
||||
image(href, title, text) {
|
||||
return `<img src="${href}" alt="${text}" title="${title || ''}" class="markdown-image" style="max-width: 708px;height: 230px;">`;
|
||||
// 验证图片 URL 是否安全
|
||||
if (!isValidImageURL(href)) {
|
||||
return `<p>无效的图片链接</p>`;
|
||||
}
|
||||
// 使用安全的图片创建函数
|
||||
return createSafeImage(href, text || '', title || '');
|
||||
}
|
||||
};
|
||||
|
||||
marked.use({ renderer });
|
||||
|
||||
// 第一次渲染
|
||||
let html = marked.parse(markdownText);
|
||||
// 安全地渲染 Markdown
|
||||
let html = marked.parse(safeMarkdown);
|
||||
|
||||
// 创建虚拟 DOM 来操作
|
||||
const parser = new DOMParser();
|
||||
const doc = parser.parseFromString(html, 'text/html');
|
||||
|
||||
// 检查所有图片
|
||||
// const images = doc.querySelectorAll('img');
|
||||
// images.forEach(async item => {
|
||||
// const isValid = await checkImage(item.src);
|
||||
// if (!isValid) {
|
||||
// item.remove();
|
||||
// }
|
||||
// });
|
||||
// if (props.isFirstEnter) {
|
||||
// emit('scroll-bottom')
|
||||
// }
|
||||
return doc.body.innerHTML;
|
||||
// 使用 DOMPurify 进行最终的安全清理
|
||||
const sanitizedHTML = sanitizeHTML(html);
|
||||
|
||||
return sanitizedHTML;
|
||||
};
|
||||
const handleImg = async (newVal) => {
|
||||
let index = newVal.lastIndexOf('![');
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
</div>
|
||||
</template>
|
||||
<div class="content">
|
||||
<span v-html="deepSession.thinkContent.replace(/\n/g, '<br/>')"></span>
|
||||
<span v-html="safeProcessThinkContent(deepSession.thinkContent)"></span>
|
||||
</div>
|
||||
</t-collapse-panel>
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
</template>
|
||||
<script setup>
|
||||
import { onMounted, watch, computed, ref, reactive, defineProps } from 'vue';
|
||||
import { sanitizeHTML } from '@/utils/security';
|
||||
|
||||
const isFold = ref(true)
|
||||
const props = defineProps({
|
||||
@@ -51,8 +52,20 @@ const showHide = () => {
|
||||
}
|
||||
const handlePanelChange = (val) => {
|
||||
isFold.value = !val.length ? true : false;
|
||||
|
||||
}
|
||||
|
||||
// 安全地处理思考内容,防止XSS攻击
|
||||
const safeProcessThinkContent = (content) => {
|
||||
if (!content || typeof content !== 'string') return '';
|
||||
|
||||
// 先处理换行符
|
||||
const contentWithBreaks = content.replace(/\n/g, '<br/>');
|
||||
|
||||
// 使用DOMPurify进行安全清理,允许基本的文本格式化标签
|
||||
const cleanContent = sanitizeHTML(contentWithBreaks);
|
||||
|
||||
return cleanContent;
|
||||
};
|
||||
</script>
|
||||
<style lang="less" scoped>
|
||||
.deep-think {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
trigger="click">
|
||||
<template #content>
|
||||
<div class="doc_content">
|
||||
<div v-html="item.content.replace(/\n/g, '<br/>')"></div>
|
||||
<div v-html="safeProcessContent(item.content)"></div>
|
||||
</div>
|
||||
</template>
|
||||
<span class="doc">
|
||||
@@ -28,6 +28,7 @@
|
||||
</template>
|
||||
<script setup>
|
||||
import { onMounted, defineProps, computed, ref, reactive } from "vue";
|
||||
import { sanitizeHTML } from '@/utils/security';
|
||||
const props = defineProps({
|
||||
// 必填项
|
||||
content: {
|
||||
@@ -44,6 +45,14 @@ const referBoxSwitch = () => {
|
||||
showReferBox.value = !showReferBox.value;
|
||||
};
|
||||
|
||||
// 安全地处理内容
|
||||
const safeProcessContent = (content) => {
|
||||
if (!content) return '';
|
||||
// 先进行安全清理,然后处理换行
|
||||
const sanitized = sanitizeHTML(content);
|
||||
return sanitized.replace(/\n/g, '<br/>');
|
||||
};
|
||||
|
||||
</script>
|
||||
<style lang="less" scoped>
|
||||
.refer {
|
||||
|
||||
@@ -38,6 +38,7 @@ const { output, onChunk, isStreaming, isLoading, error, startStream, stopStream
|
||||
const route = useRoute();
|
||||
const router = useRouter();
|
||||
const session_id = ref(route.params.chatid);
|
||||
const knowledge_base_id = ref(route.params.kbId);
|
||||
const created_at = ref('');
|
||||
const limit = ref(20);
|
||||
const messagesList = reactive([]);
|
||||
@@ -57,6 +58,7 @@ watch([() => route.params], (newvalue) => {
|
||||
}
|
||||
messagesList.splice(0);
|
||||
session_id.value = newvalue[0].chatid;
|
||||
knowledge_base_id.value = newvalue[0].kbId;
|
||||
checkmenuTitle(session_id.value)
|
||||
let data = {
|
||||
session_id: session_id.value,
|
||||
@@ -154,7 +156,14 @@ const sendMsg = async (value) => {
|
||||
loading.value = true;
|
||||
messagesList.push({ content: value, role: 'user' });
|
||||
scrollToBottom();
|
||||
await startStream({ session_id: session_id.value, query: value, method: 'POST', url: '/api/v1/knowledge-chat' });
|
||||
|
||||
await startStream({
|
||||
session_id: session_id.value,
|
||||
knowledge_base_id: knowledge_base_id.value,
|
||||
query: value,
|
||||
method: 'POST',
|
||||
url: '/api/v1/knowledge-chat'
|
||||
});
|
||||
}
|
||||
|
||||
// 处理流式数据
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<div v-show="cardList.length" class="dialogue-wrap">
|
||||
<div class="dialogue-wrap">
|
||||
<div class="dialogue-answers">
|
||||
<div class="dialogue-title">
|
||||
<span>基于知识库内容问答</span>
|
||||
@@ -7,7 +7,17 @@
|
||||
<InputField @send-msg="sendMsg"></InputField>
|
||||
</div>
|
||||
</div>
|
||||
<EmptyKnowledge v-show="!cardList.length"></EmptyKnowledge>
|
||||
|
||||
|
||||
<t-dialog v-model:visible="selectVisible" header="选择知识库" :confirmBtn="{ content: '开始对话', theme: 'primary' }" :onConfirm="confirmSelect" :onCancel="() => selectVisible = false">
|
||||
<t-form :data="{ kb: selectedKbId }">
|
||||
<t-form-item label="知识库">
|
||||
<t-select v-model="selectedKbId" :loading="kbLoading" placeholder="请选择知识库">
|
||||
<t-option v-for="kb in kbList" :key="kb.id" :value="kb.id" :label="kb.name" />
|
||||
</t-select>
|
||||
</t-form-item>
|
||||
</t-form>
|
||||
</t-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, onUnmounted, watch } from 'vue';
|
||||
@@ -17,55 +27,52 @@ import { getSessionsList, createSessions, generateSessionsTitle } from "@/api/ch
|
||||
import { useMenuStore } from '@/stores/menu';
|
||||
import { useRoute, useRouter } from 'vue-router';
|
||||
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
|
||||
import { getTestData } from '@/utils/request';
|
||||
import { listKnowledgeBases } from '@/api/knowledge-base';
|
||||
|
||||
let { cardList } = useKnowledgeBase()
|
||||
const router = useRouter();
|
||||
const route = useRoute();
|
||||
const usemenuStore = useMenuStore();
|
||||
const sendMsg = (value: string) => {
|
||||
createNewSession(value);
|
||||
}
|
||||
|
||||
const selectVisible = ref(false)
|
||||
const selectedKbId = ref<string>('')
|
||||
const kbList = ref<Array<{ id: string; name: string }>>([])
|
||||
const kbLoading = ref(false)
|
||||
|
||||
const ensureKbId = async (): Promise<string | null> => {
|
||||
// 1) 优先使用当前路由上下文(如果来自某个知识库详情页)
|
||||
const routeKb = (route.params as any)?.kbId as string
|
||||
if (routeKb) return routeKb
|
||||
|
||||
|
||||
// 3) 弹窗选择知识库(从接口拉取)
|
||||
kbLoading.value = true
|
||||
try {
|
||||
const res: any = await listKnowledgeBases()
|
||||
kbList.value = res?.data || []
|
||||
if (kbList.value.length === 0) return null
|
||||
selectedKbId.value = kbList.value[0].id
|
||||
selectVisible.value = true
|
||||
return null
|
||||
} finally {
|
||||
kbLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function createNewSession(value: string) {
|
||||
// 从localStorage获取设置中的知识库ID
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
let knowledgeBaseId = "";
|
||||
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.knowledgeBaseId) {
|
||||
knowledgeBaseId = settings.knowledgeBaseId;
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
if (res.data && res.data.id) {
|
||||
getTitle(res.data.id, value);
|
||||
} else {
|
||||
// 错误处理
|
||||
console.error("创建会话失败");
|
||||
}
|
||||
}).catch(error => {
|
||||
console.error("创建会话出错:", error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置中没有知识库ID,则使用测试数据
|
||||
const testData = getTestData();
|
||||
if (!testData || testData.knowledge_bases.length === 0) {
|
||||
console.error("测试数据未初始化或不包含知识库");
|
||||
return;
|
||||
let knowledgeBaseId = await ensureKbId()
|
||||
if (!knowledgeBaseId) {
|
||||
// 等待用户在弹窗中选择
|
||||
pendingValue.value = value
|
||||
return
|
||||
}
|
||||
|
||||
// 使用第一个知识库ID
|
||||
knowledgeBaseId = testData.knowledge_bases[0].id;
|
||||
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(async res => {
|
||||
if (res.data && res.data.id) {
|
||||
getTitle(res.data.id, value)
|
||||
await getTitle(res.data.id, value)
|
||||
} else {
|
||||
// 错误处理
|
||||
console.error("创建会话失败");
|
||||
@@ -75,12 +82,33 @@ async function createNewSession(value: string) {
|
||||
})
|
||||
}
|
||||
|
||||
const getTitle = (session_id: string, value: string) => {
|
||||
let obj = { title: '新会话', path: `chat/${session_id}`, id: session_id, isMore: false, isNoTitle: true }
|
||||
const pendingValue = ref<string>('')
|
||||
const confirmSelect = async () => {
|
||||
if (!selectedKbId.value) return
|
||||
const value = pendingValue.value
|
||||
pendingValue.value = ''
|
||||
selectVisible.value = false
|
||||
createSessions({ knowledge_base_id: selectedKbId.value }).then(async res => {
|
||||
if (res.data && res.data.id) {
|
||||
await getTitle(res.data.id, value, selectedKbId.value)
|
||||
} else {
|
||||
console.error('创建会话失败')
|
||||
}
|
||||
}).catch((e:any) => console.error('创建会话出错:', e))
|
||||
}
|
||||
|
||||
const getTitle = async (session_id: string, value: string, kbId?: string) => {
|
||||
const finalKbId = kbId || await ensureKbId();
|
||||
if (!finalKbId) {
|
||||
console.error('无法获取知识库ID');
|
||||
return;
|
||||
}
|
||||
|
||||
let obj = { title: '新会话', path: `chat/${finalKbId}/${session_id}`, id: session_id, isMore: false, isNoTitle: true }
|
||||
usemenuStore.updataMenuChildren(obj);
|
||||
usemenuStore.changeIsFirstSession(true);
|
||||
usemenuStore.changeFirstQuery(value);
|
||||
router.push(`/platform/chat/${session_id}`);
|
||||
router.push(`/platform/chat/${finalKbId}/${session_id}`);
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, watch, reactive } from "vue";
|
||||
import { ref, onMounted, onUnmounted, watch, reactive, computed } from "vue";
|
||||
import DocContent from "@/components/doc-content.vue";
|
||||
import InputField from "@/components/Input-field.vue";
|
||||
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
|
||||
@@ -7,17 +7,21 @@ import { useRoute, useRouter } from 'vue-router';
|
||||
import EmptyKnowledge from '@/components/empty-knowledge.vue';
|
||||
import { getSessionsList, createSessions, generateSessionsTitle } from "@/api/chat/index";
|
||||
import { useMenuStore } from '@/stores/menu';
|
||||
import { getTestData } from '@/utils/request';
|
||||
import { MessagePlugin } from 'tdesign-vue-next';
|
||||
const usemenuStore = useMenuStore();
|
||||
const router = useRouter();
|
||||
import {
|
||||
batchQueryKnowledge,
|
||||
listKnowledgeFiles,
|
||||
} from "@/api/knowledge-base/index";
|
||||
let { cardList, total, moreIndex, details, getKnowled, delKnowledge, openMore, onVisibleChange, getCardDetails, getfDetails } = useKnowledgeBase()
|
||||
import { formatStringDate } from "@/utils/index";
|
||||
const route = useRoute();
|
||||
const kbId = computed(() => (route.params as any).kbId as string || '');
|
||||
let { cardList, total, moreIndex, details, getKnowled, delKnowledge, openMore, onVisibleChange, getCardDetails, getfDetails } = useKnowledgeBase(kbId.value)
|
||||
let isCardDetails = ref(false);
|
||||
let timeout = null;
|
||||
let timeout: ReturnType<typeof setInterval> | null = null;
|
||||
let delDialog = ref(false)
|
||||
let knowledge = ref({})
|
||||
let knowledge = ref<KnowledgeCard>({ id: '', parse_status: '' })
|
||||
let knowledgeIndex = ref(-1)
|
||||
let knowledgeScroll = ref()
|
||||
let page = 1;
|
||||
@@ -29,29 +33,93 @@ const getPageSize = () => {
|
||||
pageSize = Math.max(35, itemsInView);
|
||||
}
|
||||
getPageSize()
|
||||
// 直接调用 API 获取知识库文件列表
|
||||
const loadKnowledgeFiles = async (kbIdValue: string) => {
|
||||
if (!kbIdValue) return;
|
||||
|
||||
try {
|
||||
const result = await listKnowledgeFiles(kbIdValue, { page: 1, page_size: pageSize });
|
||||
|
||||
// 由于响应拦截器已经返回了 data,所以 result 就是响应的 data 部分
|
||||
// 按照 useKnowledgeBase hook 中的方式处理
|
||||
const { data, total: totalResult } = result as any;
|
||||
|
||||
if (!data || !Array.isArray(data)) {
|
||||
console.error('Invalid data format. Expected array, got:', typeof data, data);
|
||||
return;
|
||||
}
|
||||
|
||||
const cardList_ = data.map((item: any) => {
|
||||
item["file_name"] = item.file_name.substring(
|
||||
0,
|
||||
item.file_name.lastIndexOf(".")
|
||||
);
|
||||
return {
|
||||
...item,
|
||||
updated_at: formatStringDate(new Date(item.updated_at)),
|
||||
isMore: false,
|
||||
file_type: item.file_type.toLocaleUpperCase(),
|
||||
};
|
||||
});
|
||||
|
||||
cardList.value = cardList_ as any[];
|
||||
total.value = totalResult;
|
||||
} catch (err) {
|
||||
console.error('Failed to load knowledge files:', err);
|
||||
}
|
||||
};
|
||||
|
||||
// 监听路由参数变化,重新获取知识库内容
|
||||
watch(() => kbId.value, (newKbId, oldKbId) => {
|
||||
if (newKbId && newKbId !== oldKbId) {
|
||||
loadKnowledgeFiles(newKbId);
|
||||
}
|
||||
}, { immediate: false });
|
||||
|
||||
// 监听文件上传事件
|
||||
const handleFileUploaded = (event: CustomEvent) => {
|
||||
const uploadedKbId = event.detail.kbId;
|
||||
console.log('接收到文件上传事件,上传的知识库ID:', uploadedKbId, '当前知识库ID:', kbId.value);
|
||||
if (uploadedKbId && uploadedKbId === kbId.value) {
|
||||
console.log('匹配当前知识库,开始刷新文件列表');
|
||||
// 如果上传的文件属于当前知识库,使用 loadKnowledgeFiles 刷新文件列表
|
||||
loadKnowledgeFiles(uploadedKbId);
|
||||
}
|
||||
};
|
||||
|
||||
onMounted(() => {
|
||||
getKnowled({ page: 1, page_size: pageSize });
|
||||
|
||||
// 监听文件上传事件
|
||||
window.addEventListener('knowledgeFileUploaded', handleFileUploaded as EventListener);
|
||||
});
|
||||
|
||||
onUnmounted(() => {
|
||||
window.removeEventListener('knowledgeFileUploaded', handleFileUploaded as EventListener);
|
||||
});
|
||||
watch(() => cardList.value, (newValue) => {
|
||||
let analyzeList = [];
|
||||
analyzeList = newValue.filter(item => {
|
||||
return item.parse_status == 'pending' || item.parse_status == 'processing';
|
||||
})
|
||||
clearInterval(timeout);
|
||||
timeout = null;
|
||||
if (timeout !== null) {
|
||||
clearInterval(timeout);
|
||||
timeout = null;
|
||||
}
|
||||
if (analyzeList.length) {
|
||||
updateStatus(analyzeList)
|
||||
}
|
||||
}, { deep: true })
|
||||
const updateStatus = (analyzeList) => {
|
||||
type KnowledgeCard = { id: string; parse_status: string; description?: string; file_name?: string; updated_at?: string; file_type?: string; isMore?: boolean };
|
||||
const updateStatus = (analyzeList: KnowledgeCard[]) => {
|
||||
let query = ``;
|
||||
for (let i = 0; i < analyzeList.length; i++) {
|
||||
query += `ids=${analyzeList[i].id}&`;
|
||||
}
|
||||
timeout = setInterval(() => {
|
||||
batchQueryKnowledge(query).then((result) => {
|
||||
batchQueryKnowledge(query).then((result: any) => {
|
||||
if (result.success && result.data) {
|
||||
result.data.forEach(item => {
|
||||
(result.data as KnowledgeCard[]).forEach((item: KnowledgeCard) => {
|
||||
if (item.parse_status == 'failed' || item.parse_status == 'completed') {
|
||||
let index = cardList.value.findIndex(card => card.id == item.id);
|
||||
if (index != -1) {
|
||||
@@ -70,12 +138,12 @@ const updateStatus = (analyzeList) => {
|
||||
const closeDoc = () => {
|
||||
isCardDetails.value = false;
|
||||
};
|
||||
const openCardDetails = (item) => {
|
||||
const openCardDetails = (item: KnowledgeCard) => {
|
||||
isCardDetails.value = true;
|
||||
getCardDetails(item);
|
||||
};
|
||||
|
||||
const delCard = (index, item) => {
|
||||
const delCard = (index: number, item: KnowledgeCard) => {
|
||||
knowledgeIndex.value = index;
|
||||
knowledge.value = item;
|
||||
delDialog.value = true;
|
||||
@@ -94,7 +162,7 @@ const handleScroll = () => {
|
||||
}
|
||||
}
|
||||
};
|
||||
const getDoc = (page) => {
|
||||
const getDoc = (page: number) => {
|
||||
getfDetails(details.id, page)
|
||||
};
|
||||
|
||||
@@ -108,51 +176,36 @@ const sendMsg = (value: string) => {
|
||||
};
|
||||
|
||||
const getTitle = (session_id: string, value: string) => {
|
||||
let obj = { title: '新会话', path: `chat/${session_id}`, id: session_id, isMore: false, isNoTitle: true };
|
||||
let obj = { title: '新会话', path: `chat/${kbId.value}/${session_id}`, id: session_id, isMore: false, isNoTitle: true };
|
||||
usemenuStore.updataMenuChildren(obj);
|
||||
usemenuStore.changeIsFirstSession(true);
|
||||
usemenuStore.changeFirstQuery(value);
|
||||
router.push(`/platform/chat/${session_id}`);
|
||||
router.push(`/platform/chat/${kbId.value}/${session_id}`);
|
||||
};
|
||||
|
||||
async function createNewSession(value: string): Promise<void> {
|
||||
// 从localStorage获取设置中的知识库ID
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
let knowledgeBaseId = "";
|
||||
// 优先使用当前页面的知识库ID
|
||||
let sessionKbId = kbId.value;
|
||||
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.knowledgeBaseId) {
|
||||
knowledgeBaseId = settings.knowledgeBaseId;
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
if (res.data && res.data.id) {
|
||||
getTitle(res.data.id, value);
|
||||
} else {
|
||||
// 错误处理
|
||||
console.error("创建会话失败");
|
||||
}
|
||||
}).catch(error => {
|
||||
console.error("创建会话出错:", error);
|
||||
});
|
||||
return;
|
||||
// 如果当前页面没有知识库ID,尝试从localStorage获取设置中的知识库ID
|
||||
if (!sessionKbId) {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
sessionKbId = settings.knowledgeBaseId;
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置中没有知识库ID,则使用测试数据
|
||||
const testData = getTestData();
|
||||
if (!testData || !testData.knowledge_bases || testData.knowledge_bases.length === 0) {
|
||||
console.error("测试数据未初始化或不包含知识库");
|
||||
if (!sessionKbId) {
|
||||
MessagePlugin.warning("请先选择一个知识库");
|
||||
return;
|
||||
}
|
||||
|
||||
// 使用第一个知识库ID
|
||||
knowledgeBaseId = testData.knowledge_bases[0].id;
|
||||
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
|
||||
createSessions({ knowledge_base_id: sessionKbId }).then(res => {
|
||||
if (res.data && res.data.id) {
|
||||
getTitle(res.data.id, value);
|
||||
} else {
|
||||
|
||||
234
frontend/src/views/knowledge/KnowledgeBaseList.vue
Normal file
234
frontend/src/views/knowledge/KnowledgeBaseList.vue
Normal file
@@ -0,0 +1,234 @@
|
||||
<template>
|
||||
<div class="kb-list-container">
|
||||
<div class="header">
|
||||
<h2>知识库</h2>
|
||||
<t-button theme="primary" @click="openCreate">新建知识库</t-button>
|
||||
</div>
|
||||
|
||||
<!-- 未初始化知识库提示 -->
|
||||
<div v-if="hasUninitializedKbs" class="warning-banner">
|
||||
<t-icon name="info-circle" size="16px" />
|
||||
<span>部分知识库尚未初始化,需要先在设置中配置模型信息才能添加知识文档</span>
|
||||
</div>
|
||||
<t-table :data="kbs" :columns="columns" row-key="id" size="medium" hover>
|
||||
<template #status="{ row }">
|
||||
<div class="status-cell">
|
||||
<t-tag
|
||||
:theme="isInitialized(row) ? 'success' : 'warning'"
|
||||
size="small"
|
||||
>
|
||||
{{ isInitialized(row) ? '已初始化' : '未初始化' }}
|
||||
</t-tag>
|
||||
<t-tooltip
|
||||
v-if="!isInitialized(row)"
|
||||
content="需要先在设置中配置模型信息才能添加知识"
|
||||
placement="top"
|
||||
>
|
||||
<span class="warning-icon">⚠</span>
|
||||
</t-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<template #description="{ row }">
|
||||
<div class="description-text">{{ row.description || '暂无描述' }}</div>
|
||||
</template>
|
||||
<template #op="{ row }">
|
||||
<t-space size="small">
|
||||
<t-button
|
||||
size="small"
|
||||
@click="goDetail(row.id)"
|
||||
:disabled="!isInitialized(row)"
|
||||
:theme="isInitialized(row) ? 'primary' : 'default'"
|
||||
:variant="isInitialized(row) ? 'base' : 'outline'"
|
||||
:title="!isInitialized(row) ? '请先在设置中配置模型信息' : ''"
|
||||
>
|
||||
文档
|
||||
</t-button>
|
||||
<t-button size="small" variant="outline" @click="goSettings(row.id)">设置</t-button>
|
||||
<t-popconfirm content="确认删除该知识库?" @confirm="remove(row.id)">
|
||||
<t-button size="small" theme="danger" variant="text">删除</t-button>
|
||||
</t-popconfirm>
|
||||
</t-space>
|
||||
</template>
|
||||
</t-table>
|
||||
|
||||
<t-dialog v-model:visible="createVisible" header="新建知识库" :footer="false">
|
||||
<t-form :data="createForm" @submit="create">
|
||||
<t-form-item label="名称" name="name" :rules="[{ required: true, message: '请输入名称' }]">
|
||||
<t-input v-model="createForm.name" />
|
||||
</t-form-item>
|
||||
<t-form-item label="描述" name="description">
|
||||
<t-textarea v-model="createForm.description" />
|
||||
</t-form-item>
|
||||
<t-form-item>
|
||||
<t-space>
|
||||
<t-button theme="primary" type="submit" :loading="creating">创建</t-button>
|
||||
<t-button variant="outline" @click="createVisible = false">取消</t-button>
|
||||
</t-space>
|
||||
</t-form-item>
|
||||
</t-form>
|
||||
</t-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { onMounted, reactive, ref, computed } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { listKnowledgeBases, createKnowledgeBase, deleteKnowledgeBase } from '@/api/knowledge-base'
|
||||
import { formatStringDate } from '@/utils/index'
|
||||
|
||||
const router = useRouter()
|
||||
|
||||
interface KB {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
updated_at?: string;
|
||||
embedding_model_id?: string;
|
||||
summary_model_id?: string;
|
||||
}
|
||||
const kbs = ref<KB[]>([])
|
||||
const loading = ref(false)
|
||||
|
||||
const columns = [
|
||||
{ colKey: 'name', title: '名称' },
|
||||
{ colKey: 'description', title: '描述', cell: 'description', width: 300 },
|
||||
{ colKey: 'status', title: '状态', cell: 'status', width: 100 },
|
||||
{ colKey: 'updated_at', title: '更新时间' },
|
||||
{ colKey: 'op', title: '操作', cell: 'op', width: 220 },
|
||||
]
|
||||
|
||||
const fetchList = () => {
|
||||
loading.value = true
|
||||
listKnowledgeBases().then((res: any) => {
|
||||
const data = res.data || []
|
||||
// 格式化时间
|
||||
kbs.value = data.map((kb: KB) => ({
|
||||
...kb,
|
||||
updated_at: kb.updated_at ? formatStringDate(new Date(kb.updated_at)) : ''
|
||||
}))
|
||||
}).finally(() => loading.value = false)
|
||||
}
|
||||
|
||||
onMounted(fetchList)
|
||||
|
||||
const createVisible = ref(false)
|
||||
const creating = ref(false)
|
||||
const createForm = reactive({ name: '', description: '' })
|
||||
const openCreate = () => {
|
||||
createForm.name = ''
|
||||
createForm.description = ''
|
||||
createVisible.value = true
|
||||
}
|
||||
const create = () => {
|
||||
if (!createForm.name) return
|
||||
creating.value = true
|
||||
const chunking_config = {
|
||||
chunk_size: 512,
|
||||
chunk_overlap: 100,
|
||||
separators: ['.', '?', '!', '。', '?', '!'],
|
||||
enable_multimodal: false
|
||||
}
|
||||
createKnowledgeBase({ name: createForm.name, description: createForm.description, chunking_config }).then((res: any) => {
|
||||
if (res.success) {
|
||||
MessagePlugin.success('创建成功')
|
||||
createVisible.value = false
|
||||
fetchList()
|
||||
} else {
|
||||
MessagePlugin.error(res.message || '创建失败')
|
||||
}
|
||||
}).catch((e: any) => {
|
||||
MessagePlugin.error(e?.message || '创建失败')
|
||||
}).finally(() => creating.value = false)
|
||||
}
|
||||
|
||||
const remove = (id: string) => {
|
||||
deleteKnowledgeBase(id).then((res: any) => {
|
||||
if (res.success) {
|
||||
MessagePlugin.success('已删除')
|
||||
fetchList()
|
||||
} else {
|
||||
MessagePlugin.error(res.message || '删除失败')
|
||||
}
|
||||
}).catch((e: any) => MessagePlugin.error(e?.message || '删除失败'))
|
||||
}
|
||||
|
||||
const isInitialized = (kb: KB) => {
|
||||
return !!(kb.embedding_model_id && kb.embedding_model_id !== '' &&
|
||||
kb.summary_model_id && kb.summary_model_id !== '')
|
||||
}
|
||||
|
||||
// 计算是否有未初始化的知识库
|
||||
const hasUninitializedKbs = computed(() => {
|
||||
return kbs.value.some(kb => !isInitialized(kb))
|
||||
})
|
||||
|
||||
const goDetail = (id: string) => {
|
||||
router.push(`/platform/knowledge-bases/${id}`)
|
||||
}
|
||||
const goSettings = (id: string) => {
|
||||
router.push(`/platform/knowledge-bases/${id}/settings`)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped lang="less">
|
||||
.kb-list-container {
|
||||
padding: 20px;
|
||||
background: #fff;
|
||||
margin: 0 20px 0 20px;
|
||||
height: calc(100vh);
|
||||
overflow-y: auto;
|
||||
box-sizing: border-box;
|
||||
flex: 1;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 16px;
|
||||
h2 { margin: 0; font-size: 20px; font-weight: 600; }
|
||||
}
|
||||
|
||||
.warning-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 12px 16px;
|
||||
margin-bottom: 16px;
|
||||
background: #fff7e6;
|
||||
border: 1px solid #ffd591;
|
||||
border-radius: 6px;
|
||||
color: #d46b08;
|
||||
font-size: 14px;
|
||||
|
||||
.t-icon {
|
||||
color: #d46b08;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.status-cell {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
|
||||
.warning-icon {
|
||||
color: #ff8800;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
font-weight: bold;
|
||||
transition: color 0.2s;
|
||||
|
||||
&:hover {
|
||||
color: #d46b08;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.description-cell {
|
||||
.description-text {
|
||||
color: #000000e6;
|
||||
font-size: 14px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<div class="main" ref="dropzone" @dragover="dragover" @drop="drop" @dragstart="dragstart">
|
||||
<div class="main" ref="dropzone">
|
||||
<Menu></Menu>
|
||||
<RouterView />
|
||||
<div class="upload-mask" v-show="ismask">
|
||||
@@ -10,42 +10,104 @@
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import Menu from '@/components/menu.vue'
|
||||
import { ref } from 'vue';
|
||||
import { useRouter } from 'vue-router'
|
||||
import { storeToRefs } from "pinia";
|
||||
import { knowledgeStore } from "@/stores/knowledge";
|
||||
const usemenuStore = knowledgeStore();
|
||||
import { ref, onMounted, onUnmounted } from 'vue';
|
||||
import { useRoute } from 'vue-router'
|
||||
import useKnowledgeBase from '@/hooks/useKnowledgeBase'
|
||||
import UploadMask from '@/components/upload-mask.vue'
|
||||
import { getKnowledgeBaseById } from '@/api/knowledge-base/index'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
|
||||
let { requestMethod } = useKnowledgeBase()
|
||||
const router = useRouter();
|
||||
const route = useRoute();
|
||||
let ismask = ref(false)
|
||||
let dropzone = ref();
|
||||
let uploadInput = ref();
|
||||
const dragover = (event) => {
|
||||
event.preventDefault();
|
||||
ismask.value = true;
|
||||
if (((window.innerWidth - event.clientX) < 50) || ((window.innerHeight - event.clientY) < 50) || event.clientX < 50 || event.clientY < 50) {
|
||||
ismask.value = false
|
||||
}
|
||||
|
||||
// 获取当前知识库ID
|
||||
const getCurrentKbId = (): string | null => {
|
||||
return (route.params as any)?.kbId as string || null
|
||||
}
|
||||
const drop = (event) => {
|
||||
event.preventDefault();
|
||||
ismask.value = false
|
||||
const DataTransferItemList = event.dataTransfer.items;
|
||||
for (const dataTransferItem of DataTransferItemList) {
|
||||
const fileEntry = dataTransferItem.webkitGetAsEntry();
|
||||
if (fileEntry) {
|
||||
fileEntry.file((file: file) => {
|
||||
requestMethod(file, uploadInput)
|
||||
router.push('/platform/knowledgeBase?upload=true')
|
||||
})
|
||||
|
||||
// 检查知识库初始化状态
|
||||
const checkKnowledgeBaseInitialization = async (): Promise<boolean> => {
|
||||
const currentKbId = getCurrentKbId();
|
||||
|
||||
if (!currentKbId) {
|
||||
MessagePlugin.error("缺少知识库ID");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const kbResponse = await getKnowledgeBaseById(currentKbId);
|
||||
const kb = kbResponse.data;
|
||||
|
||||
if (!kb.embedding_model_id || !kb.summary_model_id) {
|
||||
MessagePlugin.warning("该知识库尚未完成初始化配置,请先前往设置页面配置模型信息后再上传文件");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} catch (error) {
|
||||
MessagePlugin.error("获取知识库信息失败,无法上传文件");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const dragstart = (event) => {
|
||||
|
||||
|
||||
// 全局拖拽事件处理
|
||||
const handleGlobalDragEnter = (event: DragEvent) => {
|
||||
event.preventDefault();
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.effectAllowed = 'all';
|
||||
}
|
||||
ismask.value = true;
|
||||
}
|
||||
|
||||
const handleGlobalDragOver = (event: DragEvent) => {
|
||||
event.preventDefault();
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.dropEffect = 'copy';
|
||||
}
|
||||
ismask.value = true;
|
||||
}
|
||||
|
||||
const handleGlobalDrop = async (event: DragEvent) => {
|
||||
event.preventDefault();
|
||||
ismask.value = false;
|
||||
|
||||
const DataTransferFiles = event.dataTransfer?.files ? Array.from(event.dataTransfer.files) : [];
|
||||
const DataTransferItemList = event.dataTransfer?.items ? Array.from(event.dataTransfer.items) : [];
|
||||
|
||||
const isInitialized = await checkKnowledgeBaseInitialization();
|
||||
if (!isInitialized) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (DataTransferFiles.length > 0) {
|
||||
DataTransferFiles.forEach(file => requestMethod(file, uploadInput));
|
||||
} else if (DataTransferItemList.length > 0) {
|
||||
DataTransferItemList.forEach(dataTransferItem => {
|
||||
const fileEntry = dataTransferItem.webkitGetAsEntry() as FileSystemFileEntry | null;
|
||||
if (fileEntry) {
|
||||
fileEntry.file((file: File) => requestMethod(file, uploadInput));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MessagePlugin.warning('请拖拽文件而不是文本或链接');
|
||||
}
|
||||
}
|
||||
|
||||
// 组件挂载时添加全局事件监听器
|
||||
onMounted(() => {
|
||||
document.addEventListener('dragenter', handleGlobalDragEnter, true);
|
||||
document.addEventListener('dragover', handleGlobalDragOver, true);
|
||||
document.addEventListener('drop', handleGlobalDrop, true);
|
||||
});
|
||||
|
||||
// 组件卸载时移除全局事件监听器
|
||||
onUnmounted(() => {
|
||||
document.removeEventListener('dragenter', handleGlobalDragEnter, true);
|
||||
document.removeEventListener('dragover', handleGlobalDragOver, true);
|
||||
document.removeEventListener('drop', handleGlobalDrop, true);
|
||||
});
|
||||
</script>
|
||||
<style lang="less">
|
||||
.main {
|
||||
|
||||
211
frontend/src/views/settings/SystemSettings.vue
Normal file
211
frontend/src/views/settings/SystemSettings.vue
Normal file
@@ -0,0 +1,211 @@
|
||||
<template>
|
||||
<div class="system-settings-container">
|
||||
<!-- 页面标题区域 -->
|
||||
<div class="settings-header">
|
||||
<h2>{{ isKbSettings ? '知识库设置' : '系统设置' }}</h2>
|
||||
<p class="settings-subtitle">{{ isKbSettings ? '配置该知识库的模型与文档切分参数' : '管理和更新系统模型与服务配置' }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 配置内容 -->
|
||||
<div class="settings-content">
|
||||
<!-- 系统设置:使用初始化配置 -->
|
||||
<InitializationContent v-if="!isKbSettings" />
|
||||
<!-- 知识库设置:基础信息与文档切分配置 -->
|
||||
<div v-else>
|
||||
<t-form :data="kbForm" @submit="saveKb">
|
||||
<div class="config-section">
|
||||
<h3><span class="section-icon">⚙️</span>基础信息</h3>
|
||||
<t-form-item label="名称" name="name" :rules="[{ required: true, message: '请输入名称' }]">
|
||||
<t-input v-model="kbForm.name" />
|
||||
</t-form-item>
|
||||
<t-form-item label="描述" name="description">
|
||||
<t-textarea v-model="kbForm.description" />
|
||||
</t-form-item>
|
||||
</div>
|
||||
<div class="config-section">
|
||||
<h3><span class="section-icon">📄</span>文档切分</h3>
|
||||
<t-row :gutter="16">
|
||||
<t-col :span="6">
|
||||
<t-form-item label="Chunk Size" name="chunkSize">
|
||||
<t-input-number v-model="kbForm.config.chunking_config.chunk_size" :min="1" />
|
||||
</t-form-item>
|
||||
</t-col>
|
||||
<t-col :span="6">
|
||||
<t-form-item label="Chunk Overlap" name="chunkOverlap">
|
||||
<t-input-number v-model="kbForm.config.chunking_config.chunk_overlap" :min="0" />
|
||||
</t-form-item>
|
||||
</t-col>
|
||||
</t-row>
|
||||
</div>
|
||||
<div class="submit-section">
|
||||
<t-button theme="primary" type="submit" :loading="saving">保存</t-button>
|
||||
</div>
|
||||
</t-form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { defineAsyncComponent, onMounted, reactive, ref } from 'vue'
|
||||
import { useRoute, useRouter } from 'vue-router'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { getKnowledgeBaseById, updateKnowledgeBase } from '@/api/knowledge-base'
|
||||
|
||||
// 异步加载初始化配置组件
|
||||
const InitializationContent = defineAsyncComponent(() => import('../initialization/InitializationContent.vue'))
|
||||
|
||||
const route = useRoute()
|
||||
const router = useRouter()
|
||||
const isKbSettings = ref<boolean>(false)
|
||||
|
||||
interface KbForm {
|
||||
name: string
|
||||
description?: string
|
||||
config: { chunking_config: { chunk_size: number; chunk_overlap: number } }
|
||||
}
|
||||
const kbForm = reactive<KbForm>({
|
||||
name: '',
|
||||
description: '',
|
||||
config: { chunking_config: { chunk_size: 512, chunk_overlap: 64 } }
|
||||
})
|
||||
const saving = ref(false)
|
||||
|
||||
const loadKb = () => {
|
||||
const kbId = (route.params as any).kbId as string
|
||||
if (!kbId) return
|
||||
getKnowledgeBaseById(kbId).then((res: any) => {
|
||||
if (res?.data) {
|
||||
kbForm.name = res.data.name
|
||||
kbForm.description = res.data.description
|
||||
const cc = res.data.chunking_config || {}
|
||||
kbForm.config.chunking_config.chunk_size = cc.chunk_size ?? 512
|
||||
kbForm.config.chunking_config.chunk_overlap = cc.chunk_overlap ?? 64
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
isKbSettings.value = route.name === 'knowledgeBaseSettings'
|
||||
if (isKbSettings.value) loadKb()
|
||||
})
|
||||
|
||||
const saveKb = () => {
|
||||
const kbId = (route.params as any).kbId as string
|
||||
if (!kbId) return
|
||||
saving.value = true
|
||||
updateKnowledgeBase(kbId, { name: kbForm.name, description: kbForm.description, config: { chunking_config: { chunk_size: kbForm.config.chunking_config.chunk_size, chunk_overlap: kbForm.config.chunking_config.chunk_overlap, separators: [], enable_multimodal: false }, image_processing_config: { model_id: '' } } })
|
||||
.then((res: any) => {
|
||||
if (res.success) {
|
||||
MessagePlugin.success('保存成功')
|
||||
} else {
|
||||
MessagePlugin.error(res.message || '保存失败')
|
||||
}
|
||||
})
|
||||
.catch((e: any) => MessagePlugin.error(e?.message || '保存失败'))
|
||||
.finally(() => saving.value = false)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style lang="less" scoped>
|
||||
.system-settings-container {
|
||||
padding: 20px;
|
||||
background-color: #fff;
|
||||
margin: 0 20px 0 20px;
|
||||
height: calc(100vh);
|
||||
overflow-y: auto;
|
||||
box-sizing: border-box;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 1px solid #f0f0f0;
|
||||
padding-bottom: 16px;
|
||||
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #000000;
|
||||
margin: 0 0 8px 0;
|
||||
}
|
||||
|
||||
.settings-subtitle {
|
||||
font-size: 14px;
|
||||
color: #666666;
|
||||
margin: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.settings-content {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.system-settings-container {
|
||||
padding: 16px;
|
||||
margin: 10px;
|
||||
height: calc(100vh - 20px);
|
||||
}
|
||||
|
||||
.settings-header h2 {
|
||||
font-size: 18px;
|
||||
}
|
||||
}
|
||||
|
||||
/* 覆盖TDesign组件样式,与账户信息页面保持一致 */
|
||||
:deep(.t-card) {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
/* 调整InitializationContent内部样式,使每个配置区域显示为独立卡片 */
|
||||
:deep(.config-section) {
|
||||
background: #fff;
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
|
||||
&:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #000000;
|
||||
margin: 0 0 16px 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
background: none;
|
||||
border-left: none;
|
||||
border-radius: 0;
|
||||
border-bottom: 1px solid #f0f0f0;
|
||||
padding-bottom: 12px;
|
||||
|
||||
.section-icon {
|
||||
margin-right: 8px;
|
||||
color: #07c05f;
|
||||
font-size: 18px;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.ollama-summary-card) {
|
||||
background: #fff;
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
:deep(.submit-section) {
|
||||
margin-top: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
</style>
|
||||
535
frontend/src/views/tenant/TenantInfo.vue
Normal file
535
frontend/src/views/tenant/TenantInfo.vue
Normal file
@@ -0,0 +1,535 @@
|
||||
<template>
|
||||
<div class="tenant-info-container">
|
||||
<div class="tenant-header">
|
||||
<h2>系统信息</h2>
|
||||
<p class="tenant-subtitle">查看系统版本信息和用户账户配置</p>
|
||||
</div>
|
||||
|
||||
<div class="tenant-content" v-if="!loading && !error">
|
||||
<!-- 系统信息卡片 -->
|
||||
<t-card class="info-card" :bordered="false">
|
||||
<template #header>
|
||||
<div class="card-title">系统信息</div>
|
||||
</template>
|
||||
<div class="info-content">
|
||||
<t-descriptions :column="1" layout="vertical">
|
||||
<t-descriptions-item label="版本号">
|
||||
{{ systemInfo?.version || '未知' }}
|
||||
<span v-if="systemInfo?.commit_id" class="commit-info">
|
||||
({{ systemInfo.commit_id }})
|
||||
</span>
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="构建时间" v-if="systemInfo?.build_time">
|
||||
{{ systemInfo.build_time }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="Go版本" v-if="systemInfo?.go_version">
|
||||
{{ systemInfo.go_version }}
|
||||
</t-descriptions-item>
|
||||
</t-descriptions>
|
||||
</div>
|
||||
</t-card>
|
||||
|
||||
<!-- 用户信息卡片 -->
|
||||
<t-card class="info-card" :bordered="false">
|
||||
<template #header>
|
||||
<div class="card-title">用户信息</div>
|
||||
</template>
|
||||
<div class="info-content">
|
||||
<t-descriptions :column="1" layout="vertical">
|
||||
<t-descriptions-item label="用户 ID">
|
||||
{{ userInfo?.id }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="用户名">
|
||||
{{ userInfo?.username }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="邮箱">
|
||||
{{ userInfo?.email }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="创建时间">
|
||||
{{ formatDate(userInfo?.created_at) }}
|
||||
</t-descriptions-item>
|
||||
</t-descriptions>
|
||||
</div>
|
||||
</t-card>
|
||||
|
||||
<!-- 租户信息卡片 -->
|
||||
<t-card class="info-card" :bordered="false">
|
||||
<template #header>
|
||||
<div class="card-title">租户信息</div>
|
||||
</template>
|
||||
<div class="info-content">
|
||||
<t-descriptions :column="1" layout="vertical">
|
||||
<t-descriptions-item label="租户 ID">
|
||||
{{ tenantInfo?.id }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="租户名称">
|
||||
{{ tenantInfo?.name }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="描述">
|
||||
{{ tenantInfo?.description || '暂无描述' }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="业务">
|
||||
{{ tenantInfo?.business || '暂无' }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="状态">
|
||||
<t-tag
|
||||
:theme="getStatusTheme(tenantInfo?.status)"
|
||||
variant="light"
|
||||
>
|
||||
{{ getStatusText(tenantInfo?.status) }}
|
||||
</t-tag>
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="创建时间">
|
||||
{{ formatDate(tenantInfo?.created_at) }}
|
||||
</t-descriptions-item>
|
||||
</t-descriptions>
|
||||
</div>
|
||||
</t-card>
|
||||
|
||||
<!-- API Key 卡片 -->
|
||||
<t-card class="info-card" :bordered="false">
|
||||
<template #header>
|
||||
<div class="card-header-with-actions">
|
||||
<div class="card-title">API Key</div>
|
||||
</div>
|
||||
</template>
|
||||
<div class="api-key-content">
|
||||
<t-input
|
||||
v-model="displayApiKey"
|
||||
readonly
|
||||
class="api-key-input"
|
||||
:type="showApiKey ? 'text' : 'password'"
|
||||
/>
|
||||
<t-alert theme="warning" :close="false" class="api-warning">
|
||||
<template #icon>
|
||||
<t-icon name="error-circle" />
|
||||
</template>
|
||||
请妥善保管您的 API Key,不要在公共场所或代码仓库中暴露
|
||||
</t-alert>
|
||||
</div>
|
||||
</t-card>
|
||||
|
||||
<!-- 存储信息卡片 -->
|
||||
<t-card
|
||||
class="info-card"
|
||||
:bordered="false"
|
||||
v-if="tenantInfo?.storage_quota !== undefined"
|
||||
>
|
||||
<template #header>
|
||||
<div class="card-title">存储信息</div>
|
||||
</template>
|
||||
<div class="storage-content">
|
||||
<t-descriptions :column="1" layout="vertical">
|
||||
<t-descriptions-item label="存储配额">
|
||||
{{ formatBytes(tenantInfo.storage_quota) }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="已使用">
|
||||
{{ formatBytes(tenantInfo.storage_used || 0) }}
|
||||
</t-descriptions-item>
|
||||
<t-descriptions-item label="使用率">
|
||||
<div class="usage-info">
|
||||
<span class="usage-text">{{ getUsagePercentage() }}%</span>
|
||||
<t-progress
|
||||
:percentage="getUsagePercentage()"
|
||||
:show-info="false"
|
||||
size="medium"
|
||||
:theme="getUsagePercentage() > 80 ? 'warning' : 'success'"
|
||||
/>
|
||||
</div>
|
||||
</t-descriptions-item>
|
||||
</t-descriptions>
|
||||
</div>
|
||||
</t-card>
|
||||
|
||||
<!-- API 开发文档卡片 -->
|
||||
<t-card class="info-card" :bordered="false">
|
||||
<template #header>
|
||||
<div class="card-title">API 开发文档</div>
|
||||
</template>
|
||||
<div class="doc-content">
|
||||
<p class="doc-description">使用您的 API Key 开始开发,查看完整的 API 文档和示例代码。</p>
|
||||
<t-space class="doc-actions">
|
||||
<t-button
|
||||
theme="primary"
|
||||
@click="openApiDoc"
|
||||
>
|
||||
<template #icon>
|
||||
<t-icon name="link" />
|
||||
</template>
|
||||
查看 API 文档
|
||||
</t-button>
|
||||
</t-space>
|
||||
|
||||
</div>
|
||||
</t-card>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div v-if="loading" class="loading-container">
|
||||
<t-loading size="large" />
|
||||
<p class="loading-text">正在加载账户信息...</p>
|
||||
</div>
|
||||
|
||||
<!-- 错误状态 -->
|
||||
<div v-if="error" class="error-container">
|
||||
<t-result theme="error" title="加载失败" :description="error">
|
||||
<template #extra>
|
||||
<t-button theme="primary" @click="loadTenantInfo">重试</t-button>
|
||||
</template>
|
||||
</t-result>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, computed } from 'vue'
|
||||
import { getCurrentUser, type TenantInfo, type UserInfo } from '@/api/auth'
|
||||
import { getSystemInfo, type SystemInfo } from '@/api/system'
|
||||
|
||||
// 响应式数据
|
||||
const tenantInfo = ref<TenantInfo | null>(null)
|
||||
const userInfo = ref<UserInfo | null>(null)
|
||||
const systemInfo = ref<SystemInfo | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref('')
|
||||
const showApiKey = ref(false)
|
||||
const showApiExample = ref(false)
|
||||
|
||||
// API 基础 URL
|
||||
const apiBaseUrl = window.location.origin
|
||||
|
||||
// 计算属性
|
||||
const displayApiKey = computed(() => {
|
||||
if (!tenantInfo.value?.api_key) return ''
|
||||
return tenantInfo.value.api_key
|
||||
})
|
||||
|
||||
// API示例代码
|
||||
const apiExampleCode = computed(() => {
|
||||
return `curl -X GET "${apiBaseUrl}/api/v1/tenants/${tenantInfo.value?.id}" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-H "X-API-Key: ${tenantInfo.value?.api_key}"`
|
||||
})
|
||||
|
||||
// 方法
|
||||
const loadTenantInfo = async () => {
|
||||
try {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
|
||||
// 并行获取用户信息和系统信息
|
||||
const [userResponse, systemResponse] = await Promise.all([
|
||||
getCurrentUser(),
|
||||
getSystemInfo().catch(() => ({ data: null })) // 系统信息获取失败不影响页面显示
|
||||
])
|
||||
|
||||
if (userResponse.success && userResponse.data) {
|
||||
userInfo.value = userResponse.data.user
|
||||
tenantInfo.value = userResponse.data.tenant
|
||||
} else {
|
||||
error.value = userResponse.message || '获取用户信息失败'
|
||||
}
|
||||
|
||||
if (systemResponse.data) {
|
||||
systemInfo.value = systemResponse.data
|
||||
}
|
||||
} catch (err: any) {
|
||||
error.value = err.message || '网络错误,请稍后重试'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const toggleApiKeyVisibility = () => {
|
||||
showApiKey.value = !showApiKey.value
|
||||
}
|
||||
|
||||
const copyApiKey = async () => {
|
||||
if (!tenantInfo.value?.api_key) return
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(tenantInfo.value.api_key)
|
||||
// 使用TDesign的消息组件
|
||||
import('tdesign-vue-next').then(({ MessagePlugin }) => {
|
||||
MessagePlugin.success('API Key 已复制到剪贴板')
|
||||
})
|
||||
} catch (err) {
|
||||
// 降级到传统方式
|
||||
const textArea = document.createElement('textarea')
|
||||
textArea.value = tenantInfo.value.api_key
|
||||
document.body.appendChild(textArea)
|
||||
textArea.select()
|
||||
document.execCommand('copy')
|
||||
document.body.removeChild(textArea)
|
||||
import('tdesign-vue-next').then(({ MessagePlugin }) => {
|
||||
MessagePlugin.success('API Key 已复制到剪贴板')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const openApiDoc = () => {
|
||||
window.open('https://github.com/Tencent/WeKnora/blob/main/docs/API.md', '_blank')
|
||||
}
|
||||
|
||||
const getStatusText = (status: string | undefined) => {
|
||||
switch (status) {
|
||||
case 'active':
|
||||
return '活跃'
|
||||
case 'inactive':
|
||||
return '未激活'
|
||||
case 'suspended':
|
||||
return '已暂停'
|
||||
default:
|
||||
return '未知'
|
||||
}
|
||||
}
|
||||
|
||||
const getStatusTheme = (status: string | undefined) => {
|
||||
switch (status) {
|
||||
case 'active':
|
||||
return 'success'
|
||||
case 'inactive':
|
||||
return 'warning'
|
||||
case 'suspended':
|
||||
return 'danger'
|
||||
default:
|
||||
return 'default'
|
||||
}
|
||||
}
|
||||
|
||||
const formatDate = (dateStr: string | undefined) => {
|
||||
if (!dateStr) return '未知'
|
||||
|
||||
try {
|
||||
const date = new Date(dateStr)
|
||||
return date.toLocaleString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit'
|
||||
})
|
||||
} catch {
|
||||
return '格式错误'
|
||||
}
|
||||
}
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
if (bytes === 0) return '0 B'
|
||||
|
||||
const k = 1024
|
||||
const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k))
|
||||
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]
|
||||
}
|
||||
|
||||
const getUsagePercentage = () => {
|
||||
if (!tenantInfo.value?.storage_quota || tenantInfo.value.storage_quota === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
const used = tenantInfo.value.storage_used || 0
|
||||
const percentage = (used / tenantInfo.value.storage_quota) * 100
|
||||
return Math.min(Math.round(percentage * 100) / 100, 100) // 保留两位小数,最大100%
|
||||
}
|
||||
|
||||
// 生命周期
|
||||
onMounted(() => {
|
||||
loadTenantInfo()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style lang="less" scoped>
|
||||
.tenant-info-container {
|
||||
padding: 20px;
|
||||
background-color: #fff;
|
||||
margin: 0 20px 0 20px;
|
||||
height: calc(100vh);
|
||||
overflow-y: auto;
|
||||
box-sizing: border-box;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.tenant-header {
|
||||
margin-bottom: 20px;
|
||||
border-bottom: 1px solid #f0f0f0;
|
||||
padding-bottom: 16px;
|
||||
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #000000;
|
||||
margin: 0 0 8px 0;
|
||||
}
|
||||
|
||||
.tenant-subtitle {
|
||||
font-size: 14px;
|
||||
color: #666666;
|
||||
margin: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.tenant-content {
|
||||
display: grid;
|
||||
gap: 20px;
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.info-card {
|
||||
margin-bottom: 20px;
|
||||
|
||||
.card-title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #07C05F;
|
||||
}
|
||||
|
||||
.card-header-with-actions {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
}
|
||||
|
||||
.info-content,
|
||||
.api-key-content,
|
||||
.storage-content,
|
||||
.doc-content {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.api-key-input {
|
||||
margin-bottom: 16px;
|
||||
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
|
||||
}
|
||||
|
||||
.api-warning {
|
||||
margin-top: 16px;
|
||||
}
|
||||
|
||||
.usage-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
|
||||
.usage-text {
|
||||
font-weight: 500;
|
||||
color: #000000;
|
||||
}
|
||||
}
|
||||
|
||||
.doc-description {
|
||||
margin-bottom: 16px;
|
||||
color: #666666;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.doc-actions {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.api-example {
|
||||
margin-top: 20px;
|
||||
padding: 16px;
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
|
||||
.example-header h4 {
|
||||
margin: 0 0 16px 0;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #000000;
|
||||
}
|
||||
|
||||
.code-textarea {
|
||||
margin-bottom: 16px;
|
||||
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
|
||||
}
|
||||
|
||||
.example-note {
|
||||
margin-top: 16px;
|
||||
}
|
||||
}
|
||||
|
||||
.loading-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 60px 24px;
|
||||
text-align: center;
|
||||
|
||||
.loading-text {
|
||||
margin-top: 16px;
|
||||
color: #666666;
|
||||
font-size: 14px;
|
||||
}
|
||||
}
|
||||
|
||||
.error-container {
|
||||
padding: 40px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.tenant-info-container {
|
||||
padding: 16px;
|
||||
margin: 10px;
|
||||
height: calc(100vh - 20px);
|
||||
}
|
||||
|
||||
.tenant-header h2 {
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.card-header-with-actions {
|
||||
flex-direction: column;
|
||||
align-items: flex-start !important;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.commit-info {
|
||||
color: #666;
|
||||
font-size: 12px;
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.doc-actions {
|
||||
:deep(.t-space) {
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
|
||||
.t-button {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* 覆盖TDesign组件样式 */
|
||||
:deep(.t-card) {
|
||||
border: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
:deep(.t-descriptions-item__label) {
|
||||
font-weight: 500;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
:deep(.t-descriptions-item__content) {
|
||||
color: #000000;
|
||||
}
|
||||
|
||||
:deep(.t-input__inner) {
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
:deep(.code-textarea .t-textarea__inner) {
|
||||
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
|
||||
font-size: 13px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
</style>
|
||||
20
go.mod
20
go.mod
@@ -10,14 +10,16 @@ require (
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hibiken/asynq v0.25.1
|
||||
github.com/minio/minio-go/v7 v7.0.90
|
||||
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1
|
||||
github.com/ollama/ollama v0.11.4
|
||||
github.com/panjf2000/ants/v2 v2.11.2
|
||||
github.com/parquet-go/parquet-go v0.25.0
|
||||
github.com/pgvector/pgvector-go v0.3.0
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/sashabaranov/go-openai v1.40.5
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/viper v1.20.1
|
||||
@@ -31,9 +33,10 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.37.0
|
||||
go.opentelemetry.io/otel/trace v1.37.0
|
||||
go.uber.org/dig v1.18.1
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/sync v0.17.0
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/protobuf v1.36.6
|
||||
google.golang.org/protobuf v1.36.9
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
@@ -90,7 +93,7 @@ require (
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
@@ -100,11 +103,10 @@ require (
|
||||
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
golang.org/x/time v0.13.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
44
go.sum
44
go.sum
@@ -70,6 +70,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlnd
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -139,6 +141,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ=
|
||||
github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60=
|
||||
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1 h1:nV3ZdYJTi73jel0mm3dpWumNY3i3nwyo25y69SPGwyg=
|
||||
github.com/neo4j/neo4j-go-driver/v6 v6.0.0-alpha.1/go.mod h1:hzSTfNfM31p1uRSzL1F/BAYOgaiTarE6OAQBajfsm+I=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/ollama/ollama v0.11.4 h1:6xLYLEPTKtw6N20qQecyEL/rrBktPO4o5U05cnvkSmI=
|
||||
@@ -155,8 +159,8 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -176,8 +180,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
|
||||
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
|
||||
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
|
||||
@@ -253,22 +257,22 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw=
|
||||
golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
|
||||
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc=
|
||||
@@ -276,8 +280,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
|
||||
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
|
||||
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
231
internal/application/repository/retriever/neo4j/repository.go
Normal file
231
internal/application/repository/retriever/neo4j/repository.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package neo4j
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
)
|
||||
|
||||
type Neo4jRepository struct {
|
||||
driver neo4j.Driver
|
||||
nodePrefix string
|
||||
}
|
||||
|
||||
func NewNeo4jRepository(driver neo4j.Driver) interfaces.RetrieveGraphRepository {
|
||||
return &Neo4jRepository{driver: driver, nodePrefix: "ENTITY"}
|
||||
}
|
||||
|
||||
func _remove_hyphen(s string) string {
|
||||
return strings.ReplaceAll(s, "-", "_")
|
||||
}
|
||||
|
||||
func (n *Neo4jRepository) Labels(namespace types.NameSpace) []string {
|
||||
res := make([]string, 0)
|
||||
for _, label := range namespace.Labels() {
|
||||
res = append(res, n.nodePrefix+_remove_hyphen(label))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (n *Neo4jRepository) Label(namespace types.NameSpace) string {
|
||||
labels := n.Labels(namespace)
|
||||
return strings.Join(labels, ":")
|
||||
}
|
||||
|
||||
// AddGraph implements interfaces.RetrieveGraphRepository.
|
||||
func (n *Neo4jRepository) AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error {
|
||||
if n.driver == nil {
|
||||
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
||||
return nil
|
||||
}
|
||||
for _, graph := range graphs {
|
||||
if err := n.addGraph(ctx, namespace, graph); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Neo4jRepository) addGraph(ctx context.Context, namespace types.NameSpace, graph *types.GraphData) error {
|
||||
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
|
||||
defer session.Close(ctx)
|
||||
|
||||
_, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
||||
node_import_query := `
|
||||
UNWIND $data AS row
|
||||
CALL apoc.merge.node(row.labels, {name: row.name, kg: row.knowledge_id}, row.props, {}) YIELD node
|
||||
SET node.chunks = apoc.coll.union(node.chunks, row.chunks)
|
||||
RETURN distinct 'done' AS result
|
||||
`
|
||||
nodeData := []map[string]interface{}{}
|
||||
for _, node := range graph.Node {
|
||||
nodeData = append(nodeData, map[string]interface{}{
|
||||
"name": node.Name,
|
||||
"knowledge_id": namespace.Knowledge,
|
||||
"props": map[string][]string{"attributes": node.Attributes},
|
||||
"chunks": node.Chunks,
|
||||
"labels": n.Labels(namespace),
|
||||
})
|
||||
}
|
||||
if _, err := tx.Run(ctx, node_import_query, map[string]interface{}{"data": nodeData}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create nodes: %v", err)
|
||||
}
|
||||
|
||||
rel_import_query := `
|
||||
UNWIND $data AS row
|
||||
CALL apoc.merge.node(row.source_labels, {name: row.source, kg: row.knowledge_id}, {}, {}) YIELD node as source
|
||||
CALL apoc.merge.node(row.target_labels, {name: row.target, kg: row.knowledge_id}, {}, {}) YIELD node as target
|
||||
CALL apoc.merge.relationship(source, row.type, {}, row.attributes, target) YIELD rel
|
||||
RETURN distinct 'done'
|
||||
`
|
||||
relData := []map[string]interface{}{}
|
||||
for _, rel := range graph.Relation {
|
||||
relData = append(relData, map[string]interface{}{
|
||||
"source": rel.Node1,
|
||||
"target": rel.Node2,
|
||||
"knowledge_id": namespace.Knowledge,
|
||||
"type": rel.Type,
|
||||
"source_labels": n.Labels(namespace),
|
||||
"target_labels": n.Labels(namespace),
|
||||
})
|
||||
}
|
||||
if _, err := tx.Run(ctx, rel_import_query, map[string]interface{}{"data": relData}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create relationships: %v", err)
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to add graph: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DelGraph implements interfaces.RetrieveGraphRepository.
|
||||
func (n *Neo4jRepository) DelGraph(ctx context.Context, namespaces []types.NameSpace) error {
|
||||
if n.driver == nil {
|
||||
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
||||
return nil
|
||||
}
|
||||
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
|
||||
defer session.Close(ctx)
|
||||
|
||||
result, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
||||
for _, namespace := range namespaces {
|
||||
labelExpr := n.Label(namespace)
|
||||
|
||||
deleteRelsQuery := `
|
||||
CALL apoc.periodic.iterate(
|
||||
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id})-[r]-(m:` + labelExpr + ` {kg: $knowledge_id}) RETURN r",
|
||||
"DELETE r",
|
||||
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
|
||||
) YIELD batches, total
|
||||
RETURN total
|
||||
`
|
||||
if _, err := tx.Run(ctx, deleteRelsQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete relationships: %v", err)
|
||||
}
|
||||
|
||||
deleteNodesQuery := `
|
||||
CALL apoc.periodic.iterate(
|
||||
"MATCH (n:` + labelExpr + ` {kg: $knowledge_id}) RETURN n",
|
||||
"DELETE n",
|
||||
{batchSize: 1000, parallel: true, params: {knowledge_id: $knowledge_id}}
|
||||
) YIELD batches, total
|
||||
RETURN total
|
||||
`
|
||||
if _, err := tx.Run(ctx, deleteNodesQuery, map[string]interface{}{"knowledge_id": namespace.Knowledge}); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete nodes: %v", err)
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof(ctx, "delete graph result: %v", result)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Neo4jRepository) SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error) {
|
||||
if n.driver == nil {
|
||||
logger.Warnf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
||||
return nil, nil
|
||||
}
|
||||
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
|
||||
defer session.Close(ctx)
|
||||
|
||||
result, err := session.ExecuteRead(ctx, func(tx neo4j.ManagedTransaction) (interface{}, error) {
|
||||
labelExpr := n.Label(namespace)
|
||||
query := `
|
||||
MATCH (n:` + labelExpr + `)-[r]-(m:` + labelExpr + `)
|
||||
WHERE ANY(nodeText IN $nodes WHERE n.name CONTAINS nodeText)
|
||||
RETURN n, r, m
|
||||
`
|
||||
params := map[string]interface{}{"nodes": nodes}
|
||||
result, err := tx.Run(ctx, query, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to run query: %v", err)
|
||||
}
|
||||
|
||||
graphData := &types.GraphData{}
|
||||
nodeSeen := make(map[string]bool)
|
||||
for result.Next(ctx) {
|
||||
record := result.Record()
|
||||
node, _ := record.Get("n")
|
||||
rel, _ := record.Get("r")
|
||||
targetNode, _ := record.Get("m")
|
||||
|
||||
nodeData := node.(neo4j.Node)
|
||||
targetNodeData := targetNode.(neo4j.Node)
|
||||
|
||||
// Convert node to types.Node
|
||||
for _, n := range []neo4j.Node{nodeData, targetNodeData} {
|
||||
nameStr := n.Props["name"].(string)
|
||||
if _, ok := nodeSeen[nameStr]; !ok {
|
||||
nodeSeen[nameStr] = true
|
||||
graphData.Node = append(graphData.Node, &types.GraphNode{
|
||||
Name: nameStr,
|
||||
Chunks: listI2listS(n.Props["chunks"].([]interface{})),
|
||||
Attributes: listI2listS(n.Props["attributes"].([]interface{})),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Convert relationship to types.Relation
|
||||
relData := rel.(neo4j.Relationship)
|
||||
graphData.Relation = append(graphData.Relation, &types.GraphRelation{
|
||||
Node1: nodeData.Props["name"].(string),
|
||||
Node2: targetNodeData.Props["name"].(string),
|
||||
Type: relData.Type,
|
||||
})
|
||||
}
|
||||
return graphData, nil
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "search node failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return result.(*types.GraphData), nil
|
||||
}
|
||||
|
||||
func listI2listS(list []any) []string {
|
||||
result := make([]string, len(list))
|
||||
for i, v := range list {
|
||||
result[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mapI2mapS(prop map[string]any) map[string]string {
|
||||
attributes := make(map[string]string)
|
||||
for k, v := range prop {
|
||||
attributes[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
return attributes
|
||||
}
|
||||
154
internal/application/repository/user.go
Normal file
154
internal/application/repository/user.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
)
|
||||
|
||||
// userRepository implements user repository interface
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new user repository
|
||||
func NewUserRepository(db *gorm.DB) interfaces.UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateUser creates a user
|
||||
func (r *userRepository) CreateUser(ctx context.Context, user *types.User) error {
|
||||
logger.Infof(ctx, "Creating user in database: %s", user.Email)
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
// GetUserByID gets a user by ID
|
||||
func (r *userRepository) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail gets a user by email
|
||||
func (r *userRepository) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername gets a user by username
|
||||
func (r *userRepository) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user
|
||||
func (r *userRepository) UpdateUser(ctx context.Context, user *types.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (r *userRepository) DeleteUser(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.User{}).Error
|
||||
}
|
||||
|
||||
// ListUsers lists users with pagination
|
||||
func (r *userRepository) ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
query := r.db.WithContext(ctx).Order("created_at DESC")
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// authTokenRepository implements auth token repository interface
|
||||
type authTokenRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAuthTokenRepository creates a new auth token repository
|
||||
func NewAuthTokenRepository(db *gorm.DB) interfaces.AuthTokenRepository {
|
||||
return &authTokenRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateToken creates an auth token
|
||||
func (r *authTokenRepository) CreateToken(ctx context.Context, token *types.AuthToken) error {
|
||||
return r.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// GetTokenByValue gets a token by its value
|
||||
func (r *authTokenRepository) GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error) {
|
||||
var token types.AuthToken
|
||||
if err := r.db.WithContext(ctx).Where("token = ?", tokenValue).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetTokensByUserID gets all tokens for a user
|
||||
func (r *authTokenRepository) GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error) {
|
||||
var tokens []*types.AuthToken
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&tokens).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// UpdateToken updates a token
|
||||
func (r *authTokenRepository) UpdateToken(ctx context.Context, token *types.AuthToken) error {
|
||||
return r.db.WithContext(ctx).Save(token).Error
|
||||
}
|
||||
|
||||
// DeleteToken deletes a token
|
||||
func (r *authTokenRepository) DeleteToken(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// DeleteExpiredTokens deletes all expired tokens
|
||||
func (r *authTokenRepository) DeleteExpiredTokens(ctx context.Context) error {
|
||||
return r.db.WithContext(ctx).Where("expires_at < NOW()").Delete(&types.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// RevokeTokensByUserID revokes all tokens for a user
|
||||
func (r *authTokenRepository) RevokeTokensByUserID(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&types.AuthToken{}).Where("user_id = ?", userID).Update("is_revoked", true).Error
|
||||
}
|
||||
499
internal/application/service/chat_pipline/extract_entity.go
Normal file
499
internal/application/service/chat_pipline/extract_entity.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package chatpipline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// PluginExtractEntity is a plugin for extracting entities from user queries
|
||||
// It uses historical dialog context and large language models to identify key entities in the user's original query
|
||||
type PluginExtractEntity struct {
|
||||
modelService interfaces.ModelService // Model service for calling large language models
|
||||
template *types.PromptTemplateStructured // Template for generating prompts
|
||||
knowledgeBaseRepo interfaces.KnowledgeBaseRepository
|
||||
}
|
||||
|
||||
// NewPluginRewrite creates a new query rewriting plugin instance
|
||||
// Also registers the plugin with the event manager
|
||||
func NewPluginExtractEntity(
|
||||
eventManager *EventManager,
|
||||
modelService interfaces.ModelService,
|
||||
knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
|
||||
config *config.Config,
|
||||
) *PluginExtractEntity {
|
||||
res := &PluginExtractEntity{
|
||||
modelService: modelService,
|
||||
template: config.ExtractManager.ExtractEntity,
|
||||
knowledgeBaseRepo: knowledgeBaseRepo,
|
||||
}
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// ActivationEvents returns the list of event types this plugin responds to
|
||||
// This plugin only responds to REWRITE_QUERY events
|
||||
func (p *PluginExtractEntity) ActivationEvents() []types.EventType {
|
||||
return []types.EventType{types.REWRITE_QUERY}
|
||||
}
|
||||
|
||||
// OnEvent processes triggered events
|
||||
// When receiving a REWRITE_QUERY event, it rewrites the user query using conversation history and the language model
|
||||
func (p *PluginExtractEntity) OnEvent(ctx context.Context,
|
||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||
) *PluginError {
|
||||
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
|
||||
logger.Debugf(ctx, "skipping extract entity, neo4j is disabled")
|
||||
return next()
|
||||
}
|
||||
|
||||
query := chatManage.Query
|
||||
|
||||
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get model, session_id: %s, error: %v", chatManage.SessionID, err)
|
||||
return next()
|
||||
}
|
||||
|
||||
kb, err := p.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chatManage.KnowledgeBaseID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to get knowledge base: %v", err)
|
||||
return next()
|
||||
}
|
||||
if kb.ExtractConfig == nil {
|
||||
logger.Warnf(ctx, "failed to get extract config")
|
||||
return next()
|
||||
}
|
||||
|
||||
template := &types.PromptTemplateStructured{
|
||||
Description: p.template.Description,
|
||||
Examples: p.template.Examples,
|
||||
}
|
||||
extractor := NewExtractor(model, template)
|
||||
graph, err := extractor.Extract(ctx, query)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to extract entities, session_id: %s, error: %v", chatManage.SessionID, err)
|
||||
return next()
|
||||
}
|
||||
nodes := []string{}
|
||||
for _, node := range graph.Node {
|
||||
nodes = append(nodes, node.Name)
|
||||
}
|
||||
logger.Debugf(ctx, "extracted node: %v", nodes)
|
||||
chatManage.Entity = nodes
|
||||
return next()
|
||||
}
|
||||
|
||||
type Extractor struct {
|
||||
chat chat.Chat
|
||||
formater *Formater
|
||||
template *types.PromptTemplateStructured
|
||||
chatOpt *chat.ChatOptions
|
||||
}
|
||||
|
||||
func NewExtractor(
|
||||
chatModel chat.Chat,
|
||||
template *types.PromptTemplateStructured,
|
||||
) Extractor {
|
||||
think := false
|
||||
return Extractor{
|
||||
chat: chatModel,
|
||||
formater: NewFormater(),
|
||||
template: template,
|
||||
chatOpt: &chat.ChatOptions{
|
||||
Temperature: 0.3,
|
||||
MaxTokens: 4096,
|
||||
Thinking: &think,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Extractor) Extract(ctx context.Context, content string) (*types.GraphData, error) {
|
||||
generator := NewQAPromptGenerator(e.formater, e.template)
|
||||
|
||||
// logger.Debugf(ctx, "chat system: %s", generator.System(ctx))
|
||||
// logger.Debugf(ctx, "chat user: %s", generator.User(ctx, content))
|
||||
|
||||
chatResponse, err := e.chat.Chat(ctx, generator.Render(ctx, content), e.chatOpt)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to chat: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph, err := e.formater.ParseGraph(ctx, chatResponse.Content)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to parse graph: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
// e.RemoveUnknownRelation(ctx, graph)
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
func (e *Extractor) RemoveUnknownRelation(ctx context.Context, graph *types.GraphData) {
|
||||
relationType := make(map[string]bool)
|
||||
for _, tag := range e.template.Tags {
|
||||
relationType[tag] = true
|
||||
}
|
||||
|
||||
relationNew := make([]*types.GraphRelation, 0)
|
||||
for _, relation := range graph.Relation {
|
||||
if _, ok := relationType[relation.Type]; ok {
|
||||
relationNew = append(relationNew, relation)
|
||||
} else {
|
||||
logger.Infof(ctx, "Unknown relation type %s with %v, ignore it", relation.Type, e.template.Tags)
|
||||
}
|
||||
}
|
||||
graph.Relation = relationNew
|
||||
}
|
||||
|
||||
type QAPromptGenerator struct {
|
||||
Formater *Formater
|
||||
Template *types.PromptTemplateStructured
|
||||
ExamplesHeading string
|
||||
QuestionHeading string
|
||||
QuestionPrefix string
|
||||
AnswerPrefix string
|
||||
}
|
||||
|
||||
func NewQAPromptGenerator(formater *Formater, template *types.PromptTemplateStructured) *QAPromptGenerator {
|
||||
return &QAPromptGenerator{
|
||||
Formater: formater,
|
||||
Template: template,
|
||||
ExamplesHeading: "# Examples",
|
||||
QuestionHeading: "# Question",
|
||||
QuestionPrefix: "Q: ",
|
||||
AnswerPrefix: "A: ",
|
||||
}
|
||||
}
|
||||
|
||||
func (qa *QAPromptGenerator) System(ctx context.Context) string {
|
||||
promptLines := []string{}
|
||||
|
||||
if len(qa.Template.Tags) == 0 {
|
||||
promptLines = append(promptLines, qa.Template.Description)
|
||||
} else {
|
||||
tags, _ := json.Marshal(qa.Template.Tags)
|
||||
promptLines = append(promptLines, fmt.Sprintf(qa.Template.Description, string(tags)))
|
||||
}
|
||||
if len(qa.Template.Examples) > 0 {
|
||||
promptLines = append(promptLines, qa.ExamplesHeading)
|
||||
for _, example := range qa.Template.Examples {
|
||||
// Question
|
||||
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, strings.TrimSpace(example.Text)))
|
||||
|
||||
// Answer
|
||||
answer, err := qa.Formater.formatExtraction(example.Node, example.Relation)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.AnswerPrefix, answer))
|
||||
|
||||
// new line
|
||||
promptLines = append(promptLines, "")
|
||||
}
|
||||
}
|
||||
return strings.Join(promptLines, "\n")
|
||||
}
|
||||
|
||||
func (qa *QAPromptGenerator) User(ctx context.Context, question string) string {
|
||||
promptLines := []string{}
|
||||
promptLines = append(promptLines, qa.QuestionHeading)
|
||||
promptLines = append(promptLines, fmt.Sprintf("%s%s", qa.QuestionPrefix, question))
|
||||
promptLines = append(promptLines, qa.AnswerPrefix)
|
||||
return strings.Join(promptLines, "\n")
|
||||
}
|
||||
|
||||
func (qa *QAPromptGenerator) Render(ctx context.Context, question string) []chat.Message {
|
||||
return []chat.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: qa.System(ctx),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: qa.User(ctx, question),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type FormatType string
|
||||
|
||||
const (
|
||||
FormatTypeJSON FormatType = "json"
|
||||
FormatTypeYAML FormatType = "yaml"
|
||||
)
|
||||
|
||||
const (
|
||||
_FENCE_START = "```"
|
||||
_LANGUAGE_TAG = `(?P<lang>[A-Za-z0-9_+-]+)?`
|
||||
_FENCE_NEWLINE = `(?:\s*\n)?`
|
||||
_FENCE_BODY = `(?P<body>[\s\S]*?)`
|
||||
_FENCE_END = "```"
|
||||
)
|
||||
|
||||
var _FENCE_RE = regexp.MustCompile(
|
||||
_FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END,
|
||||
)
|
||||
|
||||
type Formater struct {
|
||||
attributeSuffix string
|
||||
formatType FormatType
|
||||
useFences bool
|
||||
nodePrefix string
|
||||
|
||||
relationSource string
|
||||
relationTarget string
|
||||
relationPrefix string
|
||||
}
|
||||
|
||||
func NewFormater() *Formater {
|
||||
return &Formater{
|
||||
attributeSuffix: "_attributes",
|
||||
formatType: FormatTypeJSON,
|
||||
useFences: true,
|
||||
nodePrefix: "entity",
|
||||
relationSource: "entity1",
|
||||
relationTarget: "entity2",
|
||||
relationPrefix: "relation",
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Formater) formatExtraction(nodes []*types.GraphNode, relations []*types.GraphRelation) (string, error) {
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for _, node := range nodes {
|
||||
item := map[string]interface{}{
|
||||
f.nodePrefix: node.Name,
|
||||
}
|
||||
if len(node.Attributes) > 0 {
|
||||
item[fmt.Sprintf("%s%s", f.nodePrefix, f.attributeSuffix)] = node.Attributes
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
for _, relation := range relations {
|
||||
item := map[string]interface{}{
|
||||
f.relationSource: relation.Node1,
|
||||
f.relationTarget: relation.Node2,
|
||||
f.relationPrefix: relation.Type,
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
formatted := ""
|
||||
switch f.formatType {
|
||||
default:
|
||||
formattedBytes, err := json.MarshalIndent(items, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
formatted = string(formattedBytes)
|
||||
}
|
||||
if f.useFences {
|
||||
formatted = f.addFences(formatted)
|
||||
}
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
func (f *Formater) parseOutput(ctx context.Context, text string) ([]map[string]interface{}, error) {
|
||||
if text == "" {
|
||||
return nil, errors.New("Empty or invalid input string.")
|
||||
}
|
||||
content := f.extractContent(ctx, text)
|
||||
// logger.Debugf(ctx, "Extracted content: %s", content)
|
||||
if content == "" {
|
||||
return nil, errors.New("Empty or invalid input string.")
|
||||
}
|
||||
|
||||
var parsed interface{}
|
||||
var err error
|
||||
if f.formatType == FormatTypeJSON {
|
||||
err = json.Unmarshal([]byte(content), &parsed)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse %s content: %s", strings.ToUpper(string(f.formatType)), err.Error())
|
||||
}
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("Content must be a list of extractions or a dict.")
|
||||
}
|
||||
|
||||
var items []interface{}
|
||||
if parsedMap, ok := parsed.(map[string]interface{}); ok {
|
||||
items = []interface{}{parsedMap}
|
||||
} else if parsedList, ok := parsed.([]interface{}); ok {
|
||||
items = parsedList
|
||||
} else {
|
||||
return nil, fmt.Errorf("Expected list or dict, got %T", parsed)
|
||||
}
|
||||
|
||||
itemsList := make([]map[string]interface{}, 0)
|
||||
for _, item := range items {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
itemsList = append(itemsList, itemMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Each item in the sequence must be a mapping.")
|
||||
}
|
||||
}
|
||||
return itemsList, nil
|
||||
}
|
||||
|
||||
func (f *Formater) ParseGraph(ctx context.Context, text string) (*types.GraphData, error) {
|
||||
matchData, err := f.parseOutput(ctx, text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(matchData) == 0 {
|
||||
logger.Debugf(ctx, "Received empty extraction data.")
|
||||
return &types.GraphData{}, nil
|
||||
}
|
||||
// mm, _ := json.Marshal(matchData)
|
||||
// logger.Debugf(ctx, "Parsed graph data: %s", string(mm))
|
||||
|
||||
var nodes []*types.GraphNode
|
||||
var relations []*types.GraphRelation
|
||||
|
||||
for _, group := range matchData {
|
||||
switch {
|
||||
case group[f.nodePrefix] != nil:
|
||||
attributes := make([]string, 0)
|
||||
attributesKey := f.nodePrefix + f.attributeSuffix
|
||||
if attr, ok := group[attributesKey].([]interface{}); ok {
|
||||
for _, v := range attr {
|
||||
attributes = append(attributes, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
nodes = append(nodes, &types.GraphNode{
|
||||
Name: fmt.Sprintf("%v", group[f.nodePrefix]),
|
||||
Attributes: attributes,
|
||||
})
|
||||
case group[f.relationSource] != nil && group[f.relationTarget] != nil:
|
||||
relations = append(relations, &types.GraphRelation{
|
||||
Node1: fmt.Sprintf("%v", group[f.relationSource]),
|
||||
Node2: fmt.Sprintf("%v", group[f.relationTarget]),
|
||||
Type: fmt.Sprintf("%v", group[f.relationPrefix]),
|
||||
})
|
||||
default:
|
||||
logger.Warnf(ctx, "Unsupported graph group: %v", group)
|
||||
continue
|
||||
}
|
||||
}
|
||||
graph := &types.GraphData{
|
||||
Node: nodes,
|
||||
Relation: relations,
|
||||
}
|
||||
f.rebuildGraph(ctx, graph)
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
func (f *Formater) rebuildGraph(ctx context.Context, graph *types.GraphData) {
|
||||
nodeMap := make(map[string]*types.GraphNode)
|
||||
nodes := make([]*types.GraphNode, 0, len(graph.Node))
|
||||
for _, node := range graph.Node {
|
||||
if prenode, ok := nodeMap[node.Name]; ok {
|
||||
logger.Infof(ctx, "Duplicate node ID: %s, merge attribute", node.Name)
|
||||
// 修复panic:检查Attributes是否为nil
|
||||
if node.Attributes == nil {
|
||||
node.Attributes = make([]string, 0)
|
||||
}
|
||||
if prenode.Attributes != nil {
|
||||
for _, attr := range prenode.Attributes {
|
||||
node.Attributes = append(node.Attributes, attr)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
nodeMap[node.Name] = node
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
relations := make([]*types.GraphRelation, 0, len(graph.Relation))
|
||||
for _, relation := range graph.Relation {
|
||||
if relation.Node1 == relation.Node2 {
|
||||
logger.Infof(ctx, "Duplicate relation, ignore it")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := nodeMap[relation.Node1]; !ok {
|
||||
node := &types.GraphNode{Name: relation.Node1}
|
||||
nodes = append(nodes, node)
|
||||
nodeMap[relation.Node1] = node
|
||||
logger.Infof(ctx, "Add unknown source node ID: %s", relation.Node1)
|
||||
}
|
||||
if _, ok := nodeMap[relation.Node2]; !ok {
|
||||
node := &types.GraphNode{Name: relation.Node2}
|
||||
nodes = append(nodes, node)
|
||||
nodeMap[relation.Node2] = node
|
||||
logger.Infof(ctx, "Add unknown target node ID: %s", relation.Node2)
|
||||
}
|
||||
|
||||
relations = append(relations, relation)
|
||||
}
|
||||
*graph = types.GraphData{
|
||||
Node: nodes,
|
||||
Relation: relations,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Formater) extractContent(ctx context.Context, text string) string {
|
||||
if !f.useFences {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
validTags := map[FormatType]map[string]struct{}{
|
||||
FormatTypeYAML: {"yaml": {}, "yml": {}},
|
||||
FormatTypeJSON: {"json": {}},
|
||||
}
|
||||
matches := _FENCE_RE.FindAllStringSubmatch(text, -1)
|
||||
var candidates []string
|
||||
for _, match := range matches {
|
||||
lang := match[1]
|
||||
body := match[2]
|
||||
if f.isValidLanguageTag(lang, validTags) {
|
||||
candidates = append(candidates, body)
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case len(candidates) == 1:
|
||||
return strings.TrimSpace(candidates[0])
|
||||
|
||||
case len(candidates) > 1:
|
||||
logger.Warnf(ctx, "multiple candidates found: %d", len(candidates))
|
||||
return strings.TrimSpace(candidates[0])
|
||||
|
||||
case len(matches) == 1:
|
||||
logger.Debugf(ctx, "no candidate found, use first match without language tag: %s", matches[0][1])
|
||||
return strings.TrimSpace(matches[0][2])
|
||||
|
||||
case len(matches) > 1:
|
||||
logger.Warnf(ctx, "multiple matches found: %d", len(matches))
|
||||
return strings.TrimSpace(matches[0][2])
|
||||
|
||||
default:
|
||||
logger.Warnf(ctx, "no match found")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Formater) addFences(content string) string {
|
||||
content = strings.TrimSpace(content)
|
||||
return fmt.Sprintf("```%s\n%s\n```", f.formatType, content)
|
||||
}
|
||||
|
||||
func (f *Formater) isValidLanguageTag(lang string, validTags map[FormatType]map[string]struct{}) bool {
|
||||
if lang == "" {
|
||||
return true
|
||||
}
|
||||
tag := strings.TrimSpace(strings.ToLower(lang))
|
||||
validSet, ok := validTags[f.formatType]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, exists := validSet[tag]
|
||||
return exists
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||||
)
|
||||
|
||||
// PluginIntoChatMessage handles the transformation of search results into chat messages
|
||||
@@ -50,9 +51,16 @@ func (p *PluginIntoChatMessage) OnEvent(ctx context.Context,
|
||||
weekdayName := []string{"星期日", "星期一", "星期二", "星期三", "星期四", "星期五", "星期六"}
|
||||
var userContent bytes.Buffer
|
||||
|
||||
// 验证用户查询的安全性
|
||||
safeQuery, isValid := secutils.ValidateInput(chatManage.Query)
|
||||
if !isValid {
|
||||
logger.Errorf(ctx, "Invalid user query: %s", chatManage.Query)
|
||||
return ErrTemplateExecute.WithError(fmt.Errorf("用户查询包含非法内容"))
|
||||
}
|
||||
|
||||
// Execute template with context data
|
||||
err = tmpl.Execute(&userContent, map[string]interface{}{
|
||||
"Query": chatManage.Query, // User's original query
|
||||
"Query": safeQuery, // User's original query
|
||||
"Contexts": passages, // Extracted passages from search results
|
||||
"CurrentTime": time.Now().Format("2006-01-02 15:04:05"), // Formatted current time
|
||||
"CurrentWeek": weekdayName[time.Now().Weekday()], // Current weekday in Chinese
|
||||
|
||||
@@ -77,6 +77,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
||||
}
|
||||
chatManage.SearchResult = append(chatManage.SearchResult, searchResults...)
|
||||
}
|
||||
|
||||
// remove duplicate results
|
||||
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
|
||||
|
||||
|
||||
136
internal/application/service/chat_pipline/search_entity.go
Normal file
136
internal/application/service/chat_pipline/search_entity.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package chatpipline
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// PluginSearch implements search functionality for chat pipeline
|
||||
type PluginSearchEntity struct {
|
||||
graphRepo interfaces.RetrieveGraphRepository
|
||||
chunkRepo interfaces.ChunkRepository
|
||||
knowledgeRepo interfaces.KnowledgeRepository
|
||||
}
|
||||
|
||||
func NewPluginSearchEntity(
|
||||
eventManager *EventManager,
|
||||
graphRepository interfaces.RetrieveGraphRepository,
|
||||
chunkRepository interfaces.ChunkRepository,
|
||||
knowledgeRepository interfaces.KnowledgeRepository,
|
||||
) *PluginSearchEntity {
|
||||
res := &PluginSearchEntity{
|
||||
graphRepo: graphRepository,
|
||||
chunkRepo: chunkRepository,
|
||||
knowledgeRepo: knowledgeRepository,
|
||||
}
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// ActivationEvents returns the event types this plugin handles
|
||||
func (p *PluginSearchEntity) ActivationEvents() []types.EventType {
|
||||
return []types.EventType{types.ENTITY_SEARCH}
|
||||
}
|
||||
|
||||
// OnEvent handles search events in the chat pipeline
|
||||
func (p *PluginSearchEntity) OnEvent(ctx context.Context,
|
||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||
) *PluginError {
|
||||
entity := chatManage.Entity
|
||||
if len(entity) == 0 {
|
||||
logger.Infof(ctx, "No entity found")
|
||||
return next()
|
||||
}
|
||||
|
||||
graph, err := p.graphRepo.SearchNode(ctx, types.NameSpace{KnowledgeBase: chatManage.KnowledgeBaseID}, entity)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to search node, session_id: %s, error: %v", chatManage.SessionID, err)
|
||||
return next()
|
||||
}
|
||||
chatManage.GraphResult = graph
|
||||
logger.Infof(ctx, "search entity result count: %d", len(graph.Node))
|
||||
// graphStr, _ := json.Marshal(graph)
|
||||
// logger.Debugf(ctx, "search entity result: %s", string(graphStr))
|
||||
|
||||
chunkIDs := filterSeenChunk(ctx, graph, chatManage.SearchResult)
|
||||
if len(chunkIDs) == 0 {
|
||||
logger.Infof(ctx, "No new chunk found")
|
||||
return next()
|
||||
}
|
||||
chunks, err := p.chunkRepo.ListChunksByID(ctx, ctx.Value(types.TenantIDContextKey).(uint), chunkIDs)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to list chunks, session_id: %s, error: %v", chatManage.SessionID, err)
|
||||
return next()
|
||||
}
|
||||
knowledgeIDs := []string{}
|
||||
for _, chunk := range chunks {
|
||||
knowledgeIDs = append(knowledgeIDs, chunk.KnowledgeID)
|
||||
}
|
||||
knowledges, err := p.knowledgeRepo.GetKnowledgeBatch(ctx, ctx.Value(types.TenantIDContextKey).(uint), knowledgeIDs)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to list knowledge, session_id: %s, error: %v", chatManage.SessionID, err)
|
||||
return next()
|
||||
}
|
||||
|
||||
knowledgeMap := map[string]*types.Knowledge{}
|
||||
for _, knowledge := range knowledges {
|
||||
knowledgeMap[knowledge.ID] = knowledge
|
||||
}
|
||||
for _, chunk := range chunks {
|
||||
searchResult := chunk2SearchResult(chunk, knowledgeMap[chunk.KnowledgeID])
|
||||
chatManage.SearchResult = append(chatManage.SearchResult, searchResult)
|
||||
}
|
||||
// remove duplicate results
|
||||
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
|
||||
if len(chatManage.SearchResult) == 0 {
|
||||
logger.Infof(ctx, "No new search result, session_id: %s", chatManage.SessionID)
|
||||
return ErrSearchNothing
|
||||
}
|
||||
logger.Infof(ctx, "search entity result count: %d, session_id: %s", len(chatManage.SearchResult), chatManage.SessionID)
|
||||
return next()
|
||||
}
|
||||
|
||||
func filterSeenChunk(ctx context.Context, graph *types.GraphData, searchResult []*types.SearchResult) []string {
|
||||
seen := map[string]bool{}
|
||||
for _, chunk := range searchResult {
|
||||
seen[chunk.ID] = true
|
||||
}
|
||||
logger.Infof(ctx, "filterSeenChunk: seen count: %d", len(seen))
|
||||
|
||||
chunkIDs := []string{}
|
||||
for _, node := range graph.Node {
|
||||
for _, chunkID := range node.Chunks {
|
||||
if seen[chunkID] {
|
||||
continue
|
||||
}
|
||||
seen[chunkID] = true
|
||||
chunkIDs = append(chunkIDs, chunkID)
|
||||
}
|
||||
}
|
||||
logger.Infof(ctx, "filterSeenChunk: new chunkIDs count: %d", len(chunkIDs))
|
||||
return chunkIDs
|
||||
}
|
||||
|
||||
func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.SearchResult {
|
||||
return &types.SearchResult{
|
||||
ID: chunk.ID,
|
||||
Content: chunk.Content,
|
||||
KnowledgeID: chunk.KnowledgeID,
|
||||
ChunkIndex: chunk.ChunkIndex,
|
||||
KnowledgeTitle: knowledge.Title,
|
||||
StartAt: chunk.StartAt,
|
||||
EndAt: chunk.EndAt,
|
||||
Seq: chunk.ChunkIndex,
|
||||
Score: 1.0,
|
||||
MatchType: types.MatchTypeGraph,
|
||||
Metadata: knowledge.GetMetadata(),
|
||||
ChunkType: string(chunk.ChunkType),
|
||||
ParentChunkID: chunk.ParentChunkID,
|
||||
ImageInfo: chunk.ImageInfo,
|
||||
KnowledgeFilename: knowledge.FileName,
|
||||
KnowledgeSource: knowledge.Source,
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -30,7 +31,7 @@ type EvaluationService struct {
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService // Service for knowledge base operations
|
||||
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
|
||||
sessionService interfaces.SessionService // Service for chat sessions
|
||||
testData *TestDataService // Service for test data
|
||||
modelService interfaces.ModelService // Service for model operations
|
||||
|
||||
evaluationMemoryStorage *evaluationMemoryStorage // In-memory storage for evaluation tasks
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func NewEvaluationService(
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService,
|
||||
knowledgeService interfaces.KnowledgeService,
|
||||
sessionService interfaces.SessionService,
|
||||
testData *TestDataService,
|
||||
modelService interfaces.ModelService,
|
||||
) interfaces.EvaluationService {
|
||||
evaluationMemoryStorage := newEvaluationMemoryStorage()
|
||||
return &EvaluationService{
|
||||
@@ -50,7 +51,7 @@ func NewEvaluationService(
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
sessionService: sessionService,
|
||||
testData: testData,
|
||||
modelService: modelService,
|
||||
evaluationMemoryStorage: evaluationMemoryStorage,
|
||||
}
|
||||
}
|
||||
@@ -144,11 +145,32 @@ func (e *EvaluationService) Evaluation(ctx context.Context,
|
||||
if knowledgeBaseID == "" {
|
||||
logger.Info(ctx, "No knowledge base ID provided, creating new knowledge base")
|
||||
// Create new knowledge base with default evaluation settings
|
||||
// 获取默认的嵌入模型和LLM模型
|
||||
models, err := e.modelService.ListModels(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to list models: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var embeddingModelID, llmModelID string
|
||||
for _, model := range models {
|
||||
if model.Type == types.ModelTypeEmbedding {
|
||||
embeddingModelID = model.ID
|
||||
}
|
||||
if model.Type == types.ModelTypeKnowledgeQA {
|
||||
llmModelID = model.ID
|
||||
}
|
||||
}
|
||||
|
||||
if embeddingModelID == "" || llmModelID == "" {
|
||||
return nil, fmt.Errorf("no default models found for evaluation")
|
||||
}
|
||||
|
||||
kb, err := e.knowledgeBaseService.CreateKnowledgeBase(ctx, &types.KnowledgeBase{
|
||||
Name: "evaluation",
|
||||
Description: "evaluation",
|
||||
EmbeddingModelID: e.testData.EmbedModel.GetModelID(),
|
||||
SummaryModelID: e.testData.LLMModel.GetModelID(),
|
||||
EmbeddingModelID: embeddingModelID,
|
||||
SummaryModelID: llmModelID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
|
||||
@@ -186,12 +208,37 @@ func (e *EvaluationService) Evaluation(ctx context.Context,
|
||||
}
|
||||
|
||||
if rerankModelID == "" {
|
||||
rerankModelID = e.testData.RerankModel.GetModelID()
|
||||
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
|
||||
// 获取默认的重排模型
|
||||
models, err := e.modelService.ListModels(ctx)
|
||||
if err == nil {
|
||||
for _, model := range models {
|
||||
if model.Type == types.ModelTypeRerank {
|
||||
rerankModelID = model.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if rerankModelID == "" {
|
||||
logger.Warnf(ctx, "No rerank model found, skipping rerank")
|
||||
} else {
|
||||
logger.Infof(ctx, "Using default rerank model: %s", rerankModelID)
|
||||
}
|
||||
}
|
||||
|
||||
if chatModelID == "" {
|
||||
chatModelID = e.testData.LLMModel.GetModelID()
|
||||
// 获取默认的LLM模型
|
||||
models, err := e.modelService.ListModels(ctx)
|
||||
if err == nil {
|
||||
for _, model := range models {
|
||||
if model.Type == types.ModelTypeKnowledgeQA {
|
||||
chatModelID = model.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if chatModelID == "" {
|
||||
return nil, fmt.Errorf("no default chat model found")
|
||||
}
|
||||
logger.Infof(ctx, "Using default chat model: %s", chatModelID)
|
||||
}
|
||||
|
||||
|
||||
137
internal/application/service/extract.go
Normal file
137
internal/application/service/extract.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
func NewChunkExtractTask(ctx context.Context, client *asynq.Client, tenantID uint, chunkID string, modelID string) error {
|
||||
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
|
||||
logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(types.ExtractChunkPayload{
|
||||
TenantID: tenantID,
|
||||
ChunkID: chunkID,
|
||||
ModelID: modelID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task := asynq.NewTask(types.TypeChunkExtract, payload, asynq.MaxRetry(3))
|
||||
info, err := client.Enqueue(task)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to enqueue task: %v", err)
|
||||
return fmt.Errorf("failed to enqueue task: %v", err)
|
||||
}
|
||||
logger.Infof(ctx, "enqueued task: id=%s queue=%s chunk=%s", info.ID, info.Queue, chunkID)
|
||||
return nil
|
||||
}
|
||||
|
||||
type ChunkExtractService struct {
|
||||
template *types.PromptTemplateStructured
|
||||
modelService interfaces.ModelService
|
||||
knowledgeBaseRepo interfaces.KnowledgeBaseRepository
|
||||
chunkRepo interfaces.ChunkRepository
|
||||
graphEngine interfaces.RetrieveGraphRepository
|
||||
}
|
||||
|
||||
func NewChunkExtractService(
|
||||
config *config.Config,
|
||||
modelService interfaces.ModelService,
|
||||
knowledgeBaseRepo interfaces.KnowledgeBaseRepository,
|
||||
chunkRepo interfaces.ChunkRepository,
|
||||
graphEngine interfaces.RetrieveGraphRepository,
|
||||
) interfaces.Extracter {
|
||||
generator := chatpipline.NewQAPromptGenerator(chatpipline.NewFormater(), config.ExtractManager.ExtractGraph)
|
||||
ctx := context.Background()
|
||||
logger.Debugf(ctx, "chunk extract system prompt: %s", generator.System(ctx))
|
||||
logger.Debugf(ctx, "chunk extract user prompt: %s", generator.User(ctx, "demo"))
|
||||
return &ChunkExtractService{
|
||||
template: config.ExtractManager.ExtractGraph,
|
||||
modelService: modelService,
|
||||
knowledgeBaseRepo: knowledgeBaseRepo,
|
||||
chunkRepo: chunkRepo,
|
||||
graphEngine: graphEngine,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChunkExtractService) Extract(ctx context.Context, t *asynq.Task) error {
|
||||
var p types.ExtractChunkPayload
|
||||
if err := json.Unmarshal(t.Payload(), &p); err != nil {
|
||||
logger.Errorf(ctx, "failed to unmarshal task payload: %v", err)
|
||||
return err
|
||||
}
|
||||
ctx = logger.WithRequestID(ctx, uuid.New().String())
|
||||
ctx = logger.WithField(ctx, "extract", p.ChunkID)
|
||||
ctx = context.WithValue(ctx, types.TenantIDContextKey, p.TenantID)
|
||||
|
||||
chunk, err := s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to get chunk: %v", err)
|
||||
return err
|
||||
}
|
||||
kb, err := s.knowledgeBaseRepo.GetKnowledgeBaseByID(ctx, chunk.KnowledgeBaseID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to get knowledge base: %v", err)
|
||||
return err
|
||||
}
|
||||
if kb.ExtractConfig == nil {
|
||||
logger.Warnf(ctx, "failed to get extract config")
|
||||
return err
|
||||
}
|
||||
|
||||
chatModel, err := s.modelService.GetChatModel(ctx, p.ModelID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "failed to get chat model: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
template := &types.PromptTemplateStructured{
|
||||
Description: s.template.Description,
|
||||
Tags: kb.ExtractConfig.Tags,
|
||||
Examples: []types.GraphData{
|
||||
{
|
||||
Text: kb.ExtractConfig.Text,
|
||||
Node: kb.ExtractConfig.Nodes,
|
||||
Relation: kb.ExtractConfig.Relations,
|
||||
},
|
||||
},
|
||||
}
|
||||
extractor := chatpipline.NewExtractor(chatModel, template)
|
||||
graph, err := extractor.Extract(ctx, chunk.Content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunk, err = s.chunkRepo.GetChunkByID(ctx, p.TenantID, p.ChunkID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "graph ignore chunk %s: %v", p.ChunkID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, node := range graph.Node {
|
||||
node.Chunks = []string{chunk.ID}
|
||||
}
|
||||
if err = s.graphEngine.AddGraph(ctx,
|
||||
types.NameSpace{KnowledgeBase: chunk.KnowledgeBaseID, Knowledge: chunk.KnowledgeID},
|
||||
[]*types.GraphData{graph},
|
||||
); err != nil {
|
||||
logger.Errorf(ctx, "failed to add graph: %v", err)
|
||||
return err
|
||||
}
|
||||
// gg, _ := json.Marshal(graph)
|
||||
// logger.Infof(ctx, "extracted graph: %s", string(gg))
|
||||
return nil
|
||||
}
|
||||
@@ -23,8 +23,8 @@ type cosFileService struct {
|
||||
}
|
||||
|
||||
// NewCosFileService creates a new COS file service instance
|
||||
func NewCosFileService(appId, region, secretId, secretKey, cosPathPrefix string) (interfaces.FileService, error) {
|
||||
bucketURL := fmt.Sprintf("https://%s.cos.%s.myqcloud.com", appId, region)
|
||||
func NewCosFileService(bucketName, region, secretId, secretKey, cosPathPrefix string) (interfaces.FileService, error) {
|
||||
bucketURL := fmt.Sprintf("https://%s.cos.%s.myqcloud.com/", bucketName, region)
|
||||
u, err := url.Parse(bucketURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse bucketURL: %w", err)
|
||||
@@ -59,7 +59,7 @@ func (s *cosFileService) SaveFile(ctx context.Context,
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload file to COS: %w", err)
|
||||
}
|
||||
return fmt.Sprintf("https://%s/%s", s.bucketURL, objectName), nil
|
||||
return fmt.Sprintf("%s%s", s.bucketURL, objectName), nil
|
||||
}
|
||||
|
||||
// GetFile retrieves a file from COS storage by its path URL
|
||||
|
||||
@@ -25,9 +25,11 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/tracing"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||||
"github.com/Tencent/WeKnora/services/docreader/src/client"
|
||||
"github.com/Tencent/WeKnora/services/docreader/src/proto"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hibiken/asynq"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -60,6 +62,8 @@ type knowledgeService struct {
|
||||
chunkRepo interfaces.ChunkRepository
|
||||
fileSvc interfaces.FileService
|
||||
modelService interfaces.ModelService
|
||||
task *asynq.Client
|
||||
graphEngine interfaces.RetrieveGraphRepository
|
||||
}
|
||||
|
||||
// NewKnowledgeService creates a new knowledge service instance
|
||||
@@ -73,6 +77,8 @@ func NewKnowledgeService(
|
||||
chunkRepo interfaces.ChunkRepository,
|
||||
fileSvc interfaces.FileService,
|
||||
modelService interfaces.ModelService,
|
||||
task *asynq.Client,
|
||||
graphEngine interfaces.RetrieveGraphRepository,
|
||||
) (interfaces.KnowledgeService, error) {
|
||||
return &knowledgeService{
|
||||
config: config,
|
||||
@@ -84,6 +90,8 @@ func NewKnowledgeService(
|
||||
chunkRepo: chunkRepo,
|
||||
fileSvc: fileSvc,
|
||||
modelService: modelService,
|
||||
task: task,
|
||||
graphEngine: graphEngine,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -191,15 +199,22 @@ func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context,
|
||||
metadataJSON = types.JSON(metadataBytes)
|
||||
}
|
||||
|
||||
// 验证文件名安全性
|
||||
safeFilename, isValid := secutils.ValidateInput(file.Filename)
|
||||
if !isValid {
|
||||
logger.Errorf(ctx, "Invalid filename: %s", file.Filename)
|
||||
return nil, werrors.NewValidationError("文件名包含非法字符")
|
||||
}
|
||||
|
||||
// Create knowledge record
|
||||
logger.Info(ctx, "Creating knowledge record")
|
||||
knowledge := &types.Knowledge{
|
||||
TenantID: tenantID,
|
||||
KnowledgeBaseID: kbID,
|
||||
Type: "file",
|
||||
Title: file.Filename,
|
||||
FileName: file.Filename,
|
||||
FileType: getFileType(file.Filename),
|
||||
Title: safeFilename,
|
||||
FileName: safeFilename,
|
||||
FileType: getFileType(safeFilename),
|
||||
FileSize: file.Size,
|
||||
FileHash: hash,
|
||||
ParseStatus: "pending",
|
||||
@@ -258,10 +273,10 @@ func (s *knowledgeService) CreateKnowledgeFromURL(ctx context.Context,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate URL format
|
||||
// Validate URL format and security
|
||||
logger.Info(ctx, "Validating URL")
|
||||
if !isValidURL(url) {
|
||||
logger.Error(ctx, "Invalid URL format")
|
||||
if !isValidURL(url) || !secutils.IsValidURL(url) {
|
||||
logger.Error(ctx, "Invalid or unsafe URL format")
|
||||
return nil, ErrInvalidURL
|
||||
}
|
||||
|
||||
@@ -339,6 +354,17 @@ func (s *knowledgeService) CreateKnowledgeFromPassage(ctx context.Context,
|
||||
logger.Info(ctx, "Start creating knowledge from passage")
|
||||
logger.Infof(ctx, "Knowledge base ID: %s, passage count: %d", kbID, len(passage))
|
||||
|
||||
// 验证段落内容安全性
|
||||
safePassages := make([]string, 0, len(passage))
|
||||
for i, p := range passage {
|
||||
safePassage, isValid := secutils.ValidateInput(p)
|
||||
if !isValid {
|
||||
logger.Errorf(ctx, "Invalid passage content at index %d", i)
|
||||
return nil, werrors.NewValidationError(fmt.Sprintf("段落 %d 包含非法内容", i+1))
|
||||
}
|
||||
safePassages = append(safePassages, safePassage)
|
||||
}
|
||||
|
||||
// Get knowledge base configuration
|
||||
logger.Info(ctx, "Getting knowledge base configuration")
|
||||
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
|
||||
@@ -370,7 +396,7 @@ func (s *knowledgeService) CreateKnowledgeFromPassage(ctx context.Context,
|
||||
|
||||
// Process passages asynchronously
|
||||
logger.Info(ctx, "Starting asynchronous passage processing")
|
||||
go s.processDocumentFromPassage(ctx, kb, knowledge, passage)
|
||||
go s.processDocumentFromPassage(ctx, kb, knowledge, safePassages)
|
||||
|
||||
logger.Infof(ctx, "Knowledge from passage created successfully, ID: %s", knowledge.ID)
|
||||
return knowledge, nil
|
||||
@@ -469,6 +495,16 @@ func (s *knowledgeService) DeleteKnowledge(ctx context.Context, id string) error
|
||||
return nil
|
||||
})
|
||||
|
||||
// Delete the knowledge graph
|
||||
wg.Go(func() error {
|
||||
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
|
||||
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
|
||||
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err = wg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -542,6 +578,19 @@ func (s *knowledgeService) DeleteKnowledgeList(ctx context.Context, ids []string
|
||||
return nil
|
||||
})
|
||||
|
||||
// Delete the knowledge graph
|
||||
wg.Go(func() error {
|
||||
namespaces := []types.NameSpace{}
|
||||
for _, knowledge := range knowledgeList {
|
||||
namespaces = append(namespaces, types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID})
|
||||
}
|
||||
if err := s.graphEngine.DelGraph(ctx, namespaces); err != nil {
|
||||
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err = wg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1139,6 +1188,15 @@ func (s *knowledgeService) processChunks(ctx context.Context,
|
||||
}
|
||||
logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList))
|
||||
|
||||
logger.Infof(ctx, "processChunks create relationship rag task")
|
||||
for _, chunk := range textChunks {
|
||||
err := NewChunkExtractTask(ctx, s.task, chunk.TenantID, chunk.ID, kb.SummaryModelID)
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks create chunk extract task failed")
|
||||
span.RecordError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update knowledge status to completed
|
||||
knowledge.ParseStatus = "completed"
|
||||
knowledge.EnableStatus = "enabled"
|
||||
|
||||
@@ -2,12 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||
"github.com/Tencent/WeKnora/internal/common"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
@@ -376,8 +375,8 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
|
||||
|
||||
// processSearchResults handles the processing of search results, optimizing database queries
|
||||
func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
|
||||
chunks []*types.IndexWithScore) ([]*types.SearchResult, error) {
|
||||
|
||||
chunks []*types.IndexWithScore,
|
||||
) ([]*types.SearchResult, error) {
|
||||
if len(chunks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -527,8 +526,8 @@ func (s *knowledgeBaseService) collectRelatedChunkIDs(chunk *types.Chunk, proces
|
||||
func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
|
||||
knowledge *types.Knowledge,
|
||||
score float64,
|
||||
matchType types.MatchType) *types.SearchResult {
|
||||
|
||||
matchType types.MatchType,
|
||||
) *types.SearchResult {
|
||||
return &types.SearchResult{
|
||||
ID: chunk.ID,
|
||||
Content: chunk.Content,
|
||||
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
var apiKeySecret = []byte(os.Getenv("TENANT_AES_KEY"))
|
||||
var apiKeySecret = func() []byte {
|
||||
return []byte(os.Getenv("TENANT_AES_KEY"))
|
||||
}
|
||||
|
||||
// ListTenantsParams defines parameters for listing tenants with filtering and pagination
|
||||
type ListTenantsParams struct {
|
||||
@@ -221,7 +223,7 @@ func (r *tenantService) generateApiKey(tenantID uint) string {
|
||||
binary.LittleEndian.PutUint64(idBytes, uint64(tenantID))
|
||||
|
||||
// 2. Encrypt tenant_id using AES-GCM
|
||||
block, err := aes.NewCipher(apiKeySecret)
|
||||
block, err := aes.NewCipher(apiKeySecret())
|
||||
if err != nil {
|
||||
panic("Failed to create AES cipher: " + err.Error())
|
||||
}
|
||||
@@ -267,7 +269,7 @@ func (r *tenantService) ExtractTenantIDFromAPIKey(apiKey string) (uint, error) {
|
||||
nonce, ciphertext := encryptedData[:12], encryptedData[12:]
|
||||
|
||||
// 4. Decrypt
|
||||
block, err := aes.NewCipher(apiKeySecret)
|
||||
block, err := aes.NewCipher(apiKeySecret())
|
||||
if err != nil {
|
||||
return 0, errors.New("decryption error")
|
||||
}
|
||||
|
||||
@@ -1,444 +0,0 @@
|
||||
// Package service 提供应用程序的核心业务逻辑服务层
|
||||
// 此包包含了知识库管理、用户租户管理、模型服务等核心功能实现
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||
"github.com/Tencent/WeKnora/internal/models/rerank"
|
||||
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// TestDataService 测试数据服务
|
||||
// 负责初始化测试环境所需的数据,包括创建测试租户、测试知识库
|
||||
// 以及配置必要的模型服务实例
|
||||
type TestDataService struct {
|
||||
config *config.Config // 应用程序配置
|
||||
kbRepo interfaces.KnowledgeBaseRepository // 知识库存储库接口
|
||||
tenantService interfaces.TenantService // 租户服务接口
|
||||
ollamaService *ollama.OllamaService // Ollama模型服务
|
||||
modelService interfaces.ModelService // 模型服务接口
|
||||
EmbedModel embedding.Embedder // 嵌入模型实例
|
||||
RerankModel rerank.Reranker // 重排模型实例
|
||||
LLMModel chat.Chat // 大语言模型实例
|
||||
}
|
||||
|
||||
// NewTestDataService 创建测试数据服务
|
||||
// 注入所需的依赖服务和组件
|
||||
func NewTestDataService(
|
||||
config *config.Config,
|
||||
kbRepo interfaces.KnowledgeBaseRepository,
|
||||
tenantService interfaces.TenantService,
|
||||
ollamaService *ollama.OllamaService,
|
||||
modelService interfaces.ModelService,
|
||||
) *TestDataService {
|
||||
return &TestDataService{
|
||||
config: config,
|
||||
kbRepo: kbRepo,
|
||||
tenantService: tenantService,
|
||||
ollamaService: ollamaService,
|
||||
modelService: modelService,
|
||||
}
|
||||
}
|
||||
|
||||
// initTenant 初始化测试租户
|
||||
// 通过环境变量获取租户ID,如果租户不存在则创建新租户,否则更新现有租户
|
||||
// 同时配置租户的检索引擎参数
|
||||
func (s *TestDataService) initTenant(ctx context.Context) error {
|
||||
logger.Info(ctx, "Start initializing test tenant")
|
||||
|
||||
// 从环境变量获取租户ID
|
||||
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
|
||||
logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID)
|
||||
|
||||
// 将字符串ID转换为uint64
|
||||
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to parse tenant ID: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建租户配置
|
||||
tenantConfig := &types.Tenant{
|
||||
Name: "Test Tenant",
|
||||
Description: "Test Tenant for Testing",
|
||||
RetrieverEngines: types.RetrieverEngines{
|
||||
Engines: []types.RetrieverEngineParams{
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
{
|
||||
RetrieverType: types.VectorRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 获取或创建测试租户
|
||||
logger.Infof(ctx, "Attempting to get tenant with ID: %d", tenantIDUint)
|
||||
tenant, err := s.tenantService.GetTenantByID(ctx, uint(tenantIDUint))
|
||||
if err != nil {
|
||||
// 租户不存在,创建新租户
|
||||
logger.Info(ctx, "Tenant not found, creating a new test tenant")
|
||||
tenant, err = s.tenantService.CreateTenant(ctx, tenantConfig)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create tenant: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Infof(ctx, "Created new test tenant with ID: %d", tenant.ID)
|
||||
} else {
|
||||
// 租户存在,更新检索引擎配置
|
||||
logger.Info(ctx, "Test tenant found, updating retriever engines")
|
||||
tenant.RetrieverEngines = tenantConfig.RetrieverEngines
|
||||
tenant, err = s.tenantService.UpdateTenant(ctx, tenant)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to update tenant: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info(ctx, "Test tenant updated successfully")
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Test tenant configured - ID: %d, Name: %s, API Key: %s",
|
||||
tenant.ID, tenant.Name, tenant.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initKnowledgeBase 初始化测试知识库
|
||||
// 从环境变量获取知识库ID,创建或更新知识库
|
||||
// 配置知识库的分块策略、嵌入模型和摘要模型
|
||||
func (s *TestDataService) initKnowledgeBase(ctx context.Context) error {
|
||||
logger.Info(ctx, "Start initializing test knowledge base")
|
||||
|
||||
// 检查上下文中的租户ID
|
||||
if ctx.Value(types.TenantIDContextKey).(uint) == 0 {
|
||||
logger.Warn(ctx, "Tenant ID is 0, skipping knowledge base initialization")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 从环境变量获取知识库ID
|
||||
knowledgeBaseID := os.Getenv("INIT_TEST_KNOWLEDGE_BASE_ID")
|
||||
logger.Infof(ctx, "Test knowledge base ID from environment: %s", knowledgeBaseID)
|
||||
|
||||
// 创建知识库配置
|
||||
kbConfig := &types.KnowledgeBase{
|
||||
ID: knowledgeBaseID,
|
||||
Name: "Test Knowledge Base",
|
||||
Description: "Knowledge Base for Testing",
|
||||
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
|
||||
ChunkingConfig: types.ChunkingConfig{
|
||||
ChunkSize: s.config.KnowledgeBase.ChunkSize,
|
||||
ChunkOverlap: s.config.KnowledgeBase.ChunkOverlap,
|
||||
Separators: s.config.KnowledgeBase.SplitMarkers,
|
||||
EnableMultimodal: s.config.KnowledgeBase.ImageProcessing.EnableMultimodal,
|
||||
},
|
||||
EmbeddingModelID: s.EmbedModel.GetModelID(),
|
||||
SummaryModelID: s.LLMModel.GetModelID(),
|
||||
RerankModelID: s.RerankModel.GetModelID(),
|
||||
}
|
||||
|
||||
// 初始化测试知识库
|
||||
logger.Info(ctx, "Attempting to get existing knowledge base")
|
||||
_, err := s.kbRepo.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
|
||||
if err != nil {
|
||||
// 知识库不存在,创建新知识库
|
||||
logger.Info(ctx, "Knowledge base not found, creating a new one")
|
||||
logger.Infof(ctx, "Creating knowledge base with ID: %s, tenant ID: %d",
|
||||
kbConfig.ID, kbConfig.TenantID)
|
||||
|
||||
if err := s.kbRepo.CreateKnowledgeBase(ctx, kbConfig); err != nil {
|
||||
logger.Errorf(ctx, "Failed to create knowledge base: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info(ctx, "Knowledge base created successfully")
|
||||
} else {
|
||||
// 知识库存在,更新配置
|
||||
logger.Info(ctx, "Knowledge base found, updating configuration")
|
||||
logger.Infof(ctx, "Updating knowledge base with ID: %s", kbConfig.ID)
|
||||
|
||||
err = s.kbRepo.UpdateKnowledgeBase(ctx, kbConfig)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to update knowledge base: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info(ctx, "Knowledge base updated successfully")
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Test knowledge base configured - ID: %s, Name: %s", kbConfig.ID, kbConfig.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeTestData 初始化测试数据
|
||||
// 这是对外暴露的主要方法,负责协调所有测试数据的初始化过程
|
||||
// 包括初始化租户、嵌入模型、重排模型、LLM模型和知识库
|
||||
func (s *TestDataService) InitializeTestData(ctx context.Context) error {
|
||||
logger.Info(ctx, "Start initializing test data")
|
||||
|
||||
// 从环境变量获取租户ID
|
||||
tenantID := os.Getenv("INIT_TEST_TENANT_ID")
|
||||
logger.Infof(ctx, "Test tenant ID from environment: %s", tenantID)
|
||||
|
||||
// 解析租户ID
|
||||
tenantIDUint, err := strconv.ParseUint(tenantID, 10, 64)
|
||||
if err != nil {
|
||||
// 解析失败时使用默认值0
|
||||
logger.Warn(ctx, "Failed to parse tenant ID, using default value 0")
|
||||
tenantIDUint = 0
|
||||
} else {
|
||||
// 初始化租户
|
||||
logger.Info(ctx, "Initializing tenant")
|
||||
err = s.initTenant(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to initialize tenant: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info(ctx, "Tenant initialized successfully")
|
||||
}
|
||||
|
||||
// 创建带有租户ID的新上下文
|
||||
newCtx := context.Background()
|
||||
newCtx = context.WithValue(newCtx, types.TenantIDContextKey, uint(tenantIDUint))
|
||||
logger.Infof(ctx, "Created new context with tenant ID: %d", tenantIDUint)
|
||||
|
||||
// 初始化模型
|
||||
modelInitFuncs := []struct {
|
||||
name string
|
||||
fn func(context.Context) error
|
||||
}{
|
||||
{"embedding model", s.initEmbeddingModel},
|
||||
{"rerank model", s.initRerankModel},
|
||||
{"LLM model", s.initLLMModel},
|
||||
}
|
||||
|
||||
for _, initFunc := range modelInitFuncs {
|
||||
logger.Infof(ctx, "Initializing %s", initFunc.name)
|
||||
if err := initFunc.fn(newCtx); err != nil {
|
||||
logger.Errorf(ctx, "Failed to initialize %s: %v", initFunc.name, err)
|
||||
return err
|
||||
}
|
||||
logger.Infof(ctx, "%s initialized successfully", initFunc.name)
|
||||
}
|
||||
|
||||
// 初始化知识库
|
||||
logger.Info(ctx, "Initializing knowledge base")
|
||||
if err := s.initKnowledgeBase(newCtx); err != nil {
|
||||
logger.Errorf(ctx, "Failed to initialize knowledge base: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info(ctx, "Knowledge base initialized successfully")
|
||||
|
||||
logger.Info(ctx, "Test data initialization completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvOrError 获取环境变量值,如果不存在则返回错误
|
||||
func (s *TestDataService) getEnvOrError(name string) (string, error) {
|
||||
value := os.Getenv(name)
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("%s environment variable is not set", name)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// updateOrCreateModel 更新或创建模型
|
||||
func (s *TestDataService) updateOrCreateModel(ctx context.Context, modelConfig *types.Model) error {
|
||||
model, err := s.modelService.GetModelByID(ctx, modelConfig.ID)
|
||||
if err != nil {
|
||||
// 模型不存在,创建新模型
|
||||
return s.modelService.CreateModel(ctx, modelConfig)
|
||||
}
|
||||
|
||||
// 模型存在,更新属性
|
||||
model.TenantID = modelConfig.TenantID
|
||||
model.Name = modelConfig.Name
|
||||
model.Source = modelConfig.Source
|
||||
model.Type = modelConfig.Type
|
||||
model.Parameters = modelConfig.Parameters
|
||||
model.Status = modelConfig.Status
|
||||
|
||||
return s.modelService.UpdateModel(ctx, model)
|
||||
}
|
||||
|
||||
// initEmbeddingModel 初始化嵌入模型
|
||||
func (s *TestDataService) initEmbeddingModel(ctx context.Context) error {
|
||||
// 从环境变量获取模型参数
|
||||
modelName, err := s.getEnvOrError("INIT_EMBEDDING_MODEL_NAME")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dimensionStr := os.Getenv("INIT_EMBEDDING_MODEL_DIMENSION")
|
||||
dimension, err := strconv.Atoi(dimensionStr)
|
||||
if err != nil || dimension == 0 {
|
||||
return fmt.Errorf("invalid embedding model dimension: %s", dimensionStr)
|
||||
}
|
||||
|
||||
baseURL := os.Getenv("INIT_EMBEDDING_MODEL_BASE_URL")
|
||||
apiKey := os.Getenv("INIT_EMBEDDING_MODEL_API_KEY")
|
||||
|
||||
// 确定模型来源
|
||||
source := types.ModelSourceRemote
|
||||
if baseURL == "" {
|
||||
source = types.ModelSourceLocal
|
||||
}
|
||||
|
||||
// 确定模型ID
|
||||
modelID := os.Getenv("INIT_EMBEDDING_MODEL_ID")
|
||||
if modelID == "" {
|
||||
modelID = fmt.Sprintf("builtin:%s:%d", modelName, dimension)
|
||||
}
|
||||
|
||||
// 创建嵌入模型实例
|
||||
s.EmbedModel, err = embedding.NewEmbedder(embedding.Config{
|
||||
Source: source,
|
||||
BaseURL: baseURL,
|
||||
ModelName: modelName,
|
||||
APIKey: apiKey,
|
||||
Dimensions: dimension,
|
||||
ModelID: modelID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create embedder: %w", err)
|
||||
}
|
||||
|
||||
// 如果是本地模型,使用Ollama拉取模型
|
||||
if source == types.ModelSourceLocal && s.ollamaService != nil {
|
||||
if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil {
|
||||
return fmt.Errorf("failed to pull embedding model: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建模型配置
|
||||
modelConfig := &types.Model{
|
||||
ID: modelID,
|
||||
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
|
||||
Name: modelName,
|
||||
Source: source,
|
||||
Type: types.ModelTypeEmbedding,
|
||||
Parameters: types.ModelParameters{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
EmbeddingParameters: types.EmbeddingParameters{
|
||||
Dimension: dimension,
|
||||
},
|
||||
},
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
// 更新或创建模型
|
||||
return s.updateOrCreateModel(ctx, modelConfig)
|
||||
}
|
||||
|
||||
// initRerankModel 初始化重排模型
|
||||
func (s *TestDataService) initRerankModel(ctx context.Context) error {
|
||||
// 从环境变量获取模型参数
|
||||
modelName, err := s.getEnvOrError("INIT_RERANK_MODEL_NAME")
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Skip Rerank Model: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURL, err := s.getEnvOrError("INIT_RERANK_MODEL_BASE_URL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("INIT_RERANK_MODEL_API_KEY")
|
||||
modelID := fmt.Sprintf("builtin:%s:rerank:%s", types.ModelSourceRemote, modelName)
|
||||
|
||||
// 创建重排模型实例
|
||||
s.RerankModel, err = rerank.NewReranker(&rerank.RerankerConfig{
|
||||
Source: types.ModelSourceRemote,
|
||||
BaseURL: baseURL,
|
||||
ModelName: modelName,
|
||||
APIKey: apiKey,
|
||||
ModelID: modelID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create reranker: %w", err)
|
||||
}
|
||||
|
||||
// 创建模型配置
|
||||
modelConfig := &types.Model{
|
||||
ID: modelID,
|
||||
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
|
||||
Name: modelName,
|
||||
Source: types.ModelSourceRemote,
|
||||
Type: types.ModelTypeRerank,
|
||||
Parameters: types.ModelParameters{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
},
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
// 更新或创建模型
|
||||
return s.updateOrCreateModel(ctx, modelConfig)
|
||||
}
|
||||
|
||||
// initLLMModel 初始化大语言模型
|
||||
func (s *TestDataService) initLLMModel(ctx context.Context) error {
|
||||
// 从环境变量获取模型参数
|
||||
modelName, err := s.getEnvOrError("INIT_LLM_MODEL_NAME")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := os.Getenv("INIT_LLM_MODEL_BASE_URL")
|
||||
apiKey := os.Getenv("INIT_LLM_MODEL_API_KEY")
|
||||
|
||||
// 确定模型来源
|
||||
source := types.ModelSourceRemote
|
||||
if baseURL == "" {
|
||||
source = types.ModelSourceLocal
|
||||
}
|
||||
|
||||
// 确定模型ID
|
||||
modelID := fmt.Sprintf("builtin:%s:llm:%s", source, modelName)
|
||||
|
||||
// 创建大语言模型实例
|
||||
s.LLMModel, err = chat.NewChat(&chat.ChatConfig{
|
||||
Source: source,
|
||||
BaseURL: baseURL,
|
||||
ModelName: modelName,
|
||||
APIKey: apiKey,
|
||||
ModelID: modelID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create llm: %w", err)
|
||||
}
|
||||
|
||||
// 如果是本地模型,使用Ollama拉取模型
|
||||
if source == types.ModelSourceLocal && s.ollamaService != nil {
|
||||
if err := s.ollamaService.PullModel(context.Background(), modelName); err != nil {
|
||||
return fmt.Errorf("failed to pull llm model: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建模型配置
|
||||
modelConfig := &types.Model{
|
||||
ID: modelID,
|
||||
TenantID: ctx.Value(types.TenantIDContextKey).(uint),
|
||||
Name: modelName,
|
||||
Source: source,
|
||||
Type: types.ModelTypeKnowledgeQA,
|
||||
Parameters: types.ModelParameters{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
},
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
// 更新或创建模型
|
||||
return s.updateOrCreateModel(ctx, modelConfig)
|
||||
}
|
||||
449
internal/application/service/user.go
Normal file
449
internal/application/service/user.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// JWT secret key - in production this should be from environment variable
|
||||
var jwtSecret = []byte("your-secret-key")
|
||||
|
||||
// userService implements the UserService interface
|
||||
type userService struct {
|
||||
userRepo interfaces.UserRepository
|
||||
tokenRepo interfaces.AuthTokenRepository
|
||||
tenantService interfaces.TenantService
|
||||
}
|
||||
|
||||
// NewUserService creates a new user service instance
|
||||
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
|
||||
return &userService{
|
||||
userRepo: userRepo,
|
||||
tokenRepo: tokenRepo,
|
||||
tenantService: tenantService,
|
||||
}
|
||||
}
|
||||
|
||||
var engine = map[string][]types.RetrieverEngineParams{
|
||||
"postgres": {
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
{
|
||||
RetrieverType: types.VectorRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
"elasticsearch_v7": {
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
"elasticsearch_v8": {
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
||||
},
|
||||
{
|
||||
RetrieverType: types.VectorRetrieverType,
|
||||
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
|
||||
logger.Info(ctx, "Start user registration")
|
||||
|
||||
// Validate input
|
||||
if req.Username == "" || req.Email == "" || req.Password == "" {
|
||||
return nil, errors.New("username, email and password are required")
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
|
||||
if existingUser != nil {
|
||||
return nil, errors.New("user with this email already exists")
|
||||
}
|
||||
|
||||
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
|
||||
if existingUser != nil {
|
||||
return nil, errors.New("user with this username already exists")
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to hash password: %v", err)
|
||||
return nil, errors.New("failed to process password")
|
||||
}
|
||||
|
||||
egs := []types.RetrieverEngineParams{}
|
||||
for _, driver := range strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",") {
|
||||
if val, ok := engine[driver]; ok {
|
||||
egs = append(egs, val...)
|
||||
}
|
||||
}
|
||||
egs = uniqueRetrieverEngine(egs)
|
||||
logger.Debugf(ctx, "user register retriever engines: %v", egs)
|
||||
|
||||
// Create default tenant for the user
|
||||
tenant := &types.Tenant{
|
||||
Name: fmt.Sprintf("%s's Workspace", req.Username),
|
||||
Description: "Default workspace",
|
||||
Status: "active",
|
||||
RetrieverEngines: types.RetrieverEngines{Engines: egs},
|
||||
}
|
||||
|
||||
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create tenant: %v", err)
|
||||
return nil, errors.New("failed to create workspace")
|
||||
}
|
||||
|
||||
// Create user
|
||||
user := &types.User{
|
||||
ID: uuid.New().String(),
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
PasswordHash: string(hashedPassword),
|
||||
TenantID: createdTenant.ID,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err = s.userRepo.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create user: %v", err)
|
||||
return nil, errors.New("failed to create user")
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User registered successfully: %s", user.Email)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns tokens
|
||||
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
|
||||
logger.Infof(ctx, "Start user login for email: %s", req.Email)
|
||||
|
||||
// Get user by email
|
||||
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
if user == nil {
|
||||
logger.Warnf(ctx, "User not found for email: %s", req.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Account is disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify password
|
||||
logger.Infof(ctx, "Verifying password for user: %s", user.Email)
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
logger.Infof(ctx, "Password verification successful for user: %s", user.Email)
|
||||
|
||||
// Generate tokens
|
||||
logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
|
||||
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Login failed",
|
||||
}, nil
|
||||
}
|
||||
logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)
|
||||
|
||||
// Get tenant information
|
||||
logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
|
||||
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
|
||||
} else {
|
||||
logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User logged in successfully: %s", user.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: true,
|
||||
Message: "Login successful",
|
||||
User: user,
|
||||
Tenant: tenant,
|
||||
Token: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByID gets a user by ID
|
||||
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetUserByEmail gets a user by email
|
||||
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByEmail(ctx, email)
|
||||
}
|
||||
|
||||
// GetUserByUsername gets a user by username
|
||||
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByUsername(ctx, username)
|
||||
}
|
||||
|
||||
// UpdateUser updates user information
|
||||
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
return s.userRepo.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (s *userService) DeleteUser(ctx context.Context, id string) error {
|
||||
return s.userRepo.DeleteUser(ctx, id)
|
||||
}
|
||||
|
||||
// ChangePassword changes user password
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
|
||||
if err != nil {
|
||||
return errors.New("invalid old password")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
return s.userRepo.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// ValidatePassword validates user password
|
||||
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
}
|
||||
|
||||
// GenerateTokens generates access and refresh tokens for user
|
||||
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
|
||||
// Generate access token (expires in 24 hours)
|
||||
accessClaims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"tenant_id": user.TenantID,
|
||||
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessToken, err = accessTokenObj.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Generate refresh token (expires in 7 days)
|
||||
refreshClaims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshToken, err = refreshTokenObj.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Store tokens in database
|
||||
accessTokenRecord := &types.AuthToken{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
Token: accessToken,
|
||||
TokenType: "access_token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
refreshTokenRecord := &types.AuthToken{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
Token: refreshToken,
|
||||
TokenType: "refresh_token",
|
||||
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
|
||||
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates an access token
|
||||
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid token claims")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid user ID in token")
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
||||
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
||||
return nil, errors.New("token is revoked")
|
||||
}
|
||||
|
||||
return s.userRepo.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes access token using refresh token
|
||||
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
|
||||
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return "", "", errors.New("invalid refresh token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid token claims")
|
||||
}
|
||||
|
||||
tokenType, ok := claims["type"].(string)
|
||||
if !ok || tokenType != "refresh" {
|
||||
return "", "", errors.New("not a refresh token")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid user ID in token")
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
|
||||
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
||||
return "", "", errors.New("refresh token is revoked")
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Revoke old refresh token
|
||||
tokenRecord.IsRevoked = true
|
||||
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
||||
|
||||
// Generate new tokens
|
||||
return s.GenerateTokens(ctx, user)
|
||||
}
|
||||
|
||||
// RevokeToken revokes a token
|
||||
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenRecord.IsRevoked = true
|
||||
tokenRecord.UpdatedAt = time.Now()
|
||||
|
||||
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
||||
}
|
||||
|
||||
// GetCurrentUser gets current user from context
|
||||
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
|
||||
user, ok := ctx.Value("user").(*types.User)
|
||||
if !ok {
|
||||
return nil, errors.New("user not found in context")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func uniqueRetrieverEngine(engine []types.RetrieverEngineParams) []types.RetrieverEngineParams {
|
||||
seen := make(map[types.RetrieverEngineParams]bool)
|
||||
var result []types.RetrieverEngineParams
|
||||
for _, v := range engine {
|
||||
if !seen[v] {
|
||||
seen[v] = true
|
||||
result = append(result, v)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
// client is the global asyncq client instance
|
||||
var client *asynq.Client
|
||||
|
||||
// InitAsyncq initializes the asyncq client with configuration
|
||||
// It creates a new client and starts the server in a goroutine
|
||||
func InitAsyncq(config *config.Config) error {
|
||||
cfg := config.Asynq
|
||||
client = asynq.NewClient(asynq.RedisClientOpt{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
})
|
||||
go run(cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAsyncqClient returns the global asyncq client instance
|
||||
func GetAsyncqClient() *asynq.Client {
|
||||
return client
|
||||
}
|
||||
|
||||
// handleFunc stores registered task handlers
|
||||
var handleFunc = map[string]asynq.HandlerFunc{}
|
||||
|
||||
// RegisterHandlerFunc registers a handler function for a specific task type
|
||||
func RegisterHandlerFunc(taskType string, handlerFunc asynq.HandlerFunc) {
|
||||
handleFunc[taskType] = handlerFunc
|
||||
}
|
||||
|
||||
// run starts the asyncq server with the given configuration
|
||||
// It creates a new server, sets up handlers, and runs the server
|
||||
func run(config *config.AsynqConfig) {
|
||||
srv := asynq.NewServer(
|
||||
asynq.RedisClientOpt{
|
||||
Addr: config.Addr,
|
||||
Username: config.Username,
|
||||
Password: config.Password,
|
||||
ReadTimeout: config.ReadTimeout,
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
},
|
||||
asynq.Config{
|
||||
Concurrency: config.Concurrency,
|
||||
Queues: map[string]int{
|
||||
"critical": 6, // Highest priority queue
|
||||
"default": 3, // Default priority queue
|
||||
"low": 1, // Lowest priority queue
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Create a new mux and register all handlers
|
||||
mux := asynq.NewServeMux()
|
||||
for typ, handler := range handleFunc {
|
||||
mux.HandleFunc(typ, handler)
|
||||
}
|
||||
|
||||
// Start the server
|
||||
if err := srv.Run(mux); err != nil {
|
||||
log.Fatalf("could not run server: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -18,10 +19,10 @@ type Config struct {
|
||||
KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"`
|
||||
Tenant *TenantConfig `yaml:"tenant" json:"tenant"`
|
||||
Models []ModelConfig `yaml:"models" json:"models"`
|
||||
Asynq *AsynqConfig `yaml:"asynq" json:"asynq"`
|
||||
VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"`
|
||||
DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"`
|
||||
StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"`
|
||||
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
|
||||
}
|
||||
|
||||
type DocReaderConfig struct {
|
||||
@@ -109,15 +110,6 @@ type ModelConfig struct {
|
||||
Parameters map[string]interface{} `yaml:"parameters" json:"parameters"`
|
||||
}
|
||||
|
||||
type AsynqConfig struct {
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
Username string `yaml:"username" json:"username"`
|
||||
Password string `yaml:"password" json:"password"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"`
|
||||
Concurrency int `yaml:"concurrency" json:"concurrency"`
|
||||
}
|
||||
|
||||
// StreamManagerConfig 流管理器配置
|
||||
type StreamManagerConfig struct {
|
||||
Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis"
|
||||
@@ -134,6 +126,18 @@ type RedisConfig struct {
|
||||
TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时)
|
||||
}
|
||||
|
||||
// ExtractManagerConfig 抽取管理器配置
|
||||
type ExtractManagerConfig struct {
|
||||
ExtractGraph *types.PromptTemplateStructured `yaml:"extract_graph" json:"extract_graph"`
|
||||
ExtractEntity *types.PromptTemplateStructured `yaml:"extract_entity" json:"extract_entity"`
|
||||
FabriText *FebriText `yaml:"fabri_text" json:"fabri_text"`
|
||||
}
|
||||
|
||||
type FebriText struct {
|
||||
WithTag string `yaml:"with_tag" json:"with_tag"`
|
||||
WithNoTag string `yaml:"with_no_tag" json:"with_no_tag"`
|
||||
}
|
||||
|
||||
// LoadConfig 从配置文件加载配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
// 设置配置文件名和路径
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
esv7 "github.com/elastic/go-elasticsearch/v7"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/dig"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/application/repository"
|
||||
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
|
||||
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
|
||||
neo4jRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/neo4j"
|
||||
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
|
||||
"github.com/Tencent/WeKnora/internal/application/service"
|
||||
chatpipline "github.com/Tencent/WeKnora/internal/application/service/chat_pipline"
|
||||
@@ -68,6 +70,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
// External service clients
|
||||
must(container.Provide(initDocReaderClient))
|
||||
must(container.Provide(initOllamaService))
|
||||
must(container.Provide(initNeo4jClient))
|
||||
must(container.Provide(stream.NewStreamManager))
|
||||
|
||||
// Data repositories layer
|
||||
@@ -78,6 +81,9 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(repository.NewSessionRepository))
|
||||
must(container.Provide(repository.NewMessageRepository))
|
||||
must(container.Provide(repository.NewModelRepository))
|
||||
must(container.Provide(repository.NewUserRepository))
|
||||
must(container.Provide(repository.NewAuthTokenRepository))
|
||||
must(container.Provide(neo4jRepo.NewNeo4jRepository))
|
||||
|
||||
// Business service layer
|
||||
must(container.Provide(service.NewTenantService))
|
||||
@@ -87,10 +93,11 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(service.NewMessageService))
|
||||
must(container.Provide(service.NewChunkService))
|
||||
must(container.Provide(embedding.NewBatchEmbedder))
|
||||
must(container.Provide(service.NewTestDataService))
|
||||
must(container.Provide(service.NewModelService))
|
||||
must(container.Provide(service.NewDatasetService))
|
||||
must(container.Provide(service.NewEvaluationService))
|
||||
must(container.Provide(service.NewUserService))
|
||||
must(container.Provide(service.NewChunkExtractService))
|
||||
|
||||
// Chat pipeline components for processing chat requests
|
||||
must(container.Provide(chatpipline.NewEventManager))
|
||||
@@ -105,6 +112,8 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Invoke(chatpipline.NewPluginFilterTopK))
|
||||
must(container.Invoke(chatpipline.NewPluginPreprocess))
|
||||
must(container.Invoke(chatpipline.NewPluginRewrite))
|
||||
must(container.Invoke(chatpipline.NewPluginExtractEntity))
|
||||
must(container.Invoke(chatpipline.NewPluginSearchEntity))
|
||||
|
||||
// HTTP handlers layer
|
||||
must(container.Provide(handler.NewTenantHandler))
|
||||
@@ -113,13 +122,17 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(handler.NewChunkHandler))
|
||||
must(container.Provide(handler.NewSessionHandler))
|
||||
must(container.Provide(handler.NewMessageHandler))
|
||||
must(container.Provide(handler.NewTestDataHandler))
|
||||
must(container.Provide(handler.NewModelHandler))
|
||||
must(container.Provide(handler.NewEvaluationHandler))
|
||||
must(container.Provide(handler.NewInitializationHandler))
|
||||
must(container.Provide(handler.NewAuthHandler))
|
||||
must(container.Provide(handler.NewSystemHandler))
|
||||
|
||||
// Router configuration
|
||||
must(container.Provide(router.NewRouter))
|
||||
must(container.Provide(router.NewAsyncqClient))
|
||||
must(container.Provide(router.NewAsynqServer))
|
||||
must(container.Invoke(router.RunAsynqServer))
|
||||
|
||||
return container
|
||||
}
|
||||
@@ -177,6 +190,16 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Auto-migrate database tables
|
||||
err = db.AutoMigrate(
|
||||
&types.User{},
|
||||
&types.AuthToken{},
|
||||
&types.KnowledgeBase{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err)
|
||||
}
|
||||
|
||||
// Get underlying SQL DB object
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
@@ -216,7 +239,7 @@ func initFileService(cfg *config.Config) (interfaces.FileService, error) {
|
||||
false,
|
||||
)
|
||||
case "cos":
|
||||
if os.Getenv("COS_APP_ID") == "" ||
|
||||
if os.Getenv("COS_BUCKET_NAME") == "" ||
|
||||
os.Getenv("COS_REGION") == "" ||
|
||||
os.Getenv("COS_SECRET_ID") == "" ||
|
||||
os.Getenv("COS_SECRET_KEY") == "" ||
|
||||
@@ -224,7 +247,7 @@ func initFileService(cfg *config.Config) (interfaces.FileService, error) {
|
||||
return nil, fmt.Errorf("missing COS configuration")
|
||||
}
|
||||
return file.NewCosFileService(
|
||||
os.Getenv("COS_APP_ID"),
|
||||
os.Getenv("COS_BUCKET_NAME"),
|
||||
os.Getenv("COS_REGION"),
|
||||
os.Getenv("COS_SECRET_ID"),
|
||||
os.Getenv("COS_SECRET_KEY"),
|
||||
@@ -373,3 +396,23 @@ func initOllamaService() (*ollama.OllamaService, error) {
|
||||
// Get Ollama service from existing factory function
|
||||
return ollama.GetOllamaService()
|
||||
}
|
||||
|
||||
func initNeo4jClient() (neo4j.Driver, error) {
|
||||
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
|
||||
logger.Debugf(context.Background(), "NOT SUPPORT RETRIEVE GRAPH")
|
||||
return nil, nil
|
||||
}
|
||||
uri := os.Getenv("NEO4J_URI")
|
||||
username := os.Getenv("NEO4J_USERNAME")
|
||||
password := os.Getenv("NEO4J_PASSWORD")
|
||||
|
||||
driver, err := neo4j.NewDriver(uri, neo4j.BasicAuth(username, password, ""))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = driver.VerifyAuthentication(context.Background(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
343
internal/handler/auth.go
Normal file
343
internal/handler/auth.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// AuthHandler implements HTTP request handlers for user authentication
|
||||
// Provides functionality for user registration, login, logout, and token management
|
||||
// through the REST API endpoints
|
||||
type AuthHandler struct {
|
||||
userService interfaces.UserService
|
||||
tenantService interfaces.TenantService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler instance with the provided services
|
||||
// Parameters:
|
||||
// - userService: An implementation of the UserService interface for business logic
|
||||
// - tenantService: An implementation of the TenantService interface for tenant management
|
||||
//
|
||||
// Returns a pointer to the newly created AuthHandler
|
||||
func NewAuthHandler(userService interfaces.UserService, tenantService interfaces.TenantService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
userService: userService,
|
||||
tenantService: tenantService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register handles the HTTP request for user registration
|
||||
// It deserializes the request body into a registration request object, validates it,
|
||||
// calls the service to create the user, and returns the result
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user registration")
|
||||
|
||||
var req types.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse registration request parameters", err)
|
||||
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Username == "" || req.Email == "" || req.Password == "" {
|
||||
logger.Error(ctx, "Missing required registration fields")
|
||||
appErr := errors.NewValidationError("Username, email and password are required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to register user
|
||||
user, err := h.userService.Register(ctx, &req)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to register user: %v", err)
|
||||
appErr := errors.NewBadRequestError("Registration failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
response := &types.RegisterResponse{
|
||||
Success: true,
|
||||
Message: "Registration successful",
|
||||
User: user,
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User registered successfully: %s", user.Email)
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Login handles the HTTP request for user login
|
||||
// It deserializes the request body into a login request object, validates it,
|
||||
// calls the service to authenticate the user, and returns tokens
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user login")
|
||||
|
||||
var req types.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse login request parameters", err)
|
||||
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Email == "" || req.Password == "" {
|
||||
logger.Error(ctx, "Missing required login fields")
|
||||
appErr := errors.NewValidationError("Email and password are required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to authenticate user
|
||||
response, err := h.userService.Login(ctx, &req)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to login user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if login was successful
|
||||
if !response.Success {
|
||||
logger.Warnf(ctx, "Login failed: %s", response.Message)
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
return
|
||||
}
|
||||
|
||||
// User is already in the correct format from service
|
||||
|
||||
logger.Infof(ctx, "User logged in successfully: %s", req.Email)
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Logout handles the HTTP request for user logout
|
||||
// It extracts the token from the Authorization header and revokes it
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user logout")
|
||||
|
||||
// Extract token from Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Error(ctx, "Missing Authorization header")
|
||||
appErr := errors.NewValidationError("Authorization header is required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Bearer token
|
||||
tokenParts := strings.Split(authHeader, " ")
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
logger.Error(ctx, "Invalid Authorization header format")
|
||||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
|
||||
// Revoke token
|
||||
err := h.userService.RevokeToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to revoke token: %v", err)
|
||||
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "User logged out successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Logout successful",
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken handles the HTTP request for refreshing access tokens
|
||||
// It extracts the refresh token from the request body and generates new tokens
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start token refresh")
|
||||
|
||||
var req struct {
|
||||
RefreshToken string `json:"refreshToken" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse refresh token request", err)
|
||||
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to refresh token
|
||||
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to refresh token: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "Token refreshed successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token refreshed successfully",
|
||||
"access_token": accessToken,
|
||||
"refresh_token": newRefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles the HTTP request for getting current user information
|
||||
// It extracts the user from the context (set by auth middleware) and returns user info
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Debugf(ctx, "Get current user info")
|
||||
|
||||
// Get current user from service (which extracts from context)
|
||||
user, err := h.userService.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant information
|
||||
var tenant *types.Tenant
|
||||
if user.TenantID > 0 {
|
||||
tenant, err = h.tenantService.GetTenantByID(ctx, user.TenantID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %d: %v", user.Email, user.TenantID, err)
|
||||
// Don't fail the request if tenant info is not available
|
||||
} else {
|
||||
logger.Debugf(ctx, "Retrieved tenant info for user %s: %s", user.Email, tenant.Name)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debugf(ctx, "Retrieved current user info: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"user": user.ToUserInfo(),
|
||||
"tenant": tenant,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword handles the HTTP request for changing user password
|
||||
// It extracts the current user and validates the old password before setting new one
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start password change")
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse password change request", err)
|
||||
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get current user
|
||||
user, err := h.userService.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Change password
|
||||
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to change password: %v", err)
|
||||
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Password changed successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateToken handles the HTTP request for validating access tokens
|
||||
// It extracts the token from the Authorization header and validates it
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) ValidateToken(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start token validation")
|
||||
|
||||
// Extract token from Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Error(ctx, "Missing Authorization header")
|
||||
appErr := errors.NewValidationError("Authorization header is required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Bearer token
|
||||
tokenParts := strings.Split(authHeader, " ")
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
logger.Error(ctx, "Invalid Authorization header format")
|
||||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
|
||||
// Validate token
|
||||
user, err := h.userService.ValidateToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to validate token: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token is valid",
|
||||
"user": user.ToUserInfo(),
|
||||
})
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -52,6 +53,13 @@ func (h *ChunkHandler) ListKnowledgeChunks(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 对 chunk 内容进行安全清理
|
||||
for _, chunk := range result.Data.([]*types.Chunk) {
|
||||
if chunk.Content != "" {
|
||||
chunk.Content = secutils.SanitizeForDisplay(chunk.Content)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof(
|
||||
ctx, "Successfully retrieved knowledge chunks list, knowledge ID: %s, total: %d",
|
||||
knowledgeID, result.Total,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
@@ -22,7 +21,6 @@ type SessionHandler struct {
|
||||
sessionService interfaces.SessionService // Service for managing sessions
|
||||
streamManager interfaces.StreamManager // Manager for handling streaming responses
|
||||
config *config.Config // Application configuration
|
||||
testDataService *service.TestDataService // Service for test data (models, etc.)
|
||||
knowledgebaseService interfaces.KnowledgeBaseService
|
||||
}
|
||||
|
||||
@@ -32,7 +30,6 @@ func NewSessionHandler(
|
||||
messageService interfaces.MessageService,
|
||||
streamManager interfaces.StreamManager,
|
||||
config *config.Config,
|
||||
testDataService *service.TestDataService,
|
||||
knowledgebaseService interfaces.KnowledgeBaseService,
|
||||
) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
@@ -40,7 +37,6 @@ func NewSessionHandler(
|
||||
messageService: messageService,
|
||||
streamManager: streamManager,
|
||||
config: config,
|
||||
testDataService: testDataService,
|
||||
knowledgebaseService: knowledgebaseService,
|
||||
}
|
||||
}
|
||||
@@ -87,7 +83,7 @@ type CreateSessionRequest struct {
|
||||
func (h *SessionHandler) CreateSession(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start creating session")
|
||||
// logger.Infof(ctx, "Start creating session, config: %+v", h.config.Conversation)
|
||||
|
||||
// Parse and validate the request body
|
||||
var request CreateSessionRequest
|
||||
|
||||
49
internal/handler/system.go
Normal file
49
internal/handler/system.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related requests
|
||||
type SystemHandler struct{}
|
||||
|
||||
// NewSystemHandler creates a new system handler
|
||||
func NewSystemHandler() *SystemHandler {
|
||||
return &SystemHandler{}
|
||||
}
|
||||
|
||||
// GetSystemInfoResponse defines the response structure for system info
|
||||
type GetSystemInfoResponse struct {
|
||||
Version string `json:"version"`
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
BuildTime string `json:"build_time,omitempty"`
|
||||
GoVersion string `json:"go_version,omitempty"`
|
||||
}
|
||||
|
||||
// 编译时注入的版本信息
|
||||
var (
|
||||
Version = "unknown"
|
||||
CommitID = "unknown"
|
||||
BuildTime = "unknown"
|
||||
GoVersion = "unknown"
|
||||
)
|
||||
|
||||
// GetSystemInfo gets system information including version and commit ID
|
||||
func (h *SystemHandler) GetSystemInfo(c *gin.Context) {
|
||||
ctx := logger.CloneContext(c.Request.Context())
|
||||
|
||||
response := GetSystemInfoResponse{
|
||||
Version: Version,
|
||||
CommitID: CommitID,
|
||||
BuildTime: BuildTime,
|
||||
GoVersion: GoVersion,
|
||||
}
|
||||
|
||||
logger.Info(ctx, "System info retrieved successfully")
|
||||
c.JSON(200, gin.H{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// TestDataHandler handles HTTP requests related to test data operations
|
||||
// Used for development and testing purposes to provide sample data
|
||||
type TestDataHandler struct {
|
||||
config *config.Config
|
||||
kbService interfaces.KnowledgeBaseService
|
||||
tenantService interfaces.TenantService
|
||||
}
|
||||
|
||||
// NewTestDataHandler creates a new instance of the test data handler
|
||||
// Parameters:
|
||||
// - config: Application configuration instance
|
||||
// - kbService: Knowledge base service for accessing knowledge base data
|
||||
// - tenantService: Tenant service for accessing tenant data
|
||||
//
|
||||
// Returns a pointer to the new TestDataHandler instance
|
||||
func NewTestDataHandler(
|
||||
config *config.Config,
|
||||
kbService interfaces.KnowledgeBaseService,
|
||||
tenantService interfaces.TenantService,
|
||||
) *TestDataHandler {
|
||||
return &TestDataHandler{
|
||||
config: config,
|
||||
kbService: kbService,
|
||||
tenantService: tenantService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTestData handles the HTTP request to retrieve test data for development purposes
|
||||
// It returns predefined test tenant and knowledge base information
|
||||
// This endpoint is only available in non-production environments
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *TestDataHandler) GetTestData(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start retrieving test data")
|
||||
|
||||
tenantID := uint(types.InitDefaultTenantID)
|
||||
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
|
||||
|
||||
// Retrieve the test tenant data
|
||||
logger.Infof(ctx, "Retrieving test tenant, ID: %d", tenantID)
|
||||
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
|
||||
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
|
||||
|
||||
// Retrieve the test knowledge base data
|
||||
logger.Infof(ctx, "Retrieving test knowledge base, ID: %s", knowledgeBaseID)
|
||||
knowledgeBase, err := h.kbService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "Test data retrieved successfully")
|
||||
// Return the test data in the response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": gin.H{
|
||||
"tenant": tenant,
|
||||
"knowledge_bases": []types.KnowledgeBase{*knowledgeBase},
|
||||
},
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
|
||||
// 无需认证的API列表
|
||||
var noAuthAPI = map[string][]string{
|
||||
"/api/v1/test-data": {"GET"},
|
||||
"/api/v1/tenants": {"POST"},
|
||||
"/api/v1/initialization/*": {"GET", "POST"},
|
||||
"/health": {"GET"},
|
||||
"/api/v1/auth/register": {"POST"},
|
||||
"/api/v1/auth/login": {"POST"},
|
||||
"/api/v1/auth/refresh": {"POST"},
|
||||
}
|
||||
|
||||
// 检查请求是否在无需认证的API列表中
|
||||
@@ -37,7 +38,7 @@ func isNoAuthAPI(path string, method string) bool {
|
||||
}
|
||||
|
||||
// Auth 认证中间件
|
||||
func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc {
|
||||
func Auth(tenantService interfaces.TenantService, userService interfaces.UserService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ignore OPTIONS request
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
@@ -51,53 +52,90 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle
|
||||
return
|
||||
}
|
||||
|
||||
// Get API Key from request header
|
||||
// 尝试JWT Token认证
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
user, err := userService.ValidateToken(c.Request.Context(), token)
|
||||
if err == nil && user != nil {
|
||||
// JWT Token认证成功
|
||||
// 获取租户信息
|
||||
tenant, err := tenantService.GetTenantByID(c.Request.Context(), user.TenantID)
|
||||
if err != nil {
|
||||
log.Printf("Error getting tenant by ID: %v, tenantID: %d, userID: %s", err, user.TenantID, user.ID)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid tenant",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 存储用户和租户信息到上下文
|
||||
c.Set(types.TenantIDContextKey.String(), user.TenantID)
|
||||
c.Set(types.TenantInfoContextKey.String(), tenant)
|
||||
c.Set("user", user)
|
||||
c.Request = c.Request.WithContext(
|
||||
context.WithValue(
|
||||
context.WithValue(
|
||||
context.WithValue(c.Request.Context(), types.TenantIDContextKey, user.TenantID),
|
||||
types.TenantInfoContextKey, tenant,
|
||||
),
|
||||
"user", user,
|
||||
),
|
||||
)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试X-API-Key认证(兼容模式)
|
||||
apiKey := c.GetHeader("X-API-Key")
|
||||
if apiKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
if apiKey != "" {
|
||||
// Get tenant information
|
||||
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key format",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Verify API key validity (matches the one in database)
|
||||
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if t == nil || t.APIKey != apiKey {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store tenant ID in context
|
||||
c.Set(types.TenantIDContextKey.String(), tenantID)
|
||||
c.Set(types.TenantInfoContextKey.String(), t)
|
||||
c.Request = c.Request.WithContext(
|
||||
context.WithValue(
|
||||
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
|
||||
types.TenantInfoContextKey, t,
|
||||
),
|
||||
)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Get tenant information
|
||||
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key format",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Verify API key validity (matches the one in database)
|
||||
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if t == nil || t.APIKey != apiKey {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid API key",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store tenant ID in context
|
||||
c.Set(types.TenantIDContextKey.String(), tenantID)
|
||||
c.Set(types.TenantInfoContextKey.String(), t)
|
||||
c.Request = c.Request.WithContext(
|
||||
context.WithValue(
|
||||
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
|
||||
types.TenantInfoContextKey, t,
|
||||
),
|
||||
)
|
||||
c.Next()
|
||||
// 没有提供任何认证信息
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized: missing authentication"})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
|
||||
Messages: c.convertMessages(messages),
|
||||
Stream: isStream,
|
||||
}
|
||||
thinking := false
|
||||
|
||||
// 添加可选参数
|
||||
if opts != nil {
|
||||
@@ -106,6 +107,13 @@ func (c *RemoteAPIChat) buildChatCompletionRequest(messages []Message,
|
||||
if opts.PresencePenalty > 0 {
|
||||
req.PresencePenalty = float32(opts.PresencePenalty)
|
||||
}
|
||||
if opts.Thinking != nil {
|
||||
thinking = *opts.Thinking
|
||||
}
|
||||
}
|
||||
|
||||
req.ChatTemplateKwargs = map[string]interface{}{
|
||||
"enable_thinking": thinking,
|
||||
}
|
||||
|
||||
return req
|
||||
|
||||
@@ -18,6 +18,14 @@ type RouterParams struct {
|
||||
dig.In
|
||||
|
||||
Config *config.Config
|
||||
UserService interfaces.UserService
|
||||
KBService interfaces.KnowledgeBaseService
|
||||
KnowledgeService interfaces.KnowledgeService
|
||||
ChunkService interfaces.ChunkService
|
||||
SessionService interfaces.SessionService
|
||||
MessageService interfaces.MessageService
|
||||
ModelService interfaces.ModelService
|
||||
EvaluationService interfaces.EvaluationService
|
||||
KBHandler *handler.KnowledgeBaseHandler
|
||||
KnowledgeHandler *handler.KnowledgeHandler
|
||||
TenantHandler *handler.TenantHandler
|
||||
@@ -25,10 +33,11 @@ type RouterParams struct {
|
||||
ChunkHandler *handler.ChunkHandler
|
||||
SessionHandler *handler.SessionHandler
|
||||
MessageHandler *handler.MessageHandler
|
||||
TestDataHandler *handler.TestDataHandler
|
||||
ModelHandler *handler.ModelHandler
|
||||
EvaluationHandler *handler.EvaluationHandler
|
||||
AuthHandler *handler.AuthHandler
|
||||
InitializationHandler *handler.InitializationHandler
|
||||
SystemHandler *handler.SystemHandler
|
||||
}
|
||||
|
||||
// NewRouter 创建新的路由
|
||||
@@ -50,7 +59,7 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.ErrorHandler())
|
||||
r.Use(middleware.Auth(params.TenantService, params.Config))
|
||||
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
|
||||
|
||||
// 添加OpenTelemetry追踪中间件
|
||||
r.Use(middleware.TracingMiddleware())
|
||||
@@ -60,31 +69,10 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// 测试数据接口(不需要认证)
|
||||
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
|
||||
|
||||
// 初始化接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
|
||||
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
|
||||
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
|
||||
|
||||
// Ollama相关接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
|
||||
r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
|
||||
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
|
||||
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
|
||||
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
|
||||
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
|
||||
|
||||
// 远程API相关接口(不需要认证)
|
||||
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
|
||||
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
|
||||
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
|
||||
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
|
||||
|
||||
// 需要认证的API路由
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
RegisterAuthRoutes(v1, params.AuthHandler)
|
||||
RegisterTenantRoutes(v1, params.TenantHandler)
|
||||
RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
|
||||
RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
|
||||
@@ -94,6 +82,8 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
RegisterMessageRoutes(v1, params.MessageHandler)
|
||||
RegisterModelRoutes(v1, params.ModelHandler)
|
||||
RegisterEvaluationRoutes(v1, params.EvaluationHandler)
|
||||
RegisterInitializationRoutes(v1, params.InitializationHandler)
|
||||
RegisterSystemRoutes(v1, params.SystemHandler)
|
||||
}
|
||||
|
||||
return r
|
||||
@@ -247,3 +237,46 @@ func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHan
|
||||
evaluationRoutes.GET("/", handler.GetEvaluationResult)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAuthRoutes registers authentication routes
|
||||
func RegisterAuthRoutes(r *gin.RouterGroup, handler *handler.AuthHandler) {
|
||||
r.POST("/auth/register", handler.Register)
|
||||
r.POST("/auth/login", handler.Login)
|
||||
r.POST("/auth/refresh", handler.RefreshToken)
|
||||
r.GET("/auth/validate", handler.ValidateToken)
|
||||
r.POST("/auth/logout", handler.Logout)
|
||||
r.GET("/auth/me", handler.GetCurrentUser)
|
||||
r.POST("/auth/change-password", handler.ChangePassword)
|
||||
}
|
||||
|
||||
func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.InitializationHandler) {
|
||||
// 初始化接口
|
||||
r.GET("/initialization/config/:kbId", handler.GetCurrentConfigByKB)
|
||||
r.POST("/initialization/initialize/:kbId", handler.InitializeByKB)
|
||||
|
||||
// Ollama相关接口
|
||||
r.GET("/initialization/ollama/status", handler.CheckOllamaStatus)
|
||||
r.GET("/initialization/ollama/models", handler.ListOllamaModels)
|
||||
r.POST("/initialization/ollama/models/check", handler.CheckOllamaModels)
|
||||
r.POST("/initialization/ollama/models/download", handler.DownloadOllamaModel)
|
||||
r.GET("/initialization/ollama/download/progress/:taskId", handler.GetDownloadProgress)
|
||||
r.GET("/initialization/ollama/download/tasks", handler.ListDownloadTasks)
|
||||
|
||||
// 远程API相关接口
|
||||
r.POST("/initialization/remote/check", handler.CheckRemoteModel)
|
||||
r.POST("/initialization/embedding/test", handler.TestEmbeddingModel)
|
||||
r.POST("/initialization/rerank/check", handler.CheckRerankModel)
|
||||
r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction)
|
||||
|
||||
r.POST("/initialization/extract/text-relation", handler.ExtractTextRelations)
|
||||
r.POST("/initialization/extract/fabri-tag", handler.FabriTag)
|
||||
r.POST("/initialization/extract/fabri-text", handler.FabriText)
|
||||
}
|
||||
|
||||
// RegisterSystemRoutes registers system information routes
|
||||
func RegisterSystemRoutes(r *gin.RouterGroup, handler *handler.SystemHandler) {
|
||||
systemRoutes := r.Group("/system")
|
||||
{
|
||||
systemRoutes.GET("/info", handler.GetSystemInfo)
|
||||
}
|
||||
}
|
||||
|
||||
66
internal/router/task.go
Normal file
66
internal/router/task.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type AsynqTaskParams struct {
|
||||
dig.In
|
||||
|
||||
Server *asynq.Server
|
||||
Extracter interfaces.Extracter
|
||||
}
|
||||
|
||||
func getAsynqRedisClientOpt() *asynq.RedisClientOpt {
|
||||
opt := &asynq.RedisClientOpt{
|
||||
Addr: os.Getenv("REDIS_ADDR"),
|
||||
Password: os.Getenv("REDIS_PASSWORD"),
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
WriteTimeout: 200 * time.Millisecond,
|
||||
DB: 0,
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewAsyncqClient() *asynq.Client {
|
||||
opt := getAsynqRedisClientOpt()
|
||||
client := asynq.NewClient(opt)
|
||||
return client
|
||||
}
|
||||
|
||||
func NewAsynqServer() *asynq.Server {
|
||||
opt := getAsynqRedisClientOpt()
|
||||
srv := asynq.NewServer(
|
||||
opt,
|
||||
asynq.Config{
|
||||
Queues: map[string]int{
|
||||
"critical": 6, // Highest priority queue
|
||||
"default": 3, // Default priority queue
|
||||
"low": 1, // Lowest priority queue
|
||||
},
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func RunAsynqServer(params AsynqTaskParams) *asynq.ServeMux {
|
||||
// Create a new mux and register all handlers
|
||||
mux := asynq.NewServeMux()
|
||||
|
||||
mux.HandleFunc(types.TypeChunkExtract, params.Extracter.Extract)
|
||||
|
||||
go func() {
|
||||
// Start the server
|
||||
if err := params.Server.Run(mux); err != nil {
|
||||
log.Fatalf("could not run server: %v", err)
|
||||
}
|
||||
}()
|
||||
return mux
|
||||
}
|
||||
@@ -28,6 +28,8 @@ type ChatManage struct {
|
||||
SearchResult []*SearchResult `json:"-"` // Results from search phase
|
||||
RerankResult []*SearchResult `json:"-"` // Results after reranking
|
||||
MergeResult []*SearchResult `json:"-"` // Final merged results after all processing
|
||||
Entity []string `json:"-"` // List of identified entities
|
||||
GraphResult *GraphData `json:"-"` // Graph data from search phase
|
||||
UserContent string `json:"-"` // Processed user content
|
||||
ChatResponse *ChatResponse `json:"-"` // Final response from chat model
|
||||
ResponseChan <-chan StreamResponse `json:"-"` // Channel for streaming responses
|
||||
@@ -75,6 +77,7 @@ const (
|
||||
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
|
||||
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
|
||||
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
|
||||
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
|
||||
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
|
||||
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
|
||||
INTO_CHAT_MESSAGE EventType = "into_chat_message" // Convert chunks into chat messages
|
||||
@@ -104,6 +107,7 @@ var Pipline = map[string][]EventType{
|
||||
REWRITE_QUERY,
|
||||
PREPROCESS_QUERY,
|
||||
CHUNK_SEARCH,
|
||||
ENTITY_SEARCH,
|
||||
CHUNK_RERANK,
|
||||
CHUNK_MERGE,
|
||||
FILTER_TOP_K,
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
MatchTypeHistory
|
||||
MatchTypeParentChunk // 父Chunk匹配类型
|
||||
MatchTypeRelationChunk // 关系Chunk匹配类型
|
||||
MatchTypeGraph
|
||||
)
|
||||
|
||||
// IndexInfo contains information about indexed content
|
||||
|
||||
51
internal/types/extract_graph.go
Normal file
51
internal/types/extract_graph.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package types
|
||||
|
||||
const (
|
||||
TypeChunkExtract = "chunk:extract"
|
||||
)
|
||||
|
||||
type ExtractChunkPayload struct {
|
||||
TenantID uint `json:"tenant_id"`
|
||||
ChunkID string `json:"chunk_id"`
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
type PromptTemplateStructured struct {
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
Examples []GraphData `json:"examples"`
|
||||
}
|
||||
|
||||
type GraphNode struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Chunks []string `json:"chunks,omitempty"`
|
||||
Attributes []string `json:"attributes,omitempty"`
|
||||
}
|
||||
|
||||
type GraphRelation struct {
|
||||
Node1 string `json:"node1,omitempty"`
|
||||
Node2 string `json:"node2,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GraphData struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Node []*GraphNode `json:"node,omitempty"`
|
||||
Relation []*GraphRelation `json:"relation,omitempty"`
|
||||
}
|
||||
|
||||
type NameSpace struct {
|
||||
KnowledgeBase string `json:"knowledge_base"`
|
||||
Knowledge string `json:"knowledge"`
|
||||
}
|
||||
|
||||
func (n NameSpace) Labels() []string {
|
||||
res := make([]string, 0)
|
||||
if n.KnowledgeBase != "" {
|
||||
res = append(res, n.KnowledgeBase)
|
||||
}
|
||||
if n.Knowledge != "" {
|
||||
res = append(res, n.Knowledge)
|
||||
}
|
||||
return res
|
||||
}
|
||||
11
internal/types/interfaces/extracter.go
Normal file
11
internal/types/interfaces/extracter.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
type Extracter interface {
|
||||
Extract(ctx context.Context, t *asynq.Task) error
|
||||
}
|
||||
13
internal/types/interfaces/retriever_graph.go
Normal file
13
internal/types/interfaces/retriever_graph.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
type RetrieveGraphRepository interface {
|
||||
AddGraph(ctx context.Context, namespace types.NameSpace, graphs []*types.GraphData) error
|
||||
DelGraph(ctx context.Context, namespace []types.NameSpace) error
|
||||
SearchNode(ctx context.Context, namespace types.NameSpace, nodes []string) (*types.GraphData, error)
|
||||
}
|
||||
75
internal/types/interfaces/user.go
Normal file
75
internal/types/interfaces/user.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// UserService defines the user service interface
|
||||
type UserService interface {
|
||||
// Register creates a new user account
|
||||
Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error)
|
||||
// Login authenticates a user and returns tokens
|
||||
Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error)
|
||||
// GetUserByID gets a user by ID
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
// GetUserByEmail gets a user by email
|
||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||
// GetUserByUsername gets a user by username
|
||||
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
|
||||
// UpdateUser updates user information
|
||||
UpdateUser(ctx context.Context, user *types.User) error
|
||||
// DeleteUser deletes a user
|
||||
DeleteUser(ctx context.Context, id string) error
|
||||
// ChangePassword changes user password
|
||||
ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error
|
||||
// ValidatePassword validates user password
|
||||
ValidatePassword(ctx context.Context, userID string, password string) error
|
||||
// GenerateTokens generates access and refresh tokens for user
|
||||
GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error)
|
||||
// ValidateToken validates an access token
|
||||
ValidateToken(ctx context.Context, token string) (*types.User, error)
|
||||
// RefreshToken refreshes access token using refresh token
|
||||
RefreshToken(ctx context.Context, refreshToken string) (accessToken, newRefreshToken string, err error)
|
||||
// RevokeToken revokes a token
|
||||
RevokeToken(ctx context.Context, token string) error
|
||||
// GetCurrentUser gets current user from context
|
||||
GetCurrentUser(ctx context.Context) (*types.User, error)
|
||||
}
|
||||
|
||||
// UserRepository defines the user repository interface
|
||||
type UserRepository interface {
|
||||
// CreateUser creates a user
|
||||
CreateUser(ctx context.Context, user *types.User) error
|
||||
// GetUserByID gets a user by ID
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
// GetUserByEmail gets a user by email
|
||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||
// GetUserByUsername gets a user by username
|
||||
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
|
||||
// UpdateUser updates a user
|
||||
UpdateUser(ctx context.Context, user *types.User) error
|
||||
// DeleteUser deletes a user
|
||||
DeleteUser(ctx context.Context, id string) error
|
||||
// ListUsers lists users with pagination
|
||||
ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error)
|
||||
}
|
||||
|
||||
// AuthTokenRepository defines the auth token repository interface
|
||||
type AuthTokenRepository interface {
|
||||
// CreateToken creates an auth token
|
||||
CreateToken(ctx context.Context, token *types.AuthToken) error
|
||||
// GetTokenByValue gets a token by its value
|
||||
GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error)
|
||||
// GetTokensByUserID gets all tokens for a user
|
||||
GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error)
|
||||
// UpdateToken updates a token
|
||||
UpdateToken(ctx context.Context, token *types.AuthToken) error
|
||||
// DeleteToken deletes a token
|
||||
DeleteToken(ctx context.Context, id string) error
|
||||
// DeleteExpiredTokens deletes all expired tokens
|
||||
DeleteExpiredTokens(ctx context.Context) error
|
||||
// RevokeTokensByUserID revokes all tokens for a user
|
||||
RevokeTokensByUserID(ctx context.Context, userID string) error
|
||||
}
|
||||
@@ -38,6 +38,8 @@ type KnowledgeBase struct {
|
||||
VLMConfig VLMConfig `yaml:"vlm_config" json:"vlm_config" gorm:"type:json"`
|
||||
// Storage config
|
||||
StorageConfig StorageConfig `yaml:"cos_config" json:"cos_config" gorm:"column:cos_config;type:json"`
|
||||
// Extract config
|
||||
ExtractConfig *ExtractConfig `yaml:"extract_config" json:"extract_config" gorm:"column:extract_config;type:json"`
|
||||
// Creation time of the knowledge base
|
||||
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
|
||||
// Last updated time of the knowledge base
|
||||
@@ -167,3 +169,27 @@ func (c *VLMConfig) Scan(value interface{}) error {
|
||||
}
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
|
||||
type ExtractConfig struct {
|
||||
Text string `yaml:"text" json:"text"`
|
||||
Tags []string `yaml:"tags" json:"tags"`
|
||||
Nodes []*GraphNode `yaml:"nodes" json:"nodes"`
|
||||
Relations []*GraphRelation `yaml:"relations" json:"relations"`
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface, used to convert ExtractConfig to database value
|
||||
func (e ExtractConfig) Value() (driver.Value, error) {
|
||||
return json.Marshal(e)
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface, used to convert database value to ExtractConfig
|
||||
func (e *ExtractConfig) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
b, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(b, e)
|
||||
}
|
||||
|
||||
114
internal/types/user.go
Normal file
114
internal/types/user.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
// Unique identifier of the user
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
// Username of the user
|
||||
Username string `json:"username" gorm:"type:varchar(100);uniqueIndex;not null"`
|
||||
// Email address of the user
|
||||
Email string `json:"email" gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||
// Hashed password of the user
|
||||
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
|
||||
// Avatar URL of the user
|
||||
Avatar string `json:"avatar" gorm:"type:varchar(500)"`
|
||||
// Tenant ID that the user belongs to
|
||||
TenantID uint `json:"tenant_id" gorm:"index"`
|
||||
// Whether the user is active
|
||||
IsActive bool `json:"is_active" gorm:"default:true"`
|
||||
// Creation time of the user
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Last updated time of the user
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
// Deletion time of the user
|
||||
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
|
||||
|
||||
// Association relationship, not stored in the database
|
||||
Tenant *Tenant `json:"tenant,omitempty" gorm:"foreignKey:TenantID"`
|
||||
}
|
||||
|
||||
// AuthToken represents an authentication token
|
||||
type AuthToken struct {
|
||||
// Unique identifier of the token
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
// User ID that owns this token
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||
// Token value (JWT or other format)
|
||||
Token string `json:"token" gorm:"type:text;not null"`
|
||||
// Token type (access_token, refresh_token)
|
||||
TokenType string `json:"token_type" gorm:"type:varchar(50);not null"`
|
||||
// Token expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
// Whether the token is revoked
|
||||
IsRevoked bool `json:"is_revoked" gorm:"default:false"`
|
||||
// Creation time of the token
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Last updated time of the token
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// Association relationship
|
||||
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// RegisterRequest represents a registration request
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
type LoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
User *User `json:"user,omitempty"`
|
||||
Tenant *Tenant `json:"tenant,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterResponse represents a registration response
|
||||
type RegisterResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
User *User `json:"user,omitempty"`
|
||||
Tenant *Tenant `json:"tenant,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo represents user information for API responses
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Avatar string `json:"avatar"`
|
||||
TenantID uint `json:"tenant_id"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ToUserInfo converts User to UserInfo (without sensitive data)
|
||||
func (u *User) ToUserInfo() *UserInfo {
|
||||
return &UserInfo{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Email: u.Email,
|
||||
Avatar: u.Avatar,
|
||||
TenantID: u.TenantID,
|
||||
IsActive: u.IsActive,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
171
internal/utils/security.go
Normal file
171
internal/utils/security.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"html"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// XSS 防护相关正则表达式
|
||||
var (
|
||||
// 匹配潜在的 XSS 攻击模式
|
||||
xssPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`),
|
||||
regexp.MustCompile(`(?i)<iframe[^>]*>.*?</iframe>`),
|
||||
regexp.MustCompile(`(?i)<object[^>]*>.*?</object>`),
|
||||
regexp.MustCompile(`(?i)<embed[^>]*>.*?</embed>`),
|
||||
regexp.MustCompile(`(?i)<embed[^>]*>`),
|
||||
regexp.MustCompile(`(?i)<form[^>]*>.*?</form>`),
|
||||
regexp.MustCompile(`(?i)<input[^>]*>`),
|
||||
regexp.MustCompile(`(?i)<button[^>]*>.*?</button>`),
|
||||
regexp.MustCompile(`(?i)javascript:`),
|
||||
regexp.MustCompile(`(?i)vbscript:`),
|
||||
regexp.MustCompile(`(?i)onload\s*=`),
|
||||
regexp.MustCompile(`(?i)onerror\s*=`),
|
||||
regexp.MustCompile(`(?i)onclick\s*=`),
|
||||
regexp.MustCompile(`(?i)onmouseover\s*=`),
|
||||
regexp.MustCompile(`(?i)onfocus\s*=`),
|
||||
regexp.MustCompile(`(?i)onblur\s*=`),
|
||||
}
|
||||
)
|
||||
|
||||
// SanitizeHTML 清理 HTML 内容,防止 XSS 攻击
|
||||
func SanitizeHTML(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 检查输入长度
|
||||
if len(input) > 10000 {
|
||||
input = input[:10000]
|
||||
}
|
||||
|
||||
// 检查是否包含潜在的 XSS 攻击
|
||||
for _, pattern := range xssPatterns {
|
||||
if pattern.MatchString(input) {
|
||||
// 如果包含恶意内容,进行 HTML 转义
|
||||
return html.EscapeString(input)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果内容相对安全,返回原内容
|
||||
return input
|
||||
}
|
||||
|
||||
// EscapeHTML 转义 HTML 特殊字符
|
||||
func EscapeHTML(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
return html.EscapeString(input)
|
||||
}
|
||||
|
||||
// ValidateInput 验证用户输入
|
||||
func ValidateInput(input string) (string, bool) {
|
||||
if input == "" {
|
||||
return "", true
|
||||
}
|
||||
|
||||
// 检查长度
|
||||
if len(input) > 10000 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 检查是否包含控制字符
|
||||
for _, r := range input {
|
||||
if r < 32 && r != 9 && r != 10 && r != 13 {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 UTF-8 有效性
|
||||
if !utf8.ValidString(input) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 检查是否包含潜在的 XSS 攻击
|
||||
for _, pattern := range xssPatterns {
|
||||
if pattern.MatchString(input) {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(input), true
|
||||
}
|
||||
|
||||
// IsValidURL 验证 URL 是否安全
|
||||
func IsValidURL(url string) bool {
|
||||
if url == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查长度
|
||||
if len(url) > 2048 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查协议
|
||||
if !strings.HasPrefix(strings.ToLower(url), "http://") &&
|
||||
!strings.HasPrefix(strings.ToLower(url), "https://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否包含恶意内容
|
||||
for _, pattern := range xssPatterns {
|
||||
if pattern.MatchString(url) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsValidImageURL 验证图片 URL 是否安全
|
||||
func IsValidImageURL(url string) bool {
|
||||
if !IsValidURL(url) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为图片文件
|
||||
imageExtensions := []string{".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", ".bmp", ".ico"}
|
||||
lowerURL := strings.ToLower(url)
|
||||
|
||||
for _, ext := range imageExtensions {
|
||||
if strings.Contains(lowerURL, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CleanMarkdown 清理 Markdown 内容
|
||||
func CleanMarkdown(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 移除潜在的恶意脚本
|
||||
cleaned := input
|
||||
for _, pattern := range xssPatterns {
|
||||
cleaned = pattern.ReplaceAllString(cleaned, "")
|
||||
}
|
||||
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// SanitizeForDisplay 为显示清理内容
|
||||
func SanitizeForDisplay(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 首先清理 Markdown
|
||||
cleaned := CleanMarkdown(input)
|
||||
|
||||
// 然后进行 HTML 转义
|
||||
escaped := html.EscapeString(cleaned)
|
||||
|
||||
return escaped
|
||||
}
|
||||
@@ -304,6 +304,19 @@ async def handle_list_tools() -> list[types.Tool]:
|
||||
),
|
||||
|
||||
# Knowledge Management
|
||||
types.Tool(
|
||||
name="create_knowledge_from_file",
|
||||
description="Create knowledge from a local file on the server filesystem",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kb_id": {"type": "string", "description": "Knowledge base ID"},
|
||||
"file_path": {"type": "string", "description": "Absolute path to the local file on the server"},
|
||||
"enable_multimodel": {"type": "boolean", "description": "Enable multimodal processing", "default": True}
|
||||
},
|
||||
"required": ["kb_id", "file_path"]
|
||||
}
|
||||
),
|
||||
types.Tool(
|
||||
name="create_knowledge_from_url",
|
||||
description="Create knowledge from URL",
|
||||
@@ -537,6 +550,12 @@ async def handle_call_tool(
|
||||
result = client.hybrid_search(args["kb_id"], args["query"], config)
|
||||
|
||||
# Knowledge Management
|
||||
elif name == "create_knowledge_from_file":
|
||||
result = client.create_knowledge_from_file(
|
||||
args["kb_id"],
|
||||
args["file_path"],
|
||||
args.get("enable_multimodel", True)
|
||||
)
|
||||
elif name == "create_knowledge_from_url":
|
||||
result = client.create_knowledge_from_url(
|
||||
args["kb_id"],
|
||||
|
||||
@@ -51,6 +51,7 @@ CREATE TABLE knowledge_bases (
|
||||
vlm_model_id VARCHAR(64) NOT NULL,
|
||||
cos_config JSON NOT NULL,
|
||||
vlm_config JSON NOT NULL,
|
||||
extract_config JSON NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP NULL DEFAULT NULL
|
||||
|
||||
@@ -62,6 +62,7 @@ CREATE TABLE IF NOT EXISTS knowledge_bases (
|
||||
vlm_model_id VARCHAR(64) NOT NULL,
|
||||
cos_config JSONB NOT NULL DEFAULT '{}',
|
||||
vlm_config JSONB NOT NULL DEFAULT '{}',
|
||||
extract_config JSONB NULL DEFAULT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE
|
||||
|
||||
@@ -78,13 +78,49 @@ check_platform() {
|
||||
log_info "检测系统平台信息..."
|
||||
if [ "$(uname -m)" = "x86_64" ]; then
|
||||
export PLATFORM="linux/amd64"
|
||||
export TARGETARCH="amd64"
|
||||
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
|
||||
export PLATFORM="linux/arm64"
|
||||
export TARGETARCH="arm64"
|
||||
else
|
||||
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
|
||||
export PLATFORM="linux/amd64"
|
||||
export TARGETARCH="amd64"
|
||||
fi
|
||||
log_info "当前平台:$PLATFORM"
|
||||
log_info "当前架构:$TARGETARCH"
|
||||
}
|
||||
|
||||
# 获取版本信息
|
||||
get_version_info() {
|
||||
# 从VERSION文件获取版本号
|
||||
if [ -f "VERSION" ]; then
|
||||
VERSION=$(cat VERSION | tr -d '\n\r')
|
||||
else
|
||||
VERSION="unknown"
|
||||
fi
|
||||
|
||||
# 获取commit ID
|
||||
if command -v git >/dev/null 2>&1; then
|
||||
COMMIT_ID=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
else
|
||||
COMMIT_ID="unknown"
|
||||
fi
|
||||
|
||||
# 获取构建时间
|
||||
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
|
||||
|
||||
# 获取Go版本
|
||||
if command -v go >/dev/null 2>&1; then
|
||||
GO_VERSION=$(go version 2>/dev/null || echo "unknown")
|
||||
else
|
||||
GO_VERSION="unknown"
|
||||
fi
|
||||
|
||||
log_info "版本信息: $VERSION"
|
||||
log_info "Commit ID: $COMMIT_ID"
|
||||
log_info "构建时间: $BUILD_TIME"
|
||||
log_info "Go版本: $GO_VERSION"
|
||||
}
|
||||
|
||||
# 构建应用镜像
|
||||
@@ -93,11 +129,18 @@ build_app_image() {
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
|
||||
# 获取版本信息
|
||||
get_version_info
|
||||
|
||||
docker build \
|
||||
--platform $PLATFORM \
|
||||
--build-arg GOPRIVATE_ARG=${GOPRIVATE:-""} \
|
||||
--build-arg GOPROXY_ARG=${GOPROXY:-"https://goproxy.cn,direct"} \
|
||||
--build-arg GOSUMDB_ARG=${GOSUMDB:-"off"} \
|
||||
--build-arg VERSION_ARG="$VERSION" \
|
||||
--build-arg COMMIT_ID_ARG="$COMMIT_ID" \
|
||||
--build-arg BUILD_TIME_ARG="$BUILD_TIME" \
|
||||
--build-arg GO_VERSION_ARG="$GO_VERSION" \
|
||||
-f docker/Dockerfile.app \
|
||||
-t wechatopenai/weknora-app:latest \
|
||||
.
|
||||
@@ -120,6 +163,7 @@ build_docreader_image() {
|
||||
docker build \
|
||||
--platform $PLATFORM \
|
||||
--build-arg PLATFORM=$PLATFORM \
|
||||
--build-arg TARGETARCH=$TARGETARCH \
|
||||
-f docker/Dockerfile.docreader \
|
||||
-t wechatopenai/weknora-docreader:latest \
|
||||
.
|
||||
|
||||
86
scripts/get_version.sh
Executable file
86
scripts/get_version.sh
Executable file
@@ -0,0 +1,86 @@
|
||||
#!/bin/bash
|
||||
# 统一的版本信息获取脚本
|
||||
# 支持本地构建和CI构建环境
|
||||
|
||||
# 设置默认值
|
||||
VERSION="unknown"
|
||||
COMMIT_ID="unknown"
|
||||
BUILD_TIME="unknown"
|
||||
GO_VERSION="unknown"
|
||||
|
||||
# 获取版本号
|
||||
if [ -f "VERSION" ]; then
|
||||
VERSION=$(cat VERSION | tr -d '\n\r')
|
||||
fi
|
||||
|
||||
# 获取commit ID
|
||||
if [ -n "$GITHUB_SHA" ]; then
|
||||
# GitHub Actions环境
|
||||
COMMIT_ID="${GITHUB_SHA:0:7}"
|
||||
elif command -v git >/dev/null 2>&1; then
|
||||
# 本地环境
|
||||
COMMIT_ID=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
fi
|
||||
|
||||
# 获取构建时间
|
||||
if [ -n "$GITHUB_ACTIONS" ]; then
|
||||
# GitHub Actions环境,使用标准时间格式
|
||||
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
|
||||
else
|
||||
# 本地环境
|
||||
BUILD_TIME=$(date -u '+%Y-%m-%d %H:%M:%S UTC')
|
||||
fi
|
||||
|
||||
# 获取Go版本
|
||||
if command -v go >/dev/null 2>&1; then
|
||||
GO_VERSION=$(go version 2>/dev/null || echo "unknown")
|
||||
fi
|
||||
|
||||
# 根据参数输出不同格式
|
||||
case "${1:-env}" in
|
||||
"env")
|
||||
# 输出环境变量格式,对包含空格的值进行转义
|
||||
echo "VERSION=$VERSION"
|
||||
echo "COMMIT_ID=$COMMIT_ID"
|
||||
echo "BUILD_TIME=\"$BUILD_TIME\""
|
||||
echo "GO_VERSION=\"$GO_VERSION\""
|
||||
;;
|
||||
"json")
|
||||
# 输出JSON格式
|
||||
cat << EOF
|
||||
{
|
||||
"version": "$VERSION",
|
||||
"commit_id": "$COMMIT_ID",
|
||||
"build_time": "$BUILD_TIME",
|
||||
"go_version": "$GO_VERSION"
|
||||
}
|
||||
EOF
|
||||
;;
|
||||
"docker-args")
|
||||
# 输出Docker构建参数格式
|
||||
echo "--build-arg VERSION_ARG=$VERSION"
|
||||
echo "--build-arg COMMIT_ID_ARG=$COMMIT_ID"
|
||||
echo "--build-arg BUILD_TIME_ARG=$BUILD_TIME"
|
||||
echo "--build-arg GO_VERSION_ARG=$GO_VERSION"
|
||||
;;
|
||||
"ldflags")
|
||||
# 输出Go ldflags格式
|
||||
echo "-X 'github.com/Tencent/WeKnora/internal/handler.Version=$VERSION' -X 'github.com/Tencent/WeKnora/internal/handler.CommitID=$COMMIT_ID' -X 'github.com/Tencent/WeKnora/internal/handler.BuildTime=$BUILD_TIME' -X 'github.com/Tencent/WeKnora/internal/handler.GoVersion=$GO_VERSION'"
|
||||
;;
|
||||
"info")
|
||||
# 输出信息格式
|
||||
echo "版本信息: $VERSION"
|
||||
echo "Commit ID: $COMMIT_ID"
|
||||
echo "构建时间: $BUILD_TIME"
|
||||
echo "Go版本: $GO_VERSION"
|
||||
;;
|
||||
*)
|
||||
echo "用法: $0 [env|json|docker-args|ldflags|info]"
|
||||
echo " env - 输出环境变量格式 (默认)"
|
||||
echo " json - 输出JSON格式"
|
||||
echo " docker-args - 输出Docker构建参数格式"
|
||||
echo " ldflags - 输出Go ldflags格式"
|
||||
echo " info - 输出信息格式"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
@@ -18,8 +18,8 @@ SCRIPT_NAME=$(basename "$0")
|
||||
|
||||
# 显示帮助信息
|
||||
show_help() {
|
||||
echo -e "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
|
||||
echo -e "${GREEN}用法:${NC} $0 [选项]"
|
||||
printf "%b\n" "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
|
||||
printf "%b\n" "${GREEN}用法:${NC} $0 [选项]"
|
||||
echo "选项:"
|
||||
echo " -h, --help 显示帮助信息"
|
||||
echo " -o, --ollama 启动Ollama服务"
|
||||
@@ -37,25 +37,25 @@ show_help() {
|
||||
|
||||
# 显示版本信息
|
||||
show_version() {
|
||||
echo -e "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
|
||||
printf "%b\n" "${GREEN}WeKnora 启动脚本 v${VERSION}${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# 日志函数
|
||||
log_info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1"
|
||||
printf "%b\n" "${BLUE}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1"
|
||||
printf "%b\n" "${YELLOW}[WARNING]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
printf "%b\n" "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||
printf "%b\n" "${GREEN}[SUCCESS]${NC} $1"
|
||||
}
|
||||
|
||||
# 选择可用的 Docker Compose 命令(优先 docker compose,其次 docker-compose)
|
||||
@@ -397,7 +397,7 @@ list_containers() {
|
||||
cd "$PROJECT_ROOT"
|
||||
|
||||
# 列出所有容器
|
||||
echo -e "${BLUE}当前正在运行的容器:${NC}"
|
||||
printf "%b\n" "${BLUE}当前正在运行的容器:${NC}"
|
||||
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD ps --services | sort
|
||||
|
||||
return 0
|
||||
@@ -703,20 +703,26 @@ else
|
||||
if [ "$START_OLLAMA" = true ] && [ "$START_DOCKER" = true ]; then
|
||||
if [ $OLLAMA_RESULT -eq 0 ] && [ $DOCKER_RESULT -eq 0 ]; then
|
||||
log_success "所有服务启动完成,可通过以下地址访问:"
|
||||
echo -e "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
|
||||
echo -e "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
|
||||
echo -e "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
|
||||
printf "%b\n" "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
|
||||
printf "%b\n" "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
|
||||
printf "%b\n" "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
|
||||
echo ""
|
||||
log_info "正在持续输出容器日志(按 Ctrl+C 退出日志,容器不会停止)..."
|
||||
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD logs app docreader postgres --since=10s -f
|
||||
else
|
||||
log_error "部分服务启动失败,请检查日志并修复问题"
|
||||
fi
|
||||
elif [ "$START_OLLAMA" = true ] && [ $OLLAMA_RESULT -eq 0 ]; then
|
||||
log_success "Ollama服务启动完成,可通过以下地址访问:"
|
||||
echo -e "${GREEN} - Ollama API: http://localhost:$OLLAMA_PORT${NC}"
|
||||
printf "%b\n" "${GREEN} - Ollama API: http://localhost:$OLLAMA_PORT${NC}"
|
||||
elif [ "$START_DOCKER" = true ] && [ $DOCKER_RESULT -eq 0 ]; then
|
||||
log_success "Docker容器启动完成,可通过以下地址访问:"
|
||||
echo -e "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
|
||||
echo -e "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
|
||||
echo -e "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
|
||||
printf "%b\n" "${GREEN} - 前端界面: http://localhost:${FRONTEND_PORT:-80}${NC}"
|
||||
printf "%b\n" "${GREEN} - API接口: http://localhost:${APP_PORT:-8080}${NC}"
|
||||
printf "%b\n" "${GREEN} - Jaeger链路追踪: http://localhost:16686${NC}"
|
||||
echo ""
|
||||
log_info "正在持续输出容器日志(按 Ctrl+C 退出日志,容器不会停止)..."
|
||||
"$DOCKER_COMPOSE_BIN" $DOCKER_COMPOSE_SUBCMD logs app docreader postgres --since=10s -f
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -40,6 +40,31 @@ class PaddleOCRBackend(OCRBackend):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
paddle.set_device('cpu')
|
||||
|
||||
# 尝试检测CPU是否支持AVX指令集
|
||||
try:
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
# 检测CPU是否支持AVX
|
||||
if platform.system() == "Linux":
|
||||
try:
|
||||
result = subprocess.run(['grep', '-o', 'avx', '/proc/cpuinfo'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
has_avx = 'avx' in result.stdout.lower()
|
||||
if not has_avx:
|
||||
logger.warning("CPU does not support AVX instructions, using compatibility mode")
|
||||
# 进一步限制指令集使用
|
||||
os.environ['FLAGS_use_avx2'] = '0'
|
||||
os.environ['FLAGS_use_avx'] = '1'
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError):
|
||||
logger.warning("Could not detect AVX support, using compatibility mode")
|
||||
os.environ['FLAGS_use_avx2'] = '0'
|
||||
os.environ['FLAGS_use_avx'] = '1'
|
||||
except Exception as e:
|
||||
logger.warning(f"Error detecting CPU capabilities: {e}, using compatibility mode")
|
||||
os.environ['FLAGS_use_avx2'] = '0'
|
||||
os.environ['FLAGS_use_avx'] = '1'
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
# OCR configuration with text orientation classification enabled
|
||||
ocr_config = {
|
||||
@@ -67,6 +92,13 @@ class PaddleOCRBackend(OCRBackend):
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import paddleocr: {str(e)}. Please install it with 'pip install paddleocr'")
|
||||
except OSError as e:
|
||||
if "Illegal instruction" in str(e) or "core dumped" in str(e):
|
||||
logger.error(f"PaddlePaddle crashed due to CPU instruction set incompatibility: {str(e)}")
|
||||
logger.error("This usually happens when the CPU doesn't support AVX instructions.")
|
||||
logger.error("Please try installing a CPU-only version of PaddlePaddle or use a different OCR backend.")
|
||||
else:
|
||||
logger.error(f"Failed to initialize PaddleOCR due to OS error: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PaddleOCR: {str(e)}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user