Files
GoMembership/go-backend/internal/validation/fields.go

110 lines
2.8 KiB
Go

package validation
import (
"errors"
"reflect"
)
// FilterAllowedStructFields reverts every exported field of input that is not
// whitelisted in allowedFields to the value held by existing. input is
// modified in place.
//
// input and existing must be non-nil pointers to structs of the same type;
// otherwise an error is returned. allowedFields is keyed by the dot-separated
// path of Go field names (for example "Address.City"); prefix is the path of
// the struct currently being processed and should be "" for a top-level call.
//
// Field handling:
//   - Unexported fields are never touched.
//   - Nested structs (and pointers to them) that have exported fields are
//     processed recursively, so the whitelist is applied to their leaf fields.
//     The recursion works on a copy of the existing nested struct.
//   - Structs without any exported field (such as time.Time) are treated as
//     plain values, because recursing into them could never filter anything.
//   - A nil pointer in input is replaced by the pointer from existing unless
//     the field is whitelisted.
//   - A whitelisted slice is additionally written back into the corresponding
//     field of existing (at nested levels, into the temporary copy).
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
	v := reflect.ValueOf(input)
	origin := reflect.ValueOf(existing)

	// Ensure both input and existing are pointers to structs.
	if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
		return errors.New("both input and existing must be pointers to structs")
	}
	v = v.Elem()
	origin = origin.Elem()
	if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
		return errors.New("both input and existing must be structs")
	}
	// Fields are paired by index below, which is only meaningful (and only
	// panic-free) when both values share the same type.
	if v.Type() != origin.Type() {
		return errors.New("input and existing must be of the same struct type")
	}

	structType := v.Type()
	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)

		// Skip unexported fields: they cannot be set through reflection.
		if !field.IsExported() {
			continue
		}

		// Build the full field path used as whitelist key.
		fullKey := field.Name
		if prefix != "" {
			fullKey = prefix + "." + field.Name
		}
		allowed := allowedFields[fullKey]

		fieldValue := v.Field(i)
		originField := origin.Field(i)

		if fieldValue.Kind() == reflect.Ptr {
			if fieldValue.IsNil() {
				// Nothing to descend into; restore the pointer unless the
				// caller is allowed to clear it.
				if !allowed {
					fieldValue.Set(originField)
				}
				continue
			}
			if originField.IsNil() && !structHasExportedFields(fieldValue.Type().Elem()) {
				// existing has no pointee to copy from. For anything that is
				// not a recursable struct the pointer itself is the value:
				// drop it when the field is not whitelisted, keep it otherwise.
				if !allowed {
					fieldValue.Set(originField)
				}
				continue
			}
			// Dereference for further processing. originField becomes the
			// invalid zero Value here only for recursable structs, which the
			// struct branch below handles explicitly.
			fieldValue = fieldValue.Elem()
			originField = originField.Elem()
		}

		// Handle slices.
		if fieldValue.Kind() == reflect.Slice {
			if allowed {
				// Whitelisted slices are propagated into existing as well.
				originField.Set(fieldValue)
			} else {
				fieldValue.Set(originField)
			}
			continue
		}

		// A value that is still a pointer to a struct at this point stems
		// from a pointer-to-pointer field; unwrap one more level.
		if fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct {
			if fieldValue.IsNil() {
				continue
			}
			fieldValue = fieldValue.Elem()
			originField = originField.Elem() // invalid if existing holds nil
		}

		// Handle nested structs.
		if fieldValue.Kind() == reflect.Struct {
			// Work on a copy so the recursion cannot write into existing.
			// When existing had a nil pointer the copy stays at its zero value.
			originCopy := reflect.New(fieldValue.Type()).Elem()
			if originField.IsValid() {
				originCopy.Set(originField)
			}

			if !structHasExportedFields(fieldValue.Type()) {
				// Opaque struct: recursion would skip every field, so apply
				// the whitelist to the value as a whole.
				if !allowed {
					fieldValue.Set(originCopy)
				}
				continue
			}

			err := FilterAllowedStructFields(
				fieldValue.Addr().Interface(),
				originCopy.Addr().Interface(),
				allowedFields,
				fullKey,
			)
			if err != nil {
				return err
			}
			continue
		}

		// Plain value: only whitelisted fields keep the input value.
		if !allowed {
			fieldValue.Set(originField)
		}
	}
	return nil
}

// structHasExportedFields reports whether t is a struct type that declares at
// least one exported field.
func structHasExportedFields(t reflect.Type) bool {
	if t.Kind() != reflect.Struct {
		return false
	}
	for i := 0; i < t.NumField(); i++ {
		if t.Field(i).IsExported() {
			return true
		}
	}
	return false
}