diff --git a/internal/repositories/user_repository.go b/internal/repositories/user_repository.go index 6bacc8a..73e059a 100644 --- a/internal/repositories/user_repository.go +++ b/internal/repositories/user_repository.go @@ -10,13 +10,13 @@ import ( "GoMembership/internal/models" "GoMembership/pkg/errors" + "GoMembership/pkg/logger" ) type UserRepositoryInterface interface { CreateUser(user *models.User) (uint, error) UpdateUser(user *models.User) (*models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error) - GetUserByID(userID *uint) (*models.User, error) GetUserByEmail(email string) (*models.User, error) SetVerificationToken(verification *models.Verification) (uint, error) IsVerified(userID *uint) (bool, error) @@ -28,6 +28,7 @@ type UserRepository struct{} func (ur *UserRepository) CreateUser(user *models.User) (uint, error) { result := database.DB.Create(user) if result.Error != nil { + logger.Error.Printf("Create User error: %#v", result.Error) return 0, result.Error } return user.ID, nil @@ -39,10 +40,14 @@ func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) { } err := database.DB.Transaction(func(tx *gorm.DB) error { - if err := tx.First(&models.User{}, user.ID).Error; err != nil { + // Check if the user exists in the database + var existingUser models.User + if err := tx.Preload("DriversLicence.LicenceCategories"). + Preload("Membership"). + First(&existingUser, user.ID).Error; err != nil { return err } - + // Update the user's main fields result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user) if result.Error != nil { return result.Error @@ -50,6 +55,29 @@ func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) { if result.RowsAffected == 0 { return errors.ErrNoRowsAffected } + + // Handle the update of the LicenceCategories explicitly + if user.DriversLicence.ID != 0 { + // Replace the LicenceCategories with the new list + if err := tx.Model(&existingUser.DriversLicence).Association("LicenceCategories").Replace(user.DriversLicence.LicenceCategories); err != nil { + return err + } + } + + // Update the Membership if provided + if user.Membership.ID != 0 { + if err := tx.Model(&existingUser.Membership).Updates(user.Membership).Error; err != nil { + return err + } + } + + // Update the DriversLicence fields if provided + if user.DriversLicence.ID != 0 { + if err := tx.Model(&existingUser.DriversLicence).Updates(user.DriversLicence).Error; err != nil { + return err + } + } + return nil }) @@ -58,10 +86,11 @@ func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) { } var updatedUser models.User - if err := database.DB.First(&updatedUser, user.ID).Error; err != nil { + if err := database.DB.Preload("DriversLicence.LicenceCategories"). + Preload("Membership"). + First(&updatedUser, user.ID).Error; err != nil { return nil, err } - return &updatedUser, nil } @@ -81,7 +110,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User return &users, nil } -func (ur *UserRepository) GetUserByID(userID *uint) (*models.User, error) { +func GetUserByID(userID *uint) (*models.User, error) { var user models.User result := database.DB. Preload(clause.Associations).