moved field validation to validation package
This commit is contained in:
109
go-backend/internal/validation/fields.go
Normal file
109
go-backend/internal/validation/fields.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
|
||||
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
|
||||
v := reflect.ValueOf(input)
|
||||
origin := reflect.ValueOf(existing)
|
||||
|
||||
// Ensure both input and target are pointers to structs
|
||||
if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
|
||||
return errors.New("both input and existing must be pointers to structs")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
origin = origin.Elem()
|
||||
|
||||
if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
|
||||
return errors.New("both input and existing must be structs")
|
||||
}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Type().Field(i)
|
||||
key := field.Name
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build the full field path
|
||||
fullKey := key
|
||||
if prefix != "" {
|
||||
fullKey = prefix + "." + key
|
||||
}
|
||||
fieldValue := v.Field(i)
|
||||
originField := origin.Field(i)
|
||||
|
||||
// Handle nil pointers
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
// If the field is nil, skip it or initialize it
|
||||
if !allowedFields[fullKey] {
|
||||
// If the field is not allowed, set it to the corresponding field from existing
|
||||
fieldValue.Set(originField)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Dereference the pointer for further processing
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem()
|
||||
}
|
||||
|
||||
// Handle slices
|
||||
if fieldValue.Kind() == reflect.Slice {
|
||||
if !allowedFields[fullKey] {
|
||||
// If the slice is not allowed, set it to the corresponding slice from existing
|
||||
fieldValue.Set(originField)
|
||||
continue
|
||||
} else {
|
||||
originField.Set(fieldValue)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nested structs (including pointers to structs)
|
||||
if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem() // May result in an invalid originField
|
||||
}
|
||||
|
||||
var originCopy reflect.Value
|
||||
|
||||
// Check if originField is valid (non-zero)
|
||||
if originField.IsValid() {
|
||||
originCopy = reflect.New(originField.Type()).Elem()
|
||||
originCopy.Set(originField)
|
||||
} else {
|
||||
// If originField is invalid (e.g., existing had a nil pointer),
|
||||
// create a new instance of the type from fieldValue
|
||||
originCopy = reflect.New(fieldValue.Type()).Elem()
|
||||
}
|
||||
|
||||
err := FilterAllowedStructFields(
|
||||
fieldValue.Addr().Interface(),
|
||||
originCopy.Addr().Interface(),
|
||||
allowedFields,
|
||||
fullKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Only allow whitelisted fields
|
||||
if !allowedFields[fullKey] {
|
||||
fieldValue.Set(originField)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
176
go-backend/internal/validation/fields_test.go
Normal file
176
go-backend/internal/validation/fields_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Age int
|
||||
Address *Address
|
||||
Tags []string
|
||||
License License
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
City string
|
||||
Country string
|
||||
}
|
||||
|
||||
type License struct {
|
||||
ID string
|
||||
Categories []string
|
||||
}
|
||||
|
||||
func TestFilterAllowedStructFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
existing interface{}
|
||||
allowedFields map[string]bool
|
||||
expectedResult interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Filter top-level fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Age: 30,
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Age: 25,
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Alice", // Allowed field
|
||||
Age: 25, // Kept from existing
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter nested struct fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Address: &Address{
|
||||
City: "New York",
|
||||
Country: "USA",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Address: &Address{
|
||||
City: "London",
|
||||
Country: "UK",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Bob", // Kept from existing
|
||||
Address: &Address{
|
||||
City: "New York", // Allowed field
|
||||
Country: "UK", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice fields",
|
||||
input: &User{
|
||||
Tags: []string{"admin", "user"},
|
||||
},
|
||||
existing: &User{
|
||||
Tags: []string{"guest"},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Tags": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Tags: []string{"admin", "user"}, // Allowed slice
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice of structs",
|
||||
input: &User{
|
||||
License: License{
|
||||
ID: "123",
|
||||
Categories: []string{"A", "B"},
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
License: License{
|
||||
ID: "456",
|
||||
Categories: []string{"C"},
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"License.ID": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
License: License{
|
||||
ID: "123", // Allowed field
|
||||
Categories: []string{"C"}, // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter pointer fields",
|
||||
input: &User{
|
||||
Address: &Address{
|
||||
City: "Paris",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Address: &Address{
|
||||
City: "Berlin",
|
||||
Country: "Germany",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Address: &Address{
|
||||
City: "Paris", // Allowed field
|
||||
Country: "Germany", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid input (non-pointer)",
|
||||
input: User{
|
||||
Name: "Alice",
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "")
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) {
|
||||
t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user