add & moved to validations folder; del validator/v10

This commit is contained in:
Alex
2024-10-09 18:12:20 +02:00
parent 6aee416b63
commit b2e4947d37
12 changed files with 253 additions and 226 deletions

View File

@@ -11,7 +11,6 @@ import (
type SubscriptionModelsRepositoryInterface interface { type SubscriptionModelsRepositoryInterface interface {
CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error) CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error)
GetMembershipModelNames() ([]string, error) GetMembershipModelNames() ([]string, error)
GetModelByName(modelname *string) (*models.SubscriptionModel, error)
GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error)
} }
@@ -26,7 +25,7 @@ func (sr *SubscriptionModelsRepository) CreateSubscriptionModel(subscriptionMode
return subscriptionModel.ID, nil return subscriptionModel.ID, nil
} }
func (sr *SubscriptionModelsRepository) GetModelByName(modelname *string) (*models.SubscriptionModel, error) { func GetModelByName(modelname *string) (*models.SubscriptionModel, error) {
var model models.SubscriptionModel var model models.SubscriptionModel
if err := database.DB.Where("name = ?", modelname).First(&model).Error; err != nil { if err := database.DB.Where("name = ?", modelname).First(&model).Error; err != nil {
return nil, err return nil, err

View File

@@ -13,6 +13,7 @@ import (
"GoMembership/internal/controllers" "GoMembership/internal/controllers"
"GoMembership/internal/middlewares" "GoMembership/internal/middlewares"
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
"GoMembership/internal/validation"
"GoMembership/internal/routes" "GoMembership/internal/routes"
"GoMembership/internal/services" "GoMembership/internal/services"
@@ -63,6 +64,7 @@ func Run() {
router.Use(middlewares.RateLimitMiddleware(limiter)) router.Use(middlewares.RateLimitMiddleware(limiter))
routes.RegisterRoutes(router, userController, membershipController, contactController) routes.RegisterRoutes(router, userController, membershipController, contactController)
validation.SetupValidators()
logger.Info.Println("Starting server on :8080") logger.Info.Println("Starting server on :8080")
srv = &http.Server{ srv = &http.Server{

View File

@@ -1,15 +1,10 @@
package services package services
import ( import (
"slices"
"time" "time"
"github.com/go-playground/validator/v10"
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
) )
type MembershipServiceInterface interface { type MembershipServiceInterface interface {
@@ -37,9 +32,6 @@ func (service *MembershipService) FindMembershipByUserID(userID uint) (*models.M
// Membership_Subscriptions // Membership_Subscriptions
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (uint, error) { func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (uint, error) {
if err := validateSubscriptionData(subscription); err != nil {
return 0, err
}
return service.SubscriptionRepo.CreateSubscriptionModel(subscription) return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
} }
@@ -48,15 +40,7 @@ func (service *MembershipService) GetMembershipModelNames() ([]string, error) {
} }
func (service *MembershipService) GetModelByName(modelname *string) (*models.SubscriptionModel, error) { func (service *MembershipService) GetModelByName(modelname *string) (*models.SubscriptionModel, error) {
sModelNames, err := service.SubscriptionRepo.GetMembershipModelNames() return repositories.GetModelByName(modelname)
if err != nil {
return nil, err
}
if !slices.Contains(sModelNames, *modelname) {
return nil, errors.ErrNotFound
}
return service.SubscriptionRepo.GetModelByName(modelname)
} }
func (service *MembershipService) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) { func (service *MembershipService) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) {
@@ -65,12 +49,3 @@ func (service *MembershipService) GetSubscriptions(where map[string]interface{})
} }
return service.SubscriptionRepo.GetSubscriptions(where) return service.SubscriptionRepo.GetSubscriptions(where)
} }
func validateSubscriptionData(subscription *models.SubscriptionModel) error {
validate := validator.New()
// subscriptionModel and membershipField don't have to be evaluated if adding a new subscription
validate.RegisterValidation("subscriptionModel", func(fl validator.FieldLevel) bool { return true })
validate.RegisterValidation("membershipField", func(fl validator.FieldLevel) bool { return true })
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
return validate.Struct(subscription)
}

View File

@@ -12,7 +12,6 @@ import (
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
"github.com/go-playground/validator/v10"
"gorm.io/gorm" "gorm.io/gorm"
"time" "time"
@@ -34,12 +33,6 @@ type UserService struct {
func (service *UserService) UpdateUser(user *models.User, userRole int8) (*models.User, error) { func (service *UserService) UpdateUser(user *models.User, userRole int8) (*models.User, error) {
if err := validateUserData(user, userRole); err != nil {
logger.Info.Printf("UPDATING user: %#v", user)
logger.Error.Printf("Failed to validate user data: %v", err)
return nil, errors.ErrInvalidUserData
}
if user.Password != "" { if user.Password != "" {
setPassword(user.Password, user) setPassword(user.Password, user)
} }
@@ -66,9 +59,6 @@ func (service *UserService) UpdateUser(user *models.User, userRole int8) (*model
} }
func (service *UserService) RegisterUser(user *models.User) (uint, string, error) { func (service *UserService) RegisterUser(user *models.User) (uint, string, error) {
if err := validateUserData(user, user.RoleID); err != nil {
return http.StatusNotAcceptable, "", err
}
setPassword(user.Password, user) setPassword(user.Password, user)
@@ -76,21 +66,19 @@ func (service *UserService) RegisterUser(user *models.User) (uint, string, error
user.CreatedAt = time.Now() user.CreatedAt = time.Now()
user.UpdatedAt = time.Now() user.UpdatedAt = time.Now()
user.PaymentStatus = constants.AwaitingPaymentStatus user.PaymentStatus = constants.AwaitingPaymentStatus
// user.DriversLicence.Status = constants.UnverifiedStatus user.DriversLicence.Status = constants.UnverifiedStatus
user.BankAccount.MandateDateSigned = time.Now() user.BankAccount.MandateDateSigned = time.Now()
id, err := service.Repo.CreateUser(user) id, err := service.Repo.CreateUser(user)
if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { if err != nil {
return http.StatusConflict, "", err return 0, "", err
} else if err != nil {
return http.StatusInternalServerError, "", err
} }
user.ID = id user.ID = id
token, err := utils.GenerateVerificationToken() token, err := utils.GenerateVerificationToken()
if err != nil { if err != nil {
return http.StatusInternalServerError, "", err return 0, "", err
} }
logger.Info.Printf("TOKEN: %v", token) logger.Info.Printf("TOKEN: %v", token)
@@ -98,10 +86,10 @@ func (service *UserService) RegisterUser(user *models.User) (uint, string, error
// Check if user is already verified // Check if user is already verified
verified, err := service.Repo.IsVerified(&user.ID) verified, err := service.Repo.IsVerified(&user.ID)
if err != nil { if err != nil {
return http.StatusInternalServerError, "", err return 0, "", err
} }
if verified { if verified {
return http.StatusAlreadyReported, "", errors.ErrAlreadyVerified return 0, "", errors.ErrAlreadyVerified
} }
// Prepare the Verification record // Prepare the Verification record
@@ -119,7 +107,7 @@ func (service *UserService) RegisterUser(user *models.User) (uint, string, error
func (service *UserService) GetUserByID(id uint) (*models.User, error) { func (service *UserService) GetUserByID(id uint) (*models.User, error) {
return service.Repo.GetUserByID(&id) return repositories.GetUserByID(&id)
} }
func (service *UserService) GetUserByEmail(email string) (*models.User, error) { func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
@@ -140,7 +128,7 @@ func (service *UserService) VerifyUser(token *string) (*models.User, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := service.Repo.GetUserByID(&verification.UserID) user, err := repositories.GetUserByID(&verification.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -159,27 +147,6 @@ func (service *UserService) VerifyUser(token *string) (*models.User, error) {
return user, nil return user, nil
} }
func validateUserData(user *models.User, userRole int8) error {
validate := validator.New()
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
if userRole == constants.Roles.Admin {
validate.RegisterValidation("membershipField", utils.ValidateToTrue)
validate.RegisterValidation("age", utils.ValidateToTrue)
validate.RegisterValidation("bic", utils.ValidateToTrue)
validate.RegisterValidation("subscriptionModel", utils.ValidateToTrue)
validate.RegisterValidation("iban", utils.ValidateToTrue)
validate.RegisterValidation("euDriversLicence", utils.ValidateToTrue)
} else {
validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField)
validate.RegisterValidation("age", utils.AgeValidator)
validate.RegisterValidation("bic", utils.BICValidator)
validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator)
validate.RegisterValidation("iban", utils.IBANValidator)
validate.RegisterValidation("euDriversLicence", utils.ValidateDriversLicence)
}
return validate.Struct(user)
}
func setPassword(plaintextPassword string, u *models.User) error { func setPassword(plaintextPassword string, u *models.User) error {
hash, err := argon2id.CreateHash(plaintextPassword, argon2id.DefaultParams) hash, err := argon2id.CreateHash(plaintextPassword, argon2id.DefaultParams)
if err != nil { if err != nil {

View File

@@ -1,157 +0,0 @@
package utils
// import "regexp"
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"reflect"
"regexp"
"slices"
"strconv"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/jbub/banking/iban"
"github.com/jbub/banking/swift"
)
var xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)(vbscript|data):`),
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
regexp.MustCompile(`(?i)<\?`),
regexp.MustCompile(`(?i)<%`),
regexp.MustCompile(`(?i)<!\[CDATA\[`),
regexp.MustCompile(`(?i)<(svg|animate)`),
regexp.MustCompile(`(?i)<(audio|video|source)`),
regexp.MustCompile(`(?i)base64`),
}
func ValidateToTrue(fl validator.FieldLevel) bool {
return true
}
func AgeValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field()
dateOfBirth := fieldValue.Interface().(time.Time)
now := time.Now()
age := now.Year() - dateOfBirth.Year()
if now.YearDay() < dateOfBirth.YearDay() {
age-- // if birthday is in the future..
}
return age >= 18
}
func SubscriptionModelValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
var names []string
if err := database.DB.Model(&models.SubscriptionModel{}).Pluck("name", &names).Error; err != nil {
logger.Error.Fatalf("Couldn't get SubscriptionModel names: %#v", err)
return false
}
return slices.Contains(names, fieldValue)
}
func IBANValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
return iban.Validate(fieldValue) == nil
}
func ValidateRequiredMembershipField(fl validator.FieldLevel) bool {
user := fl.Top().Interface().(*models.User)
membership := user.Membership
subModel := membership.SubscriptionModel
// Get the field name specified in RequiredMembershipField
fieldName := subModel.RequiredMembershipField
if fieldName == "" {
return true
}
// Get the value of the field specified by RequiredMembershipField
fieldValue := reflect.ValueOf(membership).FieldByName(fieldName)
// Check if the fieldValue is valid
if !fieldValue.IsValid() {
return false
}
// Check if the fieldValue is a nil pointer
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
return false
}
// Ensure that the fieldValue is an uint
var fieldUint uint
if fieldValue.Kind() == reflect.Uint {
fieldUint = uint(fieldValue.Uint())
} else {
return false
}
var membershipIDs []uint
if err := database.DB.Model(&models.Membership{}).Pluck("id", &membershipIDs).Error; err != nil {
logger.Error.Fatalf("Couldn't get SubscriptionModel names: %#v", err)
return false
}
// Check if the field value is zero (empty)
return slices.Contains(membershipIDs, fieldUint)
}
func BICValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
return swift.Validate(fieldValue) == nil
}
func ValidateSafeContent(fl validator.FieldLevel) bool {
input := strings.ToLower(fl.Field().String())
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
return false
}
}
return true
}
func ValidateDriversLicence(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
if len(fieldValue) != 11 {
return false
}
id, tenthChar := string(fieldValue[:9]), string(fieldValue[9])
if tenthChar == "X" {
tenthChar = "10"
}
tenthValue, _ := strconv.ParseInt(tenthChar, 10, 8)
// for readability
weights := []int{9, 8, 7, 6, 5, 4, 3, 2, 1}
sum := 0
for i := 0; i < 9; i++ {
char := string(id[i])
value, _ := strconv.ParseInt(char, 36, 64)
sum += int(value) * weights[i]
}
calcCheckDigit := sum % 11
if calcCheckDigit != int(tenthValue) {
return false
}
return true
}

View File

@@ -0,0 +1,38 @@
package validation
import (
"strconv"
"github.com/go-playground/validator/v10"
)
func ValidateDriversLicence(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
if len(fieldValue) != 11 {
return false
}
id, tenthChar := string(fieldValue[:9]), string(fieldValue[9])
if tenthChar == "X" {
tenthChar = "10"
}
tenthValue, _ := strconv.ParseInt(tenthChar, 10, 8)
// for readability
weights := []int{9, 8, 7, 6, 5, 4, 3, 2, 1}
sum := 0
for i := 0; i < 9; i++ {
char := string(id[i])
value, _ := strconv.ParseInt(char, 36, 64)
sum += int(value) * weights[i]
}
calcCheckDigit := sum % 11
if calcCheckDigit != int(tenthValue) {
return false
}
return true
}

View File

@@ -0,0 +1,19 @@
package validation
import (
"github.com/go-playground/validator/v10"
"github.com/jbub/banking/iban"
"github.com/jbub/banking/swift"
)
func IBANValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
return iban.Validate(fieldValue) == nil
}
func BICValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field().String()
return swift.Validate(fieldValue) == nil
}

View File

@@ -0,0 +1,34 @@
package validation
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)(vbscript|data):`),
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
regexp.MustCompile(`(?i)<\?`),
regexp.MustCompile(`(?i)<%`),
regexp.MustCompile(`(?i)<!\[CDATA\[`),
regexp.MustCompile(`(?i)<(svg|animate)`),
regexp.MustCompile(`(?i)<(audio|video|source)`),
regexp.MustCompile(`(?i)base64`),
}
func ValidateSafeContent(fl validator.FieldLevel) bool {
input := strings.ToLower(fl.Field().String())
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
return false
}
}
return true
}

View File

@@ -0,0 +1,30 @@
package validation
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"github.com/go-playground/validator/v10"
)
func validateMembership(sl validator.StructLevel, membership models.Membership) {
if membership.SubscriptionModel.RequiredMembershipField != "" {
switch membership.SubscriptionModel.RequiredMembershipField {
case "ParentMembershipID":
if membership.ParentMembershipID == 0 {
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "required", "")
} else {
_, err := repositories.GetUserByID(&membership.ParentMembershipID)
if err != nil {
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "user_id_not_found", "")
}
}
default:
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "not_implemented", "")
}
}
}

View File

@@ -0,0 +1,23 @@
package validation
import (
"GoMembership/internal/models"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
func SetupValidators() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// Register custom validators
v.RegisterValidation("safe_content", ValidateSafeContent)
v.RegisterValidation("iban", IBANValidator)
v.RegisterValidation("bic", BICValidator)
v.RegisterValidation("euDriversLicence", ValidateDriversLicence)
// Register struct-level validations
v.RegisterStructValidation(validateUser, models.User{})
v.RegisterStructValidation(ValidateSubscription, models.SubscriptionModel{})
}
}

View File

@@ -0,0 +1,46 @@
package validation
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"github.com/go-playground/validator/v10"
)
// ValidateNewSubscription validates a new subscription model being created
func ValidateSubscription(sl validator.StructLevel) {
subscription := sl.Current().Interface().(models.SubscriptionModel)
if subscription.Name == "" {
sl.ReportError(subscription.Name, "Name", "name", "required", "")
}
if sl.Parent().Type().Name() == "MembershipData" {
// This is subscription only operation
if subscription.Details == "" {
sl.ReportError(subscription.Details, "Details", "details", "required", "")
}
if subscription.MonthlyFee < 0 {
sl.ReportError(subscription.MonthlyFee, "MonthlyFee", "monthly_fee", "gte", "0")
}
if subscription.HourlyRate < 0 {
sl.ReportError(subscription.HourlyRate, "HourlyRate", "hourly_rate", "gte", "0")
}
if subscription.IncludedPerYear < 0 {
sl.ReportError(subscription.IncludedPerYear, "IncludedPerYear", "included_hours_per_year", "gte", "0")
}
if subscription.IncludedPerMonth < 0 {
sl.ReportError(subscription.IncludedPerMonth, "IncludedPerMonth", "included_hours_per_month", "gte", "0")
}
} else {
// This is a nested probably user struct. We are only checking if the model exists
existingSubscription, err := repositories.GetModelByName(&subscription.Name)
if err != nil || existingSubscription == nil {
sl.ReportError(subscription.Name, "Name", "name", "exists", "")
}
}
}

View File

@@ -0,0 +1,51 @@
package validation
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/logger"
"time"
"github.com/go-playground/validator/v10"
)
func validateUser(sl validator.StructLevel) {
user := sl.Current().Interface().(models.User)
if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
sl.ReportError(user.DateOfBirth, "DateOfBirth", "date_of_birth", "age", "")
}
if user.Membership.SubscriptionModel.Name == "" {
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "required", "")
} else {
selectedModel, err := repositories.GetModelByName(&user.Membership.SubscriptionModel.Name)
if err != nil {
logger.Error.Printf("Error finding subscription model for user %v: %v", user.Email, err)
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "invalid", "")
} else {
user.Membership.SubscriptionModel = *selectedModel
}
}
validateMembership(sl, user.Membership)
}
// func RequiredIfNotAdmin(fl validator.FieldLevel) bool {
// // Traverse up the struct hierarchy to find the IsAdmin field
// current := fl.Parent()
// // Check multiple levels of nesting to find userRole
// for current.IsValid() {
// if isRoleIDField := current.FieldByName("RoleID"); isRoleIDField.IsValid() {
// // If IsAdmin is found and is true, skip validation
// if isRoleIDField.Interface().(int8) == constants.Roles.Admin{
// return true
// }
// break
// }
// current = current.Parent() // Move to the next parent level
// }
// If not an admin, enforce that the field must have a non-zero value
// return !fl.Field().IsZero()
// }