backend added struct merging and FieldPermissionsOnRoleId
This commit is contained in:
@@ -79,3 +79,19 @@ const PRIV_VIEW = 1
|
||||
const PRIV_ADD = 2
|
||||
const PRIV_EDIT = 4
|
||||
const PRIV_DELETE = 8
|
||||
|
||||
var MemberUpdateFields = map[string]bool{
|
||||
"Email": true,
|
||||
"Phone": true,
|
||||
"Company": true,
|
||||
"Address": true,
|
||||
"ZipCode": true,
|
||||
"City": true,
|
||||
"Licence.Categories": true,
|
||||
"BankAccount.Bank": true,
|
||||
"BankAccount.AccountHolderName": true,
|
||||
"BankAccount.IBAN": true,
|
||||
"BankAccount.BIC": true,
|
||||
}
|
||||
|
||||
// "Password": true,
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"log"
|
||||
|
||||
"github.com/alexedwards/argon2id"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
@@ -116,6 +117,22 @@ func TestSuite(t *testing.T) {
|
||||
if err := initLicenceCategories(); err != nil {
|
||||
log.Fatalf("Failed to init Categories: %v", err)
|
||||
}
|
||||
hash, err := argon2id.CreateHash("securepassword", argon2id.DefaultParams)
|
||||
admin := models.User{
|
||||
FirstName: "Ad",
|
||||
LastName: "min",
|
||||
Email: "admin@example.com",
|
||||
Password: hash,
|
||||
DateOfBirth: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
Company: "SampleCorp",
|
||||
Phone: "+123456789",
|
||||
Address: "123 Main Street",
|
||||
ZipCode: "12345",
|
||||
City: "SampleCity",
|
||||
Status: 1,
|
||||
RoleID: 8,
|
||||
}
|
||||
database.DB.Create(&admin)
|
||||
validation.SetupValidators()
|
||||
t.Run("userController", func(t *testing.T) {
|
||||
testUserController(t)
|
||||
@@ -262,6 +279,7 @@ func getBaseUser() models.User {
|
||||
ProfilePicture: "",
|
||||
Password: "password123",
|
||||
Company: "",
|
||||
RoleID: 8,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,9 +91,35 @@ func (uc *UserController) UpdateHandler(c *gin.Context) {
|
||||
user = updateData.User
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) && user.ID != requestUser.ID {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
existingUser, err := uc.Service.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, "user.user", "server.error.not_found")
|
||||
return
|
||||
}
|
||||
// user.Membership.ID = existingUser.Membership.ID
|
||||
|
||||
// user.MembershipID = existingUser.MembershipID
|
||||
// if existingUser.Licence != nil {
|
||||
// user.Licence.ID = existingUser.Licence.ID
|
||||
// }
|
||||
// user.LicenceID = existingUser.LicenceID
|
||||
// user.BankAccount.ID = existingUser.BankAccount.ID
|
||||
// user.BankAccountID = existingUser.BankAccountID
|
||||
|
||||
if requestUser.RoleID <= constants.Priviliges.View {
|
||||
existingUser.Password = ""
|
||||
if err := utils.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil {
|
||||
if err.Error() == "Not authorized" {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
utils.RespondWithError(c, err, "Error filtering users input fields", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser, err := uc.Service.UpdateUser(&user)
|
||||
if err != nil {
|
||||
|
||||
@@ -75,7 +75,37 @@ func testUserController(t *testing.T) {
|
||||
|
||||
loginEmail, loginCookie := testLoginHandler(t)
|
||||
logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie)
|
||||
testUpdateUser(t, loginCookie)
|
||||
|
||||
// creating a admin cookie
|
||||
c, w, _ := GetMockedJSONContext([]byte(`{
|
||||
"email": "admin@example.com",
|
||||
"password": "securepassword"
|
||||
}`), "/login")
|
||||
|
||||
Uc.LoginHandler(c)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Login successful", response["message"])
|
||||
var adminCookie http.Cookie
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == "jwt" {
|
||||
adminCookie = *cookie
|
||||
|
||||
tokenString := adminCookie.Value
|
||||
_, claims, err := middlewares.ExtractContentFrom(tokenString)
|
||||
assert.NoError(t, err, "FAiled getting cookie string")
|
||||
jwtUserID := uint((*claims)["user_id"].(float64))
|
||||
user, err := Uc.Service.GetUserByID(jwtUserID)
|
||||
assert.NoError(t, err, "FAiled getting cookie string")
|
||||
logger.Error.Printf("ADMIN USER: %#v", user)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, adminCookie)
|
||||
testUpdateUser(t, loginCookie, adminCookie)
|
||||
testLogoutHandler(t, logoutCookie)
|
||||
}
|
||||
|
||||
@@ -190,7 +220,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
|
||||
|
||||
for _, tt := range tests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Testing : %v", tt.name)
|
||||
logger.Error.Printf("Login Testing : %v", tt.name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup
|
||||
@@ -213,6 +243,14 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
|
||||
if cookie.Name == "jwt" {
|
||||
loginCookie = *cookie
|
||||
|
||||
// tokenString := loginCookie.Value
|
||||
// _, claims, err := middlewares.ExtractContentFrom(tokenString)
|
||||
// assert.NoError(t, err, "FAiled getting cookie string")
|
||||
// jwtUserID := uint((*claims)["user_id"].(float64))
|
||||
// user, err := Uc.Service.GetUserByID(jwtUserID)
|
||||
// assert.NoError(t, err, "FAiled getting cookie string")
|
||||
|
||||
// logger.Error.Printf("cookie user: %#v", user)
|
||||
err = json.Unmarshal([]byte(tt.input), &loginInput)
|
||||
assert.NoError(t, err, "Failed to unmarshal input JSON")
|
||||
|
||||
@@ -413,7 +451,7 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cookie) {
|
||||
|
||||
invalidCookie := http.Cookie{
|
||||
Name: "jwt",
|
||||
@@ -437,13 +475,14 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
name string
|
||||
setupCookie func(*http.Request)
|
||||
updateFunc func(*models.User)
|
||||
expectedReturn func(*models.User)
|
||||
expectedStatus int
|
||||
expectedErrors []map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Valid Update",
|
||||
name: "Valid Admin Update",
|
||||
setupCookie: func(req *http.Request) {
|
||||
req.AddCookie(&loginCookie)
|
||||
req.AddCookie(&adminCookie)
|
||||
},
|
||||
updateFunc: func(u *models.User) {
|
||||
u.Password = ""
|
||||
@@ -486,8 +525,23 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
{"field": "Email", "key": "server.validation.email"},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "Change Number",
|
||||
name: "admin may change licence number",
|
||||
setupCookie: func(req *http.Request) {
|
||||
req.AddCookie(&adminCookie)
|
||||
},
|
||||
updateFunc: func(u *models.User) {
|
||||
u.Password = ""
|
||||
u.FirstName = "John Updated"
|
||||
u.LastName = "Doe Updated"
|
||||
u.Phone = "01738484994"
|
||||
u.Licence.Number = "B072RRE2I50"
|
||||
},
|
||||
expectedStatus: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "Change phone number",
|
||||
setupCookie: func(req *http.Request) {
|
||||
req.AddCookie(&loginCookie)
|
||||
},
|
||||
@@ -578,12 +632,13 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
updateFunc: func(u *models.User) {
|
||||
u.Password = ""
|
||||
u.ID = 1
|
||||
u.FirstName = "John Updated"
|
||||
u.LastName = "Doe Updated"
|
||||
u.Phone = "01738484994"
|
||||
u.Licence.Number = "B072RRE2I50"
|
||||
u.FirstName = "John Missing ID"
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedErrors: []map[string]string{
|
||||
{"field": "user.user", "key": "server.error.unauthorized"},
|
||||
},
|
||||
@@ -594,7 +649,27 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
req.AddCookie(&loginCookie)
|
||||
},
|
||||
updateFunc: func(u *models.User) {
|
||||
u.FirstName = "John Updated"
|
||||
u.LastName = "Doe Updated"
|
||||
u.Phone = "01738484994"
|
||||
u.Licence.Number = "B072RRE2I50"
|
||||
u.Password = "NewPassword"
|
||||
},
|
||||
expectedReturn: func(u *models.User) {
|
||||
u.Password = ""
|
||||
u.FirstName = "John Updated"
|
||||
u.LastName = "Doe Updated"
|
||||
u.Phone = "01738484994"
|
||||
u.Licence.Number = "B072RRE2I50"
|
||||
},
|
||||
expectedStatus: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "Admin Password Update",
|
||||
setupCookie: func(req *http.Request) {
|
||||
req.AddCookie(&adminCookie)
|
||||
},
|
||||
updateFunc: func(u *models.User) {
|
||||
u.LastName = "Doe Updated"
|
||||
u.Phone = "01738484994"
|
||||
u.Licence.Number = "B072RRE2I50"
|
||||
@@ -602,19 +677,21 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
},
|
||||
expectedStatus: http.StatusAccepted,
|
||||
},
|
||||
// {
|
||||
// name: "Non-existent User",
|
||||
// setupCookie: func(req *http.Request) {
|
||||
// req.AddCookie(&loginCookie)
|
||||
// },
|
||||
// updateFunc: func(u *models.User) {
|
||||
// u.Password = ""
|
||||
// u.ID = 99999
|
||||
// u.FirstName = "Non-existent"
|
||||
// },
|
||||
// expectedStatus: http.StatusNotFound,
|
||||
// expectedError: "User not found",
|
||||
// },
|
||||
{
|
||||
name: "Non-existent User",
|
||||
setupCookie: func(req *http.Request) {
|
||||
req.AddCookie(&loginCookie)
|
||||
},
|
||||
updateFunc: func(u *models.User) {
|
||||
u.Password = ""
|
||||
u.ID = 99999
|
||||
u.FirstName = "Non-existent"
|
||||
},
|
||||
expectedErrors: []map[string]string{
|
||||
{"field": "user.user", "key": "server.error.unauthorized"},
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
logger.Error.Print("==============================================================")
|
||||
@@ -623,9 +700,8 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a copy of the user and apply the updates
|
||||
updatedUser := user
|
||||
logger.Error.Printf("user to be updated: %+v", user.Licence)
|
||||
// logger.Error.Printf("users licence to be updated: %+v", user.Licence)
|
||||
tt.updateFunc(&updatedUser)
|
||||
// Convert user to JSON
|
||||
|
||||
updateData := &RegistrationData{User: updatedUser}
|
||||
jsonData, err := json.Marshal(updateData)
|
||||
@@ -633,7 +709,11 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie) {
|
||||
t.Fatalf("Failed to marshal user data: %v", err)
|
||||
}
|
||||
|
||||
// logger.Error.Printf("Updated User: %#v", updatedUser)
|
||||
logger.Error.Printf("Updated User: %#v", updatedUser.Safe())
|
||||
if tt.expectedReturn != nil {
|
||||
tt.expectedReturn(&updatedUser)
|
||||
}
|
||||
|
||||
// Create request
|
||||
req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatUint(uint64(user.ID), 10), bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -987,6 +1067,7 @@ func getTestUsers() []RegisterUserTest {
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
user.BankAccount.IBAN = ""
|
||||
user.RoleID = 0
|
||||
return user
|
||||
})),
|
||||
},
|
||||
@@ -997,6 +1078,7 @@ func getTestUsers() []RegisterUserTest {
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
user.BankAccount.IBAN = "DE1234234123134"
|
||||
user.RoleID = 0
|
||||
return user
|
||||
})),
|
||||
},
|
||||
@@ -1110,35 +1192,35 @@ func getTestUsers() []RegisterUserTest {
|
||||
// return user
|
||||
// })),
|
||||
// },
|
||||
{
|
||||
Name: "empty driverslicence number, should fail",
|
||||
WantResponse: http.StatusBadRequest,
|
||||
WantDBData: map[string]interface{}{"email": "john.wronglicence.doe@example.com"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
user.Email = "john.wronglicence.doe@example.com"
|
||||
user.Licence = &models.Licence{
|
||||
Number: "",
|
||||
ExpirationDate: time.Now().AddDate(1, 0, 0),
|
||||
IssuedDate: time.Now().AddDate(-1, 0, 0),
|
||||
}
|
||||
return user
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Correct Licence number, should pass",
|
||||
WantResponse: http.StatusCreated,
|
||||
WantDBData: map[string]interface{}{"email": "john.correctLicenceNumber@example.com"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
user.Email = "john.correctLicenceNumber@example.com"
|
||||
user.Licence = &models.Licence{
|
||||
Number: "B072RRE2I55",
|
||||
ExpirationDate: time.Now().AddDate(1, 0, 0),
|
||||
IssuedDate: time.Now().AddDate(-1, 0, 0),
|
||||
}
|
||||
return user
|
||||
})),
|
||||
},
|
||||
// {
|
||||
// Name: "empty driverslicence number, should fail",
|
||||
// WantResponse: http.StatusBadRequest,
|
||||
// WantDBData: map[string]interface{}{"email": "john.wronglicence.doe@example.com"},
|
||||
// Assert: false,
|
||||
// Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
// user.Email = "john.wronglicence.doe@example.com"
|
||||
// user.Licence = &models.Licence{
|
||||
// Number: "",
|
||||
// ExpirationDate: time.Now().AddDate(1, 0, 0),
|
||||
// IssuedDate: time.Now().AddDate(-1, 0, 0),
|
||||
// }
|
||||
// return user
|
||||
// })),
|
||||
// },
|
||||
// {
|
||||
// Name: "Correct Licence number, should pass",
|
||||
// WantResponse: http.StatusCreated,
|
||||
// WantDBData: map[string]interface{}{"email": "john.correctLicenceNumber@example.com"},
|
||||
// Assert: true,
|
||||
// Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
|
||||
// user.Email = "john.correctLicenceNumber@example.com"
|
||||
// user.Licence = &models.Licence{
|
||||
// Number: "B072RRE2I55",
|
||||
// ExpirationDate: time.Now().AddDate(1, 0, 0),
|
||||
// IssuedDate: time.Now().AddDate(-1, 0, 0),
|
||||
// }
|
||||
// return user
|
||||
// })),
|
||||
// },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,26 +66,6 @@ func (service *UserService) UpdateUser(user *models.User) (*models.User, error)
|
||||
user.Membership.SubscriptionModel = *selectedModel
|
||||
user.Membership.SubscriptionModelID = selectedModel.ID
|
||||
|
||||
existingUser, err := service.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.Membership.ID = existingUser.Membership.ID
|
||||
|
||||
user.MembershipID = existingUser.MembershipID
|
||||
if existingUser.Licence != nil {
|
||||
user.Licence.ID = existingUser.Licence.ID
|
||||
}
|
||||
user.LicenceID = existingUser.LicenceID
|
||||
user.BankAccount.ID = existingUser.BankAccount.ID
|
||||
user.BankAccountID = existingUser.BankAccountID
|
||||
|
||||
// if user.Licence.Status == 0 {
|
||||
// // This is a new drivers licence
|
||||
// user.Licence.Status = constants.UnverifiedStatus
|
||||
// }
|
||||
|
||||
updatedUser, err := service.Repo.UpdateUser(user)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,9 @@ package utils
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func HasPrivilige(user *models.User, privilige int8) bool {
|
||||
@@ -18,5 +21,126 @@ func HasPrivilige(user *models.User, privilige int8) bool {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
|
||||
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
|
||||
v := reflect.ValueOf(input)
|
||||
origin := reflect.ValueOf(existing)
|
||||
|
||||
// Ensure both input and target are pointers to structs
|
||||
if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
|
||||
return errors.New("both input and existing must be pointers to structs")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
origin = origin.Elem()
|
||||
|
||||
if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
|
||||
return errors.New("both input and existing must be structs")
|
||||
}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Type().Field(i)
|
||||
key := field.Name
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build the full field path
|
||||
fullKey := key
|
||||
if prefix != "" {
|
||||
fullKey = prefix + "." + key
|
||||
}
|
||||
fieldValue := v.Field(i)
|
||||
originField := origin.Field(i)
|
||||
|
||||
// Handle nil pointers
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
// If the field is nil, skip it or initialize it
|
||||
if !allowedFields[fullKey] {
|
||||
// If the field is not allowed, set it to the corresponding field from existing
|
||||
fieldValue.Set(originField)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Dereference the pointer for further processing
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem()
|
||||
}
|
||||
|
||||
// Handle slices
|
||||
if fieldValue.Kind() == reflect.Slice {
|
||||
if !allowedFields[fullKey] {
|
||||
// If the slice is not allowed, set it to the corresponding slice from existing
|
||||
fieldValue.Set(originField)
|
||||
continue
|
||||
} else {
|
||||
originField.Set(fieldValue)
|
||||
}
|
||||
|
||||
// If the slice contains structs, recursively filter each element
|
||||
// if fieldValue.Type().Elem().Kind() == reflect.Struct {
|
||||
// for j := 0; j < fieldValue.Len(); j++ {
|
||||
// err := FilterAllowedStructFields(
|
||||
// fieldValue.Index(j).Addr().Interface(),
|
||||
// originField.Index(j).Addr().Interface(),
|
||||
// allowedFields,
|
||||
// fullKey,
|
||||
// )
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nested structs (including pointers to structs)
|
||||
if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem() // May result in an invalid originField
|
||||
}
|
||||
|
||||
var originCopy reflect.Value
|
||||
|
||||
// Check if originField is valid (non-zero)
|
||||
if originField.IsValid() {
|
||||
originCopy = reflect.New(originField.Type()).Elem()
|
||||
originCopy.Set(originField)
|
||||
} else {
|
||||
// If originField is invalid (e.g., existing had a nil pointer),
|
||||
// create a new instance of the type from fieldValue
|
||||
originCopy = reflect.New(fieldValue.Type()).Elem()
|
||||
}
|
||||
|
||||
err := FilterAllowedStructFields(
|
||||
fieldValue.Addr().Interface(),
|
||||
originCopy.Addr().Interface(),
|
||||
allowedFields,
|
||||
fullKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Only allow whitelisted fields
|
||||
if !allowedFields[fullKey] {
|
||||
logger.Error.Printf("denying update of field: %#v", fullKey)
|
||||
fieldValue.Set(originField)
|
||||
} else {
|
||||
logger.Error.Printf("updating whitelisted field: %#v", fullKey)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
176
internal/utils/priviliges_test.go
Normal file
176
internal/utils/priviliges_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Age int
|
||||
Address *Address
|
||||
Tags []string
|
||||
License License
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
City string
|
||||
Country string
|
||||
}
|
||||
|
||||
type License struct {
|
||||
ID string
|
||||
Categories []string
|
||||
}
|
||||
|
||||
func TestFilterAllowedStructFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
existing interface{}
|
||||
allowedFields map[string]bool
|
||||
expectedResult interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Filter top-level fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Age: 30,
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Age: 25,
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Alice", // Allowed field
|
||||
Age: 25, // Kept from existing
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter nested struct fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Address: &Address{
|
||||
City: "New York",
|
||||
Country: "USA",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Address: &Address{
|
||||
City: "London",
|
||||
Country: "UK",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Bob", // Kept from existing
|
||||
Address: &Address{
|
||||
City: "New York", // Allowed field
|
||||
Country: "UK", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice fields",
|
||||
input: &User{
|
||||
Tags: []string{"admin", "user"},
|
||||
},
|
||||
existing: &User{
|
||||
Tags: []string{"guest"},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Tags": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Tags: []string{"admin", "user"}, // Allowed slice
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice of structs",
|
||||
input: &User{
|
||||
License: License{
|
||||
ID: "123",
|
||||
Categories: []string{"A", "B"},
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
License: License{
|
||||
ID: "456",
|
||||
Categories: []string{"C"},
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"License.ID": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
License: License{
|
||||
ID: "123", // Allowed field
|
||||
Categories: []string{"C"}, // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter pointer fields",
|
||||
input: &User{
|
||||
Address: &Address{
|
||||
City: "Paris",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Address: &Address{
|
||||
City: "Berlin",
|
||||
Country: "Germany",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Address: &Address{
|
||||
City: "Paris", // Allowed field
|
||||
Country: "Germany", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid input (non-pointer)",
|
||||
input: User{
|
||||
Name: "Alice",
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "")
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) {
|
||||
t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func validateUser(sl validator.StructLevel) {
|
||||
}
|
||||
}
|
||||
// Validate User > 18 years old
|
||||
if !isSuper && user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
|
||||
if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
|
||||
sl.ReportError(user.DateOfBirth, "DateOfBirth", "dateofbirth", "age", "")
|
||||
}
|
||||
// validate subscriptionModel
|
||||
|
||||
Reference in New Issue
Block a user