diff --git a/internal/controllers/user_controller.go b/internal/controllers/user_controller.go index 0bd061b..5ebd12c 100644 --- a/internal/controllers/user_controller.go +++ b/internal/controllers/user_controller.go @@ -1,17 +1,18 @@ package controllers import ( - "fmt" - + "GoMembership/internal/config" "GoMembership/internal/constants" "GoMembership/internal/middlewares" "GoMembership/internal/models" "GoMembership/internal/services" + "GoMembership/internal/utils" "net/http" "github.com/gin-gonic/gin" + "GoMembership/pkg/errors" "GoMembership/pkg/logger" ) @@ -27,12 +28,80 @@ type RegistrationData struct { User models.User `json:"user"` } -func (uc *UserController) CurrentUserHandler(c *gin.Context) { - userIDString, ok := c.Get("user_id") - if !ok || userIDString == nil { - logger.Error.Printf("Error getting user_id from header") +func (uc *UserController) UpdateHandler(c *gin.Context) { + var user models.User + if err := c.ShouldBindJSON(&user); err != nil { + logger.Error.Printf("Couldn't decode input: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Couldn't decode request data"}) + return } - userID := userIDString.(float64) + tokenString, err := c.Cookie("jwt") + if err != nil { + logger.Error.Printf("No Auth token: %v\n", err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"}) + c.Abort() + return + } + _, claims, err := middlewares.ExtractContentFrom(tokenString) + if err != nil { + + logger.Error.Printf("Error retrieving token and claims from JWT") + c.JSON(http.StatusInternalServerError, gin.H{"error": "JWT parsing error"}) + return + } + jwtUserID := int64((*claims)["user_id"].(float64)) + userRole := int8((*claims)["role_id"].(float64)) + if user.ID == 0 { + logger.Error.Printf("No User.ID in request from user with id: %v, aborting", jwtUserID) + c.JSON(http.StatusBadRequest, gin.H{"error": "No user id provided"}) + return + } + if user.ID != jwtUserID && userRole < constants.Roles.Editor { + c.JSON(http.StatusForbidden, gin.H{"error": "You are not authorized to update this user"}) + return + } + // TODO: If it's not an admin, prevent changes to critical fields + // if userRole != constants.Roles.Admin { + // existingUser, err := uc.Service.GetUserByID(jwtUserID) + // if err != nil { + // c.JSON(http.StatusInternalServerError, gin.H{"error": "Error retrieving user data"}) + // return + // } + // user.Email = existingUser.Email + // user.RoleID = existingUser.RoleID + // } + + updatedUser, err := uc.Service.UpdateUser(&user) + if err != nil { + switch err { + case errors.ErrUserNotFound: + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + case errors.ErrInvalidUserData: + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user data"}) + default: + logger.Error.Printf("Failed to update user: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server error"}) + } + return + } + c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser}) +} + +func (uc *UserController) CurrentUserHandler(c *gin.Context) { + userIDInterface, ok := c.Get("user_id") + if !ok || userIDInterface == nil { + logger.Error.Printf("Error getting user_id from header") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Missing or invalid user ID type"}) + return + } + userID, ok := userIDInterface.(int64) + + if !ok { + logger.Error.Printf("Error: user_id is not of type int64") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user ID type"}) + return + } + user, err := uc.Service.GetUserByID(int64(userID)) if err != nil { logger.Error.Printf("Error retrieving valid user: %v", err) @@ -44,7 +113,13 @@ func (uc *UserController) CurrentUserHandler(c *gin.Context) { } func (uc *UserController) LogoutHandler(c *gin.Context) { - // just clear the JWT cookie + tokenString, err := c.Cookie("jwt") + if err != nil { + logger.Error.Printf("unable to get token from cookie: %#v", err) + } + + middlewares.InvalidateSession(tokenString) + c.SetCookie("jwt", "", -1, "/", "", true, true) c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"}) } @@ -82,25 +157,17 @@ func (uc *UserController) LoginHandler(c *gin.Context) { return } - token, err := middlewares.GenerateToken(user.ID) + logger.Error.Printf("jwtsevret: %v", config.Auth.JWTSecret) + token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"}) return } - c.SetCookie( - "jwt", - token, - 10*60, // 10 minutes - "/", - "", - true, - true, - ) + utils.SetCookie(c, token) c.JSON(http.StatusOK, gin.H{ - "message": "Login successful", - "set-token": token, + "message": "Login successful", }) } @@ -118,7 +185,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) { c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"}) return } - + logger.Error.Printf("user.membership: %#v", regData.User.Membership) selectedModel, err := uc.MembershipService.GetModelByName(®Data.User.Membership.SubscriptionModel.Name) if err != nil { logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err) @@ -134,7 +201,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) { id, token, err := uc.Service.RegisterUser(®Data.User) if err != nil { logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err) - c.JSON(int(id), gin.H{"error": fmt.Sprintf("Couldn't register User: %v", err)}) + c.JSON(int(id), gin.H{"error": "Couldn't register User"}) return } regData.User.ID = id @@ -194,7 +261,7 @@ func (uc *UserController) VerifyMailHandler(c *gin.Context) { c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestÃĪtigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."}) return } - logger.Info.Printf("User: %#v", user) + logger.Info.Printf("VerificationMailHandler User: %#v", user.Email) uc.EmailService.SendWelcomeEmail(user) c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName}) diff --git a/internal/controllers/user_controller_test.go b/internal/controllers/user_controller_test.go index 49d7b35..f926af2 100644 --- a/internal/controllers/user_controller_test.go +++ b/internal/controllers/user_controller_test.go @@ -1,6 +1,7 @@ package controllers import ( + "bytes" "encoding/json" "fmt" "io" @@ -9,6 +10,7 @@ import ( "net/url" "path/filepath" "regexp" + "strconv" "strings" "testing" "time" @@ -22,12 +24,9 @@ import ( "GoMembership/internal/models" "GoMembership/internal/utils" "GoMembership/pkg/logger" -) -type loginInput struct { - Email string `json:"email"` - Password string `json:"password"` -} + "github.com/golang-jwt/jwt/v5" +) type RegisterUserTest struct { WantDBData map[string]interface{} @@ -37,6 +36,8 @@ type RegisterUserTest struct { Assert bool } +var jwtSigningMethod = jwt.SigningMethodHS256 + func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) { return GetMockedJSONContext([]byte(rt.Input), "register") } @@ -61,17 +62,23 @@ func testUserController(t *testing.T) { tests := getTestUsers() for _, tt := range tests { + logger.Error.Print("==============================================================") + logger.Error.Printf("Register User Testing : %v", tt.Name) + logger.Error.Print("==============================================================") t.Run(tt.Name, func(t *testing.T) { if err := runSingleTest(&tt); err != nil { t.Fatalf("Test failed: %v", err.Error()) } }) } - testCurrentUserHandler(t) + + loginEmail, loginCookie := testLoginHandler(t) + logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie) + testUpdateUser(t, loginEmail, loginCookie) + testLogoutHandler(t, logoutCookie) } -func testLogoutHandler(t *testing.T) { - loginCookie := testCurrentUserHandler(t) +func testLogoutHandler(t *testing.T, loginCookie http.Cookie) { tests := []struct { name string @@ -93,6 +100,9 @@ func testLogoutHandler(t *testing.T) { } for _, tt := range tests { + logger.Error.Print("==============================================================") + logger.Error.Printf("Logout User Testing : %v", tt.name) + logger.Error.Print("==============================================================") t.Run(tt.name, func(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New() @@ -125,11 +135,11 @@ func testLogoutHandler(t *testing.T) { // Verify that the user can no longer access protected routes w = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/current-user", nil) + req, _ = http.NewRequest("GET", "/current", nil) if logoutCookie != nil { req.AddCookie(logoutCookie) } - router.GET("/current-user", middlewares.AuthMiddleware(), Uc.CurrentUserHandler) + router.GET("/current", middlewares.AuthMiddleware(), Uc.CurrentUserHandler) router.ServeHTTP(w, req) assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout") }) @@ -196,9 +206,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) { assert.NoError(t, err) if tt.wantToken { - logger.Info.Printf("Response: %#v", response) - assert.Contains(t, response, "set-token") - assert.NotEmpty(t, response["set-token"]) + assert.Contains(t, response, "message") + assert.Equal(t, "Login successful", response["message"]) for _, cookie := range w.Result().Cookies() { if cookie.Name == "jwt" { loginCookie = *cookie @@ -211,7 +220,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) { } assert.NotEmpty(t, loginCookie) } else { - assert.NotContains(t, response, "set-token") + assert.Contains(t, response, "error") + assert.NotEmpty(t, response["error"]) } }) @@ -220,8 +230,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) { return loginInput.Email, loginCookie } -func testCurrentUserHandler(t *testing.T) http.Cookie { - loginEmail, loginCookie := testLoginHandler(t) +func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Cookie) http.Cookie { // This test should run after the user login test invalidCookie := http.Cookie{ Name: "jwt", @@ -232,6 +241,7 @@ func testCurrentUserHandler(t *testing.T) http.Cookie { setupCookie func(*http.Request) expectedUserMail string expectedStatus int + expectNewCookie bool }{ { name: "With valid cookie", @@ -241,6 +251,24 @@ func testCurrentUserHandler(t *testing.T) http.Cookie { expectedUserMail: loginEmail, expectedStatus: http.StatusOK, }, + { + name: "With valid expired cookie", + setupCookie: func(req *http.Request) { + sessionID := "test-session" + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": 1, + "role_id": 0, + "session_id": sessionID, + "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago + }) + tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret)) + req.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) + middlewares.UpdateSession(sessionID, 1) // Add a valid session + }, + expectedUserMail: config.Recipients.AdminEmail, + expectedStatus: http.StatusOK, + expectNewCookie: true, + }, { name: "Without cookie", setupCookie: func(req *http.Request) {}, @@ -259,18 +287,15 @@ func testCurrentUserHandler(t *testing.T) http.Cookie { logger.Error.Print("==============================================================") logger.Error.Printf("Testing : %v", tt.name) logger.Error.Print("==============================================================") - if tt.expectedStatus == http.StatusOK { - time.Sleep(time.Second) // Small delay to ensure different timestamps to get a different JWT token - } t.Run(tt.name, func(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New() router.Use(middlewares.AuthMiddleware()) - router.GET("/current-user", Uc.CurrentUserHandler) + router.GET("/current", Uc.CurrentUserHandler) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/current-user", nil) + req, _ := http.NewRequest("GET", "/current", nil) tt.setupCookie(req) router.ServeHTTP(w, req) @@ -290,9 +315,13 @@ func testCurrentUserHandler(t *testing.T) http.Cookie { break } } - assert.NotNil(t, newCookie, "Cookie should be renewed") - assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different") - assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired") + if tt.expectNewCookie { + assert.NotNil(t, newCookie, "New cookie should be set for expired token") + assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different") + assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired") + } else { + assert.Nil(t, newCookie, "No new cookie should be set for non-expired token") + } } else { // For unauthorized requests, check for an error message var errorResponse map[string]string @@ -316,29 +345,32 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error { if assert != (len(*users) != 0) { return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users) } - if assert { //check for email delivery messages := utils.SMTPGetMessages() for _, message := range messages { mail, err := utils.DecodeMail(message.MsgRequest()) if err != nil { + logger.Error.Printf("Error in validateUser: %#v", err) return err } if strings.Contains(mail.Subject, constants.MailRegistrationSubject) { if err := checkRegistrationMail(mail, &(*users)[0]); err != nil { + logger.Error.Printf("Error in checkRegistrationMail: %#v", err) return err } } else if strings.Contains(mail.Subject, constants.MailVerificationSubject) { if err := checkVerificationMail(mail, &(*users)[0]); err != nil { + logger.Error.Printf("Error in checkVerificationMail: %#v", err) return err } verifiedUsers, err := Uc.Service.GetUsers(wantDBData) if err != nil { + logger.Error.Printf("Error in GetUsers: %#v", err) return err } if (*verifiedUsers)[0].Status != constants.VerifiedStatus { - return fmt.Errorf("Users status isn't verified after email verification. Status is: %#v", (*verifiedUsers)[0].Status) + return fmt.Errorf("Users(%v) status isn't verified after email verification. Status is: %v", (*verifiedUsers)[0].Email, (*verifiedUsers)[0].Status) } } else { return fmt.Errorf("Subject not expected: %v", mail.Subject) @@ -348,6 +380,168 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error { return nil } +func testUpdateUser(t *testing.T, loginEmail string, loginCookie http.Cookie) { + + invalidCookie := http.Cookie{ + Name: "jwt", + Value: "invalid.token.here", + } + // Get the user we just created + users, err := Uc.Service.GetUsers(map[string]interface{}{"email": "john.doe@example.com"}) + if err != nil || len(*users) == 0 { + t.Fatalf("Failed to get test user: %v", err) + } + user := (*users)[0] + + tests := []struct { + name string + setupCookie func(*http.Request) + updateFunc func(*models.User) + expectedStatus int + expectedError string + }{ + { + name: "Valid Update", + setupCookie: func(req *http.Request) { + req.AddCookie(&loginCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.FirstName = "John Updated" + u.LastName = "Doe Updated" + u.Phone = "01738484994" + }, + expectedStatus: http.StatusAccepted, + }, + { + name: "Password Update", + setupCookie: func(req *http.Request) { + req.AddCookie(&loginCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "NewPassword" + }, + expectedStatus: http.StatusAccepted, + }, + { + name: "Valid Update, invalid cookie", + setupCookie: func(req *http.Request) { + req.AddCookie(&invalidCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.FirstName = "John Updated" + u.LastName = "Doe Updated" + u.Phone = "01738484994" + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Auth token invalid", + }, + { + name: "Invalid Email Update", + setupCookie: func(req *http.Request) { + req.AddCookie(&loginCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.Email = "invalid-email" + }, + expectedStatus: http.StatusBadRequest, + expectedError: "Invalid user data", + }, + { + name: "User ID mismatch while not admin", + setupCookie: func(req *http.Request) { + req.AddCookie(&loginCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.ID = 1 + u.FirstName = "John Missing ID" + }, + expectedStatus: http.StatusForbidden, + expectedError: "You are not authorized to update this user", + }, + // { + // name: "Non-existent User", + // setupCookie: func(req *http.Request) { + // req.AddCookie(&loginCookie) + // }, + // updateFunc: func(u *models.User) { + // u.Password = "" + // u.ID = 99999 + // u.FirstName = "Non-existent" + // }, + // expectedStatus: http.StatusNotFound, + // expectedError: "User not found", + // }, + } + + for _, tt := range tests { + logger.Error.Print("==============================================================") + logger.Error.Printf("Update Testing : %v", tt.name) + logger.Error.Print("==============================================================") + t.Run(tt.name, func(t *testing.T) { + // Create a copy of the user and apply the updates + updatedUser := user + tt.updateFunc(&updatedUser) + + // Convert user to JSON + jsonData, err := json.Marshal(updatedUser) + if err != nil { + t.Fatalf("Failed to marshal user data: %v", err) + } + + // Create request + req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + + tt.setupCookie(req) + // Create response recorder + w := httptest.NewRecorder() + + // Set up router and add middleware + router := gin.New() + router.Use(middlewares.AuthMiddleware()) + router.PUT("/users/:id", Uc.UpdateHandler) + + // Perform request + router.ServeHTTP(w, req) + + // Check status code + assert.Equal(t, tt.expectedStatus, w.Code) + + // Parse response + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + if tt.expectedError != "" { + assert.Equal(t, tt.expectedError, response["error"]) + } else { + assert.Equal(t, "User updated successfully", response["message"]) + + // Verify the update in the database + updatedUserFromDB, err := Uc.Service.GetUserByID(user.ID) + updatedUserFromDB.UpdatedAt = updatedUser.UpdatedAt + updatedUserFromDB.Membership.UpdatedAt = updatedUser.Membership.UpdatedAt + updatedUserFromDB.BankAccount.UpdatedAt = updatedUser.BankAccount.UpdatedAt + updatedUserFromDB.Verification.UpdatedAt = updatedUser.Verification.UpdatedAt + updatedUserFromDB.Membership.SubscriptionModel.UpdatedAt = updatedUser.Membership.SubscriptionModel.UpdatedAt + if updatedUser.Password == "" { + assert.Equal(t, user.Password, (*updatedUserFromDB).Password) + } else { + assert.NotEqual(t, user.Password, (*updatedUserFromDB).Password) + updatedUser.Password = "" + } + updatedUserFromDB.Password = "" + assert.NoError(t, err) + assert.Equal(t, updatedUser, *updatedUserFromDB, "Updated user in DB does not match expected user") + } + }) + } +} + func checkWelcomeMail(message *utils.Email, user *models.User) error { if !strings.Contains(message.To, user.Email) { @@ -506,23 +700,6 @@ func getVerificationURL(mailBody string) (string, error) { } // TEST DATA: -func getBaseUser() models.User { - return models.User{ - DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), - FirstName: "John", - LastName: "Doe", - Email: "john.doe@example.com", - Address: "Pablo Escobar Str. 4", - ZipCode: "25474", - City: "Hasloh", - Phone: "01738484993", - BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"}, - Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}}, - ProfilePicture: "", - Password: "password123", - Company: "", - } -} func customizeInput(customize func(models.User) models.User) *RegistrationData { user := getBaseUser() diff --git a/internal/repositories/user_repository.go b/internal/repositories/user_repository.go index f58583c..00ca4cd 100644 --- a/internal/repositories/user_repository.go +++ b/internal/repositories/user_repository.go @@ -1,8 +1,6 @@ package repositories import ( - "time" - "gorm.io/gorm" "GoMembership/internal/constants" @@ -16,13 +14,13 @@ import ( type UserRepositoryInterface interface { CreateUser(user *models.User) (int64, error) - UpdateUser(userID int64, user *models.User) error + UpdateUser(user *models.User) (*models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error) - GetUserByID(id int64) (*models.User, error) + GetUserByID(userID *int64) (*models.User, error) GetUserByEmail(email string) (*models.User, error) - SetVerificationToken(user *models.User, token *string) (int64, error) + SetVerificationToken(verification *models.Verification) (int64, error) IsVerified(userID *int64) (bool, error) - VerifyUserOfToken(token *string) (*models.User, error) + GetVerificationOfToken(token *string) (*models.Verification, error) } type UserRepository struct{} @@ -35,21 +33,36 @@ func (ur *UserRepository) CreateUser(user *models.User) (int64, error) { return user.ID, nil } -func (ur *UserRepository) UpdateUser(userID int64, user *models.User) error { - // logger.Info.Printf("Updating User: %#v\n", user) +func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) { if user == nil { - return errors.ErrNoData - } - result := database.DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&user) - if result.Error != nil { - return result.Error + return nil, errors.ErrNoData } - if result.RowsAffected == 0 { - return errors.ErrNoRowsAffected + err := database.DB.Transaction(func(tx *gorm.DB) error { + if err := tx.First(&models.User{}, user.ID).Error; err != nil { + return err + } + + result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errors.ErrNoRowsAffected + } + return nil + }) + + if err != nil { + return nil, err } - return nil + var updatedUser models.User + if err := database.DB.First(&updatedUser, user.ID).Error; err != nil { + return nil, err + } + + return &updatedUser, nil } func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) { @@ -70,7 +83,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User return &users, nil } -func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) { +func (ur *UserRepository) GetUserByID(userID *int64) (*models.User, error) { var user models.User result := database.DB. Preload("Consents"). @@ -78,7 +91,7 @@ func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) { Preload("Verification"). Preload("Membership", func(db *gorm.DB) *gorm.DB { return db.Preload("SubscriptionModel") - }).First(&user, id) + }).First(&user, userID) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, gorm.ErrRecordNotFound @@ -112,7 +125,8 @@ func (ur *UserRepository) IsVerified(userID *int64) (bool, error) { return user.Status != constants.UnverifiedStatus, nil } -func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error) { +func (ur *UserRepository) GetVerificationOfToken(token *string) (*models.Verification, error) { + var emailVerification models.Verification result := database.DB.Where("verification_token = ?", token).First(&emailVerification) if result.Error != nil { @@ -121,49 +135,10 @@ func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error) } return nil, result.Error } - - // Check if the user is already verified - verified, err := ur.IsVerified(&emailVerification.UserID) - if err != nil { - return nil, err - } - user, err := ur.GetUserByID(emailVerification.UserID) - if err != nil { - return nil, err - } - if verified { - return user, errors.ErrAlreadyVerified - } - // Update user status to active - t := time.Now() - emailVerification.EmailVerifiedAt = &t - user.Status = constants.VerifiedStatus - user.Verification = emailVerification - - err = ur.UpdateUser(emailVerification.UserID, user) - if err != nil { - return nil, err - } - - return user, nil + return &emailVerification, nil } -func (ur *UserRepository) SetVerificationToken(user *models.User, token *string) (int64, error) { - // Check if user is already verified - verified, err := ur.IsVerified(&user.ID) - if err != nil { - return -1, err - } - if verified { - return -1, errors.ErrAlreadyVerified - } - - // Prepare the Verification record - verification := models.Verification{ - UserID: user.ID, - VerificationToken: *token, - } - +func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (int64, error) { // Use GORM to insert or update the Verification record result := database.DB.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "user_id"}}, diff --git a/internal/services/user_service.go b/internal/services/user_service.go index 1a25bd3..c348113 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -8,10 +8,12 @@ import ( "GoMembership/internal/models" "GoMembership/internal/repositories" "GoMembership/internal/utils" + "GoMembership/pkg/errors" "GoMembership/pkg/logger" "github.com/alexedwards/argon2id" "github.com/go-playground/validator/v10" + "gorm.io/gorm" "time" ) @@ -22,50 +24,42 @@ type UserServiceInterface interface { GetUserByID(id int64) (*models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error) VerifyUser(token *string) (*models.User, error) + UpdateUser(user *models.User) (*models.User, error) } type UserService struct { Repo repositories.UserRepositoryInterface } -func (service *UserService) RegisterUser(user *models.User) (int64, string, error) { - if err := validateRegistrationData(user); err != nil { - return http.StatusNotAcceptable, "", err +func (service *UserService) UpdateUser(user *models.User) (*models.User, error) { + + if err := validateUserData(user); err != nil { + return nil, errors.ErrInvalidUserData } - setPassword(user.Password, user) + if user.Password != "" { + setPassword(user.Password, user) + } - user.Status = constants.UnverifiedStatus - user.CreatedAt = time.Now() user.UpdatedAt = time.Now() - id, err := service.Repo.CreateUser(user) + updatedUser, err := service.Repo.UpdateUser(user) - if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { - return http.StatusConflict, "", err - } else if err != nil { - return http.StatusInternalServerError, "", err - } - - user.ID = id - - token, err := utils.GenerateVerificationToken() if err != nil { - return http.StatusInternalServerError, "", err + if err == gorm.ErrRecordNotFound { + return nil, errors.ErrUserNotFound + } + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + return nil, errors.ErrDuplicateEntry + } + return nil, err } - logger.Info.Printf("TOKEN: %v", token) - - _, err = service.Repo.SetVerificationToken(user, &token) - if err != nil { - return http.StatusInternalServerError, "", err - } - - return id, token, nil + return updatedUser, nil } -func (service *UserService) Update(user *models.User) (int64, string, error) { - if err := validateRegistrationData(user); err != nil { +func (service *UserService) RegisterUser(user *models.User) (int64, string, error) { + if err := validateUserData(user); err != nil { return http.StatusNotAcceptable, "", err } @@ -92,17 +86,31 @@ func (service *UserService) Update(user *models.User) (int64, string, error) { logger.Info.Printf("TOKEN: %v", token) - _, err = service.Repo.SetVerificationToken(user, &token) + // Check if user is already verified + verified, err := service.Repo.IsVerified(&user.ID) if err != nil { return http.StatusInternalServerError, "", err } + if verified { + return http.StatusAlreadyReported, "", errors.ErrAlreadyVerified + } + + // Prepare the Verification record + verification := models.Verification{ + UserID: user.ID, + VerificationToken: token, + } + + if _, err = service.Repo.SetVerificationToken(&verification); err != nil { + return http.StatusInternalServerError, "", err + } return id, token, nil } func (service *UserService) GetUserByID(id int64) (*models.User, error) { - return service.Repo.GetUserByID(id) + return service.Repo.GetUserByID(&id) } func (service *UserService) GetUserByEmail(email string) (*models.User, error) { @@ -114,19 +122,41 @@ func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.Us } func (service *UserService) VerifyUser(token *string) (*models.User, error) { - user, err := service.Repo.VerifyUserOfToken(token) + verification, err := service.Repo.GetVerificationOfToken(token) if err != nil { return nil, err } + // Check if the user is already verified + verified, err := service.Repo.IsVerified(&verification.UserID) + if err != nil { + return nil, err + } + user, err := service.Repo.GetUserByID(&verification.UserID) + if err != nil { + return nil, err + } + if verified { + return user, errors.ErrAlreadyVerified + } + // Update user status to active + t := time.Now() + verification.EmailVerifiedAt = &t + + user.Status = constants.VerifiedStatus + user.Verification = *verification + user.ID = verification.UserID + service.Repo.UpdateUser(user) + return user, nil } -func validateRegistrationData(user *models.User) error { +func validateUserData(user *models.User) error { validate := validator.New() validate.RegisterValidation("age", utils.AgeValidator) validate.RegisterValidation("bic", utils.BICValidator) validate.RegisterValidation("iban", utils.IBANValidator) validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator) + validate.RegisterValidation("safe_content", utils.ValidateSafeContent) validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField) return validate.Struct(user)