mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 19:37:45 +08:00
feat: Add Login Page
This commit is contained in:
4
frontend/package-lock.json
generated
4
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "knowledage-base",
|
||||
"version": "0.0.0",
|
||||
"version": "0.1.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "knowledage-base",
|
||||
"version": "0.0.0",
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"axios": "^1.8.4",
|
||||
|
||||
234
frontend/src/api/auth/index.ts
Normal file
234
frontend/src/api/auth/index.ts
Normal file
@@ -0,0 +1,234 @@
|
||||
import { post, get, put } from '@/utils/request'
|
||||
|
||||
// 用户登录接口
|
||||
export interface LoginRequest {
|
||||
email: string
|
||||
password: string
|
||||
}
|
||||
|
||||
export interface LoginResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
user?: {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
avatar?: string
|
||||
tenant_id: number
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
tenant?: {
|
||||
id: number
|
||||
name: string
|
||||
description: string
|
||||
api_key: string
|
||||
status: string
|
||||
business: string
|
||||
storage_quota: number
|
||||
storage_used: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
token?: string
|
||||
refresh_token?: string
|
||||
}
|
||||
|
||||
// 用户注册接口
|
||||
export interface RegisterRequest {
|
||||
username: string
|
||||
email: string
|
||||
password: string
|
||||
}
|
||||
|
||||
export interface RegisterResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
data?: {
|
||||
user: {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
}
|
||||
tenant: {
|
||||
id: string
|
||||
name: string
|
||||
api_key: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 用户信息接口
|
||||
export interface UserInfo {
|
||||
id: string
|
||||
username: string
|
||||
email: string
|
||||
avatar?: string
|
||||
tenant_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
// 租户信息接口
|
||||
export interface TenantInfo {
|
||||
id: string
|
||||
name: string
|
||||
api_key: string
|
||||
owner_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
knowledge_bases?: KnowledgeBaseInfo[]
|
||||
}
|
||||
|
||||
// 知识库信息接口
|
||||
export interface KnowledgeBaseInfo {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
tenant_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
document_count?: number
|
||||
chunk_count?: number
|
||||
}
|
||||
|
||||
// 模型信息接口
|
||||
export interface ModelInfo {
|
||||
id: string
|
||||
name: string
|
||||
type: string
|
||||
source: string
|
||||
description?: string
|
||||
is_default?: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登录
|
||||
*/
|
||||
export async function login(data: LoginRequest): Promise<LoginResponse> {
|
||||
try {
|
||||
const response = await post('/api/v1/auth/login', data)
|
||||
return response as unknown as LoginResponse
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '登录失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户注册
|
||||
*/
|
||||
export async function register(data: RegisterRequest): Promise<RegisterResponse> {
|
||||
try {
|
||||
const response = await post('/api/v1/auth/register', data)
|
||||
return response as unknown as RegisterResponse
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '注册失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前用户信息
|
||||
*/
|
||||
export async function getCurrentUser(): Promise<{ success: boolean; data?: UserInfo; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/me')
|
||||
return response as unknown as { success: boolean; data?: UserInfo; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '获取用户信息失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前租户信息
|
||||
*/
|
||||
export async function getCurrentTenant(): Promise<{ success: boolean; data?: TenantInfo; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/tenant')
|
||||
return response as unknown as { success: boolean; data?: TenantInfo; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '获取租户信息失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新Token
|
||||
*/
|
||||
export async function refreshToken(refreshToken: string): Promise<{ success: boolean; data?: { token: string; refreshToken: string }; message?: string }> {
|
||||
try {
|
||||
const response: any = await post('/api/v1/auth/refresh', { refreshToken })
|
||||
if (response && response.success) {
|
||||
if (response.access_token || response.refresh_token) {
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
token: response.access_token,
|
||||
refreshToken: response.refresh_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 其他情况直接返回原始消息
|
||||
return {
|
||||
success: false,
|
||||
message: response?.message || '刷新Token失败'
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '刷新Token失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登出
|
||||
*/
|
||||
export async function logout(): Promise<{ success: boolean; message?: string }> {
|
||||
try {
|
||||
await post('/api/v1/auth/logout', {})
|
||||
return {
|
||||
success: true
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: error.message || '登出失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证Token有效性
|
||||
*/
|
||||
export async function validateToken(): Promise<{ success: boolean; valid?: boolean; message?: string }> {
|
||||
try {
|
||||
const response = await get('/api/v1/auth/validate')
|
||||
return response as unknown as { success: boolean; valid?: boolean; message?: string }
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
valid: false,
|
||||
message: error.message || 'Token验证失败'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,54 +1,30 @@
|
||||
import { get, post, put, del, postChat } from "../../utils/request";
|
||||
import { loadTestData } from "../test-data";
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.apiKey && settings.endpoint) {
|
||||
return settings;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// 根据是否有设置决定是否需要加载测试数据
|
||||
async function ensureConfigured() {
|
||||
const settings = getSettings();
|
||||
// 如果没有设置APIKey和Endpoint,则加载测试数据
|
||||
if (!settings) {
|
||||
await loadTestData();
|
||||
}
|
||||
}
|
||||
|
||||
export async function createSessions(data = {}) {
|
||||
await ensureConfigured();
|
||||
await loadTestData();
|
||||
return post("/api/v1/sessions", data);
|
||||
}
|
||||
|
||||
export async function getSessionsList(page: number, page_size: number) {
|
||||
await ensureConfigured();
|
||||
await loadTestData();
|
||||
return get(`/api/v1/sessions?page=${page}&page_size=${page_size}`);
|
||||
}
|
||||
|
||||
export async function generateSessionsTitle(session_id: string, data: any) {
|
||||
await ensureConfigured();
|
||||
await loadTestData();
|
||||
return post(`/api/v1/sessions/${session_id}/generate_title`, data);
|
||||
}
|
||||
|
||||
export async function knowledgeChat(data: { session_id: string; query: string; }) {
|
||||
await ensureConfigured();
|
||||
await loadTestData();
|
||||
return postChat(`/api/v1/knowledge-chat/${data.session_id}`, { query: data.query });
|
||||
}
|
||||
|
||||
export async function getMessageList(data: { session_id: string; limit: number, created_at: string }) {
|
||||
await ensureConfigured();
|
||||
|
||||
await loadTestData();
|
||||
if (data.created_at) {
|
||||
return get(`/api/v1/messages/${data.session_id}/load?before_time=${encodeURIComponent(data.created_at)}&limit=${data.limit}`);
|
||||
} else {
|
||||
@@ -57,6 +33,6 @@ export async function getMessageList(data: { session_id: string; limit: number,
|
||||
}
|
||||
|
||||
export async function delSession(session_id: string) {
|
||||
await ensureConfigured();
|
||||
await loadTestData();
|
||||
return del(`/api/v1/sessions/${session_id}`);
|
||||
}
|
||||
@@ -2,21 +2,9 @@ import { fetchEventSource } from '@microsoft/fetch-event-source'
|
||||
import { ref, type Ref, onUnmounted, nextTick } from 'vue'
|
||||
import { generateRandomString } from '@/utils/index';
|
||||
import { getTestData } from '@/utils/request';
|
||||
import { loadTestData } from '@/api/test-data';
|
||||
import { loadTestData } from "../test-data";
|
||||
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
return settings;
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
interface StreamOptions {
|
||||
// 请求方法 (默认POST)
|
||||
@@ -49,17 +37,7 @@ export function useStream() {
|
||||
isStreaming.value = true;
|
||||
isLoading.value = true;
|
||||
|
||||
// 获取设置信息
|
||||
const settings = getSettings();
|
||||
let apiUrl = '';
|
||||
let apiKey = '';
|
||||
|
||||
// 如果有设置信息,优先使用设置信息
|
||||
if (settings && settings.endpoint && settings.apiKey) {
|
||||
apiUrl = settings.endpoint;
|
||||
apiKey = settings.apiKey;
|
||||
} else {
|
||||
// 否则加载测试数据
|
||||
// 使用默认配置
|
||||
await loadTestData();
|
||||
const testData = getTestData();
|
||||
if (!testData) {
|
||||
@@ -67,9 +45,8 @@ export function useStream() {
|
||||
stopStream();
|
||||
return;
|
||||
}
|
||||
apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
apiKey = testData.tenant.api_key;
|
||||
}
|
||||
const apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
const apiKey = testData.tenant.api_key;
|
||||
|
||||
try {
|
||||
let url =
|
||||
|
||||
@@ -70,6 +70,11 @@ export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
|
||||
resolve(response.data || { initialized: false });
|
||||
})
|
||||
.catch((error: any) => {
|
||||
// 如果是401,交给全局拦截器去处理(重定向登录),这里不要把它当成未初始化
|
||||
if (error && error.status === 401) {
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
console.warn('检查初始化状态失败,假设需要初始化:', error);
|
||||
resolve({ initialized: false });
|
||||
});
|
||||
|
||||
@@ -1,41 +1,23 @@
|
||||
import { get, post, put, del, postUpload, getDown, getTestData } from "../../utils/request";
|
||||
import { loadTestData } from "../test-data";
|
||||
|
||||
// 获取知识库ID(优先从设置中获取)
|
||||
async function getKnowledgeBaseID() {
|
||||
// 从localStorage获取设置中的知识库ID
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
let knowledgeBaseId = "";
|
||||
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.knowledgeBaseId) {
|
||||
return settings.knowledgeBaseId;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
|
||||
export async function getDefaultKnowledgeBaseId(): Promise<string> {
|
||||
// 如果设置中没有知识库ID,则使用测试数据
|
||||
await loadTestData();
|
||||
|
||||
const testData = getTestData();
|
||||
if (!testData || testData.knowledge_bases.length === 0) {
|
||||
console.error("测试数据未初始化或不包含知识库");
|
||||
throw new Error("测试数据未初始化或不包含知识库");
|
||||
throw new Error('没有可用的知识库');
|
||||
}
|
||||
|
||||
return testData.knowledge_bases[0].id;
|
||||
}
|
||||
|
||||
export async function uploadKnowledgeBase(data = {}) {
|
||||
const kbId = await getKnowledgeBaseID();
|
||||
const kbId = await getDefaultKnowledgeBaseId();
|
||||
return postUpload(`/api/v1/knowledge-bases/${kbId}/knowledge/file`, data);
|
||||
}
|
||||
|
||||
export async function getKnowledgeBase({page, page_size}) {
|
||||
const kbId = await getKnowledgeBaseID();
|
||||
export async function getKnowledgeBase({page, page_size}: {page: number, page_size: number}) {
|
||||
const kbId = await getDefaultKnowledgeBaseId();
|
||||
return get(
|
||||
`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`
|
||||
);
|
||||
@@ -57,6 +39,6 @@ export function batchQueryKnowledge(ids: any) {
|
||||
return get(`/api/v1/knowledge/batch?${ids}`);
|
||||
}
|
||||
|
||||
export function getKnowledgeDetailsCon(id: any, page) {
|
||||
export function getKnowledgeDetailsCon(id: any, page: number) {
|
||||
return get(`/api/v1/chunks/${id}?page=${page}&page_size=25`);
|
||||
}
|
||||
@@ -53,3 +53,12 @@ export async function loadTestData(): Promise<boolean> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置测试数据加载状态,在重新登录或需要强制刷新时调用
|
||||
*/
|
||||
export function resetTestDataLoaded() {
|
||||
isTestDataLoaded = false;
|
||||
// 清空已缓存的测试数据,确保下次调用会重新获取
|
||||
setTestData(null);
|
||||
}
|
||||
|
||||
6
frontend/src/assets/img/logout.svg
Normal file
6
frontend/src/assets/img/logout.svg
Normal file
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||
<path d="M10 3H6a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M17 16l4-4-4-4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M21 12H10" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
||||
|
After Width: | Height: | Size: 509 B |
@@ -9,7 +9,7 @@
|
||||
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : item.path == currentpath ? 'menu_item_active' : '']">
|
||||
<div class="menu_item-box">
|
||||
<div class="menu_icon">
|
||||
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : prefixIcon)" alt="">
|
||||
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : item.icon == 'logout' ? logoutIcon : prefixIcon)" alt="">
|
||||
</div>
|
||||
<span class="menu_title">{{ item.title }}</span>
|
||||
</div>
|
||||
@@ -58,11 +58,13 @@ import { onMounted, watch, computed, ref, reactive } from 'vue';
|
||||
import { useRoute, useRouter } from 'vue-router';
|
||||
import { getSessionsList, delSession } from "@/api/chat/index";
|
||||
import { useMenuStore } from '@/stores/menu';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
|
||||
import { MessagePlugin } from "tdesign-vue-next";
|
||||
let { requestMethod } = useKnowledgeBase()
|
||||
let uploadInput = ref();
|
||||
const usemenuStore = useMenuStore();
|
||||
const authStore = useAuthStore();
|
||||
const route = useRoute();
|
||||
const router = useRouter();
|
||||
const currentpath = ref('');
|
||||
@@ -164,12 +166,14 @@ let fileAddIcon = ref('file-add-green.svg');
|
||||
let knowledgeIcon = ref('zhishiku-green.svg');
|
||||
let prefixIcon = ref('prefixIcon.svg');
|
||||
let settingIcon = ref('setting.svg');
|
||||
let logoutIcon = ref('logout.svg');
|
||||
let pathPrefix = ref(route.name)
|
||||
const getIcon = (path) => {
|
||||
fileAddIcon.value = path == 'knowledgeBase' ? 'file-add-green.svg' : 'file-add.svg';
|
||||
knowledgeIcon.value = path == 'knowledgeBase' ? 'zhishiku-green.svg' : 'zhishiku.svg';
|
||||
prefixIcon.value = path == 'creatChat' ? 'prefixIcon-green.svg' : path == 'knowledgeBase' ? 'prefixIcon-grey.svg' : 'prefixIcon.svg';
|
||||
settingIcon.value = path == 'settings' ? 'setting-green.svg' : 'setting.svg';
|
||||
logoutIcon.value = 'logout.svg';
|
||||
}
|
||||
getIcon(route.name)
|
||||
const gotopage = (path) => {
|
||||
@@ -177,6 +181,13 @@ const gotopage = (path) => {
|
||||
// 如果是系统设置,跳转到初始化配置页面
|
||||
if (path === 'settings') {
|
||||
router.push('/initialization');
|
||||
return;
|
||||
}
|
||||
// 处理退出登录
|
||||
if (path === 'logout') {
|
||||
authStore.logout();
|
||||
router.push('/login');
|
||||
return;
|
||||
} else {
|
||||
router.push(`/platform/${path}`);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
import { createRouter, createWebHistory } from 'vue-router'
|
||||
import { checkInitializationStatus } from '@/api/initialization'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { validateToken } from '@/api/auth'
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHistory(import.meta.env.BASE_URL),
|
||||
routes: [
|
||||
{
|
||||
path: "/",
|
||||
redirect: "/platform",
|
||||
redirect: "/platform/knowledgeBase",
|
||||
},
|
||||
{
|
||||
path: "/login",
|
||||
name: "login",
|
||||
component: () => import("../views/auth/Login.vue"),
|
||||
meta: { requiresAuth: false, requiresInit: false }
|
||||
},
|
||||
{
|
||||
path: "/initialization",
|
||||
@@ -18,32 +26,32 @@ const router = createRouter({
|
||||
path: "/knowledgeBase",
|
||||
name: "home",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "/platform",
|
||||
name: "Platform",
|
||||
redirect: "/platform/knowledgeBase",
|
||||
component: () => import("../views/platform/index.vue"),
|
||||
meta: { requiresInit: true },
|
||||
meta: { requiresInit: true, requiresAuth: true },
|
||||
children: [
|
||||
{
|
||||
path: "knowledgeBase",
|
||||
name: "knowledgeBase",
|
||||
component: () => import("../views/knowledge/KnowledgeBase.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "creatChat",
|
||||
name: "creatChat",
|
||||
component: () => import("../views/creatChat/creatChat.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "chat/:chatid",
|
||||
name: "chat",
|
||||
component: () => import("../views/chat/index.vue"),
|
||||
meta: { requiresInit: true }
|
||||
meta: { requiresInit: true, requiresAuth: true }
|
||||
},
|
||||
{
|
||||
path: "settings",
|
||||
@@ -56,33 +64,72 @@ const router = createRouter({
|
||||
],
|
||||
});
|
||||
|
||||
// 路由守卫:检查系统初始化状态
|
||||
// 路由守卫:检查认证状态和系统初始化状态
|
||||
router.beforeEach(async (to, from, next) => {
|
||||
// 如果访问的是初始化页面,直接放行
|
||||
if (to.meta.requiresInit === false) {
|
||||
next();
|
||||
return;
|
||||
const authStore = useAuthStore()
|
||||
|
||||
// 如果访问的是登录页面或初始化页面,直接放行
|
||||
if (to.meta.requiresAuth === false || to.meta.requiresInit === false) {
|
||||
// 如果已登录用户访问登录页面,重定向到知识库列表页面
|
||||
if (to.path === '/login' && authStore.isLoggedIn) {
|
||||
next('/platform/knowledgeBase')
|
||||
return
|
||||
}
|
||||
next()
|
||||
return
|
||||
}
|
||||
|
||||
1
|
||||
// 检查用户认证状态
|
||||
if (to.meta.requiresAuth !== false) {
|
||||
if (!authStore.isLoggedIn) {
|
||||
// 未登录,跳转到登录页面
|
||||
next('/login')
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Token有效性
|
||||
// try {
|
||||
// const { valid } = await validateToken()
|
||||
// if (!valid) {
|
||||
// // Token无效,清空认证信息并跳转到登录页面
|
||||
// authStore.logout()
|
||||
// next('/login')
|
||||
// return
|
||||
// }
|
||||
// } catch (error) {
|
||||
// console.error('Token验证失败:', error)
|
||||
// authStore.logout()
|
||||
// next('/login')
|
||||
// return
|
||||
// }
|
||||
}
|
||||
|
||||
// 检查系统初始化状态
|
||||
if (to.meta.requiresInit !== false) {
|
||||
try {
|
||||
// 检查系统是否已初始化
|
||||
const { initialized } = await checkInitializationStatus();
|
||||
const { initialized } = await checkInitializationStatus()
|
||||
|
||||
if (initialized) {
|
||||
// 系统已初始化,记录到本地存储并正常跳转
|
||||
localStorage.setItem('system_initialized', 'true');
|
||||
next();
|
||||
localStorage.setItem('system_initialized', 'true')
|
||||
next()
|
||||
} else {
|
||||
// 系统未初始化,跳转到初始化页面
|
||||
console.log('系统未初始化,跳转到初始化页面');
|
||||
next('/initialization');
|
||||
next('/initialization')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查初始化状态失败:', error);
|
||||
// 如果检查失败,默认认为需要初始化
|
||||
next('/initialization');
|
||||
console.error('检查初始化状态失败:', error)
|
||||
// 如果是401,跳转登录,不再误导去初始化
|
||||
const status = (error as any)?.status
|
||||
if (status === 401) {
|
||||
next('/login')
|
||||
return
|
||||
}
|
||||
// 其他错误默认认为需要初始化
|
||||
next('/initialization')
|
||||
}
|
||||
} else {
|
||||
next()
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
174
frontend/src/stores/auth.ts
Normal file
174
frontend/src/stores/auth.ts
Normal file
@@ -0,0 +1,174 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { resetTestDataLoaded } from '@/api/test-data'
|
||||
import { ref, computed } from 'vue'
|
||||
import type { UserInfo, TenantInfo, KnowledgeBaseInfo } from '@/api/auth'
|
||||
|
||||
export const useAuthStore = defineStore('auth', () => {
|
||||
// 状态
|
||||
const user = ref<UserInfo | null>(null)
|
||||
const tenant = ref<TenantInfo | null>(null)
|
||||
const token = ref<string>('')
|
||||
const refreshToken = ref<string>('')
|
||||
const knowledgeBases = ref<KnowledgeBaseInfo[]>([])
|
||||
const currentKnowledgeBase = ref<KnowledgeBaseInfo | null>(null)
|
||||
|
||||
// 计算属性
|
||||
const isLoggedIn = computed(() => {
|
||||
return !!token.value && !!user.value
|
||||
})
|
||||
|
||||
const hasValidTenant = computed(() => {
|
||||
return !!tenant.value && !!tenant.value.api_key
|
||||
})
|
||||
|
||||
const currentTenantId = computed(() => {
|
||||
return tenant.value?.id || ''
|
||||
})
|
||||
|
||||
const currentUserId = computed(() => {
|
||||
return user.value?.id || ''
|
||||
})
|
||||
|
||||
// 操作方法
|
||||
const setUser = (userData: UserInfo) => {
|
||||
user.value = userData
|
||||
// 保存到localStorage
|
||||
localStorage.setItem('weknora_user', JSON.stringify(userData))
|
||||
}
|
||||
|
||||
const setTenant = (tenantData: TenantInfo) => {
|
||||
tenant.value = tenantData
|
||||
// 保存到localStorage
|
||||
localStorage.setItem('weknora_tenant', JSON.stringify(tenantData))
|
||||
}
|
||||
|
||||
const setToken = (tokenValue: string) => {
|
||||
token.value = tokenValue
|
||||
localStorage.setItem('weknora_token', tokenValue)
|
||||
}
|
||||
|
||||
const setRefreshToken = (refreshTokenValue: string) => {
|
||||
refreshToken.value = refreshTokenValue
|
||||
localStorage.setItem('weknora_refresh_token', refreshTokenValue)
|
||||
}
|
||||
|
||||
const setKnowledgeBases = (kbList: KnowledgeBaseInfo[]) => {
|
||||
// 确保输入是数组
|
||||
knowledgeBases.value = Array.isArray(kbList) ? kbList : []
|
||||
localStorage.setItem('weknora_knowledge_bases', JSON.stringify(knowledgeBases.value))
|
||||
}
|
||||
|
||||
const setCurrentKnowledgeBase = (kb: KnowledgeBaseInfo | null) => {
|
||||
currentKnowledgeBase.value = kb
|
||||
if (kb) {
|
||||
localStorage.setItem('weknora_current_kb', JSON.stringify(kb))
|
||||
} else {
|
||||
localStorage.removeItem('weknora_current_kb')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const logout = () => {
|
||||
// 清空状态
|
||||
user.value = null
|
||||
tenant.value = null
|
||||
token.value = ''
|
||||
refreshToken.value = ''
|
||||
knowledgeBases.value = []
|
||||
currentKnowledgeBase.value = null
|
||||
|
||||
// 清空localStorage
|
||||
localStorage.removeItem('weknora_user')
|
||||
localStorage.removeItem('weknora_tenant')
|
||||
localStorage.removeItem('weknora_token')
|
||||
localStorage.removeItem('weknora_refresh_token')
|
||||
localStorage.removeItem('weknora_knowledge_bases')
|
||||
localStorage.removeItem('weknora_current_kb')
|
||||
|
||||
// 重置测试数据加载标志,确保重新登录后会重新获取KB列表
|
||||
try {
|
||||
resetTestDataLoaded()
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const initFromStorage = () => {
|
||||
// 从localStorage恢复状态
|
||||
const storedUser = localStorage.getItem('weknora_user')
|
||||
const storedTenant = localStorage.getItem('weknora_tenant')
|
||||
const storedToken = localStorage.getItem('weknora_token')
|
||||
const storedRefreshToken = localStorage.getItem('weknora_refresh_token')
|
||||
const storedKnowledgeBases = localStorage.getItem('weknora_knowledge_bases')
|
||||
const storedCurrentKb = localStorage.getItem('weknora_current_kb')
|
||||
|
||||
if (storedUser) {
|
||||
try {
|
||||
user.value = JSON.parse(storedUser)
|
||||
} catch (e) {
|
||||
console.error('解析用户信息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
if (storedTenant) {
|
||||
try {
|
||||
tenant.value = JSON.parse(storedTenant)
|
||||
} catch (e) {
|
||||
console.error('解析租户信息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
if (storedToken) {
|
||||
token.value = storedToken
|
||||
}
|
||||
|
||||
if (storedRefreshToken) {
|
||||
refreshToken.value = storedRefreshToken
|
||||
}
|
||||
|
||||
if (storedKnowledgeBases) {
|
||||
try {
|
||||
const parsed = JSON.parse(storedKnowledgeBases)
|
||||
knowledgeBases.value = Array.isArray(parsed) ? parsed : []
|
||||
} catch (e) {
|
||||
console.error('解析知识库列表失败:', e)
|
||||
knowledgeBases.value = []
|
||||
}
|
||||
}
|
||||
|
||||
if (storedCurrentKb) {
|
||||
try {
|
||||
currentKnowledgeBase.value = JSON.parse(storedCurrentKb)
|
||||
} catch (e) {
|
||||
console.error('解析当前知识库失败:', e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化时从localStorage恢复状态
|
||||
initFromStorage()
|
||||
|
||||
return {
|
||||
// 状态
|
||||
user,
|
||||
tenant,
|
||||
token,
|
||||
refreshToken,
|
||||
knowledgeBases,
|
||||
currentKnowledgeBase,
|
||||
|
||||
// 计算属性
|
||||
isLoggedIn,
|
||||
hasValidTenant,
|
||||
currentTenantId,
|
||||
currentUserId,
|
||||
|
||||
// 方法
|
||||
setUser,
|
||||
setTenant,
|
||||
setToken,
|
||||
setRefreshToken,
|
||||
setKnowledgeBases,
|
||||
setCurrentKnowledgeBase,
|
||||
logout,
|
||||
initFromStorage
|
||||
}
|
||||
})
|
||||
@@ -13,7 +13,8 @@ export const useMenuStore = defineStore('menuStore', {
|
||||
childrenPath: 'chat',
|
||||
children: reactive<object[]>([]),
|
||||
},
|
||||
{ title: '系统设置', icon: 'setting', path: 'settings' }
|
||||
{ title: '系统设置', icon: 'setting', path: 'settings' },
|
||||
{ title: '退出登录', icon: 'logout', path: 'logout' }
|
||||
]),
|
||||
isFirstSession: false,
|
||||
firstQuery: ''
|
||||
|
||||
@@ -2,26 +2,8 @@
|
||||
import axios from "axios";
|
||||
import { generateRandomString } from "./index";
|
||||
|
||||
// 从localStorage获取设置
|
||||
function getSettings() {
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
if (settingsStr) {
|
||||
try {
|
||||
return JSON.parse(settingsStr);
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
return {
|
||||
endpoint: import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080",
|
||||
apiKey: "",
|
||||
knowledgeBaseId: "",
|
||||
};
|
||||
}
|
||||
|
||||
// API基础URL,优先使用设置中的endpoint
|
||||
const settings = getSettings();
|
||||
const BASE_URL = settings.endpoint;
|
||||
// API基础URL
|
||||
const BASE_URL = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
|
||||
|
||||
// 测试数据
|
||||
let testData: {
|
||||
@@ -50,13 +32,6 @@ const instance = axios.create({
|
||||
// 设置测试数据
|
||||
export function setTestData(data: typeof testData) {
|
||||
testData = data;
|
||||
if (data) {
|
||||
// 优先使用设置中的ApiKey,如果没有则使用测试数据中的
|
||||
const apiKey = settings.apiKey || (data?.tenant?.api_key || "");
|
||||
if (apiKey) {
|
||||
instance.defaults.headers["X-API-Key"] = apiKey;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取测试数据
|
||||
@@ -66,25 +41,38 @@ export function getTestData() {
|
||||
|
||||
instance.interceptors.request.use(
|
||||
(config) => {
|
||||
// 每次请求前检查是否有更新的设置
|
||||
const currentSettings = getSettings();
|
||||
|
||||
// 更新BaseURL (如果有变化)
|
||||
if (currentSettings.endpoint && config.baseURL !== currentSettings.endpoint) {
|
||||
config.baseURL = currentSettings.endpoint;
|
||||
}
|
||||
|
||||
// 更新API Key (如果有)
|
||||
if (currentSettings.apiKey) {
|
||||
config.headers["X-API-Key"] = currentSettings.apiKey;
|
||||
// 添加JWT token认证
|
||||
const token = localStorage.getItem('weknora_token');
|
||||
if (token) {
|
||||
config.headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
config.headers["X-Request-ID"] = `${generateRandomString(12)}`;
|
||||
return config;
|
||||
},
|
||||
(error) => {}
|
||||
(error) => {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Token刷新标志,防止多个请求同时刷新token
|
||||
let isRefreshing = false;
|
||||
let failedQueue: Array<{ resolve: Function; reject: Function }> = [];
|
||||
let hasRedirectedOn401 = false;
|
||||
|
||||
// 处理队列中的请求
|
||||
const processQueue = (error: any, token: string | null = null) => {
|
||||
failedQueue.forEach(({ resolve, reject }) => {
|
||||
if (error) {
|
||||
reject(error);
|
||||
} else {
|
||||
resolve(token);
|
||||
}
|
||||
});
|
||||
|
||||
failedQueue = [];
|
||||
};
|
||||
|
||||
instance.interceptors.response.use(
|
||||
(response) => {
|
||||
// 根据业务状态码处理逻辑
|
||||
@@ -95,12 +83,98 @@ instance.interceptors.response.use(
|
||||
return Promise.reject(data);
|
||||
}
|
||||
},
|
||||
(error: any) => {
|
||||
async (error: any) => {
|
||||
const originalRequest = error.config;
|
||||
|
||||
if (!error.response) {
|
||||
return Promise.reject({ message: "网络错误,请检查您的网络连接" });
|
||||
}
|
||||
const { data } = error.response;
|
||||
return Promise.reject(data);
|
||||
|
||||
// 如果是登录接口的401,直接返回错误以便页面展示toast,不做跳转
|
||||
if (error.response.status === 401 && originalRequest?.url?.includes('/auth/login')) {
|
||||
const { status, data } = error.response;
|
||||
return Promise.reject({ status, message: (typeof data === 'object' ? data?.message : data) || '用户名或密码错误' });
|
||||
}
|
||||
|
||||
// 如果是401错误且不是刷新token的请求,尝试刷新token
|
||||
if (error.response.status === 401 && !originalRequest._retry && !originalRequest.url?.includes('/auth/refresh')) {
|
||||
if (isRefreshing) {
|
||||
// 如果正在刷新token,将请求加入队列
|
||||
return new Promise((resolve, reject) => {
|
||||
failedQueue.push({ resolve, reject });
|
||||
}).then(token => {
|
||||
originalRequest.headers['Authorization'] = 'Bearer ' + token;
|
||||
return instance(originalRequest);
|
||||
}).catch(err => {
|
||||
return Promise.reject(err);
|
||||
});
|
||||
}
|
||||
|
||||
originalRequest._retry = true;
|
||||
isRefreshing = true;
|
||||
|
||||
const refreshToken = localStorage.getItem('weknora_refresh_token');
|
||||
|
||||
if (refreshToken) {
|
||||
try {
|
||||
// 动态导入refresh token API
|
||||
const { refreshToken: refreshTokenAPI } = await import('../api/auth/index');
|
||||
const response = await refreshTokenAPI(refreshToken);
|
||||
|
||||
if (response.success && response.data) {
|
||||
const { token, refreshToken: newRefreshToken } = response.data;
|
||||
|
||||
// 更新localStorage中的token
|
||||
localStorage.setItem('weknora_token', token);
|
||||
localStorage.setItem('weknora_refresh_token', newRefreshToken);
|
||||
|
||||
// 更新请求头
|
||||
originalRequest.headers['Authorization'] = 'Bearer ' + token;
|
||||
|
||||
// 处理队列中的请求
|
||||
processQueue(null, token);
|
||||
|
||||
return instance(originalRequest);
|
||||
} else {
|
||||
throw new Error(response.message || 'Token刷新失败');
|
||||
}
|
||||
} catch (refreshError) {
|
||||
// 刷新失败,清除所有token并跳转到登录页
|
||||
localStorage.removeItem('weknora_token');
|
||||
localStorage.removeItem('weknora_refresh_token');
|
||||
localStorage.removeItem('weknora_user');
|
||||
localStorage.removeItem('weknora_tenant');
|
||||
|
||||
processQueue(refreshError, null);
|
||||
|
||||
// 跳转到登录页
|
||||
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
|
||||
hasRedirectedOn401 = true;
|
||||
window.location.href = '/login';
|
||||
}
|
||||
|
||||
return Promise.reject(refreshError);
|
||||
} finally {
|
||||
isRefreshing = false;
|
||||
}
|
||||
} else {
|
||||
// 没有refresh token,直接跳转到登录页
|
||||
localStorage.removeItem('weknora_token');
|
||||
localStorage.removeItem('weknora_user');
|
||||
localStorage.removeItem('weknora_tenant');
|
||||
|
||||
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
|
||||
hasRedirectedOn401 = true;
|
||||
window.location.href = '/login';
|
||||
}
|
||||
|
||||
return Promise.reject({ message: '请重新登录' });
|
||||
}
|
||||
}
|
||||
|
||||
const { status, data } = error.response;
|
||||
// 将HTTP状态码一并抛出,方便上层判断401等场景
|
||||
return Promise.reject({ status, ...(typeof data === 'object' ? data : { message: data }) });
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
559
frontend/src/views/auth/Login.vue
Normal file
559
frontend/src/views/auth/Login.vue
Normal file
@@ -0,0 +1,559 @@
|
||||
<template>
|
||||
<div class="login-container">
|
||||
<!-- 登录表单 -->
|
||||
<div class="login-card" v-if="!isRegisterMode">
|
||||
<!-- 系统Logo和标题 -->
|
||||
<div class="login-header">
|
||||
<div class="logo">
|
||||
<img src="@/assets/img/weknora.png" alt="WeKnora" class="logo-img" />
|
||||
</div>
|
||||
<p class="login-subtitle">基于大模型的文档理解与语义检索框架</p>
|
||||
</div>
|
||||
|
||||
<div class="login-form">
|
||||
<t-form
|
||||
ref="formRef"
|
||||
:data="formData"
|
||||
:rules="formRules"
|
||||
@submit="handleLogin"
|
||||
layout="vertical"
|
||||
>
|
||||
<t-form-item label="邮箱" name="email">
|
||||
<t-input
|
||||
v-model="formData.email"
|
||||
placeholder="请输入邮箱地址"
|
||||
type="email"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="密码" name="password">
|
||||
<t-input
|
||||
v-model="formData.password"
|
||||
placeholder="请输入密码(8-32位,包含字母和数字)"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
@keydown.enter="handleLogin"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-button
|
||||
type="submit"
|
||||
theme="primary"
|
||||
size="large"
|
||||
block
|
||||
:loading="loading"
|
||||
class="login-button"
|
||||
>
|
||||
{{ loading ? '登录中...' : '登录' }}
|
||||
</t-button>
|
||||
</t-form>
|
||||
|
||||
<!-- 注册链接 -->
|
||||
<div class="register-link">
|
||||
<span>还没有账号?</span>
|
||||
<a href="#" @click.prevent="toggleMode" class="register-btn">
|
||||
立即注册
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 注册表单 -->
|
||||
<div class="register-card" v-if="isRegisterMode">
|
||||
<div class="login-header">
|
||||
<h1 class="login-title">创建账号</h1>
|
||||
<p class="login-subtitle">注册后系统将为您创建专属租户</p>
|
||||
</div>
|
||||
|
||||
<div class="login-form">
|
||||
<t-form
|
||||
ref="registerFormRef"
|
||||
:data="registerData"
|
||||
:rules="registerRules"
|
||||
@submit="handleRegister"
|
||||
layout="vertical"
|
||||
>
|
||||
<t-form-item label="用户名" name="username">
|
||||
<t-input
|
||||
v-model="registerData.username"
|
||||
placeholder="请输入用户名"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="邮箱" name="email">
|
||||
<t-input
|
||||
v-model="registerData.email"
|
||||
placeholder="请输入邮箱地址"
|
||||
type="email"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="密码" name="password">
|
||||
<t-input
|
||||
v-model="registerData.password"
|
||||
placeholder="请输入密码(8-32位,包含字母和数字)"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item label="确认密码" name="confirmPassword">
|
||||
<t-input
|
||||
v-model="registerData.confirmPassword"
|
||||
placeholder="请再次输入密码"
|
||||
type="password"
|
||||
size="large"
|
||||
:disabled="loading"
|
||||
@keydown.enter="handleRegister"
|
||||
/>
|
||||
</t-form-item>
|
||||
|
||||
<t-button
|
||||
type="submit"
|
||||
theme="primary"
|
||||
size="large"
|
||||
block
|
||||
:loading="loading"
|
||||
class="login-button"
|
||||
>
|
||||
{{ loading ? '注册中...' : '注册' }}
|
||||
</t-button>
|
||||
</t-form>
|
||||
|
||||
<!-- 返回登录 -->
|
||||
<div class="register-link">
|
||||
<span>已有账号?</span>
|
||||
<a href="#" @click.prevent="toggleMode" class="register-btn">
|
||||
返回登录
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, computed, nextTick, onMounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { login, register } from '@/api/auth'
|
||||
import { loadTestData, resetTestDataLoaded } from '@/api/test-data'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
|
||||
const router = useRouter()
|
||||
const authStore = useAuthStore()
|
||||
|
||||
// 表单引用
|
||||
const formRef = ref()
|
||||
const registerFormRef = ref()
|
||||
|
||||
// 状态管理
|
||||
const loading = ref(false)
|
||||
const isRegisterMode = ref(false)
|
||||
|
||||
|
||||
// 登录表单数据
|
||||
const formData = reactive<{[key: string]: any}>({
|
||||
email: '',
|
||||
password: '',
|
||||
})
|
||||
|
||||
// 注册表单数据
|
||||
const registerData = reactive<{[key: string]: any}>({
|
||||
username: '',
|
||||
email: '',
|
||||
password: '',
|
||||
confirmPassword: ''
|
||||
})
|
||||
|
||||
// 登录表单验证规则
|
||||
const formRules = {
|
||||
email: [
|
||||
{ required: true, message: '请输入邮箱地址', type: 'error' },
|
||||
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
|
||||
],
|
||||
password: [
|
||||
{ required: true, message: '请输入密码', type: 'error' },
|
||||
{ min: 8, message: '密码至少8位', type: 'error' },
|
||||
{ max: 32, message: '密码不能超过32位', type: 'error' },
|
||||
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
|
||||
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
|
||||
]
|
||||
}
|
||||
|
||||
// 注册表单验证规则
|
||||
const registerRules = {
|
||||
username: [
|
||||
{ required: true, message: '请输入用户名', type: 'error' },
|
||||
{ min: 2, message: '用户名至少2位', type: 'error' },
|
||||
{ max: 20, message: '用户名不能超过20位', type: 'error' },
|
||||
{
|
||||
pattern: /^[a-zA-Z0-9_\u4e00-\u9fa5]+$/,
|
||||
message: '用户名只能包含字母、数字、下划线和中文',
|
||||
type: 'error'
|
||||
}
|
||||
],
|
||||
email: [
|
||||
{ required: true, message: '请输入邮箱地址', type: 'error' },
|
||||
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
|
||||
],
|
||||
password: [
|
||||
{ required: true, message: '请输入密码', type: 'error' },
|
||||
{ min: 8, message: '密码至少8位', type: 'error' },
|
||||
{ max: 32, message: '密码不能超过32位', type: 'error' },
|
||||
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
|
||||
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
|
||||
],
|
||||
confirmPassword: [
|
||||
{ required: true, message: '请确认密码', type: 'error' },
|
||||
{
|
||||
validator: (val: string) => val === registerData.password,
|
||||
message: '两次输入的密码不一致',
|
||||
type: 'error'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// 切换登录/注册模式
|
||||
const toggleMode = () => {
|
||||
isRegisterMode.value = !isRegisterMode.value
|
||||
|
||||
Object.keys(registerData).forEach(key => {
|
||||
(registerData as any)[key] = ''
|
||||
})
|
||||
}
|
||||
|
||||
// 处理登录
|
||||
const handleLogin = async () => {
|
||||
try {
|
||||
const valid = await formRef.value?.validate()
|
||||
if (!valid) return
|
||||
|
||||
loading.value = true
|
||||
|
||||
const response = await login({
|
||||
email: formData.email,
|
||||
password: formData.password,
|
||||
})
|
||||
|
||||
if (response.success) {
|
||||
// 保存用户信息和token
|
||||
if (response.user && response.tenant && response.token) {
|
||||
authStore.setUser({
|
||||
id: response.user.id || '',
|
||||
username: response.user.username || '',
|
||||
email: response.user.email || '',
|
||||
avatar: response.user.avatar,
|
||||
tenant_id: String(response.tenant.id) || '',
|
||||
created_at: response.user.created_at || new Date().toISOString(),
|
||||
updated_at: response.user.updated_at || new Date().toISOString()
|
||||
})
|
||||
authStore.setToken(response.token)
|
||||
if (response.refresh_token) {
|
||||
authStore.setRefreshToken(response.refresh_token)
|
||||
}
|
||||
authStore.setTenant({
|
||||
id: String(response.tenant.id) || '',
|
||||
name: response.tenant.name || '',
|
||||
api_key: response.tenant.api_key || '',
|
||||
owner_id: response.user.id || '',
|
||||
created_at: response.tenant.created_at || new Date().toISOString(),
|
||||
updated_at: response.tenant.updated_at || new Date().toISOString()
|
||||
})
|
||||
}
|
||||
|
||||
MessagePlugin.success('登录成功!')
|
||||
|
||||
// 登录成功后先重置并加载一次测试数据,确保有KB可用
|
||||
try {
|
||||
resetTestDataLoaded()
|
||||
await loadTestData()
|
||||
} catch (_) {}
|
||||
|
||||
// 等待状态更新完成后再跳转
|
||||
await nextTick()
|
||||
router.replace('/platform/knowledgeBase')
|
||||
} else {
|
||||
MessagePlugin.error(response.message || '登录失败,请检查邮箱或密码')
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('登录错误:', error)
|
||||
MessagePlugin.error(error.message || '登录失败,请稍后重试')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 处理注册
|
||||
const handleRegister = async () => {
|
||||
try {
|
||||
const valid = await registerFormRef.value?.validate()
|
||||
if (!valid) return
|
||||
|
||||
loading.value = true
|
||||
|
||||
const response = await register({
|
||||
username: registerData.username,
|
||||
email: registerData.email,
|
||||
password: registerData.password
|
||||
})
|
||||
|
||||
if (response.success) {
|
||||
MessagePlugin.success('注册成功!系统已为您创建专属租户,请登录使用')
|
||||
|
||||
// 切换到登录模式并填入邮箱
|
||||
isRegisterMode.value = false
|
||||
formData.email = registerData.email
|
||||
|
||||
// 清空注册表单
|
||||
Object.keys(registerData).forEach(key => {
|
||||
(registerData as any)[key] = ''
|
||||
})
|
||||
} else {
|
||||
MessagePlugin.error(response.message || '注册失败')
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('注册错误:', error)
|
||||
MessagePlugin.error(error.message || '注册失败,请稍后重试')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 处理忘记密码
|
||||
const handleForgotPassword = () => {
|
||||
MessagePlugin.info('忘记密码功能暂未开放,请联系管理员')
|
||||
}
|
||||
|
||||
// 检查是否已登录
|
||||
onMounted(() => {
|
||||
if (authStore.isLoggedIn) {
|
||||
router.replace('/platform/tenant/knowledge-bases')
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style lang="less" scoped>
|
||||
.login-container {
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
||||
padding: 20px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.login-card,
|
||||
.register-card {
|
||||
width: 100%;
|
||||
max-width: 440px;
|
||||
background: #fff;
|
||||
border-radius: 14px;
|
||||
box-shadow: 0 10px 16px 0 #0000000f, 0 20px 24px -2px #0000001a;
|
||||
padding: 40px;
|
||||
box-sizing: border-box;
|
||||
animation: fadeInUp .28s ease-out both;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
text-align: center;
|
||||
margin-bottom: 32px;
|
||||
|
||||
.logo {
|
||||
margin-bottom: 16px;
|
||||
|
||||
.logo-img {
|
||||
width: 180px;
|
||||
height: auto;
|
||||
border-radius: 12px;
|
||||
}
|
||||
}
|
||||
|
||||
.login-title {
|
||||
font-size: 28px;
|
||||
font-weight: 600;
|
||||
color: #000000e6;
|
||||
margin: 0 0 8px 0;
|
||||
font-family: "PingFang SC";
|
||||
}
|
||||
|
||||
.login-subtitle {
|
||||
font-size: 16px;
|
||||
color: #0000008c;
|
||||
margin: 0;
|
||||
font-family: "PingFang SC";
|
||||
}
|
||||
}
|
||||
|
||||
.login-form {
|
||||
:deep(.t-form-item__label) {
|
||||
font-size: 14px;
|
||||
color: #000000e6;
|
||||
font-weight: 500;
|
||||
margin-bottom: 8px;
|
||||
font-family: "PingFang SC";
|
||||
display: block;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
:deep(.t-input) {
|
||||
border: 1px solid #E7E7E7;
|
||||
border-radius: 8px;
|
||||
background: #fff;
|
||||
|
||||
&:focus-within {
|
||||
border-color: #07C05F;
|
||||
box-shadow: 0 0 0 2px rgba(7, 192, 95, 0.1);
|
||||
}
|
||||
|
||||
&:hover {
|
||||
border-color: #07C05F;
|
||||
}
|
||||
|
||||
.t-input__inner {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
background: transparent;
|
||||
font-size: 16px;
|
||||
font-family: "PingFang SC";
|
||||
|
||||
&:focus {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
.t-input__wrap {
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-form-item) {
|
||||
margin-bottom: 20px;
|
||||
|
||||
&:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-form-item__control) {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
.login-options {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin: 16px 0 24px 0;
|
||||
width: 100%;
|
||||
|
||||
:deep(.t-checkbox) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
.t-checkbox__input {
|
||||
margin-right: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.t-checkbox__label) {
|
||||
font-size: 14px;
|
||||
color: #00000099;
|
||||
font-family: "PingFang SC";
|
||||
line-height: 1.4;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.forgot-password {
|
||||
font-size: 14px;
|
||||
color: #07C05F;
|
||||
text-decoration: none;
|
||||
font-family: "PingFang SC";
|
||||
line-height: 1.4;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.login-button {
|
||||
height: 48px;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
font-family: "PingFang SC";
|
||||
margin: 16px 0 8px 0;
|
||||
|
||||
:deep(.t-button) {
|
||||
background-color: #07C05F;
|
||||
border-color: #07C05F;
|
||||
|
||||
&:hover {
|
||||
background-color: #06a855;
|
||||
border-color: #06a855;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.register-link {
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
color: #00000099;
|
||||
font-family: "PingFang SC";
|
||||
|
||||
.register-btn {
|
||||
color: #07C05F;
|
||||
text-decoration: none;
|
||||
margin-left: 4px;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应式设计
|
||||
@media (max-width: 480px) {
|
||||
.login-container {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.login-card,
|
||||
.register-card {
|
||||
padding: 28px;
|
||||
}
|
||||
|
||||
.login-header {
|
||||
margin-bottom: 24px;
|
||||
|
||||
.login-title {
|
||||
font-size: 24px;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translate3d(0, 6px, 0);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translate3d(0, 0, 0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -27,33 +27,7 @@ const sendMsg = (value: string) => {
|
||||
}
|
||||
|
||||
async function createNewSession(value: string) {
|
||||
// 从localStorage获取设置中的知识库ID
|
||||
const settingsStr = localStorage.getItem("WeKnora_settings");
|
||||
let knowledgeBaseId = "";
|
||||
|
||||
if (settingsStr) {
|
||||
try {
|
||||
const settings = JSON.parse(settingsStr);
|
||||
if (settings.knowledgeBaseId) {
|
||||
knowledgeBaseId = settings.knowledgeBaseId;
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
if (res.data && res.data.id) {
|
||||
getTitle(res.data.id, value);
|
||||
} else {
|
||||
// 错误处理
|
||||
console.error("创建会话失败");
|
||||
}
|
||||
}).catch(error => {
|
||||
console.error("创建会话出错:", error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("解析设置失败:", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置中没有知识库ID,则使用测试数据
|
||||
// 使用测试数据获取知识库ID
|
||||
const testData = getTestData();
|
||||
if (!testData || testData.knowledge_bases.length === 0) {
|
||||
console.error("测试数据未初始化或不包含知识库");
|
||||
@@ -61,7 +35,7 @@ async function createNewSession(value: string) {
|
||||
}
|
||||
|
||||
// 使用第一个知识库ID
|
||||
knowledgeBaseId = testData.knowledge_bases[0].id;
|
||||
const knowledgeBaseId = testData.knowledge_bases[0].id;
|
||||
|
||||
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
|
||||
if (res.data && res.data.id) {
|
||||
|
||||
@@ -17,6 +17,10 @@
|
||||
<span class="dot" />{{ s.label }}
|
||||
</li>
|
||||
</ul>
|
||||
<t-divider />
|
||||
<t-button size="small" variant="outline" theme="danger" block @click="handleLogout">
|
||||
退出登录
|
||||
</t-button>
|
||||
</div>
|
||||
</aside>
|
||||
<div class="init-main">
|
||||
@@ -780,8 +784,10 @@ import {
|
||||
listOllamaModels,
|
||||
testEmbeddingModel
|
||||
} from '@/api/initialization';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
|
||||
const router = useRouter();
|
||||
const authStore = useAuthStore();
|
||||
type TFormRef = {
|
||||
validate: (fields?: string[] | undefined) => Promise<true | any>;
|
||||
clearValidate?: (fields?: string | string[]) => void;
|
||||
@@ -956,6 +962,12 @@ const goToSection = (id: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 退出登录
|
||||
const handleLogout = () => {
|
||||
authStore.logout();
|
||||
router.replace('/login');
|
||||
};
|
||||
|
||||
// 监听滚动,高亮当前区块
|
||||
const onScroll = () => {
|
||||
const order = ['ollama','llm','embedding','rerank','multimodal','docsplit','submit'];
|
||||
@@ -2335,6 +2347,27 @@ const detectEmbeddingDimension = async () => {
|
||||
|
||||
<style lang="less" scoped>
|
||||
.initialization-container {
|
||||
padding: 20px 16px;
|
||||
background: linear-gradient(180deg, #f7faf9 0%, #f9fbfa 60%, #ffffff 100%);
|
||||
scroll-behavior: smooth;
|
||||
|
||||
.initialization-header {
|
||||
text-align: center;
|
||||
margin: 10px auto 18px;
|
||||
|
||||
h1 {
|
||||
margin: 0 0 6px;
|
||||
font-size: 22px;
|
||||
font-weight: 700;
|
||||
color: #0f172a;
|
||||
}
|
||||
|
||||
p {
|
||||
margin: 0;
|
||||
color: #64748b;
|
||||
font-size: 14px;
|
||||
}
|
||||
}
|
||||
.init-layout {
|
||||
display: grid;
|
||||
grid-template-columns: 220px 1fr;
|
||||
@@ -2420,6 +2453,30 @@ const detectEmbeddingDimension = async () => {
|
||||
min-width: 0;
|
||||
max-width: 960px;
|
||||
}
|
||||
/* 统一分区卡片视觉 */
|
||||
.config-section {
|
||||
background: #fff;
|
||||
border: 1px solid #eef4f0;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 6px 18px rgba(7, 192, 95, 0.04);
|
||||
padding: 16px;
|
||||
margin: 14px 0;
|
||||
|
||||
h3 {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin: 0 0 12px;
|
||||
font-size: 16px;
|
||||
font-weight: 700;
|
||||
color: #0f172a;
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
color: #07c05f;
|
||||
font-size: 18px;
|
||||
}
|
||||
}
|
||||
.ollama-summary-card {
|
||||
max-width: 100%;
|
||||
margin: 0 0 16px 0;
|
||||
|
||||
11
go.mod
11
go.mod
@@ -10,6 +10,7 @@ require (
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hibiken/asynq v0.25.1
|
||||
github.com/minio/minio-go/v7 v7.0.90
|
||||
@@ -31,7 +32,8 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.37.0
|
||||
go.opentelemetry.io/otel/trace v1.37.0
|
||||
go.uber.org/dig v1.18.1
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/sync v0.17.0
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/protobuf v1.36.6
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
@@ -100,10 +102,9 @@ require (
|
||||
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
|
||||
|
||||
26
go.sum
26
go.sum
@@ -70,6 +70,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlnd
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -253,20 +255,20 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw=
|
||||
golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
|
||||
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
|
||||
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
154
internal/application/repository/user.go
Normal file
154
internal/application/repository/user.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
)
|
||||
|
||||
// userRepository implements user repository interface
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new user repository
|
||||
func NewUserRepository(db *gorm.DB) interfaces.UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateUser creates a user
|
||||
func (r *userRepository) CreateUser(ctx context.Context, user *types.User) error {
|
||||
logger.Infof(ctx, "Creating user in database: %s", user.Email)
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
// GetUserByID gets a user by ID
|
||||
func (r *userRepository) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail gets a user by email
|
||||
func (r *userRepository) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername gets a user by username
|
||||
func (r *userRepository) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
||||
var user types.User
|
||||
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user
|
||||
func (r *userRepository) UpdateUser(ctx context.Context, user *types.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (r *userRepository) DeleteUser(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.User{}).Error
|
||||
}
|
||||
|
||||
// ListUsers lists users with pagination
|
||||
func (r *userRepository) ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
query := r.db.WithContext(ctx).Order("created_at DESC")
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// authTokenRepository implements auth token repository interface
|
||||
type authTokenRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAuthTokenRepository creates a new auth token repository
|
||||
func NewAuthTokenRepository(db *gorm.DB) interfaces.AuthTokenRepository {
|
||||
return &authTokenRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateToken creates an auth token
|
||||
func (r *authTokenRepository) CreateToken(ctx context.Context, token *types.AuthToken) error {
|
||||
return r.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// GetTokenByValue gets a token by its value
|
||||
func (r *authTokenRepository) GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error) {
|
||||
var token types.AuthToken
|
||||
if err := r.db.WithContext(ctx).Where("token = ?", tokenValue).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetTokensByUserID gets all tokens for a user
|
||||
func (r *authTokenRepository) GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error) {
|
||||
var tokens []*types.AuthToken
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&tokens).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// UpdateToken updates a token
|
||||
func (r *authTokenRepository) UpdateToken(ctx context.Context, token *types.AuthToken) error {
|
||||
return r.db.WithContext(ctx).Save(token).Error
|
||||
}
|
||||
|
||||
// DeleteToken deletes a token
|
||||
func (r *authTokenRepository) DeleteToken(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// DeleteExpiredTokens deletes all expired tokens
|
||||
func (r *authTokenRepository) DeleteExpiredTokens(ctx context.Context) error {
|
||||
return r.db.WithContext(ctx).Where("expires_at < NOW()").Delete(&types.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// RevokeTokensByUserID revokes all tokens for a user
|
||||
func (r *authTokenRepository) RevokeTokensByUserID(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&types.AuthToken{}).Where("user_id = ?", userID).Update("is_revoked", true).Error
|
||||
}
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
var apiKeySecret = []byte(os.Getenv("TENANT_AES_KEY"))
|
||||
var apiKeySecret = func() []byte {
|
||||
return []byte(os.Getenv("TENANT_AES_KEY"))
|
||||
}
|
||||
|
||||
// ListTenantsParams defines parameters for listing tenants with filtering and pagination
|
||||
type ListTenantsParams struct {
|
||||
@@ -221,7 +223,7 @@ func (r *tenantService) generateApiKey(tenantID uint) string {
|
||||
binary.LittleEndian.PutUint64(idBytes, uint64(tenantID))
|
||||
|
||||
// 2. Encrypt tenant_id using AES-GCM
|
||||
block, err := aes.NewCipher(apiKeySecret)
|
||||
block, err := aes.NewCipher(apiKeySecret())
|
||||
if err != nil {
|
||||
panic("Failed to create AES cipher: " + err.Error())
|
||||
}
|
||||
@@ -267,7 +269,7 @@ func (r *tenantService) ExtractTenantIDFromAPIKey(apiKey string) (uint, error) {
|
||||
nonce, ciphertext := encryptedData[:12], encryptedData[12:]
|
||||
|
||||
// 4. Decrypt
|
||||
block, err := aes.NewCipher(apiKeySecret)
|
||||
block, err := aes.NewCipher(apiKeySecret())
|
||||
if err != nil {
|
||||
return 0, errors.New("decryption error")
|
||||
}
|
||||
|
||||
408
internal/application/service/user.go
Normal file
408
internal/application/service/user.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// JWT secret key - in production this should be from environment variable
|
||||
var jwtSecret = []byte("your-secret-key")
|
||||
|
||||
// userService implements the UserService interface
|
||||
type userService struct {
|
||||
userRepo interfaces.UserRepository
|
||||
tokenRepo interfaces.AuthTokenRepository
|
||||
tenantService interfaces.TenantService
|
||||
}
|
||||
|
||||
// NewUserService creates a new user service instance
|
||||
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
|
||||
return &userService{
|
||||
userRepo: userRepo,
|
||||
tokenRepo: tokenRepo,
|
||||
tenantService: tenantService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
|
||||
logger.Info(ctx, "Start user registration")
|
||||
|
||||
// Validate input
|
||||
if req.Username == "" || req.Email == "" || req.Password == "" {
|
||||
return nil, errors.New("username, email and password are required")
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
|
||||
if existingUser != nil {
|
||||
return nil, errors.New("user with this email already exists")
|
||||
}
|
||||
|
||||
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
|
||||
if existingUser != nil {
|
||||
return nil, errors.New("user with this username already exists")
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to hash password: %v", err)
|
||||
return nil, errors.New("failed to process password")
|
||||
}
|
||||
|
||||
// Create default tenant for the user
|
||||
tenant := &types.Tenant{
|
||||
Name: fmt.Sprintf("%s's Workspace", req.Username),
|
||||
Description: "Default workspace",
|
||||
Status: "active",
|
||||
RetrieverEngines: types.RetrieverEngines{
|
||||
Engines: []types.RetrieverEngineParams{
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
{
|
||||
RetrieverType: types.VectorRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create tenant: %v", err)
|
||||
return nil, errors.New("failed to create workspace")
|
||||
}
|
||||
|
||||
// Create user
|
||||
user := &types.User{
|
||||
ID: uuid.New().String(),
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
PasswordHash: string(hashedPassword),
|
||||
TenantID: createdTenant.ID,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err = s.userRepo.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to create user: %v", err)
|
||||
return nil, errors.New("failed to create user")
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User registered successfully: %s", user.Email)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns tokens
|
||||
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
|
||||
logger.Infof(ctx, "Start user login for email: %s", req.Email)
|
||||
|
||||
// Get user by email
|
||||
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
if user == nil {
|
||||
logger.Warnf(ctx, "User not found for email: %s", req.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Account is disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify password
|
||||
logger.Infof(ctx, "Verifying password for user: %s", user.Email)
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Invalid email or password",
|
||||
}, nil
|
||||
}
|
||||
logger.Infof(ctx, "Password verification successful for user: %s", user.Email)
|
||||
|
||||
// Generate tokens
|
||||
logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
|
||||
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
|
||||
return &types.LoginResponse{
|
||||
Success: false,
|
||||
Message: "Login failed",
|
||||
}, nil
|
||||
}
|
||||
logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)
|
||||
|
||||
// Get tenant information
|
||||
logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
|
||||
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
|
||||
} else {
|
||||
logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User logged in successfully: %s", user.Email)
|
||||
return &types.LoginResponse{
|
||||
Success: true,
|
||||
Message: "Login successful",
|
||||
User: user,
|
||||
Tenant: tenant,
|
||||
Token: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByID gets a user by ID
|
||||
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetUserByEmail gets a user by email
|
||||
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByEmail(ctx, email)
|
||||
}
|
||||
|
||||
// GetUserByUsername gets a user by username
|
||||
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
||||
return s.userRepo.GetUserByUsername(ctx, username)
|
||||
}
|
||||
|
||||
// UpdateUser updates user information
|
||||
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
return s.userRepo.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (s *userService) DeleteUser(ctx context.Context, id string) error {
|
||||
return s.userRepo.DeleteUser(ctx, id)
|
||||
}
|
||||
|
||||
// ChangePassword changes user password
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
|
||||
if err != nil {
|
||||
return errors.New("invalid old password")
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
return s.userRepo.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
// ValidatePassword validates user password
|
||||
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
}
|
||||
|
||||
// GenerateTokens generates access and refresh tokens for user
|
||||
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
|
||||
// Generate access token (expires in 24 hours)
|
||||
accessClaims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"tenant_id": user.TenantID,
|
||||
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessToken, err = accessTokenObj.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Generate refresh token (expires in 7 days)
|
||||
refreshClaims := jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshToken, err = refreshTokenObj.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Store tokens in database
|
||||
accessTokenRecord := &types.AuthToken{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
Token: accessToken,
|
||||
TokenType: "access_token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
refreshTokenRecord := &types.AuthToken{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
Token: refreshToken,
|
||||
TokenType: "refresh_token",
|
||||
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
|
||||
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates an access token
|
||||
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid token claims")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid user ID in token")
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
||||
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
||||
return nil, errors.New("token is revoked")
|
||||
}
|
||||
|
||||
return s.userRepo.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes access token using refresh token
|
||||
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
|
||||
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return "", "", errors.New("invalid refresh token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid token claims")
|
||||
}
|
||||
|
||||
tokenType, ok := claims["type"].(string)
|
||||
if !ok || tokenType != "refresh" {
|
||||
return "", "", errors.New("not a refresh token")
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid user ID in token")
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
|
||||
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
||||
return "", "", errors.New("refresh token is revoked")
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Revoke old refresh token
|
||||
tokenRecord.IsRevoked = true
|
||||
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
||||
|
||||
// Generate new tokens
|
||||
return s.GenerateTokens(ctx, user)
|
||||
}
|
||||
|
||||
// RevokeToken revokes a token
|
||||
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
|
||||
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenRecord.IsRevoked = true
|
||||
tokenRecord.UpdatedAt = time.Now()
|
||||
|
||||
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
||||
}
|
||||
|
||||
// GetCurrentUser gets current user from context
|
||||
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
|
||||
userID, ok := ctx.Value("user_id").(string)
|
||||
if !ok {
|
||||
return nil, errors.New("user not found in context")
|
||||
}
|
||||
|
||||
return s.userRepo.GetUserByID(ctx, userID)
|
||||
}
|
||||
@@ -78,6 +78,8 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(repository.NewSessionRepository))
|
||||
must(container.Provide(repository.NewMessageRepository))
|
||||
must(container.Provide(repository.NewModelRepository))
|
||||
must(container.Provide(repository.NewUserRepository))
|
||||
must(container.Provide(repository.NewAuthTokenRepository))
|
||||
|
||||
// Business service layer
|
||||
must(container.Provide(service.NewTenantService))
|
||||
@@ -91,6 +93,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(service.NewModelService))
|
||||
must(container.Provide(service.NewDatasetService))
|
||||
must(container.Provide(service.NewEvaluationService))
|
||||
must(container.Provide(service.NewUserService))
|
||||
|
||||
// Chat pipeline components for processing chat requests
|
||||
must(container.Provide(chatpipline.NewEventManager))
|
||||
@@ -117,6 +120,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(handler.NewModelHandler))
|
||||
must(container.Provide(handler.NewEvaluationHandler))
|
||||
must(container.Provide(handler.NewInitializationHandler))
|
||||
must(container.Provide(handler.NewAuthHandler))
|
||||
|
||||
// Router configuration
|
||||
must(container.Provide(router.NewRouter))
|
||||
@@ -177,6 +181,15 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Auto-migrate database tables
|
||||
err = db.AutoMigrate(
|
||||
&types.User{},
|
||||
&types.AuthToken{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err)
|
||||
}
|
||||
|
||||
// Get underlying SQL DB object
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
|
||||
325
internal/handler/auth.go
Normal file
325
internal/handler/auth.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// AuthHandler implements HTTP request handlers for user authentication
|
||||
// Provides functionality for user registration, login, logout, and token management
|
||||
// through the REST API endpoints
|
||||
type AuthHandler struct {
|
||||
userService interfaces.UserService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler instance with the provided service
|
||||
// Parameters:
|
||||
// - userService: An implementation of the UserService interface for business logic
|
||||
//
|
||||
// Returns a pointer to the newly created AuthHandler
|
||||
func NewAuthHandler(userService interfaces.UserService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register handles the HTTP request for user registration
|
||||
// It deserializes the request body into a registration request object, validates it,
|
||||
// calls the service to create the user, and returns the result
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user registration")
|
||||
|
||||
var req types.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse registration request parameters", err)
|
||||
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Username == "" || req.Email == "" || req.Password == "" {
|
||||
logger.Error(ctx, "Missing required registration fields")
|
||||
appErr := errors.NewValidationError("Username, email and password are required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to register user
|
||||
user, err := h.userService.Register(ctx, &req)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to register user: %v", err)
|
||||
appErr := errors.NewBadRequestError("Registration failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Return success response
|
||||
response := &types.RegisterResponse{
|
||||
Success: true,
|
||||
Message: "Registration successful",
|
||||
User: user,
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "User registered successfully: %s", user.Email)
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Login handles the HTTP request for user login
|
||||
// It deserializes the request body into a login request object, validates it,
|
||||
// calls the service to authenticate the user, and returns tokens
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user login")
|
||||
|
||||
var req types.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse login request parameters", err)
|
||||
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Email == "" || req.Password == "" {
|
||||
logger.Error(ctx, "Missing required login fields")
|
||||
appErr := errors.NewValidationError("Email and password are required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to authenticate user
|
||||
response, err := h.userService.Login(ctx, &req)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to login user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if login was successful
|
||||
if !response.Success {
|
||||
logger.Warnf(ctx, "Login failed: %s", response.Message)
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
return
|
||||
}
|
||||
|
||||
// User is already in the correct format from service
|
||||
|
||||
logger.Infof(ctx, "User logged in successfully: %s", req.Email)
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Logout handles the HTTP request for user logout
|
||||
// It extracts the token from the Authorization header and revokes it
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start user logout")
|
||||
|
||||
// Extract token from Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Error(ctx, "Missing Authorization header")
|
||||
appErr := errors.NewValidationError("Authorization header is required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Bearer token
|
||||
tokenParts := strings.Split(authHeader, " ")
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
logger.Error(ctx, "Invalid Authorization header format")
|
||||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
|
||||
// Revoke token
|
||||
err := h.userService.RevokeToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to revoke token: %v", err)
|
||||
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "User logged out successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Logout successful",
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken handles the HTTP request for refreshing access tokens
|
||||
// It extracts the refresh token from the request body and generates new tokens
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start token refresh")
|
||||
|
||||
var req struct {
|
||||
RefreshToken string `json:"refreshToken" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse refresh token request", err)
|
||||
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Call service to refresh token
|
||||
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to refresh token: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "Token refreshed successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token refreshed successfully",
|
||||
"access_token": accessToken,
|
||||
"refresh_token": newRefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles the HTTP request for getting current user information
|
||||
// It extracts the user from the context (set by auth middleware) and returns user info
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Get current user info")
|
||||
|
||||
// Get current user from service (which extracts from context)
|
||||
user, err := h.userService.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Retrieved current user info: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"user": user.ToUserInfo(),
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword handles the HTTP request for changing user password
|
||||
// It extracts the current user and validates the old password before setting new one
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start password change")
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error(ctx, "Failed to parse password change request", err)
|
||||
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Get current user
|
||||
user, err := h.userService.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Change password
|
||||
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to change password: %v", err)
|
||||
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Password changed successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateToken handles the HTTP request for validating access tokens
|
||||
// It extracts the token from the Authorization header and validates it
|
||||
// Parameters:
|
||||
// - c: Gin context for the HTTP request
|
||||
func (h *AuthHandler) ValidateToken(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Start token validation")
|
||||
|
||||
// Extract token from Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Error(ctx, "Missing Authorization header")
|
||||
appErr := errors.NewValidationError("Authorization header is required")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse Bearer token
|
||||
tokenParts := strings.Split(authHeader, " ")
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
logger.Error(ctx, "Invalid Authorization header format")
|
||||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
|
||||
// Validate token
|
||||
user, err := h.userService.ValidateToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to validate token: %v", err)
|
||||
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
|
||||
c.Error(appErr)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token is valid",
|
||||
"user": user.ToUserInfo(),
|
||||
})
|
||||
}
|
||||
@@ -141,8 +141,10 @@ func (h *InitializationHandler) CheckStatus(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
logger.Info(ctx, "Checking system initialization status")
|
||||
|
||||
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
||||
|
||||
// 检查是否存在租户
|
||||
tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
|
||||
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -165,7 +167,6 @@ func (h *InitializationHandler) CheckStatus(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
|
||||
|
||||
// 检查是否存在模型
|
||||
models, err := h.modelService.ListModels(ctx)
|
||||
@@ -194,6 +195,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
logger.Info(ctx, "Starting system initialization")
|
||||
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
|
||||
|
||||
var req InitializationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -259,63 +261,16 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
var err error
|
||||
// 1. 处理租户 - 检查是否存在,不存在则创建
|
||||
tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
|
||||
tenant, _ := h.tenantService.GetTenantByID(ctx, tenantID)
|
||||
if tenant == nil {
|
||||
logger.Info(ctx, "Tenant not found, creating tenant")
|
||||
// 创建默认租户
|
||||
tenant = &types.Tenant{
|
||||
ID: types.InitDefaultTenantID,
|
||||
Name: "Default Tenant",
|
||||
Description: "System Default Tenant",
|
||||
RetrieverEngines: types.RetrieverEngines{
|
||||
Engines: []types.RetrieverEngineParams{
|
||||
{
|
||||
RetrieverType: types.KeywordsRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
{
|
||||
RetrieverType: types.VectorRetrieverType,
|
||||
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
logger.Info(ctx, "Creating default tenant")
|
||||
tenant, err = h.tenantService.CreateTenant(ctx, tenant)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
|
||||
err = errors.NewInternalServerError("Failed to get tenant")
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logger.Info(ctx, "Tenant exists, updating if needed")
|
||||
// 更新租户信息(如果需要)
|
||||
updated := false
|
||||
if tenant.Name != "Default Tenant" {
|
||||
tenant.Name = "Default Tenant"
|
||||
updated = true
|
||||
}
|
||||
if tenant.Description != "System Default Tenant" {
|
||||
tenant.Description = "System Default Tenant"
|
||||
updated = true
|
||||
}
|
||||
|
||||
if updated {
|
||||
_, err = h.tenantService.UpdateTenant(ctx, tenant)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
|
||||
return
|
||||
}
|
||||
logger.Info(ctx, "Tenant updated successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带有租户ID的新上下文
|
||||
newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
|
||||
|
||||
// 2. 处理模型 - 检查现有模型并更新或创建
|
||||
existingModels, err := h.modelService.ListModels(newCtx)
|
||||
existingModels, err := h.modelService.ListModels(ctx)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
// 如果获取失败,继续执行创建流程
|
||||
@@ -420,7 +375,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
existingModel.IsDefault = true
|
||||
existingModel.Status = types.ModelStatusActive
|
||||
|
||||
err := h.modelService.UpdateModel(newCtx, existingModel)
|
||||
err := h.modelService.UpdateModel(ctx, existingModel)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"model_name": modelConfig.name,
|
||||
@@ -437,7 +392,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
modelConfig.name, modelConfig.modelType,
|
||||
)
|
||||
newModel := &types.Model{
|
||||
TenantID: types.InitDefaultTenantID,
|
||||
TenantID: tenantID,
|
||||
Name: modelConfig.name,
|
||||
Type: modelConfig.modelType,
|
||||
Source: modelConfig.source,
|
||||
@@ -453,7 +408,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
Status: types.ModelStatusActive,
|
||||
}
|
||||
|
||||
err := h.modelService.CreateModel(newCtx, newModel)
|
||||
err := h.modelService.CreateModel(ctx, newModel)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"model_name": modelConfig.name,
|
||||
@@ -470,7 +425,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
if !req.Multimodal.Enabled {
|
||||
if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
|
||||
logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
|
||||
err := h.modelService.DeleteModel(newCtx, existingVLM.ID)
|
||||
err := h.modelService.DeleteModel(ctx, existingVLM.ID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"model_id": existingVLM.ID,
|
||||
@@ -485,7 +440,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
if !req.Rerank.Enabled {
|
||||
if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
|
||||
logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
|
||||
err := h.modelService.DeleteModel(newCtx, existingRerank.ID)
|
||||
err := h.modelService.DeleteModel(ctx, existingRerank.ID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"model_id": existingRerank.ID,
|
||||
@@ -497,7 +452,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 处理知识库 - 检查是否存在,不存在则创建,存在则更新
|
||||
kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
|
||||
kbs, err := h.kbService.ListKnowledgeBases(ctx)
|
||||
|
||||
// 找到embedding模型ID和LLM模型ID
|
||||
var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
|
||||
@@ -516,14 +471,16 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if kb == nil {
|
||||
var kb *types.KnowledgeBase
|
||||
|
||||
if len(kbs) == 0 {
|
||||
// 创建新知识库
|
||||
logger.Info(ctx, "Creating default knowledge base")
|
||||
kb = &types.KnowledgeBase{
|
||||
ID: types.InitDefaultKnowledgeBaseID,
|
||||
ID: uuid.New().String(),
|
||||
Name: "Default Knowledge Base",
|
||||
Description: "System Default Knowledge Base",
|
||||
TenantID: types.InitDefaultTenantID,
|
||||
TenantID: tenantID,
|
||||
ChunkingConfig: types.ChunkingConfig{
|
||||
ChunkSize: req.DocumentSplitting.ChunkSize,
|
||||
ChunkOverlap: req.DocumentSplitting.ChunkOverlap,
|
||||
@@ -566,7 +523,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
_, err = h.kbService.CreateKnowledgeBase(newCtx, kb)
|
||||
_, err = h.kbService.CreateKnowledgeBase(ctx, kb)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
|
||||
@@ -575,10 +532,11 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
} else {
|
||||
// 更新现有知识库
|
||||
logger.Info(ctx, "Updating existing knowledge base")
|
||||
kb = kbs[0]
|
||||
|
||||
// 检查是否有文件,如果有文件则不允许修改Embedding模型
|
||||
knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
|
||||
newCtx, types.InitDefaultKnowledgeBaseID,
|
||||
ctx, kb.ID,
|
||||
)
|
||||
hasFiles := err == nil && len(knowledgeList) > 0
|
||||
|
||||
@@ -639,7 +597,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 更新基本信息和配置
|
||||
err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
|
||||
err = h.kbRepository.UpdateKnowledgeBase(ctx, kb)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
|
||||
@@ -649,7 +607,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
// 如果需要更新模型ID,使用repository直接更新
|
||||
if !hasFiles || kb.SummaryModelID != llmModelID {
|
||||
// 刷新知识库对象以获取最新信息
|
||||
kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
|
||||
kb, err = h.kbService.GetKnowledgeBaseByID(ctx, kb.ID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
|
||||
@@ -665,7 +623,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用repository直接更新模型ID
|
||||
err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
|
||||
err = h.kbRepository.UpdateKnowledgeBase(ctx, kb)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
|
||||
@@ -1074,11 +1032,8 @@ func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
|
||||
|
||||
logger.Info(ctx, "Getting current system configuration")
|
||||
|
||||
// 设置租户上下文
|
||||
newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
|
||||
|
||||
// 获取模型信息
|
||||
models, err := h.modelService.ListModels(newCtx)
|
||||
models, err := h.modelService.ListModels(ctx)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
|
||||
@@ -1086,16 +1041,24 @@ func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取知识库信息
|
||||
kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
|
||||
kbs, err := h.kbService.ListKnowledgeBases(ctx)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(kbs) == 0 {
|
||||
logger.Error(ctx, "No knowledge bases found")
|
||||
c.Error(errors.NewInternalServerError("获取知识库信息失败"))
|
||||
return
|
||||
}
|
||||
|
||||
kb := kbs[0]
|
||||
|
||||
// 检查知识库是否有文件
|
||||
knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(newCtx,
|
||||
types.InitDefaultKnowledgeBaseID, &types.Pagination{
|
||||
knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(ctx,
|
||||
kb.ID, &types.Pagination{
|
||||
Page: 1,
|
||||
PageSize: 1,
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
@@ -48,7 +49,7 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
|
||||
|
||||
logger.Info(ctx, "Start retrieving test data")
|
||||
|
||||
tenantID := uint(types.InitDefaultTenantID)
|
||||
tenantID := c.GetUint(types.TenantIDContextKey.String())
|
||||
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
|
||||
|
||||
// Retrieve the test tenant data
|
||||
@@ -60,24 +61,26 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
|
||||
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
|
||||
|
||||
// Retrieve the test knowledge base data
|
||||
logger.Infof(ctx, "Retrieving test knowledge base, ID: %s", knowledgeBaseID)
|
||||
knowledgeBase, err := h.kbService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
|
||||
kbs, err := h.kbService.ListKnowledgeBases(ctx)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(kbs) == 0 {
|
||||
logger.Error(ctx, "No knowledge bases found")
|
||||
c.Error(errors.NewInternalServerError("获取知识库信息失败"))
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info(ctx, "Test data retrieved successfully")
|
||||
// Return the test data in the response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": gin.H{
|
||||
"tenant": tenant,
|
||||
"knowledge_bases": []types.KnowledgeBase{*knowledgeBase},
|
||||
"knowledge_bases": kbs,
|
||||
},
|
||||
"success": true,
|
||||
})
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
|
||||
// 无需认证的API列表
|
||||
var noAuthAPI = map[string][]string{
|
||||
"/api/v1/test-data": {"GET"},
|
||||
"/api/v1/tenants": {"POST"},
|
||||
"/api/v1/initialization/*": {"GET", "POST"},
|
||||
"/health": {"GET"},
|
||||
"/api/v1/auth/register": {"POST"},
|
||||
"/api/v1/auth/login": {"POST"},
|
||||
"/api/v1/auth/refresh": {"POST"},
|
||||
}
|
||||
|
||||
// 检查请求是否在无需认证的API列表中
|
||||
@@ -38,7 +38,7 @@ func isNoAuthAPI(path string, method string) bool {
|
||||
}
|
||||
|
||||
// Auth 认证中间件
|
||||
func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc {
|
||||
func Auth(tenantService interfaces.TenantService, userService interfaces.UserService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ignore OPTIONS request
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
@@ -52,14 +52,45 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle
|
||||
return
|
||||
}
|
||||
|
||||
// Get API Key from request header
|
||||
apiKey := c.GetHeader("X-API-Key")
|
||||
if apiKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
// 尝试JWT Token认证
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
user, err := userService.ValidateToken(c.Request.Context(), token)
|
||||
if err == nil && user != nil {
|
||||
// JWT Token认证成功
|
||||
// 获取租户信息
|
||||
tenant, err := tenantService.GetTenantByID(c.Request.Context(), user.TenantID)
|
||||
if err != nil {
|
||||
log.Printf("Error getting tenant by ID: %v, tenantID: %d, userID: %s", err, user.TenantID, user.ID)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized: invalid tenant",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 存储用户和租户信息到上下文
|
||||
c.Set(types.TenantIDContextKey.String(), user.TenantID)
|
||||
c.Set(types.TenantInfoContextKey.String(), tenant)
|
||||
c.Set("user", user)
|
||||
c.Request = c.Request.WithContext(
|
||||
context.WithValue(
|
||||
context.WithValue(
|
||||
context.WithValue(c.Request.Context(), types.TenantIDContextKey, user.TenantID),
|
||||
types.TenantInfoContextKey, tenant,
|
||||
),
|
||||
"user", user,
|
||||
),
|
||||
)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试X-API-Key认证(兼容模式)
|
||||
apiKey := c.GetHeader("X-API-Key")
|
||||
if apiKey != "" {
|
||||
// Get tenant information
|
||||
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
|
||||
if err != nil {
|
||||
@@ -99,6 +130,12 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle
|
||||
),
|
||||
)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 没有提供任何认证信息
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized: missing authentication"})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,14 @@ type RouterParams struct {
|
||||
dig.In
|
||||
|
||||
Config *config.Config
|
||||
UserService interfaces.UserService
|
||||
KBService interfaces.KnowledgeBaseService
|
||||
KnowledgeService interfaces.KnowledgeService
|
||||
ChunkService interfaces.ChunkService
|
||||
SessionService interfaces.SessionService
|
||||
MessageService interfaces.MessageService
|
||||
ModelService interfaces.ModelService
|
||||
EvaluationService interfaces.EvaluationService
|
||||
KBHandler *handler.KnowledgeBaseHandler
|
||||
KnowledgeHandler *handler.KnowledgeHandler
|
||||
TenantHandler *handler.TenantHandler
|
||||
@@ -28,6 +36,7 @@ type RouterParams struct {
|
||||
TestDataHandler *handler.TestDataHandler
|
||||
ModelHandler *handler.ModelHandler
|
||||
EvaluationHandler *handler.EvaluationHandler
|
||||
AuthHandler *handler.AuthHandler
|
||||
InitializationHandler *handler.InitializationHandler
|
||||
}
|
||||
|
||||
@@ -50,7 +59,7 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.ErrorHandler())
|
||||
r.Use(middleware.Auth(params.TenantService, params.Config))
|
||||
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
|
||||
|
||||
// 添加OpenTelemetry追踪中间件
|
||||
r.Use(middleware.TracingMiddleware())
|
||||
@@ -60,31 +69,10 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// 测试数据接口(不需要认证)
|
||||
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
|
||||
|
||||
// 初始化接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
|
||||
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
|
||||
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
|
||||
|
||||
// Ollama相关接口(不需要认证)
|
||||
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
|
||||
r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
|
||||
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
|
||||
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
|
||||
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
|
||||
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
|
||||
|
||||
// 远程API相关接口(不需要认证)
|
||||
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
|
||||
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
|
||||
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
|
||||
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
|
||||
|
||||
// 需要认证的API路由
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
RegisterAuthRoutes(v1, params.AuthHandler)
|
||||
RegisterTenantRoutes(v1, params.TenantHandler)
|
||||
RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
|
||||
RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
|
||||
@@ -94,6 +82,8 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
RegisterMessageRoutes(v1, params.MessageHandler)
|
||||
RegisterModelRoutes(v1, params.ModelHandler)
|
||||
RegisterEvaluationRoutes(v1, params.EvaluationHandler)
|
||||
RegisterInitializationRoutes(v1, params.InitializationHandler)
|
||||
RegisterTestDataRoutes(v1, params.TestDataHandler)
|
||||
}
|
||||
|
||||
return r
|
||||
@@ -247,3 +237,39 @@ func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHan
|
||||
evaluationRoutes.GET("/", handler.GetEvaluationResult)
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterTestDataRoutes(r *gin.RouterGroup, handler *handler.TestDataHandler) {
|
||||
r.GET("/test-data", handler.GetTestData)
|
||||
}
|
||||
|
||||
// RegisterAuthRoutes registers authentication routes
|
||||
func RegisterAuthRoutes(r *gin.RouterGroup, handler *handler.AuthHandler) {
|
||||
r.POST("/auth/register", handler.Register)
|
||||
r.POST("/auth/login", handler.Login)
|
||||
r.POST("/auth/refresh", handler.RefreshToken)
|
||||
r.GET("/auth/validate", handler.ValidateToken)
|
||||
r.POST("/auth/logout", handler.Logout)
|
||||
r.GET("/auth/me", handler.GetCurrentUser)
|
||||
r.POST("/auth/change-password", handler.ChangePassword)
|
||||
}
|
||||
|
||||
func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.InitializationHandler) {
|
||||
// 初始化接口
|
||||
r.GET("/initialization/status", handler.CheckStatus)
|
||||
r.GET("/initialization/config", handler.GetCurrentConfig)
|
||||
r.POST("/initialization/initialize", handler.Initialize)
|
||||
|
||||
// Ollama相关接口
|
||||
r.GET("/initialization/ollama/status", handler.CheckOllamaStatus)
|
||||
r.GET("/initialization/ollama/models", handler.ListOllamaModels)
|
||||
r.POST("/initialization/ollama/models/check", handler.CheckOllamaModels)
|
||||
r.POST("/initialization/ollama/models/download", handler.DownloadOllamaModel)
|
||||
r.GET("/initialization/ollama/download/progress/:taskId", handler.GetDownloadProgress)
|
||||
r.GET("/initialization/ollama/download/tasks", handler.ListDownloadTasks)
|
||||
|
||||
// 远程API相关接口
|
||||
r.POST("/initialization/remote/check", handler.CheckRemoteModel)
|
||||
r.POST("/initialization/embedding/test", handler.TestEmbeddingModel)
|
||||
r.POST("/initialization/rerank/check", handler.CheckRerankModel)
|
||||
r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction)
|
||||
}
|
||||
|
||||
75
internal/types/interfaces/user.go
Normal file
75
internal/types/interfaces/user.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// UserService defines the user service interface
|
||||
type UserService interface {
|
||||
// Register creates a new user account
|
||||
Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error)
|
||||
// Login authenticates a user and returns tokens
|
||||
Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error)
|
||||
// GetUserByID gets a user by ID
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
// GetUserByEmail gets a user by email
|
||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||
// GetUserByUsername gets a user by username
|
||||
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
|
||||
// UpdateUser updates user information
|
||||
UpdateUser(ctx context.Context, user *types.User) error
|
||||
// DeleteUser deletes a user
|
||||
DeleteUser(ctx context.Context, id string) error
|
||||
// ChangePassword changes user password
|
||||
ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error
|
||||
// ValidatePassword validates user password
|
||||
ValidatePassword(ctx context.Context, userID string, password string) error
|
||||
// GenerateTokens generates access and refresh tokens for user
|
||||
GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error)
|
||||
// ValidateToken validates an access token
|
||||
ValidateToken(ctx context.Context, token string) (*types.User, error)
|
||||
// RefreshToken refreshes access token using refresh token
|
||||
RefreshToken(ctx context.Context, refreshToken string) (accessToken, newRefreshToken string, err error)
|
||||
// RevokeToken revokes a token
|
||||
RevokeToken(ctx context.Context, token string) error
|
||||
// GetCurrentUser gets current user from context
|
||||
GetCurrentUser(ctx context.Context) (*types.User, error)
|
||||
}
|
||||
|
||||
// UserRepository defines the user repository interface
|
||||
type UserRepository interface {
|
||||
// CreateUser creates a user
|
||||
CreateUser(ctx context.Context, user *types.User) error
|
||||
// GetUserByID gets a user by ID
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
// GetUserByEmail gets a user by email
|
||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||
// GetUserByUsername gets a user by username
|
||||
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
|
||||
// UpdateUser updates a user
|
||||
UpdateUser(ctx context.Context, user *types.User) error
|
||||
// DeleteUser deletes a user
|
||||
DeleteUser(ctx context.Context, id string) error
|
||||
// ListUsers lists users with pagination
|
||||
ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error)
|
||||
}
|
||||
|
||||
// AuthTokenRepository defines the auth token repository interface
|
||||
type AuthTokenRepository interface {
|
||||
// CreateToken creates an auth token
|
||||
CreateToken(ctx context.Context, token *types.AuthToken) error
|
||||
// GetTokenByValue gets a token by its value
|
||||
GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error)
|
||||
// GetTokensByUserID gets all tokens for a user
|
||||
GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error)
|
||||
// UpdateToken updates a token
|
||||
UpdateToken(ctx context.Context, token *types.AuthToken) error
|
||||
// DeleteToken deletes a token
|
||||
DeleteToken(ctx context.Context, id string) error
|
||||
// DeleteExpiredTokens deletes all expired tokens
|
||||
DeleteExpiredTokens(ctx context.Context) error
|
||||
// RevokeTokensByUserID revokes all tokens for a user
|
||||
RevokeTokensByUserID(ctx context.Context, userID string) error
|
||||
}
|
||||
114
internal/types/user.go
Normal file
114
internal/types/user.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
// Unique identifier of the user
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
// Username of the user
|
||||
Username string `json:"username" gorm:"type:varchar(100);uniqueIndex;not null"`
|
||||
// Email address of the user
|
||||
Email string `json:"email" gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||
// Hashed password of the user
|
||||
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
|
||||
// Avatar URL of the user
|
||||
Avatar string `json:"avatar" gorm:"type:varchar(500)"`
|
||||
// Tenant ID that the user belongs to
|
||||
TenantID uint `json:"tenant_id" gorm:"index"`
|
||||
// Whether the user is active
|
||||
IsActive bool `json:"is_active" gorm:"default:true"`
|
||||
// Creation time of the user
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Last updated time of the user
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
// Deletion time of the user
|
||||
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
|
||||
|
||||
// Association relationship, not stored in the database
|
||||
Tenant *Tenant `json:"tenant,omitempty" gorm:"foreignKey:TenantID"`
|
||||
}
|
||||
|
||||
// AuthToken represents an authentication token
|
||||
type AuthToken struct {
|
||||
// Unique identifier of the token
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
// User ID that owns this token
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||
// Token value (JWT or other format)
|
||||
Token string `json:"token" gorm:"type:text;not null"`
|
||||
// Token type (access_token, refresh_token)
|
||||
TokenType string `json:"token_type" gorm:"type:varchar(50);not null"`
|
||||
// Token expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
// Whether the token is revoked
|
||||
IsRevoked bool `json:"is_revoked" gorm:"default:false"`
|
||||
// Creation time of the token
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Last updated time of the token
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// Association relationship
|
||||
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
// LoginRequest represents a login request
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// RegisterRequest represents a registration request
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// LoginResponse represents a login response
|
||||
type LoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
User *User `json:"user,omitempty"`
|
||||
Tenant *Tenant `json:"tenant,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterResponse represents a registration response
|
||||
type RegisterResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
User *User `json:"user,omitempty"`
|
||||
Tenant *Tenant `json:"tenant,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo represents user information for API responses
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Avatar string `json:"avatar"`
|
||||
TenantID uint `json:"tenant_id"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ToUserInfo converts User to UserInfo (without sensitive data)
|
||||
func (u *User) ToUserInfo() *UserInfo {
|
||||
return &UserInfo{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Email: u.Email,
|
||||
Avatar: u.Avatar,
|
||||
TenantID: u.TenantID,
|
||||
IsActive: u.IsActive,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
@@ -78,13 +78,17 @@ check_platform() {
|
||||
log_info "检测系统平台信息..."
|
||||
if [ "$(uname -m)" = "x86_64" ]; then
|
||||
export PLATFORM="linux/amd64"
|
||||
export TARGETARCH="amd64"
|
||||
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
|
||||
export PLATFORM="linux/arm64"
|
||||
export TARGETARCH="arm64"
|
||||
else
|
||||
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
|
||||
export PLATFORM="linux/amd64"
|
||||
export TARGETARCH="amd64"
|
||||
fi
|
||||
log_info "当前平台:$PLATFORM"
|
||||
log_info "当前架构:$TARGETARCH"
|
||||
}
|
||||
|
||||
# 构建应用镜像
|
||||
@@ -119,7 +123,7 @@ build_docreader_image() {
|
||||
|
||||
docker build \
|
||||
--platform $PLATFORM \
|
||||
--build-arg PLATFORM=$PLATFORM \
|
||||
--build-arg TARGETARCH=$TARGETARCH \
|
||||
-f docker/Dockerfile.docreader \
|
||||
-t wechatopenai/weknora-docreader:latest \
|
||||
.
|
||||
|
||||
Reference in New Issue
Block a user