add: update handling

This commit is contained in:
$(pass /github/name)
2024-09-20 08:29:00 +02:00
parent 62624cd0f8
commit 00facf8758
4 changed files with 406 additions and 157 deletions

View File

@@ -1,17 +1,18 @@
package controllers package controllers
import ( import (
"fmt" "GoMembership/internal/config"
"GoMembership/internal/constants" "GoMembership/internal/constants"
"GoMembership/internal/middlewares" "GoMembership/internal/middlewares"
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/services" "GoMembership/internal/services"
"GoMembership/internal/utils"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
) )
@@ -27,12 +28,80 @@ type RegistrationData struct {
User models.User `json:"user"` User models.User `json:"user"`
} }
func (uc *UserController) CurrentUserHandler(c *gin.Context) { func (uc *UserController) UpdateHandler(c *gin.Context) {
userIDString, ok := c.Get("user_id") var user models.User
if !ok || userIDString == nil { if err := c.ShouldBindJSON(&user); err != nil {
logger.Error.Printf("Error getting user_id from header") logger.Error.Printf("Couldn't decode input: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Couldn't decode request data"})
return
} }
userID := userIDString.(float64) tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("No Auth token: %v\n", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"})
c.Abort()
return
}
_, claims, err := middlewares.ExtractContentFrom(tokenString)
if err != nil {
logger.Error.Printf("Error retrieving token and claims from JWT")
c.JSON(http.StatusInternalServerError, gin.H{"error": "JWT parsing error"})
return
}
jwtUserID := int64((*claims)["user_id"].(float64))
userRole := int8((*claims)["role_id"].(float64))
if user.ID == 0 {
logger.Error.Printf("No User.ID in request from user with id: %v, aborting", jwtUserID)
c.JSON(http.StatusBadRequest, gin.H{"error": "No user id provided"})
return
}
if user.ID != jwtUserID && userRole < constants.Roles.Editor {
c.JSON(http.StatusForbidden, gin.H{"error": "You are not authorized to update this user"})
return
}
// TODO: If it's not an admin, prevent changes to critical fields
// if userRole != constants.Roles.Admin {
// existingUser, err := uc.Service.GetUserByID(jwtUserID)
// if err != nil {
// c.JSON(http.StatusInternalServerError, gin.H{"error": "Error retrieving user data"})
// return
// }
// user.Email = existingUser.Email
// user.RoleID = existingUser.RoleID
// }
updatedUser, err := uc.Service.UpdateUser(&user)
if err != nil {
switch err {
case errors.ErrUserNotFound:
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
case errors.ErrInvalidUserData:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user data"})
default:
logger.Error.Printf("Failed to update user: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server error"})
}
return
}
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser})
}
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
userIDInterface, ok := c.Get("user_id")
if !ok || userIDInterface == nil {
logger.Error.Printf("Error getting user_id from header")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Missing or invalid user ID type"})
return
}
userID, ok := userIDInterface.(int64)
if !ok {
logger.Error.Printf("Error: user_id is not of type int64")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user ID type"})
return
}
user, err := uc.Service.GetUserByID(int64(userID)) user, err := uc.Service.GetUserByID(int64(userID))
if err != nil { if err != nil {
logger.Error.Printf("Error retrieving valid user: %v", err) logger.Error.Printf("Error retrieving valid user: %v", err)
@@ -44,7 +113,13 @@ func (uc *UserController) CurrentUserHandler(c *gin.Context) {
} }
func (uc *UserController) LogoutHandler(c *gin.Context) { func (uc *UserController) LogoutHandler(c *gin.Context) {
// just clear the JWT cookie tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("unable to get token from cookie: %#v", err)
}
middlewares.InvalidateSession(tokenString)
c.SetCookie("jwt", "", -1, "/", "", true, true) c.SetCookie("jwt", "", -1, "/", "", true, true)
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"}) c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
} }
@@ -82,25 +157,17 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
return return
} }
token, err := middlewares.GenerateToken(user.ID) logger.Error.Printf("jwtsevret: %v", config.Auth.JWTSecret)
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "")
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
return return
} }
c.SetCookie( utils.SetCookie(c, token)
"jwt",
token,
10*60, // 10 minutes
"/",
"",
true,
true,
)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "Login successful", "message": "Login successful",
"set-token": token,
}) })
} }
@@ -118,7 +185,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"}) c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"})
return return
} }
logger.Error.Printf("user.membership: %#v", regData.User.Membership)
selectedModel, err := uc.MembershipService.GetModelByName(&regData.User.Membership.SubscriptionModel.Name) selectedModel, err := uc.MembershipService.GetModelByName(&regData.User.Membership.SubscriptionModel.Name)
if err != nil { if err != nil {
logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err) logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err)
@@ -134,7 +201,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
id, token, err := uc.Service.RegisterUser(&regData.User) id, token, err := uc.Service.RegisterUser(&regData.User)
if err != nil { if err != nil {
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err) logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
c.JSON(int(id), gin.H{"error": fmt.Sprintf("Couldn't register User: %v", err)}) c.JSON(int(id), gin.H{"error": "Couldn't register User"})
return return
} }
regData.User.ID = id regData.User.ID = id
@@ -194,7 +261,7 @@ func (uc *UserController) VerifyMailHandler(c *gin.Context) {
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."}) c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
return return
} }
logger.Info.Printf("User: %#v", user) logger.Info.Printf("VerificationMailHandler User: %#v", user.Email)
uc.EmailService.SendWelcomeEmail(user) uc.EmailService.SendWelcomeEmail(user)
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName}) c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})

