// File: WeKnora/internal/application/service/user.go
//
// 450 lines
// 13 KiB
// Go

package service
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// defaultJWTSecret is the fallback signing key used when no key is configured
// through the environment. It is a fixed, publicly known value, so tokens
// signed with it can be forged by anyone; it exists for local development only.
const defaultJWTSecret = "your-secret-key"

// jwtSecret is the HMAC key used to sign and verify every JWT handled in this
// file. It is resolved once, at package initialisation, by loadJWTSecret.
var jwtSecret = loadJWTSecret()

// loadJWTSecret returns the value of the JWT_SECRET environment variable as
// the signing key, or defaultJWTSecret when that variable is unset or empty.
func loadJWTSecret() []byte {
	if secret := os.Getenv("JWT_SECRET"); secret != "" {
		return []byte(secret)
	}
	return []byte(defaultJWTSecret)
}
// userService implements the UserService interface.
// It bundles the three collaborators that the methods in this file call;
// NewUserService is the constructor that fills them in.
type userService struct {
// userRepo looks up, creates, updates and deletes user records.
userRepo interfaces.UserRepository
// tokenRepo stores issued tokens and is consulted for revocation checks.
tokenRepo interfaces.AuthTokenRepository
// tenantService creates the default tenant on registration and fetches
// the tenant returned on login.
tenantService interfaces.TenantService
}
// NewUserService wires the given user repository, token repository and tenant
// service into a fresh userService and returns it as interfaces.UserService.
func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.AuthTokenRepository, tenantService interfaces.TenantService) interfaces.UserService {
	svc := new(userService)
	svc.userRepo = userRepo
	svc.tokenRepo = tokenRepo
	svc.tenantService = tenantService
	return svc
}
// engine maps a retrieval driver name to the retriever engine parameter sets
// enabled for that driver. Register looks up each comma-separated entry of the
// RETRIEVE_DRIVER environment variable in this table and stores the combined,
// de-duplicated result on the tenant it creates.
var engine = map[string][]types.RetrieverEngineParams{
// postgres: keyword and vector retrieval, both on the Postgres engine.
"postgres": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.PostgresRetrieverEngineType,
},
},
// elasticsearch_v7: keyword retrieval only, on the Elasticsearch engine.
"elasticsearch_v7": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
},
// elasticsearch_v8: keyword and vector retrieval, both on the Elasticsearch engine.
"elasticsearch_v8": {
{
RetrieverType: types.KeywordsRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
{
RetrieverType: types.VectorRetrieverType,
RetrieverEngineType: types.ElasticsearchRetrieverEngineType,
},
},
}
// Register creates a new user account together with a default tenant.
//
// It validates that username, email and password are present, rejects the
// request when a user with the same email or username is already found,
// hashes the password with bcrypt, creates a tenant named
// "<username>'s Workspace" whose retriever engines are derived from the
// RETRIEVE_DRIVER environment variable, and finally stores the user linked to
// that tenant.
//
// It returns the stored user, or an error with a generic message when
// validation, hashing, tenant creation or user creation fails (the underlying
// cause is logged, not returned).
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
	logger.Info(ctx, "Start user registration")

	// Validate input. A nil request is rejected like an empty one instead of
	// panicking on the field accesses below.
	if req == nil || req.Username == "" || req.Email == "" || req.Password == "" {
		return nil, errors.New("username, email and password are required")
	}

	// Check if user already exists.
	// NOTE(review): lookup errors are discarded, presumably because the
	// repository reports "not found" as an error — confirm; a real database
	// failure is currently indistinguishable from "no such user".
	existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
	if existingUser != nil {
		return nil, errors.New("user with this email already exists")
	}
	existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
	if existingUser != nil {
		return nil, errors.New("user with this username already exists")
	}

	// Hash password
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		logger.Errorf(ctx, "Failed to hash password: %v", err)
		return nil, errors.New("failed to process password")
	}

	// Collect the retriever engines for every configured driver. Entries are
	// trimmed so that "postgres, elasticsearch_v8" matches both drivers;
	// unknown or empty entries are skipped.
	var egs []types.RetrieverEngineParams
	for _, driver := range strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",") {
		if val, ok := engine[strings.TrimSpace(driver)]; ok {
			egs = append(egs, val...)
		}
	}
	egs = uniqueRetrieverEngine(egs)
	logger.Debugf(ctx, "user register retriever engines: %v", egs)

	// Create default tenant for the user
	tenant := &types.Tenant{
		Name:             fmt.Sprintf("%s's Workspace", req.Username),
		Description:      "Default workspace",
		Status:           "active",
		RetrieverEngines: types.RetrieverEngines{Engines: egs},
	}
	createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
	if err != nil {
		logger.Errorf(ctx, "Failed to create tenant: %v", err)
		return nil, errors.New("failed to create workspace")
	}

	// Create user. One clock reading keeps CreatedAt and UpdatedAt identical.
	now := time.Now()
	user := &types.User{
		ID:           uuid.New().String(),
		Username:     req.Username,
		Email:        req.Email,
		PasswordHash: string(hashedPassword),
		TenantID:     createdTenant.ID,
		IsActive:     true,
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	// NOTE(review): if this fails, the tenant created above is left behind;
	// nothing in this function rolls it back.
	if err = s.userRepo.CreateUser(ctx, user); err != nil {
		logger.Errorf(ctx, "Failed to create user: %v", err)
		return nil, errors.New("failed to create user")
	}

	logger.Infof(ctx, "User registered successfully: %s", user.Email)
	return user, nil
}
// Login authenticates a user by email and password and returns tokens.
//
// Authentication failures are reported through the returned LoginResponse
// (Success=false plus a message) with a nil error, so callers must inspect
// the response. The only non-nil error is for a nil request.
//
// The password is verified before the account's active flag is checked, so
// "Account is disabled" is only disclosed to a caller that presented the
// correct password; every other failure yields the same generic message.
//
// On success the response carries the user, the access and refresh tokens
// from GenerateTokens, and the user's tenant. A failed tenant lookup is only
// logged; the response is then returned with whatever GetTenantByID yielded.
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
	if req == nil {
		return nil, errors.New("login request is required")
	}
	logger.Infof(ctx, "Start user login for email: %s", req.Email)

	// Get user by email
	user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
	if err != nil {
		logger.Errorf(ctx, "Failed to get user by email %s: %v", req.Email, err)
		return &types.LoginResponse{
			Success: false,
			Message: "Invalid email or password",
		}, nil
	}
	if user == nil {
		logger.Warnf(ctx, "User not found for email: %s", req.Email)
		return &types.LoginResponse{
			Success: false,
			Message: "Invalid email or password",
		}, nil
	}
	logger.Infof(ctx, "Found user: ID=%s, Email=%s, IsActive=%t", user.ID, user.Email, user.IsActive)

	// Verify password before revealing anything about the account state.
	logger.Infof(ctx, "Verifying password for user: %s", user.Email)
	err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
	if err != nil {
		logger.Warnf(ctx, "Password verification failed for user %s: %v", user.Email, err)
		return &types.LoginResponse{
			Success: false,
			Message: "Invalid email or password",
		}, nil
	}
	logger.Infof(ctx, "Password verification successful for user: %s", user.Email)

	// Check if user is active
	if !user.IsActive {
		logger.Warnf(ctx, "User account is disabled for email: %s", req.Email)
		return &types.LoginResponse{
			Success: false,
			Message: "Account is disabled",
		}, nil
	}

	// Generate tokens
	logger.Infof(ctx, "Generating tokens for user: %s", user.Email)
	accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
	if err != nil {
		logger.Errorf(ctx, "Failed to generate tokens for user %s: %v", user.Email, err)
		return &types.LoginResponse{
			Success: false,
			Message: "Login failed",
		}, nil
	}
	logger.Infof(ctx, "Tokens generated successfully for user: %s", user.Email)

	// Get tenant information. This is best effort: a failure is logged and
	// login still succeeds.
	logger.Infof(ctx, "Getting tenant information for user %s, tenant ID: %s", user.Email, user.TenantID)
	tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
	if err != nil {
		logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %s: %v", user.Email, user.TenantID, err)
	} else {
		logger.Infof(ctx, "Tenant information retrieved successfully for user: %s", user.Email)
	}

	logger.Infof(ctx, "User logged in successfully: %s", user.Email)
	return &types.LoginResponse{
		Success:      true,
		Message:      "Login successful",
		User:         user,
		Tenant:       tenant,
		Token:        accessToken,
		RefreshToken: refreshToken,
	}, nil
}
// GetUserByID looks up a user by its identifier. It delegates to the user
// repository and hands back that call's result unchanged.
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
	found, lookupErr := s.userRepo.GetUserByID(ctx, id)
	return found, lookupErr
}
// GetUserByEmail looks up a user by email address. It delegates to the user
// repository and hands back that call's result unchanged.
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
	found, lookupErr := s.userRepo.GetUserByEmail(ctx, email)
	return found, lookupErr
}
// GetUserByUsername looks up a user by username. It delegates to the user
// repository and hands back that call's result unchanged.
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
	found, lookupErr := s.userRepo.GetUserByUsername(ctx, username)
	return found, lookupErr
}
// UpdateUser stamps the given user with the current time in UpdatedAt and
// persists it through the user repository.
//
// The passed-in user is modified in place. A nil user is rejected with an
// error instead of panicking; otherwise the repository's error is returned.
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
	if user == nil {
		return errors.New("user is required")
	}
	user.UpdatedAt = time.Now()
	return s.userRepo.UpdateUser(ctx, user)
}
// DeleteUser removes the user with the given ID by delegating to the user
// repository; the repository's error, if any, is returned as is.
func (s *userService) DeleteUser(ctx context.Context, id string) error {
	deleteErr := s.userRepo.DeleteUser(ctx, id)
	return deleteErr
}
// ChangePassword replaces a user's password after verifying the old one.
//
// It loads the user by ID, checks oldPassword against the stored bcrypt hash,
// hashes newPassword with bcrypt at the default cost, and saves the user with
// the new hash and a refreshed UpdatedAt.
//
// Errors: "new password is required" for an empty newPassword (the same rule
// Register applies), the repository error or "user not found" when the user
// cannot be loaded, "invalid old password" on a mismatch, and the bcrypt or
// repository error when hashing or saving fails.
//
// NOTE(review): tokens already issued to the user are not touched here.
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
	if newPassword == "" {
		return errors.New("new password is required")
	}

	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return err
	}
	// Guard against a repository that reports "no such user" as (nil, nil).
	if user == nil {
		return errors.New("user not found")
	}

	// Verify old password
	if err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword)); err != nil {
		return errors.New("invalid old password")
	}

	// Hash new password
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return err
	}

	user.PasswordHash = string(hashedPassword)
	user.UpdatedAt = time.Now()
	return s.userRepo.UpdateUser(ctx, user)
}
// ValidatePassword checks password against the stored bcrypt hash of the user
// with the given ID.
//
// It returns nil when the password matches. Otherwise it returns the
// repository error when the lookup fails, "user not found" when the lookup
// yields no user, or the bcrypt comparison error on a mismatch.
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return err
	}
	// Guard against a repository that reports "no such user" as (nil, nil).
	if user == nil {
		return errors.New("user not found")
	}
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
// GenerateTokens issues a signed access token (valid 24 hours) and refresh
// token (valid 7 days) for user and records both in the token repository.
//
// Both tokens are HS256 JWTs signed with jwtSecret. The access token carries
// user_id, email, tenant_id, exp, iat, type="access" and jti; the refresh
// token carries user_id, exp, iat, type="refresh" and jti. The jti claim is
// the ID of the stored token record and makes every token string unique, even
// for two tokens issued to the same user within the same second.
//
// An error is returned for a nil user, when signing fails, or when either
// token record cannot be stored. Storage failures are not ignored because
// ValidateToken and RefreshToken reject any token that has no stored record.
func (s *userService) GenerateTokens(ctx context.Context, user *types.User) (accessToken, refreshToken string, err error) {
	if user == nil {
		return "", "", errors.New("user is required")
	}

	const (
		accessTTL  = 24 * time.Hour
		refreshTTL = 7 * 24 * time.Hour
	)

	// One clock reading keeps the JWT claims and the stored records in
	// agreement about issue and expiry times.
	now := time.Now()
	accessExpiry := now.Add(accessTTL)
	refreshExpiry := now.Add(refreshTTL)
	accessID := uuid.New().String()
	refreshID := uuid.New().String()

	// Generate access token (expires in 24 hours)
	accessClaims := jwt.MapClaims{
		"user_id":   user.ID,
		"email":     user.Email,
		"tenant_id": user.TenantID,
		"exp":       accessExpiry.Unix(),
		"iat":       now.Unix(),
		"type":      "access",
		"jti":       accessID,
	}
	accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString(jwtSecret)
	if err != nil {
		return "", "", err
	}

	// Generate refresh token (expires in 7 days)
	refreshClaims := jwt.MapClaims{
		"user_id": user.ID,
		"exp":     refreshExpiry.Unix(),
		"iat":     now.Unix(),
		"type":    "refresh",
		"jti":     refreshID,
	}
	refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString(jwtSecret)
	if err != nil {
		return "", "", err
	}

	// Store tokens in database
	accessTokenRecord := &types.AuthToken{
		ID:        accessID,
		UserID:    user.ID,
		Token:     accessToken,
		TokenType: "access_token",
		ExpiresAt: accessExpiry,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err = s.tokenRepo.CreateToken(ctx, accessTokenRecord); err != nil {
		return "", "", fmt.Errorf("store access token: %w", err)
	}

	refreshTokenRecord := &types.AuthToken{
		ID:        refreshID,
		UserID:    user.ID,
		Token:     refreshToken,
		TokenType: "refresh_token",
		ExpiresAt: refreshExpiry,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err = s.tokenRepo.CreateToken(ctx, refreshTokenRecord); err != nil {
		return "", "", fmt.Errorf("store refresh token: %w", err)
	}

	return accessToken, refreshToken, nil
}
// ValidateToken validates an access token and returns the user it belongs to.
//
// The token must parse as a JWT signed with an HMAC method using jwtSecret,
// carry type="access" and a string user_id claim, and have a stored record
// that is not revoked. Refresh tokens are rejected here: they are signed with
// the same key, so without the type check a 7-day refresh token would be
// accepted wherever an access token is expected.
//
// On success the result of looking the user up by ID is returned. Otherwise
// an error describing the failed check is returned; a failed or empty record
// lookup is reported as "token is revoked".
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens so the key cannot be misused with
		// another algorithm named in the token header.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return jwtSecret, nil
	})
	if err != nil || !token.Valid {
		return nil, errors.New("invalid token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid token claims")
	}

	tokenType, ok := claims["type"].(string)
	if !ok || tokenType != "access" {
		return nil, errors.New("not an access token")
	}

	userID, ok := claims["user_id"].(string)
	if !ok {
		return nil, errors.New("invalid user ID in token")
	}

	// Check if token is revoked
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
	if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
		return nil, errors.New("token is revoked")
	}

	return s.userRepo.GetUserByID(ctx, userID)
}
// RefreshToken exchanges a valid refresh token for a new access/refresh pair.
//
// The refresh token must parse as a JWT signed with an HMAC method using
// jwtSecret, carry type="refresh" and a string user_id claim, and have a
// stored record that is not revoked. The old refresh token is then marked
// revoked and new tokens are produced by GenerateTokens.
//
// If the revocation cannot be saved, an error is returned and no new tokens
// are issued; otherwise the old refresh token would stay usable alongside the
// new pair. An error is also returned when any validation step fails, when the
// user cannot be loaded, or when GenerateTokens fails.
func (s *userService) RefreshToken(ctx context.Context, refreshTokenString string) (accessToken, newRefreshToken string, err error) {
	token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return jwtSecret, nil
	})
	if err != nil || !token.Valid {
		return "", "", errors.New("invalid refresh token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", "", errors.New("invalid token claims")
	}

	tokenType, ok := claims["type"].(string)
	if !ok || tokenType != "refresh" {
		return "", "", errors.New("not a refresh token")
	}

	userID, ok := claims["user_id"].(string)
	if !ok {
		return "", "", errors.New("invalid user ID in token")
	}

	// Check if token is revoked
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
	if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
		return "", "", errors.New("refresh token is revoked")
	}

	// Get user
	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return "", "", err
	}
	// Guard against a repository that reports "no such user" as (nil, nil).
	if user == nil {
		return "", "", errors.New("user not found")
	}

	// Revoke old refresh token; UpdatedAt is stamped the same way RevokeToken
	// does it.
	tokenRecord.IsRevoked = true
	tokenRecord.UpdatedAt = time.Now()
	if err = s.tokenRepo.UpdateToken(ctx, tokenRecord); err != nil {
		return "", "", fmt.Errorf("revoke old refresh token: %w", err)
	}

	// Generate new tokens
	return s.GenerateTokens(ctx, user)
}
// RevokeToken marks the stored record of tokenString as revoked.
//
// It looks the record up by token value, sets IsRevoked, refreshes UpdatedAt
// and saves it. It returns the repository error when the lookup or the save
// fails, and "token not found" when the lookup yields no record.
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
	if err != nil {
		return err
	}
	// Guard against a repository that reports "no such token" as (nil, nil).
	if tokenRecord == nil {
		return errors.New("token not found")
	}

	tokenRecord.IsRevoked = true
	tokenRecord.UpdatedAt = time.Now()
	return s.tokenRepo.UpdateToken(ctx, tokenRecord)
}
// GetCurrentUser returns the user stored in ctx under the key "user".
//
// It returns "user not found in context" when the key is absent, holds a
// value that is not a *types.User, or holds a nil *types.User. The nil check
// matters because a typed nil pointer still satisfies the type assertion.
//
// NOTE(review): the plain string key has to match whatever stores the user in
// the context; that code is not visible here, so the key is left as is.
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
	user, ok := ctx.Value("user").(*types.User)
	if !ok || user == nil {
		return nil, errors.New("user not found in context")
	}
	return user, nil
}
// uniqueRetrieverEngine returns the elements of engine with duplicates
// removed, keeping the first occurrence of each value in its original order.
// The result is nil when engine is empty.
func uniqueRetrieverEngine(engine []types.RetrieverEngineParams) []types.RetrieverEngineParams {
	var deduped []types.RetrieverEngineParams
	known := make(map[types.RetrieverEngineParams]struct{})
	for i := range engine {
		candidate := engine[i]
		if _, dup := known[candidate]; dup {
			continue
		}
		known[candidate] = struct{}{}
		deduped = append(deduped, candidate)
	}
	return deduped
}