feat: Add Login Page

This commit is contained in:
wizardchen
2025-09-16 02:47:39 +08:00
parent 092b30af3e
commit 81bd2e6c2c
30 changed files with 2583 additions and 370 deletions

View File

@@ -1,12 +1,12 @@
{
"name": "knowledage-base",
"version": "0.0.0",
"version": "0.1.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "knowledage-base",
"version": "0.0.0",
"version": "0.1.0",
"dependencies": {
"@microsoft/fetch-event-source": "^2.0.1",
"axios": "^1.8.4",

View File

@@ -0,0 +1,234 @@
import { post, get, put } from '@/utils/request'
// 用户登录接口
export interface LoginRequest {
email: string
password: string
}
export interface LoginResponse {
success: boolean
message?: string
user?: {
id: string
username: string
email: string
avatar?: string
tenant_id: number
is_active: boolean
created_at: string
updated_at: string
}
tenant?: {
id: number
name: string
description: string
api_key: string
status: string
business: string
storage_quota: number
storage_used: number
created_at: string
updated_at: string
}
token?: string
refresh_token?: string
}
// 用户注册接口
export interface RegisterRequest {
username: string
email: string
password: string
}
export interface RegisterResponse {
success: boolean
message?: string
data?: {
user: {
id: string
username: string
email: string
}
tenant: {
id: string
name: string
api_key: string
}
}
}
// 用户信息接口
export interface UserInfo {
id: string
username: string
email: string
avatar?: string
tenant_id: string
created_at: string
updated_at: string
}
// 租户信息接口
export interface TenantInfo {
id: string
name: string
api_key: string
owner_id: string
created_at: string
updated_at: string
knowledge_bases?: KnowledgeBaseInfo[]
}
// 知识库信息接口
export interface KnowledgeBaseInfo {
id: string
name: string
description: string
tenant_id: string
created_at: string
updated_at: string
document_count?: number
chunk_count?: number
}
// 模型信息接口
export interface ModelInfo {
id: string
name: string
type: string
source: string
description?: string
is_default?: boolean
created_at: string
updated_at: string
}
/**
* 用户登录
*/
export async function login(data: LoginRequest): Promise<LoginResponse> {
try {
const response = await post('/api/v1/auth/login', data)
return response as unknown as LoginResponse
} catch (error: any) {
return {
success: false,
message: error.message || '登录失败'
}
}
}
/**
* 用户注册
*/
export async function register(data: RegisterRequest): Promise<RegisterResponse> {
try {
const response = await post('/api/v1/auth/register', data)
return response as unknown as RegisterResponse
} catch (error: any) {
return {
success: false,
message: error.message || '注册失败'
}
}
}
/**
* 获取当前用户信息
*/
export async function getCurrentUser(): Promise<{ success: boolean; data?: UserInfo; message?: string }> {
try {
const response = await get('/api/v1/auth/me')
return response as unknown as { success: boolean; data?: UserInfo; message?: string }
} catch (error: any) {
return {
success: false,
message: error.message || '获取用户信息失败'
}
}
}
/**
* 获取当前租户信息
*/
export async function getCurrentTenant(): Promise<{ success: boolean; data?: TenantInfo; message?: string }> {
try {
const response = await get('/api/v1/auth/tenant')
return response as unknown as { success: boolean; data?: TenantInfo; message?: string }
} catch (error: any) {
return {
success: false,
message: error.message || '获取租户信息失败'
}
}
}
/**
* 刷新Token
*/
export async function refreshToken(refreshToken: string): Promise<{ success: boolean; data?: { token: string; refreshToken: string }; message?: string }> {
try {
const response: any = await post('/api/v1/auth/refresh', { refreshToken })
if (response && response.success) {
if (response.access_token || response.refresh_token) {
return {
success: true,
data: {
token: response.access_token,
refreshToken: response.refresh_token,
}
}
}
}
// 其他情况直接返回原始消息
return {
success: false,
message: response?.message || '刷新Token失败'
}
} catch (error: any) {
return {
success: false,
message: error.message || '刷新Token失败'
}
}
}
/**
* 用户登出
*/
export async function logout(): Promise<{ success: boolean; message?: string }> {
try {
await post('/api/v1/auth/logout', {})
return {
success: true
}
} catch (error: any) {
return {
success: false,
message: error.message || '登出失败'
}
}
}
/**
* 验证Token有效性
*/
export async function validateToken(): Promise<{ success: boolean; valid?: boolean; message?: string }> {
try {
const response = await get('/api/v1/auth/validate')
return response as unknown as { success: boolean; valid?: boolean; message?: string }
} catch (error: any) {
return {
success: false,
valid: false,
message: error.message || 'Token验证失败'
}
}
}

View File

@@ -1,54 +1,30 @@
import { get, post, put, del, postChat } from "../../utils/request";
import { loadTestData } from "../test-data";
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.apiKey && settings.endpoint) {
return settings;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
return null;
}
// 根据是否有设置决定是否需要加载测试数据
async function ensureConfigured() {
const settings = getSettings();
// 如果没有设置APIKey和Endpoint则加载测试数据
if (!settings) {
await loadTestData();
}
}
export async function createSessions(data = {}) {
await ensureConfigured();
await loadTestData();
return post("/api/v1/sessions", data);
}
export async function getSessionsList(page: number, page_size: number) {
await ensureConfigured();
await loadTestData();
return get(`/api/v1/sessions?page=${page}&page_size=${page_size}`);
}
export async function generateSessionsTitle(session_id: string, data: any) {
await ensureConfigured();
await loadTestData();
return post(`/api/v1/sessions/${session_id}/generate_title`, data);
}
export async function knowledgeChat(data: { session_id: string; query: string; }) {
await ensureConfigured();
await loadTestData();
return postChat(`/api/v1/knowledge-chat/${data.session_id}`, { query: data.query });
}
export async function getMessageList(data: { session_id: string; limit: number, created_at: string }) {
await ensureConfigured();
await loadTestData();
if (data.created_at) {
return get(`/api/v1/messages/${data.session_id}/load?before_time=${encodeURIComponent(data.created_at)}&limit=${data.limit}`);
} else {
@@ -57,6 +33,6 @@ export async function getMessageList(data: { session_id: string; limit: number,
}
export async function delSession(session_id: string) {
await ensureConfigured();
await loadTestData();
return del(`/api/v1/sessions/${session_id}`);
}

View File

@@ -2,21 +2,9 @@ import { fetchEventSource } from '@microsoft/fetch-event-source'
import { ref, type Ref, onUnmounted, nextTick } from 'vue'
import { generateRandomString } from '@/utils/index';
import { getTestData } from '@/utils/request';
import { loadTestData } from '@/api/test-data';
import { loadTestData } from "../test-data";
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
return settings;
} catch (e) {
console.error("解析设置失败:", e);
}
}
return null;
}
interface StreamOptions {
// 请求方法 (默认POST)
@@ -49,27 +37,16 @@ export function useStream() {
isStreaming.value = true;
isLoading.value = true;
// 获取设置信息
const settings = getSettings();
let apiUrl = '';
let apiKey = '';
// 如果有设置信息,优先使用设置信息
if (settings && settings.endpoint && settings.apiKey) {
apiUrl = settings.endpoint;
apiKey = settings.apiKey;
} else {
// 否则加载测试数据
await loadTestData();
const testData = getTestData();
if (!testData) {
error.value = "测试数据未初始化,无法进行聊天";
stopStream();
return;
}
apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
apiKey = testData.tenant.api_key;
// 使用默认配置
await loadTestData();
const testData = getTestData();
if (!testData) {
error.value = "测试数据未初始化,无法进行聊天";
stopStream();
return;
}
const apiUrl = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
const apiKey = testData.tenant.api_key;
try {
let url =

View File

@@ -70,6 +70,11 @@ export function checkInitializationStatus(): Promise<{ initialized: boolean }> {
resolve(response.data || { initialized: false });
})
.catch((error: any) => {
// 如果是401交给全局拦截器去处理重定向登录这里不要把它当成未初始化
if (error && error.status === 401) {
reject(error);
return;
}
console.warn('检查初始化状态失败,假设需要初始化:', error);
resolve({ initialized: false });
});

View File

@@ -1,41 +1,23 @@
import { get, post, put, del, postUpload, getDown, getTestData } from "../../utils/request";
import { loadTestData } from "../test-data";
// 获取知识库ID优先从设置中获取
async function getKnowledgeBaseID() {
// 从localStorage获取设置中的知识库ID
const settingsStr = localStorage.getItem("WeKnora_settings");
let knowledgeBaseId = "";
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.knowledgeBaseId) {
return settings.knowledgeBaseId;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
export async function getDefaultKnowledgeBaseId(): Promise<string> {
// 如果设置中没有知识库ID则使用测试数据
await loadTestData();
const testData = getTestData();
if (!testData || testData.knowledge_bases.length === 0) {
console.error("测试数据未初始化或不包含知识库");
throw new Error("测试数据未初始化或不包含知识库");
throw new Error('没有可用的知识库');
}
return testData.knowledge_bases[0].id;
}
export async function uploadKnowledgeBase(data = {}) {
const kbId = await getKnowledgeBaseID();
const kbId = await getDefaultKnowledgeBaseId();
return postUpload(`/api/v1/knowledge-bases/${kbId}/knowledge/file`, data);
}
export async function getKnowledgeBase({page, page_size}) {
const kbId = await getKnowledgeBaseID();
export async function getKnowledgeBase({page, page_size}: {page: number, page_size: number}) {
const kbId = await getDefaultKnowledgeBaseId();
return get(
`/api/v1/knowledge-bases/${kbId}/knowledge?page=${page}&page_size=${page_size}`
);
@@ -57,6 +39,6 @@ export function batchQueryKnowledge(ids: any) {
return get(`/api/v1/knowledge/batch?${ids}`);
}
export function getKnowledgeDetailsCon(id: any, page) {
export function getKnowledgeDetailsCon(id: any, page: number) {
return get(`/api/v1/chunks/${id}?page=${page}&page_size=25`);
}

View File

@@ -53,3 +53,12 @@ export async function loadTestData(): Promise<boolean> {
return false;
}
}
/**
* 重置测试数据加载状态,在重新登录或需要强制刷新时调用
*/
export function resetTestDataLoaded() {
isTestDataLoaded = false;
// 清空已缓存的测试数据,确保下次调用会重新获取
setTestData(null);
}

View File

@@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none">
<path d="M10 3H6a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M17 16l4-4-4-4" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M21 12H10" stroke="#000" stroke-opacity="0.6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 509 B

View File

@@ -9,7 +9,7 @@
:class="['menu_item', item.childrenPath && item.childrenPath == currentpath ? 'menu_item_c_active' : item.path == currentpath ? 'menu_item_active' : '']">
<div class="menu_item-box">
<div class="menu_icon">
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : prefixIcon)" alt="">
<img class="icon" :src="getImgSrc(item.icon == 'zhishiku' ? knowledgeIcon : item.icon == 'setting' ? settingIcon : item.icon == 'logout' ? logoutIcon : prefixIcon)" alt="">
</div>
<span class="menu_title">{{ item.title }}</span>
</div>
@@ -58,11 +58,13 @@ import { onMounted, watch, computed, ref, reactive } from 'vue';
import { useRoute, useRouter } from 'vue-router';
import { getSessionsList, delSession } from "@/api/chat/index";
import { useMenuStore } from '@/stores/menu';
import { useAuthStore } from '@/stores/auth';
import useKnowledgeBase from '@/hooks/useKnowledgeBase';
import { MessagePlugin } from "tdesign-vue-next";
let { requestMethod } = useKnowledgeBase()
let uploadInput = ref();
const usemenuStore = useMenuStore();
const authStore = useAuthStore();
const route = useRoute();
const router = useRouter();
const currentpath = ref('');
@@ -164,12 +166,14 @@ let fileAddIcon = ref('file-add-green.svg');
let knowledgeIcon = ref('zhishiku-green.svg');
let prefixIcon = ref('prefixIcon.svg');
let settingIcon = ref('setting.svg');
let logoutIcon = ref('logout.svg');
let pathPrefix = ref(route.name)
const getIcon = (path) => {
fileAddIcon.value = path == 'knowledgeBase' ? 'file-add-green.svg' : 'file-add.svg';
knowledgeIcon.value = path == 'knowledgeBase' ? 'zhishiku-green.svg' : 'zhishiku.svg';
prefixIcon.value = path == 'creatChat' ? 'prefixIcon-green.svg' : path == 'knowledgeBase' ? 'prefixIcon-grey.svg' : 'prefixIcon.svg';
settingIcon.value = path == 'settings' ? 'setting-green.svg' : 'setting.svg';
logoutIcon.value = 'logout.svg';
}
getIcon(route.name)
const gotopage = (path) => {
@@ -177,6 +181,13 @@ const gotopage = (path) => {
// 如果是系统设置,跳转到初始化配置页面
if (path === 'settings') {
router.push('/initialization');
return;
}
// 处理退出登录
if (path === 'logout') {
authStore.logout();
router.push('/login');
return;
} else {
router.push(`/platform/${path}`);
}

View File

@@ -1,12 +1,20 @@
import { createRouter, createWebHistory } from 'vue-router'
import { checkInitializationStatus } from '@/api/initialization'
import { useAuthStore } from '@/stores/auth'
import { validateToken } from '@/api/auth'
const router = createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
routes: [
{
path: "/",
redirect: "/platform",
redirect: "/platform/knowledgeBase",
},
{
path: "/login",
name: "login",
component: () => import("../views/auth/Login.vue"),
meta: { requiresAuth: false, requiresInit: false }
},
{
path: "/initialization",
@@ -18,71 +26,110 @@ const router = createRouter({
path: "/knowledgeBase",
name: "home",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "/platform",
name: "Platform",
redirect: "/platform/knowledgeBase",
component: () => import("../views/platform/index.vue"),
meta: { requiresInit: true },
meta: { requiresInit: true, requiresAuth: true },
children: [
{
path: "knowledgeBase",
name: "knowledgeBase",
component: () => import("../views/knowledge/KnowledgeBase.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "creatChat",
name: "creatChat",
component: () => import("../views/creatChat/creatChat.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "chat/:chatid",
name: "chat",
component: () => import("../views/chat/index.vue"),
meta: { requiresInit: true }
meta: { requiresInit: true, requiresAuth: true }
},
{
path: "settings",
name: "settings",
component: () => import("../views/settings/Settings.vue"),
path: "settings",
name: "settings",
component: () => import("../views/settings/Settings.vue"),
meta: { requiresInit: true }
},
},
],
},
],
});
// 路由守卫:检查系统初始化状态
// 路由守卫:检查认证状态和系统初始化状态
router.beforeEach(async (to, from, next) => {
// 如果访问的是初始化页面,直接放行
if (to.meta.requiresInit === false) {
next();
return;
const authStore = useAuthStore()
// 如果访问的是登录页面或初始化页面,直接放行
if (to.meta.requiresAuth === false || to.meta.requiresInit === false) {
// 如果已登录用户访问登录页面,重定向到知识库列表页面
if (to.path === '/login' && authStore.isLoggedIn) {
next('/platform/knowledgeBase')
return
}
next()
return
}
1
try {
// 检查系统是否已初始化
const { initialized } = await checkInitializationStatus();
if (initialized) {
// 系统已初始化,记录到本地存储并正常跳转
localStorage.setItem('system_initialized', 'true');
next();
} else {
// 系统未初始化,跳转到初始化页面
console.log('系统未初始化,跳转到初始化页面');
next('/initialization');
// 检查用户认证状态
if (to.meta.requiresAuth !== false) {
if (!authStore.isLoggedIn) {
// 未登录,跳转到登录页面
next('/login')
return
}
} catch (error) {
console.error('检查初始化状态失败:', error);
// 如果检查失败,默认认为需要初始化
next('/initialization');
// 验证Token有效性
// try {
// const { valid } = await validateToken()
// if (!valid) {
// // Token无效清空认证信息并跳转到登录页面
// authStore.logout()
// next('/login')
// return
// }
// } catch (error) {
// console.error('Token验证失败:', error)
// authStore.logout()
// next('/login')
// return
// }
}
// 检查系统初始化状态
if (to.meta.requiresInit !== false) {
try {
const { initialized } = await checkInitializationStatus()
if (initialized) {
// 系统已初始化,记录到本地存储并正常跳转
localStorage.setItem('system_initialized', 'true')
next()
} else {
// 系统未初始化,跳转到初始化页面
next('/initialization')
}
} catch (error) {
console.error('检查初始化状态失败:', error)
// 如果是401跳转登录不再误导去初始化
const status = (error as any)?.status
if (status === 401) {
next('/login')
return
}
// 其他错误默认认为需要初始化
next('/initialization')
}
} else {
next()
}
});

174
frontend/src/stores/auth.ts Normal file
View File

@@ -0,0 +1,174 @@
import { defineStore } from 'pinia'
import { resetTestDataLoaded } from '@/api/test-data'
import { ref, computed } from 'vue'
import type { UserInfo, TenantInfo, KnowledgeBaseInfo } from '@/api/auth'
export const useAuthStore = defineStore('auth', () => {
// 状态
const user = ref<UserInfo | null>(null)
const tenant = ref<TenantInfo | null>(null)
const token = ref<string>('')
const refreshToken = ref<string>('')
const knowledgeBases = ref<KnowledgeBaseInfo[]>([])
const currentKnowledgeBase = ref<KnowledgeBaseInfo | null>(null)
// 计算属性
const isLoggedIn = computed(() => {
return !!token.value && !!user.value
})
const hasValidTenant = computed(() => {
return !!tenant.value && !!tenant.value.api_key
})
const currentTenantId = computed(() => {
return tenant.value?.id || ''
})
const currentUserId = computed(() => {
return user.value?.id || ''
})
// 操作方法
const setUser = (userData: UserInfo) => {
user.value = userData
// 保存到localStorage
localStorage.setItem('weknora_user', JSON.stringify(userData))
}
const setTenant = (tenantData: TenantInfo) => {
tenant.value = tenantData
// 保存到localStorage
localStorage.setItem('weknora_tenant', JSON.stringify(tenantData))
}
const setToken = (tokenValue: string) => {
token.value = tokenValue
localStorage.setItem('weknora_token', tokenValue)
}
const setRefreshToken = (refreshTokenValue: string) => {
refreshToken.value = refreshTokenValue
localStorage.setItem('weknora_refresh_token', refreshTokenValue)
}
const setKnowledgeBases = (kbList: KnowledgeBaseInfo[]) => {
// 确保输入是数组
knowledgeBases.value = Array.isArray(kbList) ? kbList : []
localStorage.setItem('weknora_knowledge_bases', JSON.stringify(knowledgeBases.value))
}
const setCurrentKnowledgeBase = (kb: KnowledgeBaseInfo | null) => {
currentKnowledgeBase.value = kb
if (kb) {
localStorage.setItem('weknora_current_kb', JSON.stringify(kb))
} else {
localStorage.removeItem('weknora_current_kb')
}
}
const logout = () => {
// 清空状态
user.value = null
tenant.value = null
token.value = ''
refreshToken.value = ''
knowledgeBases.value = []
currentKnowledgeBase.value = null
// 清空localStorage
localStorage.removeItem('weknora_user')
localStorage.removeItem('weknora_tenant')
localStorage.removeItem('weknora_token')
localStorage.removeItem('weknora_refresh_token')
localStorage.removeItem('weknora_knowledge_bases')
localStorage.removeItem('weknora_current_kb')
// 重置测试数据加载标志确保重新登录后会重新获取KB列表
try {
resetTestDataLoaded()
} catch {}
}
const initFromStorage = () => {
// 从localStorage恢复状态
const storedUser = localStorage.getItem('weknora_user')
const storedTenant = localStorage.getItem('weknora_tenant')
const storedToken = localStorage.getItem('weknora_token')
const storedRefreshToken = localStorage.getItem('weknora_refresh_token')
const storedKnowledgeBases = localStorage.getItem('weknora_knowledge_bases')
const storedCurrentKb = localStorage.getItem('weknora_current_kb')
if (storedUser) {
try {
user.value = JSON.parse(storedUser)
} catch (e) {
console.error('解析用户信息失败:', e)
}
}
if (storedTenant) {
try {
tenant.value = JSON.parse(storedTenant)
} catch (e) {
console.error('解析租户信息失败:', e)
}
}
if (storedToken) {
token.value = storedToken
}
if (storedRefreshToken) {
refreshToken.value = storedRefreshToken
}
if (storedKnowledgeBases) {
try {
const parsed = JSON.parse(storedKnowledgeBases)
knowledgeBases.value = Array.isArray(parsed) ? parsed : []
} catch (e) {
console.error('解析知识库列表失败:', e)
knowledgeBases.value = []
}
}
if (storedCurrentKb) {
try {
currentKnowledgeBase.value = JSON.parse(storedCurrentKb)
} catch (e) {
console.error('解析当前知识库失败:', e)
}
}
}
// 初始化时从localStorage恢复状态
initFromStorage()
return {
// 状态
user,
tenant,
token,
refreshToken,
knowledgeBases,
currentKnowledgeBase,
// 计算属性
isLoggedIn,
hasValidTenant,
currentTenantId,
currentUserId,
// 方法
setUser,
setTenant,
setToken,
setRefreshToken,
setKnowledgeBases,
setCurrentKnowledgeBase,
logout,
initFromStorage
}
})

View File

@@ -13,7 +13,8 @@ export const useMenuStore = defineStore('menuStore', {
childrenPath: 'chat',
children: reactive<object[]>([]),
},
{ title: '系统设置', icon: 'setting', path: 'settings' }
{ title: '系统设置', icon: 'setting', path: 'settings' },
{ title: '退出登录', icon: 'logout', path: 'logout' }
]),
isFirstSession: false,
firstQuery: ''

View File

@@ -2,26 +2,8 @@
import axios from "axios";
import { generateRandomString } from "./index";
// 从localStorage获取设置
function getSettings() {
const settingsStr = localStorage.getItem("WeKnora_settings");
if (settingsStr) {
try {
return JSON.parse(settingsStr);
} catch (e) {
console.error("解析设置失败:", e);
}
}
return {
endpoint: import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080",
apiKey: "",
knowledgeBaseId: "",
};
}
// API基础URL优先使用设置中的endpoint
const settings = getSettings();
const BASE_URL = settings.endpoint;
// API基础URL
const BASE_URL = import.meta.env.VITE_IS_DOCKER ? "" : "http://localhost:8080";
// 测试数据
let testData: {
@@ -50,13 +32,6 @@ const instance = axios.create({
// 设置测试数据
export function setTestData(data: typeof testData) {
testData = data;
if (data) {
// 优先使用设置中的ApiKey如果没有则使用测试数据中的
const apiKey = settings.apiKey || (data?.tenant?.api_key || "");
if (apiKey) {
instance.defaults.headers["X-API-Key"] = apiKey;
}
}
}
// 获取测试数据
@@ -66,25 +41,38 @@ export function getTestData() {
instance.interceptors.request.use(
(config) => {
// 每次请求前检查是否有更新的设置
const currentSettings = getSettings();
// 更新BaseURL (如果有变化)
if (currentSettings.endpoint && config.baseURL !== currentSettings.endpoint) {
config.baseURL = currentSettings.endpoint;
}
// 更新API Key (如果有)
if (currentSettings.apiKey) {
config.headers["X-API-Key"] = currentSettings.apiKey;
// 添加JWT token认证
const token = localStorage.getItem('weknora_token');
if (token) {
config.headers["Authorization"] = `Bearer ${token}`;
}
config.headers["X-Request-ID"] = `${generateRandomString(12)}`;
return config;
},
(error) => {}
(error) => {
return Promise.reject(error);
}
);
// Token刷新标志防止多个请求同时刷新token
let isRefreshing = false;
let failedQueue: Array<{ resolve: Function; reject: Function }> = [];
let hasRedirectedOn401 = false;
// 处理队列中的请求
const processQueue = (error: any, token: string | null = null) => {
failedQueue.forEach(({ resolve, reject }) => {
if (error) {
reject(error);
} else {
resolve(token);
}
});
failedQueue = [];
};
instance.interceptors.response.use(
(response) => {
// 根据业务状态码处理逻辑
@@ -95,12 +83,98 @@ instance.interceptors.response.use(
return Promise.reject(data);
}
},
(error: any) => {
async (error: any) => {
const originalRequest = error.config;
if (!error.response) {
return Promise.reject({ message: "网络错误,请检查您的网络连接" });
}
const { data } = error.response;
return Promise.reject(data);
// 如果是登录接口的401直接返回错误以便页面展示toast不做跳转
if (error.response.status === 401 && originalRequest?.url?.includes('/auth/login')) {
const { status, data } = error.response;
return Promise.reject({ status, message: (typeof data === 'object' ? data?.message : data) || '用户名或密码错误' });
}
// 如果是401错误且不是刷新token的请求尝试刷新token
if (error.response.status === 401 && !originalRequest._retry && !originalRequest.url?.includes('/auth/refresh')) {
if (isRefreshing) {
// 如果正在刷新token将请求加入队列
return new Promise((resolve, reject) => {
failedQueue.push({ resolve, reject });
}).then(token => {
originalRequest.headers['Authorization'] = 'Bearer ' + token;
return instance(originalRequest);
}).catch(err => {
return Promise.reject(err);
});
}
originalRequest._retry = true;
isRefreshing = true;
const refreshToken = localStorage.getItem('weknora_refresh_token');
if (refreshToken) {
try {
// 动态导入refresh token API
const { refreshToken: refreshTokenAPI } = await import('../api/auth/index');
const response = await refreshTokenAPI(refreshToken);
if (response.success && response.data) {
const { token, refreshToken: newRefreshToken } = response.data;
// 更新localStorage中的token
localStorage.setItem('weknora_token', token);
localStorage.setItem('weknora_refresh_token', newRefreshToken);
// 更新请求头
originalRequest.headers['Authorization'] = 'Bearer ' + token;
// 处理队列中的请求
processQueue(null, token);
return instance(originalRequest);
} else {
throw new Error(response.message || 'Token刷新失败');
}
} catch (refreshError) {
// 刷新失败清除所有token并跳转到登录页
localStorage.removeItem('weknora_token');
localStorage.removeItem('weknora_refresh_token');
localStorage.removeItem('weknora_user');
localStorage.removeItem('weknora_tenant');
processQueue(refreshError, null);
// 跳转到登录页
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
hasRedirectedOn401 = true;
window.location.href = '/login';
}
return Promise.reject(refreshError);
} finally {
isRefreshing = false;
}
} else {
// 没有refresh token直接跳转到登录页
localStorage.removeItem('weknora_token');
localStorage.removeItem('weknora_user');
localStorage.removeItem('weknora_tenant');
if (!hasRedirectedOn401 && typeof window !== 'undefined') {
hasRedirectedOn401 = true;
window.location.href = '/login';
}
return Promise.reject({ message: '请重新登录' });
}
}
const { status, data } = error.response;
// 将HTTP状态码一并抛出方便上层判断401等场景
return Promise.reject({ status, ...(typeof data === 'object' ? data : { message: data }) });
}
);

View File

@@ -0,0 +1,559 @@
<template>
<div class="login-container">
<!-- 登录表单 -->
<div class="login-card" v-if="!isRegisterMode">
<!-- 系统Logo和标题 -->
<div class="login-header">
<div class="logo">
<img src="@/assets/img/weknora.png" alt="WeKnora" class="logo-img" />
</div>
<p class="login-subtitle">基于大模型的文档理解与语义检索框架</p>
</div>
<div class="login-form">
<t-form
ref="formRef"
:data="formData"
:rules="formRules"
@submit="handleLogin"
layout="vertical"
>
<t-form-item label="邮箱" name="email">
<t-input
v-model="formData.email"
placeholder="请输入邮箱地址"
type="email"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="密码" name="password">
<t-input
v-model="formData.password"
placeholder="请输入密码8-32位包含字母和数字"
type="password"
size="large"
:disabled="loading"
@keydown.enter="handleLogin"
/>
</t-form-item>
<t-button
type="submit"
theme="primary"
size="large"
block
:loading="loading"
class="login-button"
>
{{ loading ? '登录中...' : '登录' }}
</t-button>
</t-form>
<!-- 注册链接 -->
<div class="register-link">
<span>还没有账号</span>
<a href="#" @click.prevent="toggleMode" class="register-btn">
立即注册
</a>
</div>
</div>
</div>
<!-- 注册表单 -->
<div class="register-card" v-if="isRegisterMode">
<div class="login-header">
<h1 class="login-title">创建账号</h1>
<p class="login-subtitle">注册后系统将为您创建专属租户</p>
</div>
<div class="login-form">
<t-form
ref="registerFormRef"
:data="registerData"
:rules="registerRules"
@submit="handleRegister"
layout="vertical"
>
<t-form-item label="用户名" name="username">
<t-input
v-model="registerData.username"
placeholder="请输入用户名"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="邮箱" name="email">
<t-input
v-model="registerData.email"
placeholder="请输入邮箱地址"
type="email"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="密码" name="password">
<t-input
v-model="registerData.password"
placeholder="请输入密码8-32位包含字母和数字"
type="password"
size="large"
:disabled="loading"
/>
</t-form-item>
<t-form-item label="确认密码" name="confirmPassword">
<t-input
v-model="registerData.confirmPassword"
placeholder="请再次输入密码"
type="password"
size="large"
:disabled="loading"
@keydown.enter="handleRegister"
/>
</t-form-item>
<t-button
type="submit"
theme="primary"
size="large"
block
:loading="loading"
class="login-button"
>
{{ loading ? '注册中...' : '注册' }}
</t-button>
</t-form>
<!-- 返回登录 -->
<div class="register-link">
<span>已有账号</span>
<a href="#" @click.prevent="toggleMode" class="register-btn">
返回登录
</a>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, nextTick, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { MessagePlugin } from 'tdesign-vue-next'
import { login, register } from '@/api/auth'
import { loadTestData, resetTestDataLoaded } from '@/api/test-data'
import { useAuthStore } from '@/stores/auth'
const router = useRouter()
const authStore = useAuthStore()
// 表单引用
const formRef = ref()
const registerFormRef = ref()
// 状态管理
const loading = ref(false)
const isRegisterMode = ref(false)
// 登录表单数据
const formData = reactive<{[key: string]: any}>({
email: '',
password: '',
})
// 注册表单数据
const registerData = reactive<{[key: string]: any}>({
username: '',
email: '',
password: '',
confirmPassword: ''
})
// 登录表单验证规则
const formRules = {
email: [
{ required: true, message: '请输入邮箱地址', type: 'error' },
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
],
password: [
{ required: true, message: '请输入密码', type: 'error' },
{ min: 8, message: '密码至少8位', type: 'error' },
{ max: 32, message: '密码不能超过32位', type: 'error' },
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
]
}
// 注册表单验证规则
const registerRules = {
username: [
{ required: true, message: '请输入用户名', type: 'error' },
{ min: 2, message: '用户名至少2位', type: 'error' },
{ max: 20, message: '用户名不能超过20位', type: 'error' },
{
pattern: /^[a-zA-Z0-9_\u4e00-\u9fa5]+$/,
message: '用户名只能包含字母、数字、下划线和中文',
type: 'error'
}
],
email: [
{ required: true, message: '请输入邮箱地址', type: 'error' },
{ email: true, message: '请输入正确的邮箱格式', type: 'error' }
],
password: [
{ required: true, message: '请输入密码', type: 'error' },
{ min: 8, message: '密码至少8位', type: 'error' },
{ max: 32, message: '密码不能超过32位', type: 'error' },
{ pattern: /[a-zA-Z]/, message: '密码必须包含字母', type: 'error' },
{ pattern: /\d/, message: '密码必须包含数字', type: 'error' }
],
confirmPassword: [
{ required: true, message: '请确认密码', type: 'error' },
{
validator: (val: string) => val === registerData.password,
message: '两次输入的密码不一致',
type: 'error'
}
]
}
// 切换登录/注册模式
const toggleMode = () => {
isRegisterMode.value = !isRegisterMode.value
Object.keys(registerData).forEach(key => {
(registerData as any)[key] = ''
})
}
// 处理登录
const handleLogin = async () => {
try {
const valid = await formRef.value?.validate()
if (!valid) return
loading.value = true
const response = await login({
email: formData.email,
password: formData.password,
})
if (response.success) {
// 保存用户信息和token
if (response.user && response.tenant && response.token) {
authStore.setUser({
id: response.user.id || '',
username: response.user.username || '',
email: response.user.email || '',
avatar: response.user.avatar,
tenant_id: String(response.tenant.id) || '',
created_at: response.user.created_at || new Date().toISOString(),
updated_at: response.user.updated_at || new Date().toISOString()
})
authStore.setToken(response.token)
if (response.refresh_token) {
authStore.setRefreshToken(response.refresh_token)
}
authStore.setTenant({
id: String(response.tenant.id) || '',
name: response.tenant.name || '',
api_key: response.tenant.api_key || '',
owner_id: response.user.id || '',
created_at: response.tenant.created_at || new Date().toISOString(),
updated_at: response.tenant.updated_at || new Date().toISOString()
})
}
MessagePlugin.success('登录成功!')
// 登录成功后先重置并加载一次测试数据确保有KB可用
try {
resetTestDataLoaded()
await loadTestData()
} catch (_) {}
// 等待状态更新完成后再跳转
await nextTick()
router.replace('/platform/knowledgeBase')
} else {
MessagePlugin.error(response.message || '登录失败,请检查邮箱或密码')
}
} catch (error: any) {
console.error('登录错误:', error)
MessagePlugin.error(error.message || '登录失败,请稍后重试')
} finally {
loading.value = false
}
}
// 处理注册
const handleRegister = async () => {
try {
const valid = await registerFormRef.value?.validate()
if (!valid) return
loading.value = true
const response = await register({
username: registerData.username,
email: registerData.email,
password: registerData.password
})
if (response.success) {
MessagePlugin.success('注册成功!系统已为您创建专属租户,请登录使用')
// 切换到登录模式并填入邮箱
isRegisterMode.value = false
formData.email = registerData.email
// 清空注册表单
Object.keys(registerData).forEach(key => {
(registerData as any)[key] = ''
})
} else {
MessagePlugin.error(response.message || '注册失败')
}
} catch (error: any) {
console.error('注册错误:', error)
MessagePlugin.error(error.message || '注册失败,请稍后重试')
} finally {
loading.value = false
}
}
// 处理忘记密码
const handleForgotPassword = () => {
MessagePlugin.info('忘记密码功能暂未开放,请联系管理员')
}
// 检查是否已登录
onMounted(() => {
if (authStore.isLoggedIn) {
router.replace('/platform/tenant/knowledge-bases')
}
})
</script>
<style lang="less" scoped>
.login-container {
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 20px;
box-sizing: border-box;
}
.login-card,
.register-card {
width: 100%;
max-width: 440px;
background: #fff;
border-radius: 14px;
box-shadow: 0 10px 16px 0 #0000000f, 0 20px 24px -2px #0000001a;
padding: 40px;
box-sizing: border-box;
animation: fadeInUp .28s ease-out both;
}
.login-header {
text-align: center;
margin-bottom: 32px;
.logo {
margin-bottom: 16px;
.logo-img {
width: 180px;
height: auto;
border-radius: 12px;
}
}
.login-title {
font-size: 28px;
font-weight: 600;
color: #000000e6;
margin: 0 0 8px 0;
font-family: "PingFang SC";
}
.login-subtitle {
font-size: 16px;
color: #0000008c;
margin: 0;
font-family: "PingFang SC";
}
}
.login-form {
:deep(.t-form-item__label) {
font-size: 14px;
color: #000000e6;
font-weight: 500;
margin-bottom: 8px;
font-family: "PingFang SC";
display: block;
text-align: left;
}
:deep(.t-input) {
border: 1px solid #E7E7E7;
border-radius: 8px;
background: #fff;
&:focus-within {
border-color: #07C05F;
box-shadow: 0 0 0 2px rgba(7, 192, 95, 0.1);
}
&:hover {
border-color: #07C05F;
}
.t-input__inner {
border: none !important;
box-shadow: none !important;
outline: none !important;
background: transparent;
font-size: 16px;
font-family: "PingFang SC";
&:focus {
border: none !important;
box-shadow: none !important;
outline: none !important;
}
}
.t-input__wrap {
border: none !important;
box-shadow: none !important;
}
}
:deep(.t-form-item) {
margin-bottom: 20px;
&:last-child {
margin-bottom: 0;
}
}
:deep(.t-form-item__control) {
width: 100%;
}
}
.login-options {
display: flex;
justify-content: space-between;
align-items: center;
margin: 16px 0 24px 0;
width: 100%;
:deep(.t-checkbox) {
display: flex;
align-items: center;
.t-checkbox__input {
margin-right: 8px;
}
}
:deep(.t-checkbox__label) {
font-size: 14px;
color: #00000099;
font-family: "PingFang SC";
line-height: 1.4;
margin-left: 0;
}
.forgot-password {
font-size: 14px;
color: #07C05F;
text-decoration: none;
font-family: "PingFang SC";
line-height: 1.4;
&:hover {
text-decoration: underline;
}
}
}
.login-button {
height: 48px;
border-radius: 8px;
font-size: 16px;
font-weight: 500;
font-family: "PingFang SC";
margin: 16px 0 8px 0;
:deep(.t-button) {
background-color: #07C05F;
border-color: #07C05F;
&:hover {
background-color: #06a855;
border-color: #06a855;
}
}
}
.register-link {
text-align: center;
font-size: 14px;
color: #00000099;
font-family: "PingFang SC";
.register-btn {
color: #07C05F;
text-decoration: none;
margin-left: 4px;
&:hover {
text-decoration: underline;
}
}
}
// 响应式设计
@media (max-width: 480px) {
.login-container {
padding: 16px;
}
.login-card,
.register-card {
padding: 28px;
}
.login-header {
margin-bottom: 24px;
.login-title {
font-size: 24px;
}
}
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translate3d(0, 6px, 0);
}
to {
opacity: 1;
transform: translate3d(0, 0, 0);
}
}
</style>

View File

@@ -27,33 +27,7 @@ const sendMsg = (value: string) => {
}
async function createNewSession(value: string) {
// 从localStorage获取设置中的知识库ID
const settingsStr = localStorage.getItem("WeKnora_settings");
let knowledgeBaseId = "";
if (settingsStr) {
try {
const settings = JSON.parse(settingsStr);
if (settings.knowledgeBaseId) {
knowledgeBaseId = settings.knowledgeBaseId;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
if (res.data && res.data.id) {
getTitle(res.data.id, value);
} else {
// 错误处理
console.error("创建会话失败");
}
}).catch(error => {
console.error("创建会话出错:", error);
});
return;
}
} catch (e) {
console.error("解析设置失败:", e);
}
}
// 如果设置中没有知识库ID则使用测试数据
// 使用测试数据获取知识库ID
const testData = getTestData();
if (!testData || testData.knowledge_bases.length === 0) {
console.error("测试数据未初始化或不包含知识库");
@@ -61,7 +35,7 @@ async function createNewSession(value: string) {
}
// 使用第一个知识库ID
knowledgeBaseId = testData.knowledge_bases[0].id;
const knowledgeBaseId = testData.knowledge_bases[0].id;
createSessions({ knowledge_base_id: knowledgeBaseId }).then(res => {
if (res.data && res.data.id) {

View File

@@ -17,6 +17,10 @@
<span class="dot" />{{ s.label }}
</li>
</ul>
<t-divider />
<t-button size="small" variant="outline" theme="danger" block @click="handleLogout">
退出登录
</t-button>
</div>
</aside>
<div class="init-main">
@@ -780,8 +784,10 @@ import {
listOllamaModels,
testEmbeddingModel
} from '@/api/initialization';
import { useAuthStore } from '@/stores/auth';
const router = useRouter();
const authStore = useAuthStore();
type TFormRef = {
validate: (fields?: string[] | undefined) => Promise<true | any>;
clearValidate?: (fields?: string | string[]) => void;
@@ -956,6 +962,12 @@ const goToSection = (id: string) => {
}
};
// 退出登录
const handleLogout = () => {
authStore.logout();
router.replace('/login');
};
// 监听滚动,高亮当前区块
const onScroll = () => {
const order = ['ollama','llm','embedding','rerank','multimodal','docsplit','submit'];
@@ -2335,6 +2347,27 @@ const detectEmbeddingDimension = async () => {
<style lang="less" scoped>
.initialization-container {
padding: 20px 16px;
background: linear-gradient(180deg, #f7faf9 0%, #f9fbfa 60%, #ffffff 100%);
scroll-behavior: smooth;
.initialization-header {
text-align: center;
margin: 10px auto 18px;
h1 {
margin: 0 0 6px;
font-size: 22px;
font-weight: 700;
color: #0f172a;
}
p {
margin: 0;
color: #64748b;
font-size: 14px;
}
}
.init-layout {
display: grid;
grid-template-columns: 220px 1fr;
@@ -2420,6 +2453,30 @@ const detectEmbeddingDimension = async () => {
min-width: 0;
max-width: 960px;
}
/* 统一分区卡片视觉 */
.config-section {
background: #fff;
border: 1px solid #eef4f0;
border-radius: 12px;
box-shadow: 0 6px 18px rgba(7, 192, 95, 0.04);
padding: 16px;
margin: 14px 0;
h3 {
display: flex;
align-items: center;
gap: 8px;
margin: 0 0 12px;
font-size: 16px;
font-weight: 700;
color: #0f172a;
}
.section-icon {
color: #07c05f;
font-size: 18px;
}
}
.ollama-summary-card {
max-width: 100%;
margin: 0 0 16px 0;

11
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/gin-contrib/cors v1.7.5
github.com/gin-gonic/gin v1.10.0
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/hibiken/asynq v0.25.1
github.com/minio/minio-go/v7 v7.0.90
@@ -31,7 +32,8 @@ require (
go.opentelemetry.io/otel/sdk v1.37.0
go.opentelemetry.io/otel/trace v1.37.0
go.uber.org/dig v1.18.1
golang.org/x/sync v0.15.0
golang.org/x/crypto v0.42.0
golang.org/x/sync v0.17.0
google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6
gorm.io/driver/postgres v1.5.11
@@ -100,10 +102,9 @@ require (
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.15.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect

26
go.sum
View File

@@ -70,6 +70,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlnd
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -253,20 +255,20 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/arch v0.15.0 h1:QtOrQd0bTUnhNVNndMpLHNWrDmYzZ2KDqSrEymqInZw=
golang.org/x/arch v0.15.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -0,0 +1,154 @@
package repository
import (
"context"
"errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrUserAlreadyExists = errors.New("user already exists")
ErrTokenNotFound = errors.New("token not found")
)
// userRepository implements user repository interface
type userRepository struct {
db *gorm.DB
}
// NewUserRepository creates a new user repository
func NewUserRepository(db *gorm.DB) interfaces.UserRepository {
return &userRepository{db: db}
}
// CreateUser creates a user
func (r *userRepository) CreateUser(ctx context.Context, user *types.User) error {
logger.Infof(ctx, "Creating user in database: %s", user.Email)
return r.db.WithContext(ctx).Create(user).Error
}
// GetUserByID gets a user by ID
func (r *userRepository) GetUserByID(ctx context.Context, id string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByEmail gets a user by email
func (r *userRepository) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// GetUserByUsername gets a user by username
func (r *userRepository) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
var user types.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// UpdateUser updates a user
func (r *userRepository) UpdateUser(ctx context.Context, user *types.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
// DeleteUser deletes a user
func (r *userRepository) DeleteUser(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.User{}).Error
}
// ListUsers lists users with pagination
func (r *userRepository) ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error) {
var users []*types.User
query := r.db.WithContext(ctx).Order("created_at DESC")
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
if err := query.Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
// authTokenRepository implements auth token repository interface
type authTokenRepository struct {
db *gorm.DB
}
// NewAuthTokenRepository creates a new auth token repository
func NewAuthTokenRepository(db *gorm.DB) interfaces.AuthTokenRepository {
return &authTokenRepository{db: db}
}
// CreateToken creates an auth token
func (r *authTokenRepository) CreateToken(ctx context.Context, token *types.AuthToken) error {
return r.db.WithContext(ctx).Create(token).Error
}
// GetTokenByValue gets a token by its value
func (r *authTokenRepository) GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error) {
var token types.AuthToken
if err := r.db.WithContext(ctx).Where("token = ?", tokenValue).First(&token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTokenNotFound
}
return nil, err
}
return &token, nil
}
// GetTokensByUserID gets all tokens for a user
func (r *authTokenRepository) GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error) {
var tokens []*types.AuthToken
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&tokens).Error; err != nil {
return nil, err
}
return tokens, nil
}
// UpdateToken updates a token
func (r *authTokenRepository) UpdateToken(ctx context.Context, token *types.AuthToken) error {
return r.db.WithContext(ctx).Save(token).Error
}
// DeleteToken deletes a token
func (r *authTokenRepository) DeleteToken(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&types.AuthToken{}).Error
}
// DeleteExpiredTokens deletes all expired tokens
func (r *authTokenRepository) DeleteExpiredTokens(ctx context.Context) error {
return r.db.WithContext(ctx).Where("expires_at < NOW()").Delete(&types.AuthToken{}).Error
}
// RevokeTokensByUserID revokes all tokens for a user
func (r *authTokenRepository) RevokeTokensByUserID(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&types.AuthToken{}).Where("user_id = ?", userID).Update("is_revoked", true).Error
}

View File

@@ -18,7 +18,9 @@ import (
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
var apiKeySecret = []byte(os.Getenv("TENANT_AES_KEY"))
var apiKeySecret = func() []byte {
return []byte(os.Getenv("TENANT_AES_KEY"))
}
// ListTenantsParams defines parameters for listing tenants with filtering and pagination
type ListTenantsParams struct {
@@ -221,7 +223,7 @@ func (r *tenantService) generateApiKey(tenantID uint) string {
binary.LittleEndian.PutUint64(idBytes, uint64(tenantID))
// 2. Encrypt tenant_id using AES-GCM
block, err := aes.NewCipher(apiKeySecret)
block, err := aes.NewCipher(apiKeySecret())
if err != nil {
panic("Failed to create AES cipher: " + err.Error())
}
@@ -267,7 +269,7 @@ func (r *tenantService) ExtractTenantIDFromAPIKey(apiKey string) (uint, error) {
nonce, ciphertext := encryptedData[:12], encryptedData[12:]
// 4. Decrypt
block, err := aes.NewCipher(apiKeySecret)
block, err := aes.NewCipher(apiKeySecret())
if err != nil {
return 0, errors.New("decryption error")
}

View File

@@ -0,0 +1,408 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// JWT secret key - in production this should be from environment variable
var jwtSecret = []byte("your-secret-key")
// userService implements the UserService interface
type userService struct {
userRepo interfaces.UserRepository
tokenRepo interfaces.AuthTokenRepository
tenantService interfaces.TenantService
}
// NewUserService creates a new user service instance
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
return &userService{
userRepo: userRepo,
tokenRepo: tokenRepo,
tenantService: tenantService,
}
}
// Register creates a new user account
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
logger.Info(ctx, "Start user registration")
// Validate input
if req.Username == "" || req.Email == "" || req.Password == "" {
return nil, errors.New("username, email and password are required")
}
// Check if user already exists
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
if existingUser != nil {
return nil, errors.New("user with this email already exists")
}
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
if existingUser != nil {
return nil, errors.New("user with this username already exists")
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
logger.Errorf(ctx, "Failed to hash password: %v", err)
return nil, errors.New("failed to process password")
}
// Create default tenant for the user
tenant := &types.Tenant{
Name: fmt.Sprintf("%s's Workspace", req.Username),
Description: "Default workspace",
Status: "active",
RetrieverEngines: types.RetrieverEngines{
Engines: []types.RetrieverEngineParams{
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
},
},
}
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
if err != nil {
logger.Errorf(ctx, "Failed to create tenant: %v", err)
return nil, errors.New("failed to create workspace")
}
// Create user
user := &types.User{
ID: uuid.New().String(),
Username: req.Username,
Email: req.Email,
PasswordHash: string(hashedPassword),
TenantID: createdTenant.ID,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
err = s.userRepo.CreateUser(ctx, user)
if err != nil {
logger.Errorf(ctx, "Failed to create user: %v", err)
return nil, errors.New("failed to create user")
}
logger.Infof(ctx, "User registered successfully: %s", user.Email)
return user, nil
}
// Login authenticates a user and returns tokens
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
logger.Infof(ctx, "Start user login for email: %s", req.Email)
// Get user by email
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
if err != nil {
logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
if user == nil {
logger.Warnf(ctx, "User not found for email: %s", req.Email)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)
// Check if user is active
if !user.IsActive {
logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
return &types.LoginResponse{
Success: false,
Message: "Account is disabled",
}, nil
}
// Verify password
logger.Infof(ctx, "Verifying password for user: %s", user.Email)
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil {
logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Invalid email or password",
}, nil
}
logger.Infof(ctx, "Password verification successful for user: %s", user.Email)
// Generate tokens
logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
if err != nil {
logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
return &types.LoginResponse{
Success: false,
Message: "Login failed",
}, nil
}
logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)
// Get tenant information
logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
if err != nil {
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
} else {
logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
}
logger.Infof(ctx, "User logged in successfully: %s", user.Email)
return &types.LoginResponse{
Success: true,
Message: "Login successful",
User: user,
Tenant: tenant,
Token: accessToken,
RefreshToken: refreshToken,
}, nil
}
// GetUserByID gets a user by ID
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return s.userRepo.GetUserByID(ctx, id)
}
// GetUserByEmail gets a user by email
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
return s.userRepo.GetUserByEmail(ctx, email)
}
// GetUserByUsername gets a user by username
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
return s.userRepo.GetUserByUsername(ctx, username)
}
// UpdateUser updates user information
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
user.UpdatedAt = time.Now()
return s.userRepo.UpdateUser(ctx, user)
}
// DeleteUser deletes a user
func (s *userService) DeleteUser(ctx context.Context, id string) error {
return s.userRepo.DeleteUser(ctx, id)
}
// ChangePassword changes user password
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return err
}
// Verify old password
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
if err != nil {
return errors.New("invalid old password")
}
// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hashedPassword)
user.UpdatedAt = time.Now()
return s.userRepo.UpdateUser(ctx, user)
}
// ValidatePassword validates user password
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return err
}
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
// GenerateTokens generates access and refresh tokens for user
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
// Generate access token (expires in 24 hours)
accessClaims := jwt.MapClaims{
"user_id": user.ID,
"email": user.Email,
"tenant_id": user.TenantID,
"exp": time.Now().Add(24 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"type": "access",
}
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessToken, err = accessTokenObj.SignedString(jwtSecret)
if err != nil {
return "", "", err
}
// Generate refresh token (expires in 7 days)
refreshClaims := jwt.MapClaims{
"user_id": user.ID,
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"type": "refresh",
}
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
refreshToken, err = refreshTokenObj.SignedString(jwtSecret)
if err != nil {
return "", "", err
}
// Store tokens in database
accessTokenRecord := &types.AuthToken{
ID: uuid.New().String(),
UserID: user.ID,
Token: accessToken,
TokenType: "access_token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
refreshTokenRecord := &types.AuthToken{
ID: uuid.New().String(),
UserID: user.ID,
Token: refreshToken,
TokenType: "refresh_token",
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
return accessToken, refreshToken, nil
}
// ValidateToken validates an access token
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
return nil, errors.New("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("invalid token claims")
}
userID, ok := claims["user_id"].(string)
if !ok {
return nil, errors.New("invalid user ID in token")
}
// Check if token is revoked
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
return nil, errors.New("token is revoked")
}
return s.userRepo.GetUserByID(ctx, userID)
}
// RefreshToken refreshes access token using refresh token
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
return "", "", errors.New("invalid refresh token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", "", errors.New("invalid token claims")
}
tokenType, ok := claims["type"].(string)
if !ok || tokenType != "refresh" {
return "", "", errors.New("not a refresh token")
}
userID, ok := claims["user_id"].(string)
if !ok {
return "", "", errors.New("invalid user ID in token")
}
// Check if token is revoked
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
return "", "", errors.New("refresh token is revoked")
}
// Get user
user, err := s.userRepo.GetUserByID(ctx, userID)
if err != nil {
return "", "", err
}
// Revoke old refresh token
tokenRecord.IsRevoked = true
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
// Generate new tokens
return s.GenerateTokens(ctx, user)
}
// RevokeToken revokes a token
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
if err != nil {
return err
}
tokenRecord.IsRevoked = true
tokenRecord.UpdatedAt = time.Now()
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
}
// GetCurrentUser gets current user from context
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
userID, ok := ctx.Value("user_id").(string)
if !ok {
return nil, errors.New("user not found in context")
}
return s.userRepo.GetUserByID(ctx, userID)
}

View File

@@ -78,6 +78,8 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(repository.NewSessionRepository))
must(container.Provide(repository.NewMessageRepository))
must(container.Provide(repository.NewModelRepository))
must(container.Provide(repository.NewUserRepository))
must(container.Provide(repository.NewAuthTokenRepository))
// Business service layer
must(container.Provide(service.NewTenantService))
@@ -91,6 +93,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(service.NewModelService))
must(container.Provide(service.NewDatasetService))
must(container.Provide(service.NewEvaluationService))
must(container.Provide(service.NewUserService))
// Chat pipeline components for processing chat requests
must(container.Provide(chatpipline.NewEventManager))
@@ -117,6 +120,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(handler.NewModelHandler))
must(container.Provide(handler.NewEvaluationHandler))
must(container.Provide(handler.NewInitializationHandler))
must(container.Provide(handler.NewAuthHandler))
// Router configuration
must(container.Provide(router.NewRouter))
@@ -177,6 +181,15 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
return nil, err
}
// Auto-migrate database tables
err = db.AutoMigrate(
&types.User{},
&types.AuthToken{},
)
if err != nil {
return nil, fmt.Errorf("failed to auto-migrate database tables: %v", err)
}
// Get underlying SQL DB object
sqlDB, err := db.DB()
if err != nil {

325
internal/handler/auth.go Normal file
View File

@@ -0,0 +1,325 @@
package handler
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// AuthHandler implements HTTP request handlers for user authentication
// Provides functionality for user registration, login, logout, and token management
// through the REST API endpoints
type AuthHandler struct {
userService interfaces.UserService
}
// NewAuthHandler creates a new auth handler instance with the provided service
// Parameters:
// - userService: An implementation of the UserService interface for business logic
//
// Returns a pointer to the newly created AuthHandler
func NewAuthHandler(userService interfaces.UserService) *AuthHandler {
return &AuthHandler{
userService: userService,
}
}
// Register handles the HTTP request for user registration
// It deserializes the request body into a registration request object, validates it,
// calls the service to create the user, and returns the result
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Register(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user registration")
var req types.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse registration request parameters", err)
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
// Validate required fields
if req.Username == "" || req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required registration fields")
appErr := errors.NewValidationError("Username, email and password are required")
c.Error(appErr)
return
}
// Call service to register user
user, err := h.userService.Register(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to register user: %v", err)
appErr := errors.NewBadRequestError("Registration failed").WithDetails(err.Error())
c.Error(appErr)
return
}
// Return success response
response := &types.RegisterResponse{
Success: true,
Message: "Registration successful",
User: user,
}
logger.Infof(ctx, "User registered successfully: %s", user.Email)
c.JSON(http.StatusCreated, response)
}
// Login handles the HTTP request for user login
// It deserializes the request body into a login request object, validates it,
// calls the service to authenticate the user, and returns tokens
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Login(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user login")
var req types.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse login request parameters", err)
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
// Validate required fields
if req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required login fields")
appErr := errors.NewValidationError("Email and password are required")
c.Error(appErr)
return
}
// Call service to authenticate user
response, err := h.userService.Login(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to login user: %v", err)
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
c.Error(appErr)
return
}
// Check if login was successful
if !response.Success {
logger.Warnf(ctx, "Login failed: %s", response.Message)
c.JSON(http.StatusUnauthorized, response)
return
}
// User is already in the correct format from service
logger.Infof(ctx, "User logged in successfully: %s", req.Email)
c.JSON(http.StatusOK, response)
}
// Logout handles the HTTP request for user logout
// It extracts the token from the Authorization header and revokes it
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) Logout(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user logout")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Revoke token
err := h.userService.RevokeToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to revoke token: %v", err)
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "User logged out successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Logout successful",
})
}
// RefreshToken handles the HTTP request for refreshing access tokens
// It extracts the refresh token from the request body and generates new tokens
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) RefreshToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token refresh")
var req struct {
RefreshToken string `json:"refreshToken" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse refresh token request", err)
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Call service to refresh token
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
if err != nil {
logger.Errorf(ctx, "Failed to refresh token: %v", err)
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "Token refreshed successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token refreshed successfully",
"access_token": accessToken,
"refresh_token": newRefreshToken,
})
}
// GetCurrentUser handles the HTTP request for getting current user information
// It extracts the user from the context (set by auth middleware) and returns user info
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Get current user info")
// Get current user from service (which extracts from context)
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Retrieved current user info: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"user": user.ToUserInfo(),
})
}
// ChangePassword handles the HTTP request for changing user password
// It extracts the current user and validates the old password before setting new one
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) ChangePassword(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start password change")
var req struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse password change request", err)
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Get current user
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
// Change password
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
if err != nil {
logger.Errorf(ctx, "Failed to change password: %v", err)
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Password changed successfully",
})
}
// ValidateToken handles the HTTP request for validating access tokens
// It extracts the token from the Authorization header and validates it
// Parameters:
// - c: Gin context for the HTTP request
func (h *AuthHandler) ValidateToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token validation")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Validate token
user, err := h.userService.ValidateToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to validate token: %v", err)
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token is valid",
"user": user.ToUserInfo(),
})
}

View File

@@ -141,8 +141,10 @@ func (h *InitializationHandler) CheckStatus(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Checking system initialization status")
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
// 检查是否存在租户
tenant, err := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
tenant, err := h.tenantService.GetTenantByID(ctx, tenantID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.JSON(http.StatusOK, gin.H{
@@ -165,7 +167,6 @@ func (h *InitializationHandler) CheckStatus(c *gin.Context) {
})
return
}
ctx = context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
// 检查是否存在模型
models, err := h.modelService.ListModels(ctx)
@@ -194,6 +195,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Starting system initialization")
tenantID := ctx.Value(types.TenantIDContextKey).(uint)
var req InitializationRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -259,63 +261,16 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
var err error
// 1. 处理租户 - 检查是否存在,不存在则创建
tenant, _ := h.tenantService.GetTenantByID(ctx, types.InitDefaultTenantID)
tenant, _ := h.tenantService.GetTenantByID(ctx, tenantID)
if tenant == nil {
logger.Info(ctx, "Tenant not found, creating tenant")
// 创建默认租户
tenant = &types.Tenant{
ID: types.InitDefaultTenantID,
Name: "Default Tenant",
Description: "System Default Tenant",
RetrieverEngines: types.RetrieverEngines{
Engines: []types.RetrieverEngineParams{
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
},
},
}
logger.Info(ctx, "Creating default tenant")
tenant, err = h.tenantService.CreateTenant(ctx, tenant)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("创建租户失败: " + err.Error()))
return
}
} else {
logger.Info(ctx, "Tenant exists, updating if needed")
// 更新租户信息(如果需要)
updated := false
if tenant.Name != "Default Tenant" {
tenant.Name = "Default Tenant"
updated = true
}
if tenant.Description != "System Default Tenant" {
tenant.Description = "System Default Tenant"
updated = true
}
if updated {
_, err = h.tenantService.UpdateTenant(ctx, tenant)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("更新租户失败: " + err.Error()))
return
}
logger.Info(ctx, "Tenant updated successfully")
}
err = errors.NewInternalServerError("Failed to get tenant")
c.Error(err)
return
}
// 创建带有租户ID的新上下文
newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
// 2. 处理模型 - 检查现有模型并更新或创建
existingModels, err := h.modelService.ListModels(newCtx)
existingModels, err := h.modelService.ListModels(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
// 如果获取失败,继续执行创建流程
@@ -420,7 +375,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
existingModel.IsDefault = true
existingModel.Status = types.ModelStatusActive
err := h.modelService.UpdateModel(newCtx, existingModel)
err := h.modelService.UpdateModel(ctx, existingModel)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"model_name": modelConfig.name,
@@ -437,7 +392,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
modelConfig.name, modelConfig.modelType,
)
newModel := &types.Model{
TenantID: types.InitDefaultTenantID,
TenantID: tenantID,
Name: modelConfig.name,
Type: modelConfig.modelType,
Source: modelConfig.source,
@@ -453,7 +408,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
Status: types.ModelStatusActive,
}
err := h.modelService.CreateModel(newCtx, newModel)
err := h.modelService.CreateModel(ctx, newModel)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"model_name": modelConfig.name,
@@ -470,7 +425,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
if !req.Multimodal.Enabled {
if existingVLM, exists := modelMap[types.ModelTypeVLLM]; exists {
logger.Info(ctx, "Deleting VLM model as multimodal is disabled")
err := h.modelService.DeleteModel(newCtx, existingVLM.ID)
err := h.modelService.DeleteModel(ctx, existingVLM.ID)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"model_id": existingVLM.ID,
@@ -485,7 +440,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
if !req.Rerank.Enabled {
if existingRerank, exists := modelMap[types.ModelTypeRerank]; exists {
logger.Info(ctx, "Deleting Rerank model as rerank is disabled")
err := h.modelService.DeleteModel(newCtx, existingRerank.ID)
err := h.modelService.DeleteModel(ctx, existingRerank.ID)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"model_id": existingRerank.ID,
@@ -497,7 +452,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
// 3. 处理知识库 - 检查是否存在,不存在则创建,存在则更新
kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
kbs, err := h.kbService.ListKnowledgeBases(ctx)
// 找到embedding模型ID和LLM模型ID
var embeddingModelID, llmModelID, rerankModelID, vlmModelID string
@@ -516,14 +471,16 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
}
if kb == nil {
var kb *types.KnowledgeBase
if len(kbs) == 0 {
// 创建新知识库
logger.Info(ctx, "Creating default knowledge base")
kb = &types.KnowledgeBase{
ID: types.InitDefaultKnowledgeBaseID,
ID: uuid.New().String(),
Name: "Default Knowledge Base",
Description: "System Default Knowledge Base",
TenantID: types.InitDefaultTenantID,
TenantID: tenantID,
ChunkingConfig: types.ChunkingConfig{
ChunkSize: req.DocumentSplitting.ChunkSize,
ChunkOverlap: req.DocumentSplitting.ChunkOverlap,
@@ -566,7 +523,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
}
_, err = h.kbService.CreateKnowledgeBase(newCtx, kb)
_, err = h.kbService.CreateKnowledgeBase(ctx, kb)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("创建知识库失败: " + err.Error()))
@@ -575,10 +532,11 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
} else {
// 更新现有知识库
logger.Info(ctx, "Updating existing knowledge base")
kb = kbs[0]
// 检查是否有文件如果有文件则不允许修改Embedding模型
knowledgeList, err := h.knowledgeService.ListKnowledgeByKnowledgeBaseID(
newCtx, types.InitDefaultKnowledgeBaseID,
ctx, kb.ID,
)
hasFiles := err == nil && len(knowledgeList) > 0
@@ -639,7 +597,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
// 更新基本信息和配置
err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
err = h.kbRepository.UpdateKnowledgeBase(ctx, kb)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("更新知识库配置失败: " + err.Error()))
@@ -649,7 +607,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
// 如果需要更新模型ID使用repository直接更新
if !hasFiles || kb.SummaryModelID != llmModelID {
// 刷新知识库对象以获取最新信息
kb, err = h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
kb, err = h.kbService.GetKnowledgeBaseByID(ctx, kb.ID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("获取更新后的知识库失败: " + err.Error()))
@@ -665,7 +623,7 @@ func (h *InitializationHandler) Initialize(c *gin.Context) {
}
// 使用repository直接更新模型ID
err = h.kbRepository.UpdateKnowledgeBase(newCtx, kb)
err = h.kbRepository.UpdateKnowledgeBase(ctx, kb)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("更新知识库模型ID失败: " + err.Error()))
@@ -1074,11 +1032,8 @@ func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
logger.Info(ctx, "Getting current system configuration")
// 设置租户上下文
newCtx := context.WithValue(ctx, types.TenantIDContextKey, types.InitDefaultTenantID)
// 获取模型信息
models, err := h.modelService.ListModels(newCtx)
models, err := h.modelService.ListModels(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("获取模型列表失败: " + err.Error()))
@@ -1086,16 +1041,24 @@ func (h *InitializationHandler) GetCurrentConfig(c *gin.Context) {
}
// 获取知识库信息
kb, err := h.kbService.GetKnowledgeBaseByID(newCtx, types.InitDefaultKnowledgeBaseID)
kbs, err := h.kbService.ListKnowledgeBases(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError("获取知识库信息失败: " + err.Error()))
return
}
if len(kbs) == 0 {
logger.Error(ctx, "No knowledge bases found")
c.Error(errors.NewInternalServerError("获取知识库信息失败"))
return
}
kb := kbs[0]
// 检查知识库是否有文件
knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(newCtx,
types.InitDefaultKnowledgeBaseID, &types.Pagination{
knowledgeList, err := h.knowledgeService.ListPagedKnowledgeByKnowledgeBaseID(ctx,
kb.ID, &types.Pagination{
Page: 1,
PageSize: 1,
})

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
@@ -48,7 +49,7 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
logger.Info(ctx, "Start retrieving test data")
tenantID := uint(types.InitDefaultTenantID)
tenantID := c.GetUint(types.TenantIDContextKey.String())
logger.Debugf(ctx, "Test tenant ID environment variable: %d", tenantID)
// Retrieve the test tenant data
@@ -60,24 +61,26 @@ func (h *TestDataHandler) GetTestData(c *gin.Context) {
return
}
knowledgeBaseID := types.InitDefaultKnowledgeBaseID
logger.Debugf(ctx, "Test knowledge base ID environment variable: %s", knowledgeBaseID)
// Retrieve the test knowledge base data
logger.Infof(ctx, "Retrieving test knowledge base, ID: %s", knowledgeBaseID)
knowledgeBase, err := h.kbService.GetKnowledgeBaseByID(ctx, knowledgeBaseID)
kbs, err := h.kbService.ListKnowledgeBases(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(err)
return
}
if len(kbs) == 0 {
logger.Error(ctx, "No knowledge bases found")
c.Error(errors.NewInternalServerError("获取知识库信息失败"))
return
}
logger.Info(ctx, "Test data retrieved successfully")
// Return the test data in the response
c.JSON(http.StatusOK, gin.H{
"data": gin.H{
"tenant": tenant,
"knowledge_bases": []types.KnowledgeBase{*knowledgeBase},
"knowledge_bases": kbs,
},
"success": true,
})

View File

@@ -16,10 +16,10 @@ import (
// 无需认证的API列表
var noAuthAPI = map[string][]string{
"/api/v1/test-data": {"GET"},
"/api/v1/tenants": {"POST"},
"/api/v1/initialization/*": {"GET", "POST"},
"/health": {"GET"},
"/health": {"GET"},
"/api/v1/auth/register": {"POST"},
"/api/v1/auth/login": {"POST"},
"/api/v1/auth/refresh": {"POST"},
}
// 检查请求是否在无需认证的API列表中
@@ -38,7 +38,7 @@ func isNoAuthAPI(path string, method string) bool {
}
// Auth 认证中间件
func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.HandlerFunc {
func Auth(tenantService interfaces.TenantService, userService interfaces.UserService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// ignore OPTIONS request
if c.Request.Method == "OPTIONS" {
@@ -52,53 +52,90 @@ func Auth(tenantService interfaces.TenantService, cfg *config.Config) gin.Handle
return
}
// Get API Key from request header
// 尝试JWT Token认证
authHeader := c.GetHeader("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
token := strings.TrimPrefix(authHeader, "Bearer ")
user, err := userService.ValidateToken(c.Request.Context(), token)
if err == nil && user != nil {
// JWT Token认证成功
// 获取租户信息
tenant, err := tenantService.GetTenantByID(c.Request.Context(), user.TenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, userID: %s", err, user.TenantID, user.ID)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid tenant",
})
c.Abort()
return
}
// 存储用户和租户信息到上下文
c.Set(types.TenantIDContextKey.String(), user.TenantID)
c.Set(types.TenantInfoContextKey.String(), tenant)
c.Set("user", user)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, user.TenantID),
types.TenantInfoContextKey, tenant,
),
"user", user,
),
)
c.Next()
return
}
}
// 尝试X-API-Key认证兼容模式
apiKey := c.GetHeader("X-API-Key")
if apiKey == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
if apiKey != "" {
// Get tenant information
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key format",
})
c.Abort()
return
}
// Verify API key validity (matches the one in database)
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
if t == nil || t.APIKey != apiKey {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
// Store tenant ID in context
c.Set(types.TenantIDContextKey.String(), tenantID)
c.Set(types.TenantInfoContextKey.String(), t)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
types.TenantInfoContextKey, t,
),
)
c.Next()
return
}
// Get tenant information
tenantID, err := tenantService.ExtractTenantIDFromAPIKey(apiKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key format",
})
c.Abort()
return
}
// Verify API key validity (matches the one in database)
t, err := tenantService.GetTenantByID(c.Request.Context(), tenantID)
if err != nil {
log.Printf("Error getting tenant by ID: %v, tenantID: %d, apiKey: %s", err, tenantID, apiKey)
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
if t == nil || t.APIKey != apiKey {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized: invalid API key",
})
c.Abort()
return
}
// Store tenant ID in context
c.Set(types.TenantIDContextKey.String(), tenantID)
c.Set(types.TenantInfoContextKey.String(), t)
c.Request = c.Request.WithContext(
context.WithValue(
context.WithValue(c.Request.Context(), types.TenantIDContextKey, tenantID),
types.TenantInfoContextKey, t,
),
)
c.Next()
// 没有提供任何认证信息
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized: missing authentication"})
c.Abort()
}
}

View File

@@ -18,6 +18,14 @@ type RouterParams struct {
dig.In
Config *config.Config
UserService interfaces.UserService
KBService interfaces.KnowledgeBaseService
KnowledgeService interfaces.KnowledgeService
ChunkService interfaces.ChunkService
SessionService interfaces.SessionService
MessageService interfaces.MessageService
ModelService interfaces.ModelService
EvaluationService interfaces.EvaluationService
KBHandler *handler.KnowledgeBaseHandler
KnowledgeHandler *handler.KnowledgeHandler
TenantHandler *handler.TenantHandler
@@ -28,6 +36,7 @@ type RouterParams struct {
TestDataHandler *handler.TestDataHandler
ModelHandler *handler.ModelHandler
EvaluationHandler *handler.EvaluationHandler
AuthHandler *handler.AuthHandler
InitializationHandler *handler.InitializationHandler
}
@@ -50,7 +59,7 @@ func NewRouter(params RouterParams) *gin.Engine {
r.Use(middleware.Logger())
r.Use(middleware.Recovery())
r.Use(middleware.ErrorHandler())
r.Use(middleware.Auth(params.TenantService, params.Config))
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
// 添加OpenTelemetry追踪中间件
r.Use(middleware.TracingMiddleware())
@@ -60,31 +69,10 @@ func NewRouter(params RouterParams) *gin.Engine {
c.JSON(200, gin.H{"status": "ok"})
})
// 测试数据接口(不需要认证)
r.GET("/api/v1/test-data", params.TestDataHandler.GetTestData)
// 初始化接口(不需要认证)
r.GET("/api/v1/initialization/status", params.InitializationHandler.CheckStatus)
r.GET("/api/v1/initialization/config", params.InitializationHandler.GetCurrentConfig)
r.POST("/api/v1/initialization/initialize", params.InitializationHandler.Initialize)
// Ollama相关接口不需要认证
r.GET("/api/v1/initialization/ollama/status", params.InitializationHandler.CheckOllamaStatus)
r.GET("/api/v1/initialization/ollama/models", params.InitializationHandler.ListOllamaModels)
r.POST("/api/v1/initialization/ollama/models/check", params.InitializationHandler.CheckOllamaModels)
r.POST("/api/v1/initialization/ollama/models/download", params.InitializationHandler.DownloadOllamaModel)
r.GET("/api/v1/initialization/ollama/download/progress/:taskId", params.InitializationHandler.GetDownloadProgress)
r.GET("/api/v1/initialization/ollama/download/tasks", params.InitializationHandler.ListDownloadTasks)
// 远程API相关接口不需要认证
r.POST("/api/v1/initialization/remote/check", params.InitializationHandler.CheckRemoteModel)
r.POST("/api/v1/initialization/embedding/test", params.InitializationHandler.TestEmbeddingModel)
r.POST("/api/v1/initialization/rerank/check", params.InitializationHandler.CheckRerankModel)
r.POST("/api/v1/initialization/multimodal/test", params.InitializationHandler.TestMultimodalFunction)
// 需要认证的API路由
v1 := r.Group("/api/v1")
{
RegisterAuthRoutes(v1, params.AuthHandler)
RegisterTenantRoutes(v1, params.TenantHandler)
RegisterKnowledgeBaseRoutes(v1, params.KBHandler)
RegisterKnowledgeRoutes(v1, params.KnowledgeHandler)
@@ -94,6 +82,8 @@ func NewRouter(params RouterParams) *gin.Engine {
RegisterMessageRoutes(v1, params.MessageHandler)
RegisterModelRoutes(v1, params.ModelHandler)
RegisterEvaluationRoutes(v1, params.EvaluationHandler)
RegisterInitializationRoutes(v1, params.InitializationHandler)
RegisterTestDataRoutes(v1, params.TestDataHandler)
}
return r
@@ -247,3 +237,39 @@ func RegisterEvaluationRoutes(r *gin.RouterGroup, handler *handler.EvaluationHan
evaluationRoutes.GET("/", handler.GetEvaluationResult)
}
}
func RegisterTestDataRoutes(r *gin.RouterGroup, handler *handler.TestDataHandler) {
r.GET("/test-data", handler.GetTestData)
}
// RegisterAuthRoutes registers authentication routes
func RegisterAuthRoutes(r *gin.RouterGroup, handler *handler.AuthHandler) {
r.POST("/auth/register", handler.Register)
r.POST("/auth/login", handler.Login)
r.POST("/auth/refresh", handler.RefreshToken)
r.GET("/auth/validate", handler.ValidateToken)
r.POST("/auth/logout", handler.Logout)
r.GET("/auth/me", handler.GetCurrentUser)
r.POST("/auth/change-password", handler.ChangePassword)
}
func RegisterInitializationRoutes(r *gin.RouterGroup, handler *handler.InitializationHandler) {
// 初始化接口
r.GET("/initialization/status", handler.CheckStatus)
r.GET("/initialization/config", handler.GetCurrentConfig)
r.POST("/initialization/initialize", handler.Initialize)
// Ollama相关接口
r.GET("/initialization/ollama/status", handler.CheckOllamaStatus)
r.GET("/initialization/ollama/models", handler.ListOllamaModels)
r.POST("/initialization/ollama/models/check", handler.CheckOllamaModels)
r.POST("/initialization/ollama/models/download", handler.DownloadOllamaModel)
r.GET("/initialization/ollama/download/progress/:taskId", handler.GetDownloadProgress)
r.GET("/initialization/ollama/download/tasks", handler.ListDownloadTasks)
// 远程API相关接口
r.POST("/initialization/remote/check", handler.CheckRemoteModel)
r.POST("/initialization/embedding/test", handler.TestEmbeddingModel)
r.POST("/initialization/rerank/check", handler.CheckRerankModel)
r.POST("/initialization/multimodal/test", handler.TestMultimodalFunction)
}

View File

@@ -0,0 +1,75 @@
package interfaces
import (
"context"
"github.com/Tencent/WeKnora/internal/types"
)
// UserService defines the user service interface
type UserService interface {
// Register creates a new user account
Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error)
// Login authenticates a user and returns tokens
Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error)
// GetUserByID gets a user by ID
GetUserByID(ctx context.Context, id string) (*types.User, error)
// GetUserByEmail gets a user by email
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
// GetUserByUsername gets a user by username
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
// UpdateUser updates user information
UpdateUser(ctx context.Context, user *types.User) error
// DeleteUser deletes a user
DeleteUser(ctx context.Context, id string) error
// ChangePassword changes user password
ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error
// ValidatePassword validates user password
ValidatePassword(ctx context.Context, userID string, password string) error
// GenerateTokens generates access and refresh tokens for user
GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error)
// ValidateToken validates an access token
ValidateToken(ctx context.Context, token string) (*types.User, error)
// RefreshToken refreshes access token using refresh token
RefreshToken(ctx context.Context, refreshToken string) (accessToken, newRefreshToken string, err error)
// RevokeToken revokes a token
RevokeToken(ctx context.Context, token string) error
// GetCurrentUser gets current user from context
GetCurrentUser(ctx context.Context) (*types.User, error)
}
// UserRepository defines the user repository interface
type UserRepository interface {
// CreateUser creates a user
CreateUser(ctx context.Context, user *types.User) error
// GetUserByID gets a user by ID
GetUserByID(ctx context.Context, id string) (*types.User, error)
// GetUserByEmail gets a user by email
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
// GetUserByUsername gets a user by username
GetUserByUsername(ctx context.Context, username string) (*types.User, error)
// UpdateUser updates a user
UpdateUser(ctx context.Context, user *types.User) error
// DeleteUser deletes a user
DeleteUser(ctx context.Context, id string) error
// ListUsers lists users with pagination
ListUsers(ctx context.Context, offset, limit int) ([]*types.User, error)
}
// AuthTokenRepository defines the auth token repository interface
type AuthTokenRepository interface {
// CreateToken creates an auth token
CreateToken(ctx context.Context, token *types.AuthToken) error
// GetTokenByValue gets a token by its value
GetTokenByValue(ctx context.Context, tokenValue string) (*types.AuthToken, error)
// GetTokensByUserID gets all tokens for a user
GetTokensByUserID(ctx context.Context, userID string) ([]*types.AuthToken, error)
// UpdateToken updates a token
UpdateToken(ctx context.Context, token *types.AuthToken) error
// DeleteToken deletes a token
DeleteToken(ctx context.Context, id string) error
// DeleteExpiredTokens deletes all expired tokens
DeleteExpiredTokens(ctx context.Context) error
// RevokeTokensByUserID revokes all tokens for a user
RevokeTokensByUserID(ctx context.Context, userID string) error
}

114
internal/types/user.go Normal file
View File

@@ -0,0 +1,114 @@
package types
import (
"time"
"gorm.io/gorm"
)
// User represents a user in the system
type User struct {
// Unique identifier of the user
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
// Username of the user
Username string `json:"username" gorm:"type:varchar(100);uniqueIndex;not null"`
// Email address of the user
Email string `json:"email" gorm:"type:varchar(255);uniqueIndex;not null"`
// Hashed password of the user
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
// Avatar URL of the user
Avatar string `json:"avatar" gorm:"type:varchar(500)"`
// Tenant ID that the user belongs to
TenantID uint `json:"tenant_id" gorm:"index"`
// Whether the user is active
IsActive bool `json:"is_active" gorm:"default:true"`
// Creation time of the user
CreatedAt time.Time `json:"created_at"`
// Last updated time of the user
UpdatedAt time.Time `json:"updated_at"`
// Deletion time of the user
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
// Association relationship, not stored in the database
Tenant *Tenant `json:"tenant,omitempty" gorm:"foreignKey:TenantID"`
}
// AuthToken represents an authentication token
type AuthToken struct {
// Unique identifier of the token
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
// User ID that owns this token
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
// Token value (JWT or other format)
Token string `json:"token" gorm:"type:text;not null"`
// Token type (access_token, refresh_token)
TokenType string `json:"token_type" gorm:"type:varchar(50);not null"`
// Token expiration time
ExpiresAt time.Time `json:"expires_at"`
// Whether the token is revoked
IsRevoked bool `json:"is_revoked" gorm:"default:false"`
// Creation time of the token
CreatedAt time.Time `json:"created_at"`
// Last updated time of the token
UpdatedAt time.Time `json:"updated_at"`
// Association relationship
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
}
// LoginRequest represents a login request
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
// RegisterRequest represents a registration request
type RegisterRequest struct {
Username string `json:"username" binding:"required,min=3,max=50"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
// LoginResponse represents a login response
type LoginResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
User *User `json:"user,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"`
Token string `json:"token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// RegisterResponse represents a registration response
type RegisterResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
User *User `json:"user,omitempty"`
Tenant *Tenant `json:"tenant,omitempty"`
}
// UserInfo represents user information for API responses
type UserInfo struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Avatar string `json:"avatar"`
TenantID uint `json:"tenant_id"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ToUserInfo converts User to UserInfo (without sensitive data)
func (u *User) ToUserInfo() *UserInfo {
return &UserInfo{
ID: u.ID,
Username: u.Username,
Email: u.Email,
Avatar: u.Avatar,
TenantID: u.TenantID,
IsActive: u.IsActive,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -78,13 +78,17 @@ check_platform() {
log_info "检测系统平台信息..."
if [ "$(uname -m)" = "x86_64" ]; then
export PLATFORM="linux/amd64"
export TARGETARCH="amd64"
elif [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
export PLATFORM="linux/arm64"
export TARGETARCH="arm64"
else
log_warning "未识别的平台类型:$(uname -m),将使用默认平台 linux/amd64"
export PLATFORM="linux/amd64"
export TARGETARCH="amd64"
fi
log_info "当前平台:$PLATFORM"
log_info "当前架构:$TARGETARCH"
}
# 构建应用镜像
@@ -119,7 +123,7 @@ build_docreader_image() {
docker build \
--platform $PLATFORM \
--build-arg PLATFORM=$PLATFORM \
--build-arg TARGETARCH=$TARGETARCH \
-f docker/Dockerfile.docreader \
-t wechatopenai/weknora-docreader:latest \
.