View File

@@ -1,6 +1,7 @@
package controllers package controllers
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -9,6 +10,7 @@ import (
"net/url" "net/url"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -22,12 +24,9 @@ import (
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/utils" "GoMembership/internal/utils"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
)
type loginInput struct { "github.com/golang-jwt/jwt/v5"
Email string `json:"email"` )
Password string `json:"password"`
}
type RegisterUserTest struct { type RegisterUserTest struct {
WantDBData map[string]interface{} WantDBData map[string]interface{}
@@ -37,6 +36,8 @@ type RegisterUserTest struct {
Assert bool Assert bool
} }
var jwtSigningMethod = jwt.SigningMethodHS256
func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) { func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedJSONContext([]byte(rt.Input), "register") return GetMockedJSONContext([]byte(rt.Input), "register")
} }
@@ -61,17 +62,23 @@ func testUserController(t *testing.T) {
tests := getTestUsers() tests := getTestUsers()
for _, tt := range tests { for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Register User Testing : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) { t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil { if err := runSingleTest(&tt); err != nil {
t.Fatalf("Test failed: %v", err.Error()) t.Fatalf("Test failed: %v", err.Error())
} }
}) })
} }
testCurrentUserHandler(t)
loginEmail, loginCookie := testLoginHandler(t)
logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie)
testUpdateUser(t, loginEmail, loginCookie)
testLogoutHandler(t, logoutCookie)
} }
func testLogoutHandler(t *testing.T) { func testLogoutHandler(t *testing.T, loginCookie http.Cookie) {
loginCookie := testCurrentUserHandler(t)
tests := []struct { tests := []struct {
name string name string
@@ -93,6 +100,9 @@ func testLogoutHandler(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Logout User Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New() router := gin.New()
@@ -125,11 +135,11 @@ func testLogoutHandler(t *testing.T) {
// Verify that the user can no longer access protected routes // Verify that the user can no longer access protected routes
w = httptest.NewRecorder() w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/current-user", nil) req, _ = http.NewRequest("GET", "/current", nil)
if logoutCookie != nil { if logoutCookie != nil {
req.AddCookie(logoutCookie) req.AddCookie(logoutCookie)
} }
router.GET("/current-user", middlewares.AuthMiddleware(), Uc.CurrentUserHandler) router.GET("/current", middlewares.AuthMiddleware(), Uc.CurrentUserHandler)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout") assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout")
}) })
@@ -196,9 +206,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
assert.NoError(t, err) assert.NoError(t, err)
if tt.wantToken { if tt.wantToken {
logger.Info.Printf("Response: %#v", response) assert.Contains(t, response, "message")
assert.Contains(t, response, "set-token") assert.Equal(t, "Login successful", response["message"])
assert.NotEmpty(t, response["set-token"])
for _, cookie := range w.Result().Cookies() { for _, cookie := range w.Result().Cookies() {
if cookie.Name == "jwt" { if cookie.Name == "jwt" {
loginCookie = *cookie loginCookie = *cookie
@@ -211,7 +220,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
} }
assert.NotEmpty(t, loginCookie) assert.NotEmpty(t, loginCookie)
} else { } else {
assert.NotContains(t, response, "set-token") assert.Contains(t, response, "error")
assert.NotEmpty(t, response["error"])
} }
}) })
@@ -220,8 +230,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
return loginInput.Email, loginCookie return loginInput.Email, loginCookie
} }
func testCurrentUserHandler(t *testing.T) http.Cookie { func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Cookie) http.Cookie {
loginEmail, loginCookie := testLoginHandler(t)
// This test should run after the user login test // This test should run after the user login test
invalidCookie := http.Cookie{ invalidCookie := http.Cookie{
Name: "jwt", Name: "jwt",
@@ -232,6 +241,7 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
setupCookie func(*http.Request) setupCookie func(*http.Request)
expectedUserMail string expectedUserMail string
expectedStatus int expectedStatus int
expectNewCookie bool
}{ }{
{ {
name: "With valid cookie", name: "With valid cookie",
@@ -241,6 +251,24 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
expectedUserMail: loginEmail, expectedUserMail: loginEmail,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{
name: "With valid expired cookie",
setupCookie: func(req *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 1,
"role_id": 0,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
req.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
middlewares.UpdateSession(sessionID, 1) // Add a valid session
},
expectedUserMail: config.Recipients.AdminEmail,
expectedStatus: http.StatusOK,
expectNewCookie: true,
},
{ {
name: "Without cookie", name: "Without cookie",
setupCookie: func(req *http.Request) {}, setupCookie: func(req *http.Request) {},
@@ -259,18 +287,15 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
logger.Error.Print("==============================================================") logger.Error.Print("==============================================================")
logger.Error.Printf("Testing : %v", tt.name) logger.Error.Printf("Testing : %v", tt.name)
logger.Error.Print("==============================================================") logger.Error.Print("==============================================================")
if tt.expectedStatus == http.StatusOK {
time.Sleep(time.Second) // Small delay to ensure different timestamps to get a different JWT token
}
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New() router := gin.New()
router.Use(middlewares.AuthMiddleware()) router.Use(middlewares.AuthMiddleware())
router.GET("/current-user", Uc.CurrentUserHandler) router.GET("/current", Uc.CurrentUserHandler)
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/current-user", nil) req, _ := http.NewRequest("GET", "/current", nil)
tt.setupCookie(req) tt.setupCookie(req)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
@@ -290,9 +315,13 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
break break
} }
} }
assert.NotNil(t, newCookie, "Cookie should be renewed") if tt.expectNewCookie {
assert.NotNil(t, newCookie, "New cookie should be set for expired token")
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different") assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different")
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired") assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
} else {
assert.Nil(t, newCookie, "No new cookie should be set for non-expired token")
}
} else { } else {
// For unauthorized requests, check for an error message // For unauthorized requests, check for an error message
var errorResponse map[string]string var errorResponse map[string]string
@@ -316,29 +345,32 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
if assert != (len(*users) != 0) { if assert != (len(*users) != 0) {
return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users) return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users)
} }
if assert { if assert {
//check for email delivery //check for email delivery
messages := utils.SMTPGetMessages() messages := utils.SMTPGetMessages()
for _, message := range messages { for _, message := range messages {
mail, err := utils.DecodeMail(message.MsgRequest()) mail, err := utils.DecodeMail(message.MsgRequest())
if err != nil { if err != nil {
logger.Error.Printf("Error in validateUser: %#v", err)
return err return err
} }
if strings.Contains(mail.Subject, constants.MailRegistrationSubject) { if strings.Contains(mail.Subject, constants.MailRegistrationSubject) {
if err := checkRegistrationMail(mail, &(*users)[0]); err != nil { if err := checkRegistrationMail(mail, &(*users)[0]); err != nil {
logger.Error.Printf("Error in checkRegistrationMail: %#v", err)
return err return err
} }
} else if strings.Contains(mail.Subject, constants.MailVerificationSubject) { } else if strings.Contains(mail.Subject, constants.MailVerificationSubject) {
if err := checkVerificationMail(mail, &(*users)[0]); err != nil { if err := checkVerificationMail(mail, &(*users)[0]); err != nil {
logger.Error.Printf("Error in checkVerificationMail: %#v", err)
return err return err
} }
verifiedUsers, err := Uc.Service.GetUsers(wantDBData) verifiedUsers, err := Uc.Service.GetUsers(wantDBData)
if err != nil { if err != nil {
logger.Error.Printf("Error in GetUsers: %#v", err)
return err return err
} }
if (*verifiedUsers)[0].Status != constants.VerifiedStatus { if (*verifiedUsers)[0].Status != constants.VerifiedStatus {
return fmt.Errorf("Users status isn't verified after email verification. Status is: %#v", (*verifiedUsers)[0].Status) return fmt.Errorf("Users(%v) status isn't verified after email verification. Status is: %v", (*verifiedUsers)[0].Email, (*verifiedUsers)[0].Status)
} }
} else { } else {
return fmt.Errorf("Subject not expected: %v", mail.Subject) return fmt.Errorf("Subject not expected: %v", mail.Subject)
@@ -348,6 +380,168 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
return nil return nil
} }
func testUpdateUser(t *testing.T, loginEmail string, loginCookie http.Cookie) {
invalidCookie := http.Cookie{
Name: "jwt",
Value: "invalid.token.here",
}
// Get the user we just created
users, err := Uc.Service.GetUsers(map[string]interface{}{"email": "john.doe@example.com"})
if err != nil || len(*users) == 0 {
t.Fatalf("Failed to get test user: %v", err)
}
user := (*users)[0]
tests := []struct {
name string
setupCookie func(*http.Request)
updateFunc func(*models.User)
expectedStatus int
expectedError string
}{
{
name: "Valid Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.FirstName = "John Updated"
u.LastName = "Doe Updated"
u.Phone = "01738484994"
},
expectedStatus: http.StatusAccepted,
},
{
name: "Password Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = "NewPassword"
},
expectedStatus: http.StatusAccepted,
},
{
name: "Valid Update, invalid cookie",
setupCookie: func(req *http.Request) {
req.AddCookie(&invalidCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.FirstName = "John Updated"
u.LastName = "Doe Updated"
u.Phone = "01738484994"
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Auth token invalid",
},
{
name: "Invalid Email Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.Email = "invalid-email"
},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid user data",
},
{
name: "User ID mismatch while not admin",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.ID = 1
u.FirstName = "John Missing ID"
},
expectedStatus: http.StatusForbidden,
expectedError: "You are not authorized to update this user",
},
// {
// name: "Non-existent User",
// setupCookie: func(req *http.Request) {
// req.AddCookie(&loginCookie)
// },
// updateFunc: func(u *models.User) {
// u.Password = ""
// u.ID = 99999
// u.FirstName = "Non-existent"
// },
// expectedStatus: http.StatusNotFound,
// expectedError: "User not found",
// },
}
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Update Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) {
// Create a copy of the user and apply the updates
updatedUser := user
tt.updateFunc(&updatedUser)
// Convert user to JSON
jsonData, err := json.Marshal(updatedUser)
if err != nil {
t.Fatalf("Failed to marshal user data: %v", err)
}
// Create request
req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
tt.setupCookie(req)
// Create response recorder
w := httptest.NewRecorder()
// Set up router and add middleware
router := gin.New()
router.Use(middlewares.AuthMiddleware())
router.PUT("/users/:id", Uc.UpdateHandler)
// Perform request
router.ServeHTTP(w, req)
// Check status code
assert.Equal(t, tt.expectedStatus, w.Code)
// Parse response
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
if tt.expectedError != "" {
assert.Equal(t, tt.expectedError, response["error"])
} else {
assert.Equal(t, "User updated successfully", response["message"])
// Verify the update in the database
updatedUserFromDB, err := Uc.Service.GetUserByID(user.ID)
updatedUserFromDB.UpdatedAt = updatedUser.UpdatedAt
updatedUserFromDB.Membership.UpdatedAt = updatedUser.Membership.UpdatedAt
updatedUserFromDB.BankAccount.UpdatedAt = updatedUser.BankAccount.UpdatedAt
updatedUserFromDB.Verification.UpdatedAt = updatedUser.Verification.UpdatedAt
updatedUserFromDB.Membership.SubscriptionModel.UpdatedAt = updatedUser.Membership.SubscriptionModel.UpdatedAt
if updatedUser.Password == "" {
assert.Equal(t, user.Password, (*updatedUserFromDB).Password)
} else {
assert.NotEqual(t, user.Password, (*updatedUserFromDB).Password)
updatedUser.Password = ""
}
updatedUserFromDB.Password = ""
assert.NoError(t, err)
assert.Equal(t, updatedUser, *updatedUserFromDB, "Updated user in DB does not match expected user")
}
})
}
}
func checkWelcomeMail(message *utils.Email, user *models.User) error { func checkWelcomeMail(message *utils.Email, user *models.User) error {
if !strings.Contains(message.To, user.Email) { if !strings.Contains(message.To, user.Email) {
@@ -506,23 +700,6 @@ func getVerificationURL(mailBody string) (string, error) {
} }
// TEST DATA: // TEST DATA:
func getBaseUser() models.User {
return models.User{
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
FirstName: "John",
LastName: "Doe",
Email: "john.doe@example.com",
Address: "Pablo Escobar Str. 4",
ZipCode: "25474",
City: "Hasloh",
Phone: "01738484993",
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
ProfilePicture: "",
Password: "password123",
Company: "",
}
}
func customizeInput(customize func(models.User) models.User) *RegistrationData { func customizeInput(customize func(models.User) models.User) *RegistrationData {
user := getBaseUser() user := getBaseUser()

View File

@@ -1,8 +1,6 @@
package repositories package repositories
import ( import (
"time"
"gorm.io/gorm" "gorm.io/gorm"
"GoMembership/internal/constants" "GoMembership/internal/constants"
@@ -16,13 +14,13 @@ import (
type UserRepositoryInterface interface { type UserRepositoryInterface interface {
CreateUser(user *models.User) (int64, error) CreateUser(user *models.User) (int64, error)
UpdateUser(userID int64, user *models.User) error UpdateUser(user *models.User) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error)
GetUserByID(id int64) (*models.User, error) GetUserByID(userID *int64) (*models.User, error)
GetUserByEmail(email string) (*models.User, error) GetUserByEmail(email string) (*models.User, error)
SetVerificationToken(user *models.User, token *string) (int64, error) SetVerificationToken(verification *models.Verification) (int64, error)
IsVerified(userID *int64) (bool, error) IsVerified(userID *int64) (bool, error)
VerifyUserOfToken(token *string) (*models.User, error) GetVerificationOfToken(token *string) (*models.Verification, error)
} }
type UserRepository struct{} type UserRepository struct{}
@@ -35,21 +33,36 @@ func (ur *UserRepository) CreateUser(user *models.User) (int64, error) {
return user.ID, nil return user.ID, nil
} }
func (ur *UserRepository) UpdateUser(userID int64, user *models.User) error { func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
// logger.Info.Printf("Updating User: %#v\n", user)
if user == nil { if user == nil {
return errors.ErrNoData return nil, errors.ErrNoData
} }
result := database.DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&user)
err := database.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.First(&models.User{}, user.ID).Error; err != nil {
return err
}
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected return errors.ErrNoRowsAffected
} }
return nil return nil
})
if err != nil {
return nil, err
}
var updatedUser models.User
if err := database.DB.First(&updatedUser, user.ID).Error; err != nil {
return nil, err
}
return &updatedUser, nil
} }
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) { func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
@@ -70,7 +83,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User
return &users, nil return &users, nil
} }
func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) { func (ur *UserRepository) GetUserByID(userID *int64) (*models.User, error) {
var user models.User var user models.User
result := database.DB. result := database.DB.
Preload("Consents"). Preload("Consents").
@@ -78,7 +91,7 @@ func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
Preload("Verification"). Preload("Verification").
Preload("Membership", func(db *gorm.DB) *gorm.DB { Preload("Membership", func(db *gorm.DB) *gorm.DB {
return db.Preload("SubscriptionModel") return db.Preload("SubscriptionModel")
}).First(&user, id) }).First(&user, userID)
if result.Error != nil { if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound { if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound return nil, gorm.ErrRecordNotFound
@@ -112,7 +125,8 @@ func (ur *UserRepository) IsVerified(userID *int64) (bool, error) {
return user.Status != constants.UnverifiedStatus, nil return user.Status != constants.UnverifiedStatus, nil
} }
func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error) { func (ur *UserRepository) GetVerificationOfToken(token *string) (*models.Verification, error) {
var emailVerification models.Verification var emailVerification models.Verification
result := database.DB.Where("verification_token = ?", token).First(&emailVerification) result := database.DB.Where("verification_token = ?", token).First(&emailVerification)
if result.Error != nil { if result.Error != nil {
@@ -121,49 +135,10 @@ func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error)
} }
return nil, result.Error return nil, result.Error
} }
return &emailVerification, nil
// Check if the user is already verified
verified, err := ur.IsVerified(&emailVerification.UserID)
if err != nil {
return nil, err
}
user, err := ur.GetUserByID(emailVerification.UserID)
if err != nil {
return nil, err
}
if verified {
return user, errors.ErrAlreadyVerified
}
// Update user status to active
t := time.Now()
emailVerification.EmailVerifiedAt = &t
user.Status = constants.VerifiedStatus
user.Verification = emailVerification
err = ur.UpdateUser(emailVerification.UserID, user)
if err != nil {
return nil, err
}
return user, nil
}
func (ur *UserRepository) SetVerificationToken(user *models.User, token *string) (int64, error) {
// Check if user is already verified
verified, err := ur.IsVerified(&user.ID)
if err != nil {
return -1, err
}
if verified {
return -1, errors.ErrAlreadyVerified
}
// Prepare the Verification record
verification := models.Verification{
UserID: user.ID,
VerificationToken: *token,
} }
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (int64, error) {
// Use GORM to insert or update the Verification record // Use GORM to insert or update the Verification record
result := database.DB.Clauses(clause.OnConflict{ result := database.DB.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}}, Columns: []clause.Column{{Name: "user_id"}},

View File

@@ -8,10 +8,12 @@ import (
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
"GoMembership/internal/utils" "GoMembership/internal/utils"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"gorm.io/gorm"
"time" "time"
) )
@@ -22,14 +24,42 @@ type UserServiceInterface interface {
GetUserByID(id int64) (*models.User, error) GetUserByID(id int64) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error)
VerifyUser(token *string) (*models.User, error) VerifyUser(token *string) (*models.User, error)
UpdateUser(user *models.User) (*models.User, error)
} }
type UserService struct { type UserService struct {
Repo repositories.UserRepositoryInterface Repo repositories.UserRepositoryInterface
} }
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) {
if err := validateUserData(user); err != nil {
return nil, errors.ErrInvalidUserData
}
if user.Password != "" {
setPassword(user.Password, user)
}
user.UpdatedAt = time.Now()
updatedUser, err := service.Repo.UpdateUser(user)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.ErrUserNotFound
}
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
return nil, errors.ErrDuplicateEntry
}
return nil, err
}
return updatedUser, nil
}
func (service *UserService) RegisterUser(user *models.User) (int64, string, error) { func (service *UserService) RegisterUser(user *models.User) (int64, string, error) {
if err := validateRegistrationData(user); err != nil { if err := validateUserData(user); err != nil {
return http.StatusNotAcceptable, "", err return http.StatusNotAcceptable, "", err
} }
@@ -56,44 +86,22 @@ func (service *UserService) RegisterUser(user *models.User) (int64, string, erro
logger.Info.Printf("TOKEN: %v", token) logger.Info.Printf("TOKEN: %v", token)
_, err = service.Repo.SetVerificationToken(user, &token) // Check if user is already verified
verified, err := service.Repo.IsVerified(&user.ID)
if err != nil { if err != nil {
return http.StatusInternalServerError, "", err return http.StatusInternalServerError, "", err
} }
if verified {
return id, token, nil return http.StatusAlreadyReported, "", errors.ErrAlreadyVerified
} }
func (service *UserService) Update(user *models.User) (int64, string, error) { // Prepare the Verification record
if err := validateRegistrationData(user); err != nil { verification := models.Verification{
return http.StatusNotAcceptable, "", err UserID: user.ID,
VerificationToken: token,
} }
setPassword(user.Password, user) if _, err = service.Repo.SetVerificationToken(&verification); err != nil {
user.Status = constants.UnverifiedStatus
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
id, err := service.Repo.CreateUser(user)
if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") {
return http.StatusConflict, "", err
} else if err != nil {
return http.StatusInternalServerError, "", err
}
user.ID = id
token, err := utils.GenerateVerificationToken()
if err != nil {
return http.StatusInternalServerError, "", err
}
logger.Info.Printf("TOKEN: %v", token)
_, err = service.Repo.SetVerificationToken(user, &token)
if err != nil {
return http.StatusInternalServerError, "", err return http.StatusInternalServerError, "", err
} }
@@ -102,7 +110,7 @@ func (service *UserService) Update(user *models.User) (int64, string, error) {
func (service *UserService) GetUserByID(id int64) (*models.User, error) { func (service *UserService) GetUserByID(id int64) (*models.User, error) {
return service.Repo.GetUserByID(id) return service.Repo.GetUserByID(&id)
} }
func (service *UserService) GetUserByEmail(email string) (*models.User, error) { func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
@@ -114,19 +122,41 @@ func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.Us
} }
func (service *UserService) VerifyUser(token *string) (*models.User, error) { func (service *UserService) VerifyUser(token *string) (*models.User, error) {
user, err := service.Repo.VerifyUserOfToken(token) verification, err := service.Repo.GetVerificationOfToken(token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Check if the user is already verified
verified, err := service.Repo.IsVerified(&verification.UserID)
if err != nil {
return nil, err
}
user, err := service.Repo.GetUserByID(&verification.UserID)
if err != nil {
return nil, err
}
if verified {
return user, errors.ErrAlreadyVerified
}
// Update user status to active
t := time.Now()
verification.EmailVerifiedAt = &t
user.Status = constants.VerifiedStatus
user.Verification = *verification
user.ID = verification.UserID
service.Repo.UpdateUser(user)
return user, nil return user, nil
} }
func validateRegistrationData(user *models.User) error { func validateUserData(user *models.User) error {
validate := validator.New() validate := validator.New()
validate.RegisterValidation("age", utils.AgeValidator) validate.RegisterValidation("age", utils.AgeValidator)
validate.RegisterValidation("bic", utils.BICValidator) validate.RegisterValidation("bic", utils.BICValidator)
validate.RegisterValidation("iban", utils.IBANValidator) validate.RegisterValidation("iban", utils.IBANValidator)
validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator) validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator)
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField) validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField)
return validate.Struct(user) return validate.Struct(user)