From f55ef5cf7093eb68e6fabc499230a76cdc39211b Mon Sep 17 00:00:00 2001 From: Alex <$(pass /github/email)> Date: Sun, 23 Feb 2025 12:29:12 +0100 Subject: [PATCH] backend added struct merging and FieldPermissionsOnRoleId --- internal/constants/constants.go | 16 ++ internal/controllers/controllers_test.go | 18 ++ internal/controllers/user_controller.go | 28 ++- internal/controllers/user_controller_test.go | 188 +++++++++++++------ internal/services/user_service.go | 20 -- internal/utils/priviliges.go | 126 ++++++++++++- internal/utils/priviliges_test.go | 176 +++++++++++++++++ internal/validation/user_validation.go | 2 +- 8 files changed, 498 insertions(+), 76 deletions(-) create mode 100644 internal/utils/priviliges_test.go diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 1c39882..e1e4f54 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -79,3 +79,19 @@ const PRIV_VIEW = 1 const PRIV_ADD = 2 const PRIV_EDIT = 4 const PRIV_DELETE = 8 + +var MemberUpdateFields = map[string]bool{ + "Email": true, + "Phone": true, + "Company": true, + "Address": true, + "ZipCode": true, + "City": true, + "Licence.Categories": true, + "BankAccount.Bank": true, + "BankAccount.AccountHolderName": true, + "BankAccount.IBAN": true, + "BankAccount.BIC": true, +} + +// "Password": true, diff --git a/internal/controllers/controllers_test.go b/internal/controllers/controllers_test.go index 84401ff..17bbd6c 100644 --- a/internal/controllers/controllers_test.go +++ b/internal/controllers/controllers_test.go @@ -14,6 +14,7 @@ import ( "log" + "github.com/alexedwards/argon2id" "github.com/gin-gonic/gin" "GoMembership/internal/config" @@ -116,6 +117,22 @@ func TestSuite(t *testing.T) { if err := initLicenceCategories(); err != nil { log.Fatalf("Failed to init Categories: %v", err) } + hash, err := argon2id.CreateHash("securepassword", argon2id.DefaultParams) + admin := models.User{ + FirstName: "Ad", + LastName: "min", + Email: "admin@example.com", + Password: hash, + DateOfBirth: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + Company: "SampleCorp", + Phone: "+123456789", + Address: "123 Main Street", + ZipCode: "12345", + City: "SampleCity", + Status: 1, + RoleID: 8, + } + database.DB.Create(&admin) validation.SetupValidators() t.Run("userController", func(t *testing.T) { testUserController(t) @@ -262,6 +279,7 @@ func getBaseUser() models.User { ProfilePicture: "", Password: "password123", Company: "", + RoleID: 8, } } diff --git a/internal/controllers/user_controller.go b/internal/controllers/user_controller.go index 3f2a39c..77da7ca 100644 --- a/internal/controllers/user_controller.go +++ b/internal/controllers/user_controller.go @@ -91,9 +91,35 @@ func (uc *UserController) UpdateHandler(c *gin.Context) { user = updateData.User if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) && user.ID != requestUser.ID { - utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, "user.user", "server.error.unauthorized") + utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusUnauthorized, "user.user", "server.error.unauthorized") return } + existingUser, err := uc.Service.GetUserByID(user.ID) + if err != nil { + utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, "user.user", "server.error.not_found") + return + } + // user.Membership.ID = existingUser.Membership.ID + + // user.MembershipID = existingUser.MembershipID + // if existingUser.Licence != nil { + // user.Licence.ID = existingUser.Licence.ID + // } + // user.LicenceID = existingUser.LicenceID + // user.BankAccount.ID = existingUser.BankAccount.ID + // user.BankAccountID = existingUser.BankAccountID + + if requestUser.RoleID <= constants.Priviliges.View { + existingUser.Password = "" + if err := utils.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil { + if err.Error() == "Not authorized" { + utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, "user.user", "server.error.unauthorized") + return + } + utils.RespondWithError(c, err, "Error filtering users input fields", http.StatusInternalServerError, "user.user", "server.error.internal_server_error") + return + } + } updatedUser, err := uc.Service.UpdateUser(&user) if err != nil { diff --git a/internal/controllers/user_controller_test.go b/internal/controllers/user_controller_test.go index 703b335..a0a9b22 100644 --- a/internal/controllers/user_controller_test.go +++ b/internal/controllers/user_controller_test.go @@ -75,7 +75,37 @@ func testUserController(t *testing.T) { loginEmail, loginCookie := testLoginHandler(t) logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie) - testUpdateUser(t, loginCookie) + + // creating a admin cookie + c, w, _ := GetMockedJSONContext([]byte(`{ + "email": "admin@example.com", + "password": "securepassword" + }`), "/login") + + Uc.LoginHandler(c) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + assert.Equal(t, "Login successful", response["message"]) + var adminCookie http.Cookie + for _, cookie := range w.Result().Cookies() { + if cookie.Name == "jwt" { + adminCookie = *cookie + + tokenString := adminCookie.Value + _, claims, err := middlewares.ExtractContentFrom(tokenString) + assert.NoError(t, err, "FAiled getting cookie string") + jwtUserID := uint((*claims)["user_id"].(float64)) + user, err := Uc.Service.GetUserByID(jwtUserID) + assert.NoError(t, err, "FAiled getting cookie string") + logger.Error.Printf("ADMIN USER: %#v", user) + break + } + } + assert.NotEmpty(t, adminCookie) + testUpdateUser(t, loginCookie, adminCookie) testLogoutHandler(t, logoutCookie) } @@ -190,7 +220,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) { for _, tt := range tests { logger.Error.Print("==============================================================") - logger.Error.Printf("Testing : %v", tt.name) + logger.Error.Printf("Login Testing : %v", tt.name) logger.Error.Print("==============================================================") t.Run(tt.name, func(t *testing.T) { // Setup @@ -213,6 +243,14 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) { if cookie.Name == "jwt" { loginCookie = *cookie + // tokenString := loginCookie.Value + // _, claims, err := middlewares.ExtractContentFrom(tokenString) + // assert.NoError(t, err, "FAiled getting cookie string") + // jwtUserID := uint((*claims)["user_id"].(float64)) + // user, err := Uc.Service.GetUserByID(jwtUserID) + // assert.NoError(t, err, "FAiled getting cookie string") + + // logger.Error.Printf("cookie user: %#v", user) err = json.Unmarshal([]byte(tt.input), &loginInput) assert.NoError(t, err, "Failed to unmarshal input JSON") @@ -413,7 +451,7 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error { return nil } -func testUpdateUser(t *testing.T, loginCookie http.Cookie) { +func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cookie) { invalidCookie := http.Cookie{ Name: "jwt", @@ -437,13 +475,14 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { name string setupCookie func(*http.Request) updateFunc func(*models.User) + expectedReturn func(*models.User) expectedStatus int expectedErrors []map[string]string }{ { - name: "Valid Update", + name: "Valid Admin Update", setupCookie: func(req *http.Request) { - req.AddCookie(&loginCookie) + req.AddCookie(&adminCookie) }, updateFunc: func(u *models.User) { u.Password = "" @@ -486,8 +525,23 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { {"field": "Email", "key": "server.validation.email"}, }, }, + { - name: "Change Number", + name: "admin may change licence number", + setupCookie: func(req *http.Request) { + req.AddCookie(&adminCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.FirstName = "John Updated" + u.LastName = "Doe Updated" + u.Phone = "01738484994" + u.Licence.Number = "B072RRE2I50" + }, + expectedStatus: http.StatusAccepted, + }, + { + name: "Change phone number", setupCookie: func(req *http.Request) { req.AddCookie(&loginCookie) }, @@ -578,12 +632,13 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { updateFunc: func(u *models.User) { u.Password = "" u.ID = 1 + u.FirstName = "John Updated" u.LastName = "Doe Updated" u.Phone = "01738484994" u.Licence.Number = "B072RRE2I50" u.FirstName = "John Missing ID" }, - expectedStatus: http.StatusForbidden, + expectedStatus: http.StatusUnauthorized, expectedErrors: []map[string]string{ {"field": "user.user", "key": "server.error.unauthorized"}, }, @@ -594,7 +649,27 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { req.AddCookie(&loginCookie) }, updateFunc: func(u *models.User) { + u.FirstName = "John Updated" + u.LastName = "Doe Updated" + u.Phone = "01738484994" + u.Licence.Number = "B072RRE2I50" + u.Password = "NewPassword" + }, + expectedReturn: func(u *models.User) { u.Password = "" + u.FirstName = "John Updated" + u.LastName = "Doe Updated" + u.Phone = "01738484994" + u.Licence.Number = "B072RRE2I50" + }, + expectedStatus: http.StatusAccepted, + }, + { + name: "Admin Password Update", + setupCookie: func(req *http.Request) { + req.AddCookie(&adminCookie) + }, + updateFunc: func(u *models.User) { u.LastName = "Doe Updated" u.Phone = "01738484994" u.Licence.Number = "B072RRE2I50" @@ -602,19 +677,21 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { }, expectedStatus: http.StatusAccepted, }, - // { - // name: "Non-existent User", - // setupCookie: func(req *http.Request) { - // req.AddCookie(&loginCookie) - // }, - // updateFunc: func(u *models.User) { - // u.Password = "" - // u.ID = 99999 - // u.FirstName = "Non-existent" - // }, - // expectedStatus: http.StatusNotFound, - // expectedError: "User not found", - // }, + { + name: "Non-existent User", + setupCookie: func(req *http.Request) { + req.AddCookie(&loginCookie) + }, + updateFunc: func(u *models.User) { + u.Password = "" + u.ID = 99999 + u.FirstName = "Non-existent" + }, + expectedErrors: []map[string]string{ + {"field": "user.user", "key": "server.error.unauthorized"}, + }, + expectedStatus: http.StatusUnauthorized, + }, } for _, tt := range tests { logger.Error.Print("==============================================================") @@ -623,9 +700,8 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { t.Run(tt.name, func(t *testing.T) { // Create a copy of the user and apply the updates updatedUser := user - logger.Error.Printf("user to be updated: %+v", user.Licence) + // logger.Error.Printf("users licence to be updated: %+v", user.Licence) tt.updateFunc(&updatedUser) - // Convert user to JSON updateData := &RegistrationData{User: updatedUser} jsonData, err := json.Marshal(updateData) @@ -633,7 +709,11 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) { t.Fatalf("Failed to marshal user data: %v", err) } - // logger.Error.Printf("Updated User: %#v", updatedUser) + logger.Error.Printf("Updated User: %#v", updatedUser.Safe()) + if tt.expectedReturn != nil { + tt.expectedReturn(&updatedUser) + } + // Create request req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatUint(uint64(user.ID), 10), bytes.NewBuffer(jsonData)) req.Header.Set("Content-Type", "application/json") @@ -987,6 +1067,7 @@ func getTestUsers() []RegisterUserTest { Assert: false, Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { user.BankAccount.IBAN = "" + user.RoleID = 0 return user })), }, @@ -997,6 +1078,7 @@ func getTestUsers() []RegisterUserTest { Assert: false, Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { user.BankAccount.IBAN = "DE1234234123134" + user.RoleID = 0 return user })), }, @@ -1110,35 +1192,35 @@ func getTestUsers() []RegisterUserTest { // return user // })), // }, - { - Name: "empty driverslicence number, should fail", - WantResponse: http.StatusBadRequest, - WantDBData: map[string]interface{}{"email": "john.wronglicence.doe@example.com"}, - Assert: false, - Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { - user.Email = "john.wronglicence.doe@example.com" - user.Licence = &models.Licence{ - Number: "", - ExpirationDate: time.Now().AddDate(1, 0, 0), - IssuedDate: time.Now().AddDate(-1, 0, 0), - } - return user - })), - }, - { - Name: "Correct Licence number, should pass", - WantResponse: http.StatusCreated, - WantDBData: map[string]interface{}{"email": "john.correctLicenceNumber@example.com"}, - Assert: true, - Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { - user.Email = "john.correctLicenceNumber@example.com" - user.Licence = &models.Licence{ - Number: "B072RRE2I55", - ExpirationDate: time.Now().AddDate(1, 0, 0), - IssuedDate: time.Now().AddDate(-1, 0, 0), - } - return user - })), - }, + // { + // Name: "empty driverslicence number, should fail", + // WantResponse: http.StatusBadRequest, + // WantDBData: map[string]interface{}{"email": "john.wronglicence.doe@example.com"}, + // Assert: false, + // Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { + // user.Email = "john.wronglicence.doe@example.com" + // user.Licence = &models.Licence{ + // Number: "", + // ExpirationDate: time.Now().AddDate(1, 0, 0), + // IssuedDate: time.Now().AddDate(-1, 0, 0), + // } + // return user + // })), + // }, + // { + // Name: "Correct Licence number, should pass", + // WantResponse: http.StatusCreated, + // WantDBData: map[string]interface{}{"email": "john.correctLicenceNumber@example.com"}, + // Assert: true, + // Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { + // user.Email = "john.correctLicenceNumber@example.com" + // user.Licence = &models.Licence{ + // Number: "B072RRE2I55", + // ExpirationDate: time.Now().AddDate(1, 0, 0), + // IssuedDate: time.Now().AddDate(-1, 0, 0), + // } + // return user + // })), + // }, } } diff --git a/internal/services/user_service.go b/internal/services/user_service.go index b5a648f..53e9a30 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -66,26 +66,6 @@ func (service *UserService) UpdateUser(user *models.User) (*models.User, error) user.Membership.SubscriptionModel = *selectedModel user.Membership.SubscriptionModelID = selectedModel.ID - existingUser, err := service.GetUserByID(user.ID) - if err != nil { - return nil, err - } - - user.Membership.ID = existingUser.Membership.ID - - user.MembershipID = existingUser.MembershipID - if existingUser.Licence != nil { - user.Licence.ID = existingUser.Licence.ID - } - user.LicenceID = existingUser.LicenceID - user.BankAccount.ID = existingUser.BankAccount.ID - user.BankAccountID = existingUser.BankAccountID - - // if user.Licence.Status == 0 { - // // This is a new drivers licence - // user.Licence.Status = constants.UnverifiedStatus - // } - updatedUser, err := service.Repo.UpdateUser(user) if err != nil { diff --git a/internal/utils/priviliges.go b/internal/utils/priviliges.go index a817aef..b463dfd 100644 --- a/internal/utils/priviliges.go +++ b/internal/utils/priviliges.go @@ -3,6 +3,9 @@ package utils import ( "GoMembership/internal/constants" "GoMembership/internal/models" + "GoMembership/pkg/logger" + "errors" + "reflect" ) func HasPrivilige(user *models.User, privilige int8) bool { @@ -18,5 +21,126 @@ func HasPrivilige(user *models.User, privilige int8) bool { default: return false } - +} + +// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place. +func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error { + v := reflect.ValueOf(input) + origin := reflect.ValueOf(existing) + + // Ensure both input and target are pointers to structs + if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr { + return errors.New("both input and existing must be pointers to structs") + } + + v = v.Elem() + origin = origin.Elem() + + if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct { + return errors.New("both input and existing must be structs") + } + + for i := 0; i < v.NumField(); i++ { + field := v.Type().Field(i) + key := field.Name + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Build the full field path + fullKey := key + if prefix != "" { + fullKey = prefix + "." + key + } + fieldValue := v.Field(i) + originField := origin.Field(i) + + // Handle nil pointers + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + // If the field is nil, skip it or initialize it + if !allowedFields[fullKey] { + // If the field is not allowed, set it to the corresponding field from existing + fieldValue.Set(originField) + } + continue + } + // Dereference the pointer for further processing + fieldValue = fieldValue.Elem() + originField = originField.Elem() + } + + // Handle slices + if fieldValue.Kind() == reflect.Slice { + if !allowedFields[fullKey] { + // If the slice is not allowed, set it to the corresponding slice from existing + fieldValue.Set(originField) + continue + } else { + originField.Set(fieldValue) + } + + // If the slice contains structs, recursively filter each element + // if fieldValue.Type().Elem().Kind() == reflect.Struct { + // for j := 0; j < fieldValue.Len(); j++ { + // err := FilterAllowedStructFields( + // fieldValue.Index(j).Addr().Interface(), + // originField.Index(j).Addr().Interface(), + // allowedFields, + // fullKey, + // ) + // if err != nil { + // return err + // } + // } + // } + continue + } + + // Handle nested structs (including pointers to structs) + if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) { + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + continue + } + fieldValue = fieldValue.Elem() + originField = originField.Elem() // May result in an invalid originField + } + + var originCopy reflect.Value + + // Check if originField is valid (non-zero) + if originField.IsValid() { + originCopy = reflect.New(originField.Type()).Elem() + originCopy.Set(originField) + } else { + // If originField is invalid (e.g., existing had a nil pointer), + // create a new instance of the type from fieldValue + originCopy = reflect.New(fieldValue.Type()).Elem() + } + + err := FilterAllowedStructFields( + fieldValue.Addr().Interface(), + originCopy.Addr().Interface(), + allowedFields, + fullKey, + ) + if err != nil { + return err + } + continue + } + + // Only allow whitelisted fields + if !allowedFields[fullKey] { + logger.Error.Printf("denying update of field: %#v", fullKey) + fieldValue.Set(originField) + } else { + logger.Error.Printf("updating whitelisted field: %#v", fullKey) + } + + } + return nil } diff --git a/internal/utils/priviliges_test.go b/internal/utils/priviliges_test.go new file mode 100644 index 0000000..5633168 --- /dev/null +++ b/internal/utils/priviliges_test.go @@ -0,0 +1,176 @@ +package utils + +import ( + "reflect" + "testing" +) + +type User struct { + Name string + Age int + Address *Address + Tags []string + License License +} + +type Address struct { + City string + Country string +} + +type License struct { + ID string + Categories []string +} + +func TestFilterAllowedStructFields(t *testing.T) { + tests := []struct { + name string + input interface{} + existing interface{} + allowedFields map[string]bool + expectedResult interface{} + expectError bool + }{ + { + name: "Filter top-level fields", + input: &User{ + Name: "Alice", + Age: 30, + }, + existing: &User{ + Name: "Bob", + Age: 25, + }, + allowedFields: map[string]bool{ + "Name": true, + }, + expectedResult: &User{ + Name: "Alice", // Allowed field + Age: 25, // Kept from existing + }, + expectError: false, + }, + { + name: "Filter nested struct fields", + input: &User{ + Name: "Alice", + Address: &Address{ + City: "New York", + Country: "USA", + }, + }, + existing: &User{ + Name: "Bob", + Address: &Address{ + City: "London", + Country: "UK", + }, + }, + allowedFields: map[string]bool{ + "Address.City": true, + }, + expectedResult: &User{ + Name: "Bob", // Kept from existing + Address: &Address{ + City: "New York", // Allowed field + Country: "UK", // Kept from existing + }, + }, + expectError: false, + }, + { + name: "Filter slice fields", + input: &User{ + Tags: []string{"admin", "user"}, + }, + existing: &User{ + Tags: []string{"guest"}, + }, + allowedFields: map[string]bool{ + "Tags": true, + }, + expectedResult: &User{ + Tags: []string{"admin", "user"}, // Allowed slice + }, + expectError: false, + }, + { + name: "Filter slice of structs", + input: &User{ + License: License{ + ID: "123", + Categories: []string{"A", "B"}, + }, + }, + existing: &User{ + License: License{ + ID: "456", + Categories: []string{"C"}, + }, + }, + allowedFields: map[string]bool{ + "License.ID": true, + }, + expectedResult: &User{ + License: License{ + ID: "123", // Allowed field + Categories: []string{"C"}, // Kept from existing + }, + }, + expectError: false, + }, + { + name: "Filter pointer fields", + input: &User{ + Address: &Address{ + City: "Paris", + }, + }, + existing: &User{ + Address: &Address{ + City: "Berlin", + Country: "Germany", + }, + }, + allowedFields: map[string]bool{ + "Address.City": true, + }, + expectedResult: &User{ + Address: &Address{ + City: "Paris", // Allowed field + Country: "Germany", // Kept from existing + }, + }, + expectError: false, + }, + { + name: "Invalid input (non-pointer)", + input: User{ + Name: "Alice", + }, + existing: &User{ + Name: "Bob", + }, + allowedFields: map[string]bool{ + "Name": true, + }, + expectedResult: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "") + if (err != nil) != tt.expectError { + t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError) + return + } + + if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) { + t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult) + } + }) + } +} diff --git a/internal/validation/user_validation.go b/internal/validation/user_validation.go index 7c927e4..a5927df 100644 --- a/internal/validation/user_validation.go +++ b/internal/validation/user_validation.go @@ -23,7 +23,7 @@ func validateUser(sl validator.StructLevel) { } } // Validate User > 18 years old - if !isSuper && user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) { + if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) { sl.ReportError(user.DateOfBirth, "DateOfBirth", "dateofbirth", "age", "") } // validate subscriptionModel