diff --git a/frontend/src/views/initialization/InitializationContent.vue b/frontend/src/views/initialization/InitializationContent.vue index 5552a37..6e909d2 100644 --- a/frontend/src/views/initialization/InitializationContent.vue +++ b/frontend/src/views/initialization/InitializationContent.vue @@ -2800,4 +2800,4 @@ onMounted(async () => { } } } - + \ No newline at end of file diff --git a/internal/application/service/user.go b/internal/application/service/user.go index a0475b6..671d12e 100644 --- a/internal/application/service/user.go +++ b/internal/application/service/user.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "os" + "strings" "time" "github.com/golang-jwt/jwt/v5" @@ -34,6 +36,35 @@ func NewUserService(userRepo interfaces.UserRepository, tokenRepo interfaces.Aut } } +var engine = map[string][]types.RetrieverEngineParams{ + "postgres": { + { + RetrieverType: types.KeywordsRetrieverType, + RetrieverEngineType: types.PostgresRetrieverEngineType, + }, + { + RetrieverType: types.VectorRetrieverType, + RetrieverEngineType: types.PostgresRetrieverEngineType, + }, + }, + "elasticsearch_v7": { + { + RetrieverType: types.KeywordsRetrieverType, + RetrieverEngineType: types.ElasticsearchRetrieverEngineType, + }, + }, + "elasticsearch_v8": { + { + RetrieverType: types.KeywordsRetrieverType, + RetrieverEngineType: types.ElasticsearchRetrieverEngineType, + }, + { + RetrieverType: types.VectorRetrieverType, + RetrieverEngineType: types.ElasticsearchRetrieverEngineType, + }, + }, +} + // Register creates a new user account func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) { logger.Info(ctx, "Start user registration") @@ -61,23 +92,21 @@ func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) return nil, errors.New("failed to process password") } + egs := []types.RetrieverEngineParams{} + for _, driver := range strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",") { + if val, ok := engine[driver]; ok { + egs = append(egs, val...) + } + } + egs = uniqueRetrieverEngine(egs) + logger.Debugf(ctx, "user register retriever engines: %v", egs) + // Create default tenant for the user tenant := &types.Tenant{ - Name: fmt.Sprintf("%s's Workspace", req.Username), - Description: "Default workspace", - Status: "active", - RetrieverEngines: types.RetrieverEngines{ - Engines: []types.RetrieverEngineParams{ - { - RetrieverType: types.KeywordsRetrieverType, - RetrieverEngineType: types.PostgresRetrieverEngineType, - }, - { - RetrieverType: types.VectorRetrieverType, - RetrieverEngineType: types.PostgresRetrieverEngineType, - }, - }, - }, + Name: fmt.Sprintf("%s's Workspace", req.Username), + Description: "Default workspace", + Status: "active", + RetrieverEngines: types.RetrieverEngines{Engines: egs}, } createdTenant, err := s.tenantService.CreateTenant(ctx, tenant) @@ -406,3 +435,15 @@ func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) { return user, nil } + +func uniqueRetrieverEngine(engine []types.RetrieverEngineParams) []types.RetrieverEngineParams { + seen := make(map[types.RetrieverEngineParams]bool) + var result []types.RetrieverEngineParams + for _, v := range engine { + if !seen[v] { + seen[v] = true + result = append(result, v) + } + } + return result +}