mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 03:15:00 +08:00
450 lines
13 KiB
Go
450 lines
13 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
)
|
|
|
|
// jwtSecret is the HMAC key used to sign and verify every JWT handled by this
// service. It is resolved once, when the package is initialised.
var jwtSecret = loadJWTSecret()

// loadJWTSecret returns the signing key taken from the JWT_SECRET environment
// variable. When the variable is unset or empty it falls back to the legacy
// built-in key so that existing deployments keep working.
//
// NOTE(review): the fallback key is publicly known; production deployments
// must set JWT_SECRET.
func loadJWTSecret() []byte {
	if secret := os.Getenv("JWT_SECRET"); secret != "" {
		return []byte(secret)
	}
	return []byte("your-secret-key")
}
|
|
|
|
// userService implements the UserService interface
|
|
type userService struct {
|
|
userRepo interfaces.UserRepository
|
|
tokenRepo interfaces.AuthTokenRepository
|
|
tenantService interfaces.TenantService
|
|
}
|
|
|
|
// NewUserService creates a new user service instance
|
|
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
|
|
return &userService{
|
|
userRepo: userRepo,
|
|
tokenRepo: tokenRepo,
|
|
tenantService: tenantService,
|
|
}
|
|
}
|
|
|
|
var engine = map[string][]types.RetrieverEngineParams{
|
|
"postgres": {
|
|
{
|
|
RetrieverType: types.KeywordsRetrieverType,
|
|
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
|
},
|
|
{
|
|
RetrieverType: types.VectorRetrieverType,
|
|
RetrieverEngineType: types.PostgresRetrieverEngineType,
|
|
},
|
|
},
|
|
"elasticsearch_v7": {
|
|
{
|
|
RetrieverType: types.KeywordsRetrieverType,
|
|
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
|
},
|
|
},
|
|
"elasticsearch_v8": {
|
|
{
|
|
RetrieverType: types.KeywordsRetrieverType,
|
|
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
|
},
|
|
{
|
|
RetrieverType: types.VectorRetrieverType,
|
|
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Register creates a new user account
|
|
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
|
|
logger.Info(ctx, "Start user registration")
|
|
|
|
// Validate input
|
|
if req.Username == "" || req.Email == "" || req.Password == "" {
|
|
return nil, errors.New("username, email and password are required")
|
|
}
|
|
|
|
// Check if user already exists
|
|
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
|
|
if existingUser != nil {
|
|
return nil, errors.New("user with this email already exists")
|
|
}
|
|
|
|
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
|
|
if existingUser != nil {
|
|
return nil, errors.New("user with this username already exists")
|
|
}
|
|
|
|
// Hash password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to hash password: %v", err)
|
|
return nil, errors.New("failed to process password")
|
|
}
|
|
|
|
egs := []types.RetrieverEngineParams{}
|
|
for _, driver := range strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",") {
|
|
if val, ok := engine[driver]; ok {
|
|
egs = append(egs, val...)
|
|
}
|
|
}
|
|
egs = uniqueRetrieverEngine(egs)
|
|
logger.Debugf(ctx, "user register retriever engines: %v", egs)
|
|
|
|
// Create default tenant for the user
|
|
tenant := &types.Tenant{
|
|
Name: fmt.Sprintf("%s's Workspace", req.Username),
|
|
Description: "Default workspace",
|
|
Status: "active",
|
|
RetrieverEngines: types.RetrieverEngines{Engines: egs},
|
|
}
|
|
|
|
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create tenant: %v", err)
|
|
return nil, errors.New("failed to create workspace")
|
|
}
|
|
|
|
// Create user
|
|
user := &types.User{
|
|
ID: uuid.New().String(),
|
|
Username: req.Username,
|
|
Email: req.Email,
|
|
PasswordHash: string(hashedPassword),
|
|
TenantID: createdTenant.ID,
|
|
IsActive: true,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
err = s.userRepo.CreateUser(ctx, user)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create user: %v", err)
|
|
return nil, errors.New("failed to create user")
|
|
}
|
|
|
|
logger.Infof(ctx, "User registered successfully: %s", user.Email)
|
|
return user, nil
|
|
}
|
|
|
|
// Login authenticates a user and returns tokens
|
|
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
|
|
logger.Infof(ctx, "Start user login for email: %s", req.Email)
|
|
|
|
// Get user by email
|
|
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
if user == nil {
|
|
logger.Warnf(ctx, "User not found for email: %s", req.Email)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
|
|
logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)
|
|
|
|
// Check if user is active
|
|
if !user.IsActive {
|
|
logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Account is disabled",
|
|
}, nil
|
|
}
|
|
|
|
// Verify password
|
|
logger.Infof(ctx, "Verifying password for user: %s", user.Email)
|
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
logger.Infof(ctx, "Password verification successful for user: %s", user.Email)
|
|
|
|
// Generate tokens
|
|
logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
|
|
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Login failed",
|
|
}, nil
|
|
}
|
|
logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)
|
|
|
|
// Get tenant information
|
|
logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
|
|
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
|
|
} else {
|
|
logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
|
|
}
|
|
|
|
logger.Infof(ctx, "User logged in successfully: %s", user.Email)
|
|
return &types.LoginResponse{
|
|
Success: true,
|
|
Message: "Login successful",
|
|
User: user,
|
|
Tenant: tenant,
|
|
Token: accessToken,
|
|
RefreshToken: refreshToken,
|
|
}, nil
|
|
}
|
|
|
|
// GetUserByID gets a user by ID
|
|
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
|
return s.userRepo.GetUserByID(ctx, id)
|
|
}
|
|
|
|
// GetUserByEmail gets a user by email
|
|
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
|
return s.userRepo.GetUserByEmail(ctx, email)
|
|
}
|
|
|
|
// GetUserByUsername gets a user by username
|
|
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
|
return s.userRepo.GetUserByUsername(ctx, username)
|
|
}
|
|
|
|
// UpdateUser updates user information
|
|
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
|
|
user.UpdatedAt = time.Now()
|
|
return s.userRepo.UpdateUser(ctx, user)
|
|
}
|
|
|
|
// DeleteUser deletes a user
|
|
func (s *userService) DeleteUser(ctx context.Context, id string) error {
|
|
return s.userRepo.DeleteUser(ctx, id)
|
|
}
|
|
|
|
// ChangePassword changes user password
|
|
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify old password
|
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
|
|
if err != nil {
|
|
return errors.New("invalid old password")
|
|
}
|
|
|
|
// Hash new password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user.PasswordHash = string(hashedPassword)
|
|
user.UpdatedAt = time.Now()
|
|
|
|
return s.userRepo.UpdateUser(ctx, user)
|
|
}
|
|
|
|
// ValidatePassword validates user password
|
|
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
|
}
|
|
|
|
// GenerateTokens generates access and refresh tokens for user
|
|
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
|
|
// Generate access token (expires in 24 hours)
|
|
accessClaims := jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"email": user.Email,
|
|
"tenant_id": user.TenantID,
|
|
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"type": "access",
|
|
}
|
|
|
|
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
|
accessToken, err = accessTokenObj.SignedString(jwtSecret)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Generate refresh token (expires in 7 days)
|
|
refreshClaims := jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"type": "refresh",
|
|
}
|
|
|
|
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
|
refreshToken, err = refreshTokenObj.SignedString(jwtSecret)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Store tokens in database
|
|
accessTokenRecord := &types.AuthToken{
|
|
ID: uuid.New().String(),
|
|
UserID: user.ID,
|
|
Token: accessToken,
|
|
TokenType: "access_token",
|
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
refreshTokenRecord := &types.AuthToken{
|
|
ID: uuid.New().String(),
|
|
UserID: user.ID,
|
|
Token: refreshToken,
|
|
TokenType: "refresh_token",
|
|
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
|
|
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
|
|
|
|
return accessToken, refreshToken, nil
|
|
}
|
|
|
|
// ValidateToken validates an access token
|
|
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, errors.New("invalid token claims")
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
if !ok {
|
|
return nil, errors.New("invalid user ID in token")
|
|
}
|
|
|
|
// Check if token is revoked
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
|
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
|
return nil, errors.New("token is revoked")
|
|
}
|
|
|
|
return s.userRepo.GetUserByID(ctx, userID)
|
|
}
|
|
|
|
// RefreshToken refreshes access token using refresh token
|
|
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
|
|
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
return "", "", errors.New("invalid refresh token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return "", "", errors.New("invalid token claims")
|
|
}
|
|
|
|
tokenType, ok := claims["type"].(string)
|
|
if !ok || tokenType != "refresh" {
|
|
return "", "", errors.New("not a refresh token")
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
if !ok {
|
|
return "", "", errors.New("invalid user ID in token")
|
|
}
|
|
|
|
// Check if token is revoked
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
|
|
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
|
return "", "", errors.New("refresh token is revoked")
|
|
}
|
|
|
|
// Get user
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Revoke old refresh token
|
|
tokenRecord.IsRevoked = true
|
|
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
|
|
|
// Generate new tokens
|
|
return s.GenerateTokens(ctx, user)
|
|
}
|
|
|
|
// RevokeToken revokes a token
|
|
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tokenRecord.IsRevoked = true
|
|
tokenRecord.UpdatedAt = time.Now()
|
|
|
|
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
|
}
|
|
|
|
// GetCurrentUser gets current user from context
|
|
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
|
|
user, ok := ctx.Value("user").(*types.User)
|
|
if !ok {
|
|
return nil, errors.New("user not found in context")
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func uniqueRetrieverEngine(engine []types.RetrieverEngineParams) []types.RetrieverEngineParams {
|
|
seen := make(map[types.RetrieverEngineParams]bool)
|
|
var result []types.RetrieverEngineParams
|
|
for _, v := range engine {
|
|
if !seen[v] {
|
|
seen[v] = true
|
|
result = append(result, v)
|
|
}
|
|
}
|
|
return result
|
|
}
|