moved field validation to validation package

This commit is contained in:
Alex
2025-03-11 20:42:45 +01:00
parent 0d6013d566
commit c6ea179eca
2 changed files with 2 additions and 7 deletions

View File

@@ -0,0 +1,109 @@
package validation
import (
"errors"
"reflect"
)
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
v := reflect.ValueOf(input)
origin := reflect.ValueOf(existing)
// Ensure both input and target are pointers to structs
if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
return errors.New("both input and existing must be pointers to structs")
}
v = v.Elem()
origin = origin.Elem()
if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
return errors.New("both input and existing must be structs")
}
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
key := field.Name
// Skip unexported fields
if !field.IsExported() {
continue
}
// Build the full field path
fullKey := key
if prefix != "" {
fullKey = prefix + "." + key
}
fieldValue := v.Field(i)
originField := origin.Field(i)
// Handle nil pointers
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
// If the field is nil, skip it or initialize it
if !allowedFields[fullKey] {
// If the field is not allowed, set it to the corresponding field from existing
fieldValue.Set(originField)
}
continue
}
// Dereference the pointer for further processing
fieldValue = fieldValue.Elem()
originField = originField.Elem()
}
// Handle slices
if fieldValue.Kind() == reflect.Slice {
if !allowedFields[fullKey] {
// If the slice is not allowed, set it to the corresponding slice from existing
fieldValue.Set(originField)
continue
} else {
originField.Set(fieldValue)
}
continue
}
// Handle nested structs (including pointers to structs)
if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) {
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
continue
}
fieldValue = fieldValue.Elem()
originField = originField.Elem() // May result in an invalid originField
}
var originCopy reflect.Value
// Check if originField is valid (non-zero)
if originField.IsValid() {
originCopy = reflect.New(originField.Type()).Elem()
originCopy.Set(originField)
} else {
// If originField is invalid (e.g., existing had a nil pointer),
// create a new instance of the type from fieldValue
originCopy = reflect.New(fieldValue.Type()).Elem()
}
err := FilterAllowedStructFields(
fieldValue.Addr().Interface(),
originCopy.Addr().Interface(),
allowedFields,
fullKey,
)
if err != nil {
return err
}
continue
}
// Only allow whitelisted fields
if !allowedFields[fullKey] {
fieldValue.Set(originField)
}
}
return nil
}

View File

@@ -0,0 +1,176 @@
package validation
import (
"reflect"
"testing"
)
type User struct {
Name string
Age int
Address *Address
Tags []string
License License
}
type Address struct {
City string
Country string
}
type License struct {
ID string
Categories []string
}
func TestFilterAllowedStructFields(t *testing.T) {
tests := []struct {
name string
input interface{}
existing interface{}
allowedFields map[string]bool
expectedResult interface{}
expectError bool
}{
{
name: "Filter top-level fields",
input: &User{
Name: "Alice",
Age: 30,
},
existing: &User{
Name: "Bob",
Age: 25,
},
allowedFields: map[string]bool{
"Name": true,
},
expectedResult: &User{
Name: "Alice", // Allowed field
Age: 25, // Kept from existing
},
expectError: false,
},
{
name: "Filter nested struct fields",
input: &User{
Name: "Alice",
Address: &Address{
City: "New York",
Country: "USA",
},
},
existing: &User{
Name: "Bob",
Address: &Address{
City: "London",
Country: "UK",
},
},
allowedFields: map[string]bool{
"Address.City": true,
},
expectedResult: &User{
Name: "Bob", // Kept from existing
Address: &Address{
City: "New York", // Allowed field
Country: "UK", // Kept from existing
},
},
expectError: false,
},
{
name: "Filter slice fields",
input: &User{
Tags: []string{"admin", "user"},
},
existing: &User{
Tags: []string{"guest"},
},
allowedFields: map[string]bool{
"Tags": true,
},
expectedResult: &User{
Tags: []string{"admin", "user"}, // Allowed slice
},
expectError: false,
},
{
name: "Filter slice of structs",
input: &User{
License: License{
ID: "123",
Categories: []string{"A", "B"},
},
},
existing: &User{
License: License{
ID: "456",
Categories: []string{"C"},
},
},
allowedFields: map[string]bool{
"License.ID": true,
},
expectedResult: &User{
License: License{
ID: "123", // Allowed field
Categories: []string{"C"}, // Kept from existing
},
},
expectError: false,
},
{
name: "Filter pointer fields",
input: &User{
Address: &Address{
City: "Paris",
},
},
existing: &User{
Address: &Address{
City: "Berlin",
Country: "Germany",
},
},
allowedFields: map[string]bool{
"Address.City": true,
},
expectedResult: &User{
Address: &Address{
City: "Paris", // Allowed field
Country: "Germany", // Kept from existing
},
},
expectError: false,
},
{
name: "Invalid input (non-pointer)",
input: User{
Name: "Alice",
},
existing: &User{
Name: "Bob",
},
allowedFields: map[string]bool{
"Name": true,
},
expectedResult: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "")
if (err != nil) != tt.expectError {
t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError)
return
}
if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) {
t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult)
}
})
}
}