frontend: disabled button while processing password reset

This commit is contained in:
Alex
2025-02-28 08:51:35 +01:00
parent 8137f121ed
commit 9c9430ca9c
92 changed files with 37 additions and 17 deletions

View File

@@ -0,0 +1,161 @@
// Package config provides functionality for loading application configuration from a JSON file and environment variables.
// It defines structs for different configuration sections (database, authentication, SMTP, templates) and functions
// to read and populate these configurations. It also generates secrets for JWT and CSRF tokens.
//
// This package uses the `envconfig` library to map environment variables to struct fields, falls back to variables of a config
// file and provides functions for error handling and logging during the configuration loading process.
package config
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"github.com/kelseyhightower/envconfig"
"GoMembership/internal/utils"
"GoMembership/pkg/logger"
)
type DatabaseConfig struct {
Path string `json:"Path" default:"data/db.sqlite3" envconfig:"DB_PATH"`
}
type SiteConfig struct {
AllowOrigins string `json:"AllowOrigins" envconfig:"ALLOW_ORIGINS"`
WebsiteTitle string `json:"WebsiteTitle" envconfig:"WEBSITE_TITLE"`
BaseURL string `json:"BaseUrl" envconfig:"BASE_URL"`
}
type AuthenticationConfig struct {
JWTSecret string
CSRFSecret string
APIKEY string `json:"APIKey" envconfig:"API_KEY"`
}
type SMTPConfig struct {
Host string `json:"Host" envconfig:"SMTP_HOST"`
User string `json:"User" envconfig:"SMTP_USER"`
Password string `json:"Password" envconfig:"SMTP_PASS"`
Port int `json:"Port" default:"465" envconfig:"SMTP_PORT"`
}
type TemplateConfig struct {
MailPath string `json:"MailPath" default:"templates/email" envconfig:"TEMPLATE_MAIL_PATH"`
HTMLPath string `json:"HTMLPath" default:"templates/html" envconfig:"TEMPLATE_HTML_PATH"`
StaticPath string `json:"StaticPath" default:"templates/css" envconfig:"TEMPLATE_STATIC_PATH"`
LogoURI string `json:"LogoURI" envconfig:"LOGO_URI"`
}
type RecipientsConfig struct {
ContactForm string `json:"ContactForm" envconfig:"RECIPIENT_CONTACT_FORM"`
UserRegistration string `json:"UserRegistration" envconfig:"RECIPIENT_USER_REGISTRATION"`
AdminEmail string `json:"AdminEmail" envconfig:"ADMIN_MAIL"`
}
type SecurityConfig struct {
Ratelimits struct {
Limit int `json:"Limit" default:"1" envconfig:"RATE_LIMIT"`
Burst int `json:"Burst" default:"60" envconfig:"BURST_LIMIT"`
} `json:"RateLimits"`
}
type Config struct {
Auth AuthenticationConfig `json:"auth"`
Site SiteConfig `json:"site"`
Templates TemplateConfig `json:"templates"`
Recipients RecipientsConfig `json:"recipients"`
ConfigFilePath string `json:"config_file_path" envconfig:"CONFIG_FILE_PATH"`
Env string `json:"Environment" default:"development" envconfig:"ENV"`
DB DatabaseConfig `json:"db"`
SMTP SMTPConfig `json:"smtp"`
Security SecurityConfig `json:"security"`
}
var (
Site SiteConfig
CFGPath string
CFG Config
Auth AuthenticationConfig
DB DatabaseConfig
Templates TemplateConfig
SMTP SMTPConfig
Recipients RecipientsConfig
Env string
Security SecurityConfig
)
var environmentOptions map[string]bool = map[string]bool{
"development": true,
"production": true,
"dev": true,
"prod": true,
}
// LoadConfig initializes the configuration by reading from a file and environment variables.
// It also generates JWT and CSRF secrets. Returns a Config pointer or an error if any step fails.
func LoadConfig() {
CFGPath = os.Getenv("CONFIG_FILE_PATH")
logger.Info.Printf("Config file environment: %v", CFGPath)
readFile(&CFG)
readEnv(&CFG)
csrfSecret, err := utils.GenerateRandomString(32)
if err != nil {
logger.Error.Fatalf("could not generate CSRF secret: %v", err)
}
jwtSecret, err := utils.GenerateRandomString(32)
if err != nil {
logger.Error.Fatalf("could not generate JWT secret: %v", err)
}
CFG.Auth.JWTSecret = jwtSecret
CFG.Auth.CSRFSecret = csrfSecret
if environmentOptions[CFG.Env] && strings.Contains("development", CFG.Env) {
CFG.Env = "development"
} else {
CFG.Env = "production"
}
Auth = CFG.Auth
DB = CFG.DB
Templates = CFG.Templates
SMTP = CFG.SMTP
Recipients = CFG.Recipients
Security = CFG.Security
Env = CFG.Env
Site = CFG.Site
logger.Info.Printf("Config loaded: %#v", CFG)
}
// readFile reads the configuration from the specified file path into the provided Config struct.
// If the file path is empty, it defaults to "configs/config.json" in the current working directory.
// Returns an error if the file cannot be opened or the JSON cannot be decoded.
func readFile(cfg *Config) {
if CFGPath == "" {
path, err := os.Getwd()
if err != nil {
logger.Error.Fatalf("could not get working directory: %v", err)
}
CFGPath = filepath.Join(path, "configs", "config.json")
}
configFile, err := os.Open(CFGPath)
// configFile, err := os.Open("config.json")
if err != nil {
logger.Error.Fatalf("could not open config file: %v", err)
}
defer configFile.Close()
decoder := json.NewDecoder(configFile)
err = decoder.Decode(cfg)
if err != nil {
logger.Error.Fatalf("could not decode config file: %v", err)
}
}
// readEnv populates the Config struct with values from environment variables using the envconfig package.
// Returns an error if environment variable decoding fails.
func readEnv(cfg *Config) {
err := envconfig.Process("", cfg)
if err != nil {
logger.Error.Fatalf("could not decode env variables: %#v", err)
}
}

View File

@@ -0,0 +1,101 @@
package constants
const (
UnverifiedStatus = iota + 1
DisabledStatus
VerifiedStatus
ActiveStatus
PassiveStatus
DelayedPaymentStatus
SettledPaymentStatus
AwaitingPaymentStatus
MailVerificationSubject = "Nur noch ein kleiner Schritt!"
MailChangePasswordSubject = "Passwort Änderung angefordert"
MailRegistrationSubject = "Neues Mitglied hat sich registriert"
MailWelcomeSubject = "Willkommen beim Dörpsmobil Hasloh e.V."
MailContactSubject = "Jemand hat das Kontaktformular gefunden"
)
var Roles = struct {
Member int8
Viewer int8
Editor int8
Admin int8
}{
Member: 0,
Viewer: 1,
Editor: 4,
Admin: 8,
}
var Licences = struct {
AM string
A1 string
A2 string
A string
B string
C1 string
C string
D1 string
D string
BE string
C1E string
CE string
D1E string
DE string
L string
T string
}{
AM: "AM",
A1: "A1",
A2: "A2",
A: "A",
B: "B",
C1: "C1",
C: "C",
D1: "D1",
D: "D",
BE: "BE",
C1E: "C1E",
CE: "CE",
D1E: "D1E",
DE: "DE",
L: "L",
T: "T",
}
var VerificationTypes = struct {
Email string
Password string
}{
Email: "email",
Password: "password",
}
var Priviliges = struct {
View int8
Create int8
Update int8
Delete int8
}{
View: 0,
Update: 10,
Create: 20,
Delete: 30,
}
var MemberUpdateFields = map[string]bool{
"Email": true,
"Phone": true,
"Company": true,
"Address": true,
"ZipCode": true,
"City": true,
"Licence.Categories": true,
"BankAccount.Bank": true,
"BankAccount.AccountHolderName": true,
"BankAccount.IBAN": true,
"BankAccount.BIC": true,
}
// "Password": true,

View File

@@ -0,0 +1,71 @@
package controllers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
type SQLInjectionTest struct {
name string
email string
password string
expectedStatus int
}
func (sit *SQLInjectionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
loginData := loginInput{
Email: sit.email,
Password: sit.password,
}
jsonData, _ := json.Marshal(loginData)
return GetMockedJSONContext(jsonData, "/login")
}
func (sit *SQLInjectionTest) RunHandler(c *gin.Context, router *gin.Engine) {
router.POST("/login", Uc.LoginHandler)
router.ServeHTTP(c.Writer, c.Request)
}
func (sit *SQLInjectionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if sit.expectedStatus != w.Code {
responseBody, _ := io.ReadAll(w.Body)
return fmt.Errorf("SQL Injection Attempt: Didn't get the expected response code: got: %v; expected: %v. Context: %#v", w.Code, sit.expectedStatus, string(responseBody))
}
return nil
}
func (sit *SQLInjectionTest) ValidateResult() error {
// Add any additional validation if needed
return nil
}
func testSQLInjectionAttempt(t *testing.T) {
tests := []SQLInjectionTest{
{
name: "SQL Injection Attempt in Email",
email: "' OR '1'='1",
password: "password123",
expectedStatus: http.StatusNotFound,
},
{
name: "SQL Injection Attempt in Password",
email: "user@example.com",
password: "' OR '1'='1",
expectedStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
}

View File

@@ -0,0 +1,31 @@
package controllers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func testXSSAttempt(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.POST("/register", Uc.RegisterUser)
xssPayload := "<script>alert('XSS')</script>"
user := getBaseUser()
user.FirstName = xssPayload
user.Email = "user@xss.hack"
jsonData, _ := json.Marshal(RegistrationData{User: user})
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.NotContains(t, w.Body.String(), xssPayload)
}

View File

@@ -0,0 +1,50 @@
package controllers
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"GoMembership/internal/services"
"GoMembership/pkg/logger"
)
type ContactController struct {
EmailService *services.EmailService
}
type contactData struct {
Email string `form:"REPLY_TO" validate:"required,email"`
Name string `form:"name"`
Message string `form:"message" validate:"required"`
Honeypot string `form:"username" validate:"eq="`
}
func (cc *ContactController) RelayContactRequest(c *gin.Context) {
var msgData contactData
if err := c.ShouldBind(&msgData); err != nil {
// A bot is talking to us
c.JSON(http.StatusNotAcceptable, gin.H{"error": "Not Acceptable"})
return
}
validate := validator.New()
if err := validate.Struct(msgData); err != nil {
logger.Error.Printf("Couldn't validate contact form data: %#v: %v", msgData, err)
c.HTML(http.StatusNotAcceptable, "contactForm_reply.html", gin.H{"Error": "Form validation failed. Please check again."})
// c.JSON(http.StatusNotAcceptable, gin.H{"error": "Couldn't validate contact form data"})
return
}
if err := cc.EmailService.RelayContactFormMessage(msgData.Email, msgData.Name, msgData.Message); err != nil {
logger.Error.Printf("Couldn't send contact message mail: %v", err)
c.HTML(http.StatusInternalServerError, "contactForm_reply.html", gin.H{"Error": "Email submission failed. Please try again."})
// c.JSON(http.StatusInternalServerError, gin.H{"error": "Couldn't send mail"})
return
}
// c.JSON(http.StatusAccepted, "Your message has been sent")
c.HTML(http.StatusAccepted, "contactForm_reply.html", gin.H{"Success": true})
}

View File

@@ -0,0 +1,159 @@
package controllers
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/utils"
"github.com/gin-gonic/gin"
)
type RelayContactRequestTest struct {
Input url.Values
Name string
WantResponse int
Assert bool
}
func testContactController(t *testing.T) {
tests := getContactData()
for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
}
func (rt *RelayContactRequestTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedFormContext(rt.Input, "/contact")
}
func (rt *RelayContactRequestTest) RunHandler(c *gin.Context, router *gin.Engine) {
router.POST("/contact", Cc.RelayContactRequest)
router.ServeHTTP(c.Writer, c.Request)
// Cc.RelayContactRequest(c)
}
func (rt *RelayContactRequestTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if w.Code != rt.WantResponse {
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, rt.WantResponse)
}
return nil
}
func (rt *RelayContactRequestTest) ValidateResult() error {
messages := utils.SMTPGetMessages()
for _, message := range messages {
mail, err := utils.DecodeMail(message.MsgRequest())
if err != nil {
return err
}
if strings.Contains(mail.Subject, constants.MailContactSubject) {
if err := checkContactRequestMail(mail, rt); err != nil {
return err
}
} else {
return fmt.Errorf("Subject not expected: %v", mail.Subject)
}
}
return nil
}
func checkContactRequestMail(mail *utils.Email, rt *RelayContactRequestTest) error {
if !strings.Contains(mail.To, config.Recipients.ContactForm) {
return fmt.Errorf("Contact Information didn't reach the admin! Recipient was: %v instead of %v", mail.To, config.Recipients.ContactForm)
}
if !strings.Contains(mail.From, config.SMTP.User) {
return fmt.Errorf("Contact Information was sent from unexpected address! Sender was: %v instead of %v", mail.From, config.SMTP.User)
}
//Check if all the relevant data has been passed to the mail.
if !strings.Contains(mail.Body, rt.Input.Get("name")) {
return fmt.Errorf("User name(%v) has not been rendered in contact mail.", rt.Input.Get("name"))
}
if !strings.Contains(mail.Body, rt.Input.Get("message")) {
return fmt.Errorf("User message(%v) has not been rendered in contact mail.", rt.Input.Get("message"))
}
return nil
}
func getBaseRequest() *url.Values {
return &url.Values{
"username": {""},
"name": {"My-First and-Last-Name"},
"REPLY_TO": {"name@domain.de"},
"message": {"My message to the world"},
}
}
func customizeRequest(updates map[string]string) *url.Values {
form := getBaseRequest()
for key, value := range updates {
form.Set(key, value)
}
return form
}
func getContactData() []RelayContactRequestTest {
return []RelayContactRequestTest{
{
Name: "mail empty, should fail",
WantResponse: http.StatusNotAcceptable,
Assert: false,
Input: *customizeRequest(
map[string]string{
"REPLY_TO": "",
}),
},
{
Name: "mail invalid, should fail",
WantResponse: http.StatusNotAcceptable,
Assert: false,
Input: *customizeRequest(
map[string]string{
"REPLY_TO": "novalid#email.de",
}),
},
{
Name: "No message should fail",
WantResponse: http.StatusNotAcceptable,
Assert: true,
Input: *customizeRequest(
map[string]string{
"message": "",
}),
},
{
Name: "Honeypot set, should fail",
WantResponse: http.StatusNotAcceptable,
Assert: true,
Input: *customizeRequest(
map[string]string{
"username": "I'm a bot",
}),
},
{
Name: "Correct message, should pass",
WantResponse: http.StatusAccepted,
Assert: true,
Input: *customizeRequest(
map[string]string{}),
},
}
}

View File

@@ -0,0 +1,313 @@
package controllers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"log"
"github.com/gin-gonic/gin"
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/database"
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/internal/services"
"GoMembership/internal/utils"
"GoMembership/internal/validation"
"GoMembership/pkg/logger"
)
type TestCase interface {
SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine)
RunHandler(*gin.Context, *gin.Engine)
ValidateResponse(*httptest.ResponseRecorder) error
ValidateResult() error
}
const (
Host = "127.0.0.1"
Port int = 2525
)
type loginInput struct {
Email string `json:"email"`
Password string `json:"password"`
}
var (
Uc *UserController
Mc *MembershipController
Cc *ContactController
)
func TestSuite(t *testing.T) {
_ = deleteTestDB("test.db")
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
}
// Build paths relative to the current working directory
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_HOST", Host); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("DB_PATH", "test.db"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
config.LoadConfig()
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil {
log.Fatalf("Failed to create DB: %#v", err)
}
utils.SMTPStart(Host, Port)
emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password)
var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{}
consentService := &services.ConsentService{Repo: consentRepo}
var bankAccountRepo repositories.BankAccountRepositoryInterface = &repositories.BankAccountRepository{}
bankAccountService := &services.BankAccountService{Repo: bankAccountRepo}
var membershipRepo repositories.MembershipRepositoryInterface = &repositories.MembershipRepository{}
var subscriptionRepo repositories.SubscriptionModelsRepositoryInterface = &repositories.SubscriptionModelsRepository{}
membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo}
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{}
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
licenceService := &services.LicenceService{Repo: licenceRepo}
Uc = &UserController{Service: userService, LicenceService: licenceService, EmailService: emailService, ConsentService: consentService, BankAccountService: bankAccountService, MembershipService: membershipService}
Mc = &MembershipController{UserController: &MockUserController{}, Service: *membershipService}
Cc = &ContactController{EmailService: emailService}
if err := initSubscriptionPlans(); err != nil {
log.Fatalf("Failed to init Subscription plans: %#v", err)
}
if err := initLicenceCategories(); err != nil {
log.Fatalf("Failed to init Categories: %v", err)
}
admin := models.User{
FirstName: "Ad",
LastName: "min",
Email: "admin@example.com",
DateOfBirth: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
Company: "SampleCorp",
Phone: "+123456789",
Address: "123 Main Street",
ZipCode: "12345",
City: "SampleCity",
Status: constants.ActiveStatus,
RoleID: 8,
}
admin.SetPassword("securepassword")
database.DB.Create(&admin)
validation.SetupValidators()
t.Run("userController", func(t *testing.T) {
testUserController(t)
})
t.Run("SQL_Injection", func(t *testing.T) {
testSQLInjectionAttempt(t)
})
t.Run("contactController", func(t *testing.T) {
testContactController(t)
})
t.Run("membershipController", func(t *testing.T) {
testMembershipController(t)
})
t.Run("XSSAttempt", func(t *testing.T) {
testXSSAttempt(t)
})
if err := utils.SMTPStop(); err != nil {
log.Fatalf("Failed to stop SMTP Mockup Server: %#v", err)
}
// if err := deleteTestDB("test.db"); err != nil {
// log.Fatalf("Failed to tear down DB: %#v", err)
// }
}
func initLicenceCategories() error {
categories := []models.Category{
{Name: "AM"},
{Name: "A1"},
{Name: "A2"},
{Name: "A"},
{Name: "B"},
{Name: "C1"},
{Name: "C"},
{Name: "D1"},
{Name: "D"},
{Name: "BE"},
{Name: "C1E"},
{Name: "CE"},
{Name: "D1E"},
{Name: "DE"},
{Name: "T"},
{Name: "L"},
}
for _, category := range categories {
result := database.DB.Create(&category)
if result.Error != nil {
return result.Error
}
}
return nil
}
func initSubscriptionPlans() error {
subscriptions := []models.SubscriptionModel{
{
Name: "Basic",
Details: "Test Plan",
MonthlyFee: 2,
HourlyRate: 3,
},
{
Name: "additional",
Details: "This plan needs another membership id to validate",
RequiredMembershipField: "ParentMembershipID",
MonthlyFee: 2,
HourlyRate: 3,
},
}
for _, subscription := range subscriptions {
result := database.DB.Create(&subscription)
if result.Error != nil {
return result.Error
}
}
return nil
}
func GetMockedJSONContext(jsonStr []byte, url string) (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
router := gin.New()
// Load HTML templates
router.LoadHTMLGlob(config.Templates.HTMLPath + "/*")
var err error
c.Request, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonStr))
if err != nil {
log.Fatalf("Failed to create new Request: %#v", err)
}
c.Request.Header.Set("Content-Type", "application/json")
return c, w, router
}
func GetMockedFormContext(formData url.Values, url string) (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
router := gin.New()
// Load HTML templates
router.LoadHTMLGlob(config.Templates.HTMLPath + "/*")
req, err := http.NewRequest("POST",
url,
bytes.NewBufferString(formData.Encode()))
if err != nil {
log.Fatalf("Failed to create new Request: %#v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c.Request = req
return c, w, router
}
func getBaseUser() models.User {
return models.User{
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
FirstName: "John",
LastName: "Doe",
Email: "john.doe@example.com",
Address: "Pablo Escobar Str. 4",
ZipCode: "25474",
City: "Hasloh",
Phone: "01738484993",
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
Licence: nil,
ProfilePicture: "",
Password: "password123",
Company: "",
RoleID: 8,
}
}
func deleteTestDB(dbPath string) error {
err := os.Remove(dbPath)
if err != nil {
return err
}
return nil
}
func runSingleTest(tc TestCase) error {
c, w, router := tc.SetupContext()
tc.RunHandler(c, router)
if err := tc.ValidateResponse(w); err != nil {
return err
}
return tc.ValidateResult()
}
func GenerateInputJSON(aStruct interface{}) string {
// Marshal the object into JSON
jsonBytes, err := json.Marshal(aStruct)
if err != nil {
logger.Error.Fatalf("Couldn't generate JSON: %#v\nERROR: %#v", aStruct, err)
return ""
}
return string(jsonBytes)
}

View File

@@ -0,0 +1,26 @@
package controllers
import (
"GoMembership/internal/services"
"GoMembership/internal/utils"
"net/http"
"github.com/gin-gonic/gin"
)
type LicenceController struct {
Service services.LicenceService
}
func (lc *LicenceController) GetAllCategories(c *gin.Context) {
categories, err := lc.Service.GetAllCategories()
if err != nil {
utils.RespondWithError(c, err, "Error retrieving licence categories", http.StatusInternalServerError, "general", "server.error.internal_server_error")
return
}
c.JSON(http.StatusOK, gin.H{
"licence_categories": categories,
})
}

View File

@@ -0,0 +1,153 @@
package controllers
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/internal/services"
"GoMembership/internal/utils"
"strings"
"net/http"
"github.com/gin-gonic/gin"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
)
type MembershipController struct {
Service services.MembershipService
UserController interface {
ExtractUserFromContext(*gin.Context) (*models.User, error)
}
}
type MembershipData struct {
// APIKey string `json:"api_key"`
Subscription models.SubscriptionModel `json:"subscription"`
}
func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
var regData MembershipData
requestUser, err := mc.UserController.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription registrationHandler", http.StatusBadRequest, "general", "server.validation.invalid_user_data")
return
}
if !utils.HasPrivilige(requestUser, constants.Priviliges.Create) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to register subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
return
}
if err := c.ShouldBindJSON(&regData); err != nil {
utils.HandleValidationError(c, err)
return
}
// Register Subscription
logger.Info.Printf("Registering subscription %v", regData.Subscription.Name)
id, err := mc.Service.RegisterSubscription(&regData.Subscription)
if err != nil {
logger.Error.Printf("Couldn't register Membershipmodel: %v", err)
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.JSON(http.StatusConflict, "Duplicate subscription name")
return
}
c.JSON(http.StatusNotAcceptable, "Couldn't register Membershipmodel")
return
}
logger.Info.Printf("registering subscription: %+v", regData)
c.JSON(http.StatusCreated, gin.H{
"status": "success",
"id": id,
})
}
func (mc *MembershipController) UpdateHandler(c *gin.Context) {
var regData MembershipData
requestUser, err := mc.UserController.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
return
}
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
return
}
if err := c.ShouldBindJSON(&regData); err != nil {
utils.HandleValidationError(c, err)
return
}
// update Subscription
logger.Info.Printf("Updating subscription %v", regData.Subscription.Name)
id, err := mc.Service.UpdateSubscription(&regData.Subscription)
if err != nil {
logger.Error.Printf("Couldn't update Membershipmodel: %v", err)
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.JSON(http.StatusConflict, "Duplicate subscription name")
return
}
c.JSON(http.StatusNotAcceptable, "Couldn't update Membershipmodel")
return
}
logger.Info.Printf("updating subscription: %+v", regData)
c.JSON(http.StatusAccepted, gin.H{
"status": "success",
"id": id,
})
}
func (mc *MembershipController) DeleteSubscription(c *gin.Context) {
type deleteData struct {
Subscription struct {
ID uint `json:"id"`
Name string `json:"name"`
} `json:"subscription"`
}
var data deleteData
requestUser, err := mc.UserController.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
return
}
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
return
}
if err := c.ShouldBindJSON(&data); err != nil {
utils.HandleValidationError(c, err)
return
}
if err := mc.Service.DeleteSubscription(&data.Subscription.ID, &data.Subscription.Name); err != nil {
utils.RespondWithError(c, err, "Error during subscription Deletion", http.StatusExpectationFailed, "subscription", "server.error.not_possible")
return
}
c.JSON(http.StatusOK, gin.H{"message": "Subscription deleted successfully"})
}
func (mc *MembershipController) GetSubscriptions(c *gin.Context) {
subscriptions, err := mc.Service.GetSubscriptions(nil)
if err != nil {
logger.Error.Printf("Error retrieving subscriptions: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"errors": []gin.H{{
"field": "general",
"key": "validation.internal_server_error",
}}})
return
}
c.JSON(http.StatusOK, gin.H{
"subscriptions": subscriptions,
})
}

View File

@@ -0,0 +1,393 @@
package controllers
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"GoMembership/internal/constants"
"GoMembership/internal/database"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin"
)
type RegisterSubscriptionTest struct {
WantDBData map[string]interface{}
Input string
Name string
WantResponse int
Assert bool
}
type UpdateSubscriptionTest struct {
WantDBData map[string]interface{}
Input string
Name string
WantResponse int
Assert bool
}
type DeleteSubscriptionTest struct {
WantDBData map[string]interface{}
Input string
Name string
WantResponse int
Assert bool
}
type MockUserController struct {
UserController // Embed the UserController
}
func (m *MockUserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
return &models.User{
ID: 1,
FirstName: "Admin",
LastName: "User",
Email: "admin@test.com",
RoleID: constants.Roles.Admin,
}, nil
}
func setupMockAuth() {
// Create and assign the mock controller
mockController := &MockUserController{}
Mc.UserController = mockController
}
func testMembershipController(t *testing.T) {
setupMockAuth()
tests := getSubscriptionRegistrationData()
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("MembershipController : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
updateTests := getSubscriptionUpdateData()
for _, tt := range updateTests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Update SubscriptionData : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
deleteTests := getSubscriptionDeleteData()
for _, tt := range deleteTests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Delete SubscriptionData : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
}
func (rt *RegisterSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedJSONContext([]byte(rt.Input), "api/subscription")
}
func (rt *RegisterSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
Mc.RegisterSubscription(c)
}
func (rt *RegisterSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if w.Code != rt.WantResponse {
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, rt.WantResponse)
}
return nil
}
func (rt *RegisterSubscriptionTest) ValidateResult() error {
return validateSubscription(rt.Assert, rt.WantDBData)
}
func validateSubscription(assert bool, wantDBData map[string]interface{}) error {
subscriptions, err := Mc.Service.GetSubscriptions(wantDBData)
if err != nil {
return fmt.Errorf("Error in database ops: %#v", err)
}
if assert != (len(*subscriptions) != 0) {
return fmt.Errorf("Subscription entry query didn't met expectation: %v != %#v", assert, *subscriptions)
}
return nil
}
func (ut *UpdateSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedJSONContext([]byte(ut.Input), "api/subscription/upsert")
}
func (ut *UpdateSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
Mc.UpdateHandler(c)
}
func (ut *UpdateSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if w.Code != ut.WantResponse {
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, ut.WantResponse)
}
return nil
}
func (ut *UpdateSubscriptionTest) ValidateResult() error {
return validateSubscription(ut.Assert, ut.WantDBData)
}
func (dt *DeleteSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedJSONContext([]byte(dt.Input), "api/subscription/delete")
}
func (dt *DeleteSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
Mc.DeleteSubscription(c)
}
func (dt *DeleteSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if w.Code != dt.WantResponse {
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, dt.WantResponse)
}
return nil
}
func (dt *DeleteSubscriptionTest) ValidateResult() error {
return validateSubscription(dt.Assert, dt.WantDBData)
}
func getBaseSubscription() MembershipData {
return MembershipData{
// APIKey: config.Auth.APIKEY,
Subscription: models.SubscriptionModel{
Name: "Premium",
Details: "A subscription detail",
MonthlyFee: 12.0,
HourlyRate: 14.0,
},
}
}
func customizeSubscription(customize func(MembershipData) MembershipData) MembershipData {
subscription := getBaseSubscription()
return customize(subscription)
}
func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
return []RegisterSubscriptionTest{
{
Name: "Missing details should fail",
WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Details = ""
return subscription
})),
},
{
Name: "Missing model name should fail",
WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": ""},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = ""
return subscription
})),
},
{
Name: "Negative monthly fee should fail",
WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
sub.Subscription.MonthlyFee = -10.0
return sub
})),
},
{
Name: "Negative hourly rate should fail",
WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
sub.Subscription.HourlyRate = -1.0
return sub
})),
},
{
Name: "correct entry should pass",
WantResponse: http.StatusCreated,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Conditions = "Some Condition"
subscription.Subscription.IncludedPerYear = 0
subscription.Subscription.IncludedPerMonth = 1
return subscription
})),
},
{
Name: "Duplicate subscription name should fail",
WantResponse: http.StatusConflict,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true, // The original subscription should still exist
Input: GenerateInputJSON(getBaseSubscription()),
},
}
}
func getSubscriptionUpdateData() []UpdateSubscriptionTest {
return []UpdateSubscriptionTest{
{
Name: "Modified Monthly Fee, should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium", "monthly_fee": "12"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.MonthlyFee = 123.0
return subscription
})),
},
{
Name: "Missing ID, should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.ID = 0
return subscription
})),
},
{
Name: "Modified Hourly Rate, should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium", "hourly_rate": "14"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.HourlyRate = 3254.0
return subscription
})),
},
{
Name: "IncludedPerYear changed, should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium", "included_per_year": "0"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.IncludedPerYear = 9873.0
return subscription
})),
},
{
Name: "IncludedPerMonth changed, should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium", "included_per_month": "1"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.IncludedPerMonth = 23415.0
return subscription
})),
},
{
Name: "Update non-existent subscription should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = "NonExistentSubscription"
return subscription
})),
},
{
Name: "Correct Update should pass",
WantResponse: http.StatusAccepted,
WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Details = "Altered Details"
subscription.Subscription.Conditions = "Some Condition"
subscription.Subscription.IncludedPerYear = 0
subscription.Subscription.IncludedPerMonth = 1
return subscription
})),
},
}
}
func getSubscriptionDeleteData() []DeleteSubscriptionTest {
var premiumSub, basicSub models.SubscriptionModel
database.DB.Where("name = ?", "Premium").First(&premiumSub)
database.DB.Where("name = ?", "Basic").First(&basicSub)
logger.Error.Printf("premiumSub.ID: %v", premiumSub.ID)
logger.Error.Printf("basicSub.ID: %v", basicSub.ID)
return []DeleteSubscriptionTest{
{
Name: "Delete non-existent subscription should fail",
WantResponse: http.StatusExpectationFailed,
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = "NonExistentSubscription"
subscription.Subscription.ID = basicSub.ID
return subscription
})),
},
{
Name: "Delete subscription without name should fail",
WantResponse: http.StatusExpectationFailed,
WantDBData: map[string]interface{}{"name": ""},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = ""
subscription.Subscription.ID = basicSub.ID
return subscription
})),
},
{
Name: "Delete subscription with users should fail",
WantResponse: http.StatusExpectationFailed,
WantDBData: map[string]interface{}{"name": "Basic"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = "Basic"
subscription.Subscription.ID = basicSub.ID
return subscription
})),
},
{
Name: "Delete valid subscription should succeed",
WantResponse: http.StatusOK,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = "Premium"
subscription.Subscription.ID = premiumSub.ID
return subscription
})),
},
}
}

View File

@@ -0,0 +1,105 @@
package controllers
import (
"GoMembership/internal/constants"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
)
func (uc *UserController) RequestPasswordChangeHandler(c *gin.Context) {
// Expected data from the user
var input struct {
Email string `json:"email" binding:"required,email"`
}
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandleValidationError(c, err)
return
}
// find user
db_user, err := uc.Service.GetUserByEmail(input.Email)
if err != nil {
utils.RespondWithError(c, err, "couldn't get user by email", http.StatusNotFound, "user.user", "user.email")
return
}
// check if user may change the password
if db_user.Status <= constants.DisabledStatus {
utils.RespondWithError(c, errors.ErrNotAuthorized, "User password change request denied, user is disabled", http.StatusForbidden, errors.Responses.Fields.Login, errors.Responses.Keys.UserDisabled)
return
}
// create token
token, err := uc.Service.HandlePasswordChangeRequest(db_user)
if err != nil {
utils.RespondWithError(c, err, "couldn't handle password change request", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return
}
// send email
if err := uc.EmailService.SendChangePasswordEmail(db_user, &token); err != nil {
utils.RespondWithError(c, err, "Couldn't send change password email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return
}
c.JSON(http.StatusAccepted, gin.H{
"message": "password_change_requested",
})
}
func (uc *UserController) ChangePassword(c *gin.Context) {
// Expected data from the user
var input struct {
Password string `json:"password" binding:"required"`
Token string `json:"token" binding:"required"`
}
userIDint, err := strconv.Atoi(c.Param("id"))
if err != nil {
utils.RespondWithError(c, err, "Invalid user ID", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.InvalidUserID)
return
}
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandleValidationError(c, err)
return
}
verification, err := uc.Service.VerifyUser(&input.Token, &constants.VerificationTypes.Password)
if err != nil || uint(userIDint) != verification.UserID {
if err == errors.ErrAlreadyVerified {
utils.RespondWithError(c, err, "User already changed password", http.StatusConflict, errors.Responses.Fields.User, errors.Responses.Keys.PasswordAlreadyChanged)
} else if err.Error() == "record not found" {
utils.RespondWithError(c, err, "Couldn't find verification. This is most probably a outdated token.", http.StatusGone, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
} else {
utils.RespondWithError(c, err, "Couldn't verify user", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
}
return
}
user, err := uc.Service.GetUserByID(verification.UserID)
if err != nil {
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
return
}
user.Status = constants.ActiveStatus
user.Verification = *verification
user.ID = verification.UserID
user.Password = input.Password
_, err = uc.Service.UpdateUser(user)
if err != nil {
utils.RespondWithError(c, err, "Couldn't update user", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{
"message": "password_changed",
})
}

View File

@@ -0,0 +1,366 @@
package controllers
import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/middlewares"
"GoMembership/internal/models"
"GoMembership/internal/services"
"GoMembership/internal/utils"
"GoMembership/internal/validation"
"fmt"
"strings"
"net/http"
"github.com/gin-gonic/gin"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
)
type UserController struct {
Service services.UserServiceInterface
EmailService *services.EmailService
ConsentService services.ConsentServiceInterface
BankAccountService services.BankAccountServiceInterface
MembershipService services.MembershipServiceInterface
LicenceService services.LicenceInterface
}
type RegistrationData struct {
User models.User `json:"user"`
}
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in CurrentUserHandler", http.StatusBadRequest, "general", "server.error.internal_server_error")
return
}
c.JSON(http.StatusOK, gin.H{
"user": requestUser.Safe(),
})
}
func (uc *UserController) GetAllUsers(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
return
}
if requestUser.RoleID == constants.Roles.Member {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, "user.user", "server.error.unauthorized")
return
}
users, err := uc.Service.GetUsers(nil)
if err != nil {
utils.RespondWithError(c, err, "Error getting users in GetAllUsers", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
return
}
// Create a slice to hold the safe user representations
safeUsers := make([]map[string]interface{}, len(*users))
// Convert each user to its safe representation
for i, user := range *users {
safeUsers[i] = user.Safe()
}
c.JSON(http.StatusOK, gin.H{
"users": users,
})
}
func (uc *UserController) UpdateHandler(c *gin.Context) {
// 1. Extract and validate the user ID from the route
requestUser, err := uc.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
return
}
var user models.User
var updateData RegistrationData
if err := c.ShouldBindJSON(&updateData); err != nil {
utils.HandleValidationError(c, err)
return
}
user = updateData.User
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) && user.ID != requestUser.ID {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
return
}
existingUser, err := uc.Service.GetUserByID(user.ID)
if err != nil {
utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, "user.user", "server.error.not_found")
return
}
// user.Membership.ID = existingUser.Membership.ID
// user.MembershipID = existingUser.MembershipID
// if existingUser.Licence != nil {
// user.Licence.ID = existingUser.Licence.ID
// }
// user.LicenceID = existingUser.LicenceID
// user.BankAccount.ID = existingUser.BankAccount.ID
// user.BankAccountID = existingUser.BankAccountID
if requestUser.RoleID <= constants.Priviliges.View {
existingUser.Password = ""
if err := utils.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil {
if err.Error() == "Not authorized" {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
return
}
utils.RespondWithError(c, err, "Error filtering users input fields", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
return
}
}
updatedUser, err := uc.Service.UpdateUser(&user)
if err != nil {
utils.HandleUserUpdateError(c, err)
return
}
logger.Info.Printf("User %d updated successfully by user %d", updatedUser.ID, requestUser.ID)
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser.Safe()})
}
func (uc *UserController) DeleteUser(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c)
if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in DeleteUser", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
return
}
type deleteData struct {
User struct {
ID uint `json:"id"`
LastName string `json:"last_name"`
} `json:"user"`
}
var data deleteData
if err := c.ShouldBindJSON(&data); err != nil {
utils.HandleValidationError(c, err)
return
}
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) && data.User.ID != requestUser.ID {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to delete user", http.StatusForbidden, "user.user", "server.error.unauthorized")
return
}
logger.Error.Printf("Deleting user: %v", data.User)
if err := uc.Service.DeleteUser(data.User.LastName, data.User.ID); err != nil {
utils.RespondWithError(c, err, "Error during user deletion", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
return
}
c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"})
}
func (uc *UserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
tokenString, err := c.Cookie("jwt")
if err != nil {
return nil, err
}
_, claims, err := middlewares.ExtractContentFrom(tokenString)
if err != nil {
return nil, err
}
jwtUserID := uint((*claims)["user_id"].(float64))
user, err := uc.Service.GetUserByID(jwtUserID)
if err != nil {
return nil, err
}
return user, nil
}
func (uc *UserController) LogoutHandler(c *gin.Context) {
tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("unable to get token from cookie: %#v", err)
}
middlewares.InvalidateSession(tokenString)
c.SetCookie("jwt", "", -1, "/", "", true, true)
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
}
func (uc *UserController) LoginHandler(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&input); err != nil {
utils.RespondWithError(c, err, "Invalid JSON or malformed request", http.StatusBadRequest, errors.Responses.Fields.General, errors.Responses.Keys.Invalid)
return
}
user, err := uc.Service.GetUserByEmail(input.Email)
if err != nil {
utils.RespondWithError(c, err, "Login Error; user not found", http.StatusNotFound,
errors.Responses.Fields.Login,
errors.Responses.Keys.UserNotFoundWrongPassword)
return
}
if user.Status <= constants.DisabledStatus {
utils.RespondWithError(c, fmt.Errorf("User banned from login %v %v", user.FirstName, user.LastName),
"Login Error; user is disabled",
http.StatusNotAcceptable,
errors.Responses.Fields.Login,
errors.Responses.Keys.UserDisabled)
return
}
ok, err := user.PasswordMatches(input.Password)
if err != nil {
utils.RespondWithError(c, err, "Login Error; password comparisson failed", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.InternalServerError)
return
}
if !ok {
utils.RespondWithError(c, fmt.Errorf("%v %v(%v)", user.FirstName, user.LastName, user.Email),
"Login Error; wrong password",
http.StatusNotAcceptable,
errors.Responses.Fields.Login,
errors.Responses.Keys.UserNotFoundWrongPassword)
return
}
logger.Error.Printf("jwtsecret: %v", config.Auth.JWTSecret)
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "")
if err != nil {
utils.RespondWithError(c, err, "Error generating token in LoginHandler", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.JwtGenerationFailed)
return
}
utils.SetCookie(c, token)
c.JSON(http.StatusOK, gin.H{
"message": "Login successful",
})
}
func (uc *UserController) RegisterUser(c *gin.Context) {
var regData RegistrationData
logger.Error.Printf("registering user...")
if err := c.ShouldBindJSON(&regData); err != nil {
utils.HandleValidationError(c, err)
return
}
logger.Info.Printf("Registering user %v", regData.User.Email)
selectedModel, err := uc.MembershipService.GetSubscriptionByName(&regData.User.Membership.SubscriptionModel.Name)
if err != nil {
utils.RespondWithError(c, err, "Error in Registeruser, couldn't get selected model", http.StatusNotFound, "subscription_model", "server.validation.subscription_model_not_found")
return
}
regData.User.Membership.SubscriptionModel = *selectedModel
if selectedModel.RequiredMembershipField != "" {
if err := validation.CheckParentMembershipID(regData.User.Membership); err != nil {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't check parent membership id", http.StatusBadRequest, "parent_membership_id", "server.validation.parent_membership_id_not_found")
return
}
}
regData.User.RoleID = constants.Roles.Member
// Register User
id, token, err := uc.Service.RegisterUser(&regData.User)
if err != nil {
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
if strings.Contains(err.Error(), "UNIQUE constraint failed: users.email") {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, "email", "server.validation.email_already_exists")
} else {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, "general", "server.error.internal_server_error")
}
return
}
regData.User.ID = id
// Register Consents
var consents = [2]models.Consent{
{
FirstName: regData.User.FirstName,
LastName: regData.User.LastName,
Email: regData.User.Email,
ConsentType: "TermsOfService",
},
{
FirstName: regData.User.FirstName,
LastName: regData.User.LastName,
Email: regData.User.Email,
ConsentType: "Privacy",
},
}
for _, consent := range consents {
_, err = uc.ConsentService.RegisterConsent(&consent)
if err != nil {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register consent", http.StatusInternalServerError, "general", "server.error.internal_server_error")
return
}
}
// Send notifications
if err := uc.EmailService.SendVerificationEmail(&regData.User, &token); err != nil {
logger.Error.Printf("Failed to send email verification email to user(%v): %v", regData.User.Email, err)
// Proceed without returning error since user registration is successful
// TODO Notify Admin
}
// Notify admin of new user registration
if err := uc.EmailService.SendRegistrationNotification(&regData.User); err != nil {
logger.Error.Printf("Failed to notify admin of new user(%v) registration: %v", regData.User.Email, err)
// Proceed without returning error since user registration is successful
// TODO Notify Admin
}
c.JSON(http.StatusCreated, gin.H{
"message": "Registration successuful",
"id": regData.User.ID,
})
}
func (uc *UserController) VerifyMailHandler(c *gin.Context) {
token := c.Query("token")
if token == "" {
logger.Error.Println("Missing token to verify mail")
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Missing token"})
return
}
verification, err := uc.Service.VerifyUser(&token, &constants.VerificationTypes.Email)
if err != nil {
logger.Error.Printf("Cannot verify user: %v", err)
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
return
}
user, err := uc.Service.GetUserByID(verification.UserID)
if err != nil {
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
return
}
user.Status = constants.VerifiedStatus
user.Verification = *verification
user.ID = verification.UserID
user.Password = ""
uc.Service.UpdateUser(user)
logger.Info.Printf("Verified User: %#v", user.Email)
uc.EmailService.SendWelcomeEmail(user)
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,177 @@
package database
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"crypto/rand"
"encoding/base64"
"time"
"github.com/alexedwards/argon2id"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var DB *gorm.DB
func Open(dbPath string, adminMail string) error {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
return err
}
if err := db.AutoMigrate(
&models.User{},
&models.SubscriptionModel{},
&models.Membership{},
&models.Consent{},
&models.Verification{},
&models.Licence{},
&models.Category{},
&models.BankAccount{}); err != nil {
logger.Error.Fatalf("Couldn't create database: %v", err)
return err
}
DB = db
logger.Info.Print("Opened DB")
var categoriesCount int64
db.Model(&models.Category{}).Count(&categoriesCount)
if categoriesCount == 0 {
categories := createLicenceCategories()
for _, model := range categories {
result := db.Create(&model)
if result.Error != nil {
return result.Error
}
}
}
var subscriptionsCount int64
db.Model(&models.SubscriptionModel{}).Count(&subscriptionsCount)
subscriptionModels := createSubscriptionModels()
for _, model := range subscriptionModels {
var exists int64
db.
Model(&models.SubscriptionModel{}).
Where("name = ?", model.Name).
Count(&exists)
logger.Error.Printf("looked for model.name %v and found %v", model.Name, exists)
if exists == 0 {
result := db.Create(&model)
if result.Error != nil {
return result.Error
}
}
}
var userCount int64
db.Model(&models.User{}).Count(&userCount)
if userCount == 0 {
var createdModel models.SubscriptionModel
if err := db.First(&createdModel).Error; err != nil {
return err
}
admin, err := createAdmin(adminMail, createdModel.ID)
if err != nil {
return err
}
result := db.Session(&gorm.Session{FullSaveAssociations: true}).Create(&admin)
if result.Error != nil {
return result.Error
}
}
return nil
}
func createSubscriptionModels() []models.SubscriptionModel {
return []models.SubscriptionModel{
{
Name: "Keins",
Details: "Dieses Modell ist für Vereinsmitglieder, die keinen Wunsch haben, an dem Carhsharing teilzunehmen.",
HourlyRate: 999,
MonthlyFee: 0,
},
}
}
func createLicenceCategories() []models.Category {
return []models.Category{
{Name: "AM"},
{Name: "A1"},
{Name: "A2"},
{Name: "A"},
{Name: "B"},
{Name: "C1"},
{Name: "C"},
{Name: "D1"},
{Name: "D"},
{Name: "BE"},
{Name: "C1E"},
{Name: "CE"},
{Name: "D1E"},
{Name: "DE"},
{Name: "T"},
{Name: "L"},
}
}
// TODO: Landing page to create an admin
func createAdmin(userMail string, subscriptionModelID uint) (*models.User, error) {
passwordBytes := make([]byte, 12)
_, err := rand.Read(passwordBytes)
if err != nil {
return nil, err
}
// Encode into a URL-safe base64 string
password := base64.URLEncoding.EncodeToString(passwordBytes)[:12]
hash, err := argon2id.CreateHash(password, argon2id.DefaultParams)
if err != nil {
return nil, err
}
logger.Error.Print("==============================================================")
logger.Error.Printf("Admin Email: %v", userMail)
logger.Error.Printf("Admin Password: %v", password)
logger.Error.Print("==============================================================")
return &models.User{
FirstName: "ad",
LastName: "min",
DateOfBirth: time.Now().AddDate(-20, 0, 0),
Password: hash,
Address: "Downhill 4",
ZipCode: "99999",
City: "TechTown",
Phone: "0123455678",
Email: userMail,
Status: constants.ActiveStatus,
RoleID: constants.Roles.Admin,
Membership: models.Membership{
Status: constants.DisabledStatus,
StartDate: time.Now(),
SubscriptionModelID: subscriptionModelID,
},
BankAccount: models.BankAccount{},
Licence: &models.Licence{
Status: constants.UnverifiedStatus,
},
}, nil
//"DE49700500000008447644", //fake
}
func Close() error {
logger.Info.Print("Closing DB")
db, err := DB.DB()
if err != nil {
return err
}
return db.Close()
}

View File

@@ -0,0 +1,31 @@
package middlewares
import (
"crypto/subtle"
"net/http"
"github.com/gin-gonic/gin"
"GoMembership/internal/config"
)
func APIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
clientAPIKey := c.GetHeader("X-API-Key")
if clientAPIKey == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "API key is missing"})
c.Abort()
return
}
// Using subtle.ConstantTimeCompare to mitigate timing attacks
if subtle.ConstantTimeCompare([]byte(clientAPIKey), []byte(config.Auth.APIKEY)) != 1 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,61 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config"
)
func TestAPIKeyMiddleware(t *testing.T) {
// Set up a test API key
testAPIKey := "test-api-key-12345"
config.Auth.APIKEY = testAPIKey
// Set Gin to Test Mode
gin.SetMode(gin.TestMode)
// Tests table
tests := []struct {
name string
apiKey string
wantStatus int
}{
{"Valid API Key", testAPIKey, http.StatusOK},
{"Missing API Key", "", http.StatusUnauthorized},
{"Invalid API Key", "wrong-key", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up a new test router and handler
router := gin.New()
router.Use(APIKeyMiddleware())
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Create a test request
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
if tt.apiKey != "" {
req.Header.Set("X-API-Key", tt.apiKey)
}
// Serve the request
router.ServeHTTP(w, req)
// Assert the response
assert.Equal(t, tt.wantStatus, w.Code)
// Additional assertions for specific cases
if tt.wantStatus == http.StatusUnauthorized {
assert.Contains(t, w.Body.String(), "API key")
}
})
}
}

View File

@@ -0,0 +1,179 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
type Session struct {
UserID uint
ExpiresAt time.Time
}
var (
sessionDuration = 5 * 24 * time.Hour
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
sessions = make(map[string]*Session)
)
func verifyAndRenewToken(tokenString string) (string, uint, error) {
if tokenString == "" {
logger.Error.Printf("empty tokenstring")
return "", 0, fmt.Errorf("Authorization token is required")
}
token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", 0, err
}
sessionID := (*claims)["session_id"].(string)
userID := uint((*claims)["user_id"].(float64))
roleID := int8((*claims)["role_id"].(float64))
session, ok := sessions[sessionID]
if !ok {
logger.Error.Printf("session not found")
return "", 0, fmt.Errorf("session not found")
}
if userID != session.UserID {
return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
}
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID)
logger.Error.Printf("session expired")
return "", 0, fmt.Errorf("session expired")
}
session.ExpiresAt = time.Now().Add(sessionDuration)
logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
if err != nil {
return "", 0, err
}
return newTokenString, session.UserID, nil
}
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("No Auth token: %v\n", err)
c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{
"field": "general",
"key": "server.error.no_auth_token",
}}})
c.Abort()
return
}
newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil {
if err == customerrors.ErrValidToken {
c.Set("user_id", uint(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{
"field": "general",
"key": "server.error.no_auth_token",
}}})
c.Abort()
return
}
utils.SetCookie(c, newToken)
c.Set("user_id", uint(userID))
c.Next()
}
}
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
if sessionID == "" {
sessionID = uuid.New().String()
}
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": user.ID,
"role_id": user.RoleID,
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
})
UpdateSession(sessionID, user.ID)
return token.SignedString([]byte(jwtKey))
}
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.Auth.JWTSecret), nil
})
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
return nil, nil, err
}
// Token is expired, check if session is still valid
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.Error.Printf("Invalid Token Claims")
return nil, nil, fmt.Errorf("invalid token claims")
}
if !ok {
logger.Error.Printf("invalid session_id in token")
return nil, nil, fmt.Errorf("invalid session_id in token")
}
return token, &claims, nil
}
func UpdateSession(sessionID string, userID uint) {
sessions[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(sessionDuration),
}
}
func InvalidateSession(token string) (bool, error) {
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(
token,
claims,
func(token *jwt.Token) (interface{}, error) {
return config.Auth.JWTSecret, nil
},
)
if err != nil {
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
}
sessionID, ok := claims["session_id"].(string)
if !ok {
return false, fmt.Errorf("No SessionID found")
}
delete(sessions, sessionID)
return true, nil
}

View File

@@ -0,0 +1,191 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"encoding/json"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)
func setupTestEnvironment() {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
}
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
config.LoadConfig()
logger.Info.Printf("Config: %#v", config.CFG)
}
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
setupTestEnvironment()
tests := []struct {
name string
setupAuth func(r *http.Request)
expectedStatus int
expectNewCookie bool
expectedUserID uint
}{
{
name: "Valid Token",
setupAuth: func(r *http.Request) {
user := models.User{ID: 123, RoleID: constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
},
expectedStatus: http.StatusOK,
expectedUserID: 123,
},
{
name: "Missing Cookie",
setupAuth: func(r *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Token",
setupAuth: func(r *http.Request) {
r.AddCookie(&http.Cookie{Name: "jwt", Value: "InvalidToken"})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Expired Token with Valid Session",
setupAuth: func(r *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
UpdateSession(sessionID, 123) // Add a valid session
},
expectedStatus: http.StatusOK,
expectNewCookie: true,
expectedUserID: 123,
},
{
name: "Expired Token with Expired Session",
setupAuth: func(r *http.Request) {
sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
// Don't add a session, simulating an expired session
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Signature",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte("wrong_secret"))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Signing Method",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123,
"session_id": "some-session",
"role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
}
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) {
// Setup
r := gin.New()
r.Use(AuthMiddleware())
r.GET("/test", func(c *gin.Context) {
userID, exists := c.Get("user_id")
if exists {
c.JSON(http.StatusOK, gin.H{"user_id": userID})
} else {
c.JSON(http.StatusUnauthorized, gin.H{"user_id": 0})
}
})
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
tt.setupAuth(req)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedStatus == http.StatusOK {
var response map[string]uint
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tt.expectedUserID, response["user_id"])
// Check if a new cookie was set
cookies := w.Result().Cookies()
if tt.expectNewCookie {
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
} else {
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
}
} else {
assert.Equal(t, 0, len(w.Result().Cookies()))
}
})
}
}

View File

@@ -0,0 +1,22 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/pkg/logger"
"strings"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
func CORSMiddleware() gin.HandlerFunc {
logger.Info.Print("Applying CORS")
return cors.New(cors.Config{
AllowOrigins: strings.Split(config.Site.AllowOrigins, ","),
AllowMethods: []string{"GET", "POST", "PATCH", "PUT", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-CSRF-Token"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * 60 * 60, // 12 hours
})
}

View File

@@ -0,0 +1,104 @@
package middlewares
import (
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config"
"GoMembership/pkg/logger"
)
const (
Host = "127.0.0.1"
Port int = 2525
)
func TestCORSMiddleware(t *testing.T) {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
}
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_HOST", Host); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
// Load your configuration
config.LoadConfig()
// Create a gin router with the CORS middleware
router := gin.New()
router.Use(CORSMiddleware())
// Add a simple handler
router.GET("/test", func(c *gin.Context) {
c.String(200, "test")
})
tests := []struct {
name string
origin string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "Allowed origin",
origin: config.Site.AllowOrigins,
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": config.Site.AllowOrigins,
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Allow-Credentials": "true",
},
},
{
name: "Disallowed origin",
origin: "http://evil.com",
expectedStatus: http.StatusForbidden,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", tt.origin)
router.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
logger.Info.Printf("Recieved Headers: %#v", w.Header())
for key, value := range tt.expectedHeaders {
assert.Equal(t, value, w.Header().Get(key))
}
})
}
}

View File

@@ -0,0 +1,44 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/pkg/logger"
"net/http"
"github.com/gin-gonic/gin"
)
func CSPMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying CSP")
return func(c *gin.Context) {
policy := "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests;"
if config.Env == "development" {
policy += " report-uri /csp-report;"
c.Header("Content-Security-Policy-Report-Only", policy)
} else {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}
func CSPReportHandling(c *gin.Context) {
var report map[string]interface{}
if err := c.BindJSON(&report); err != nil {
logger.Error.Printf("Couldn't Bind JSON: %#v", err)
return
}
logger.Info.Printf("CSP Violation: %+v", report)
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,81 @@
package middlewares
import (
"GoMembership/internal/config"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestCSPMiddleware(t *testing.T) {
// Save the current environment and restore it after the test
originalEnv := config.Env
tests := []struct {
name string
environment string
expectedHeader string
expectedPolicy string
}{
{
name: "Development Environment",
environment: "development",
expectedHeader: "Content-Security-Policy-Report-Only",
expectedPolicy: "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests; report-uri /csp-report;",
},
{
name: "Production Environment",
environment: "production",
expectedHeader: "Content-Security-Policy",
expectedPolicy: "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up the test environment
config.Env = tt.environment
// Create a new Gin router with the middleware
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CSPMiddleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "test")
})
// Create a test request
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, tt.expectedPolicy, w.Header().Get(tt.expectedHeader))
})
}
config.Env = originalEnv
}

View File

@@ -0,0 +1,21 @@
package middlewares
import (
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin"
)
func SecurityHeadersMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying headers")
return func(c *gin.Context) {
c.Header("X-Frame-Options", "DENY")
c.Header("X-Content-Type-Options", "nosniff")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Feature-Policy", "geolocation 'none'; midi 'none'; sync-xhr 'none'; microphone 'none'; camera 'none'; magnetometer 'none'; gyroscope 'none'; speaker 'none'; fullscreen 'self'; payment 'none'")
c.Header("Permissions-Policy", "geolocation=(), midi=(), sync-xhr=(), microphone=(), camera=(), magnetometer=(), gyroscope=(), fullscreen=(self), payment=()")
c.Next()
}
}

View File

@@ -0,0 +1,83 @@
package middlewares
import (
"GoMembership/pkg/logger"
"net/http"
"sync"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type IPRateLimiter struct {
ips map[string]*rate.Limiter
mu *sync.RWMutex
r rate.Limit
b int
}
func NewIPRateLimiter(r int, b int) *IPRateLimiter {
return &IPRateLimiter{
ips: make(map[string]*rate.Limiter),
mu: &sync.RWMutex{},
r: rate.Limit(r),
b: b,
}
}
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()
limiter, exists := i.ips[ip]
if !exists {
limiter = rate.NewLimiter(i.r, i.b)
i.ips[ip] = limiter
}
return limiter
}
// func RateLimitMiddleware() gin.HandlerFunc {
// if iPLimiter == nil {
// iPLimiter := NewIPRateLimiter(
// rate.Limit(config.Security.Ratelimits.Limit),
// config.Security.Ratelimits.Burst)
// }
// return func(c *gin.Context) {
// ip := c.ClientIP()
// l := iPLimiter.GetLimiter(ip)
// if !l.Allow() {
// c.JSON(http.StatusTooManyRequests, gin.H{
// "error": "Too many requests",
// })
// c.Abort()
// return
// }
// c.Next()
// }
// }
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b)
return func(c *gin.Context) {
if limiter == nil {
logger.Error.Println("Limiter missing")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
ip := c.ClientIP()
l := limiter.GetLimiter(ip)
if !l.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Too many requests",
})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,143 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
)
var (
limit = 1
burst = 60
)
func TestRateLimitMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a new rate limiter that allows 2 requests per second with a burst of 4
// Create a new Gin router with the rate limit middleware
router := setupRouter()
// Test cases
tests := []struct {
name string
requests int
expectedStatus int
sleep time.Duration
}{
{"Allow first request", 1, http.StatusOK, 0},
{"Allow up to burst limit", burst - 1, http.StatusOK, 0},
{"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0},
{"Allow after rate limit replenishes", 1, http.StatusOK, time.Second},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
time.Sleep(tt.sleep)
var status int
for i := 0; i < tt.requests; i++ {
status = makeRequest(router, "192.168.0.2")
}
if status != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, status)
}
})
}
}
func TestIPRateLimiter(t *testing.T) {
limiter := NewIPRateLimiter(limit, burst)
limiter1 := limiter.GetLimiter("127.0.0.1")
limiter2 := limiter.GetLimiter("192.168.0.1")
if limiter1 == limiter2 {
t.Error("Expected different limiters for different IPs")
}
limiter3 := limiter.GetLimiter("127.0.0.1")
if limiter1 != limiter3 {
t.Error("Expected the same limiter for the same IP")
}
}
func TestDifferentRateLimits(t *testing.T) {
testCases := []struct {
name string
duration time.Duration
requests int
expectedRequests int
ip string
}{
{"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"},
{"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"},
{"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"},
{"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
router := setupRouter()
testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip)
})
}
}
func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) {
start := time.Now()
successCount := 0
totalRequests := 0
t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests)))
for time.Since(start) < duration {
status := makeRequest(router, ip)
totalRequests++
if status == http.StatusOK {
successCount++
}
time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests)))
}
actualDuration := time.Since(start)
t.Logf("limit: %v, burst: %v", limit, burst)
t.Logf("Test duration: %v", actualDuration)
t.Logf("Successful requests: %d", successCount)
t.Logf("Expected successful requests: %d", expectedRequests)
t.Logf("Total requests: %d", totalRequests)
if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 {
t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount)
}
if requests-expectedRequests != 0 && totalRequests <= successCount {
t.Errorf("Expected some requests to be rate limited")
}
}
func setupRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
limiter := NewIPRateLimiter(limit, burst)
router.Use(RateLimitMiddleware(limiter))
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "success")
})
return router
}
func makeRequest(router *gin.Engine, ip string) int {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP
router.ServeHTTP(w, req)
return w.Code
}

View File

@@ -0,0 +1,15 @@
package models
import "time"
type BankAccount struct {
CreatedAt time.Time
UpdatedAt time.Time
MandateDateSigned time.Time `gorm:"not null" json:"mandate_date_signed"`
Bank string `json:"bank_name" binding:"safe_content"`
AccountHolderName string `json:"account_holder_name" binding:"safe_content"`
IBAN string `json:"iban"`
BIC string `json:"bic"`
MandateReference string `gorm:"not null" json:"mandate_reference"`
ID uint `gorm:"primaryKey"`
}

View File

@@ -0,0 +1,17 @@
package models
import (
"time"
)
type Consent struct {
CreatedAt time.Time
UpdatedAt time.Time
FirstName string `gorm:"not null" json:"first_name" binding:"safe_content"`
LastName string `gorm:"not null" json:"last_name" binding:"safe_content"`
Email string `json:"email" binding:"email,safe_content"`
ConsentType string `gorm:"not null" json:"consent_type" binding:"safe_content"`
ID uint `gorm:"primaryKey"`
User User
UserID uint
}

View File

@@ -0,0 +1,22 @@
package models
import (
"time"
)
type Licence struct {
ID uint `json:"id"`
CreatedAt time.Time
UpdatedAt time.Time
Status int8 `json:"status" binding:"omitempty,number"`
Number string `json:"number" binding:"omitempty,safe_content"`
IssuedDate time.Time `json:"issued_date" binding:"omitempty"`
ExpirationDate time.Time `json:"expiration_date" binding:"omitempty"`
IssuingCountry string `json:"country" binding:"safe_content"`
Categories []Category `json:"categories" gorm:"many2many:licence_2_categories"`
}
type Category struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"category" binding:"safe_content"`
}

View File

@@ -0,0 +1,15 @@
package models
import "time"
type Membership struct {
CreatedAt time.Time
UpdatedAt time.Time
StartDate time.Time `json:"start_date"`
EndDate time.Time `json:"end_date"`
Status int8 `json:"status" binding:"number,safe_content"`
SubscriptionModel SubscriptionModel `gorm:"foreignKey:SubscriptionModelID" json:"subscription_model"`
SubscriptionModelID uint `json:"subsription_model_id"`
ParentMembershipID uint `json:"parent_member_id" binding:"omitempty,omitnil,number"`
ID uint `json:"id"`
}

View File

@@ -0,0 +1,19 @@
package models
import (
"time"
)
type SubscriptionModel struct {
CreatedAt time.Time
UpdatedAt time.Time
Name string `gorm:"unique" json:"name" binding:"required"`
Details string `json:"details"`
Conditions string `json:"conditions"`
RequiredMembershipField string `json:"required_membership_field"`
ID uint `json:"id" gorm:"primaryKey"`
MonthlyFee float32 `json:"monthly_fee"`
HourlyRate float32 `json:"hourly_rate"`
IncludedPerYear int16 `json:"included_hours_per_year"`
IncludedPerMonth int16 `json:"included_hours_per_month"`
}

View File

@@ -0,0 +1,132 @@
package models
import (
"GoMembership/pkg/logger"
"fmt"
"time"
"github.com/alexedwards/argon2id"
"gorm.io/gorm"
)
type User struct {
ID uint `gorm:"primarykey" json:"id"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
DateOfBirth time.Time `gorm:"not null" json:"dateofbirth" binding:"required,safe_content"`
Company string `json:"company" binding:"omitempty,omitnil,safe_content"`
Phone string `json:"phone" binding:"omitempty,omitnil,safe_content"`
Notes string `json:"notes" binding:"safe_content"`
FirstName string `gorm:"not null" json:"first_name" binding:"required,safe_content"`
Password string `json:"password" binding:"safe_content"`
Email string `gorm:"unique;not null" json:"email" binding:"required,email,safe_content"`
LastName string `gorm:"not null" json:"last_name" binding:"required,safe_content"`
ProfilePicture string `json:"profile_picture" binding:"omitempty,omitnil,image,safe_content"`
Address string `gorm:"not null" json:"address" binding:"required,safe_content"`
ZipCode string `gorm:"not null" json:"zip_code" binding:"required,alphanum,safe_content"`
City string `form:"not null" json:"city" binding:"required,alphaunicode,safe_content"`
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
BankAccountID uint
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
VerificationID uint
Membership Membership `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"membership"`
MembershipID uint
Licence *Licence `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"licence"`
LicenceID uint
PaymentStatus int8 `json:"payment_status"`
Status int8 `json:"status"`
RoleID int8 `json:"role_id"`
}
func (u *User) AfterCreate(tx *gorm.DB) (err error) {
if u.BankAccount.ID != 0 && u.BankAccount.MandateReference == "" {
mandateReference := u.GenerateMandateReference()
return tx.Model(&u.BankAccount).Update("MandateReference", mandateReference).Error
}
return nil
}
func (u *User) GenerateMandateReference() string {
return fmt.Sprintf("%s%d%s", time.Now().Format("20060102"), u.ID, u.BankAccount.IBAN)
}
func (u *User) SetPassword(plaintextPassword string) error {
if plaintextPassword == "" {
return nil
}
hash, err := argon2id.CreateHash(plaintextPassword, argon2id.DefaultParams)
if err != nil {
return err
}
u.Password = hash
return nil
}
func (u *User) PasswordMatches(plaintextPassword string) (bool, error) {
logger.Error.Printf("plaintext: %v user password: %v", plaintextPassword, u.Password)
return argon2id.ComparePasswordAndHash(plaintextPassword, u.Password)
}
func (u *User) Safe() map[string]interface{} {
result := map[string]interface{}{
"email": u.Email,
"first_name": u.FirstName,
"last_name": u.LastName,
"phone": u.Phone,
"notes": u.Notes,
"address": u.Address,
"zip_code": u.ZipCode,
"city": u.City,
"status": u.Status,
"id": u.ID,
"role_id": u.RoleID,
"company": u.Company,
"dateofbirth": u.DateOfBirth,
"membership": map[string]interface{}{
"id": u.Membership.ID,
"start_date": u.Membership.StartDate,
"end_date": u.Membership.EndDate,
"status": u.Membership.Status,
"subscription_model": map[string]interface{}{
"id": u.Membership.SubscriptionModel.ID,
"name": u.Membership.SubscriptionModel.Name,
"details": u.Membership.SubscriptionModel.Details,
"conditions": u.Membership.SubscriptionModel.Conditions,
"monthly_fee": u.Membership.SubscriptionModel.MonthlyFee,
"hourly_rate": u.Membership.SubscriptionModel.HourlyRate,
"included_per_year": u.Membership.SubscriptionModel.IncludedPerYear,
"included_per_month": u.Membership.SubscriptionModel.IncludedPerMonth,
},
},
"licence": map[string]interface{}{
"id": 0,
},
"bank_account": map[string]interface{}{
"id": u.BankAccount.ID,
"mandate_date_signed": u.BankAccount.MandateDateSigned,
"bank": u.BankAccount.Bank,
"account_holder_name": u.BankAccount.AccountHolderName,
"iban": u.BankAccount.IBAN,
"bic": u.BankAccount.BIC,
"mandate_reference": u.BankAccount.MandateReference,
},
}
if u.Licence != nil {
result["licence"] = map[string]interface{}{
"id": u.Licence.ID,
"number": u.Licence.Number,
"status": u.Licence.Status,
"issued_date": u.Licence.IssuedDate,
"expiration_date": u.Licence.ExpirationDate,
"country": u.Licence.IssuingCountry,
"categories": u.Licence.Categories,
}
}
return result
}

View File

@@ -0,0 +1,13 @@
package models
import "time"
type Verification struct {
UpdatedAt time.Time
CreatedAt time.Time
VerifiedAt *time.Time `gorm:"Default:NULL" json:"verified_at"`
VerificationToken string `json:"token"`
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"unique;" json:"user_id"`
Type string
}

View File

@@ -0,0 +1,20 @@
package repositories
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
)
type BankAccountRepositoryInterface interface {
CreateBankAccount(account *models.BankAccount) (uint, error)
}
type BankAccountRepository struct{}
func (repo *BankAccountRepository) CreateBankAccount(account *models.BankAccount) (uint, error) {
result := database.DB.Create(account)
if result.Error != nil {
return 0, result.Error
}
return account.ID, nil
}

View File

@@ -0,0 +1,21 @@
package repositories
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
)
type ConsentRepositoryInterface interface {
CreateConsent(consent *models.Consent) (uint, error)
}
type ConsentRepository struct{}
func (repo *ConsentRepository) CreateConsent(consent *models.Consent) (uint, error) {
result := database.DB.Create(consent)
if result.Error != nil {
return 0, result.Error
}
return consent.ID, nil
}

View File

@@ -0,0 +1,31 @@
package repositories
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
)
type LicenceInterface interface {
FindCategoryByName(categoryName string) (models.Category, error)
FindCategoriesByIDs(ids []uint) ([]models.Category, error)
GetAllCategories() ([]models.Category, error)
}
type LicenceRepository struct{}
func (r *LicenceRepository) GetAllCategories() ([]models.Category, error) {
var categories []models.Category
err := database.DB.Find(&categories).Error
return categories, err
}
func (r *LicenceRepository) FindCategoriesByIDs(ids []uint) ([]models.Category, error) {
var categories []models.Category
err := database.DB.Where("id IN ?", ids).Find(&categories).Error
return categories, err
}
func (r *LicenceRepository) FindCategoryByName(categoryName string) (models.Category, error) {
var category models.Category
err := database.DB.Where("name = ?", categoryName).First(&category).Error
return category, err
}

View File

@@ -0,0 +1,37 @@
package repositories
import (
"GoMembership/internal/database"
"gorm.io/gorm"
"GoMembership/internal/models"
)
type MembershipRepositoryInterface interface {
CreateMembership(membership *models.Membership) (uint, error)
FindMembershipByUserID(userID uint) (*models.Membership, error)
}
type MembershipRepository struct{}
func (repo *MembershipRepository) CreateMembership(membership *models.Membership) (uint, error) {
result := database.DB.Create(membership)
if result.Error != nil {
return 0, result.Error
}
return membership.ID, nil
}
func (repo *MembershipRepository) FindMembershipByUserID(userID uint) (*models.Membership, error) {
var membership models.Membership
result := database.DB.First(&membership, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &membership, nil
}

View File

@@ -0,0 +1,97 @@
package repositories
import (
"GoMembership/internal/database"
"gorm.io/gorm"
"GoMembership/internal/models"
)
type SubscriptionModelsRepositoryInterface interface {
CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error)
UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error)
GetSubscriptionModelNames() ([]string, error)
GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error)
// GetUsersBySubscription(id uint) (*[]models.SubscriptionModel, error)
DeleteSubscription(id *uint) error
}
type SubscriptionModelsRepository struct{}
func (sr *SubscriptionModelsRepository) CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error) {
result := database.DB.Create(subscriptionModel)
if result.Error != nil {
return 0, result.Error
}
return subscriptionModel.ID, nil
}
func (sr *SubscriptionModelsRepository) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) {
result := database.DB.Model(&models.SubscriptionModel{ID: subscription.ID}).Updates(subscription)
if result.Error != nil {
return nil, result.Error
}
return subscription, nil
}
func (sr *SubscriptionModelsRepository) DeleteSubscription(id *uint) error {
result := database.DB.Delete(&models.SubscriptionModel{}, id)
if result.Error != nil {
return result.Error
}
return nil
}
func GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) {
var model models.SubscriptionModel
result := database.DB.Where("name = ?", modelname).First(&model)
if result.Error != nil {
return nil, result.Error
}
return &model, nil
}
func (sr *SubscriptionModelsRepository) GetSubscriptionModelNames() ([]string, error) {
var names []string
if err := database.DB.Model(&models.SubscriptionModel{}).Pluck("name", &names).Error; err != nil {
return []string{}, err
}
return names, nil
}
func (sr *SubscriptionModelsRepository) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) {
var subscriptions []models.SubscriptionModel
result := database.DB.Where(where).Find(&subscriptions)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &subscriptions, nil
}
func GetUsersBySubscription(subscriptionID uint) (*[]models.User, error) {
var users []models.User
err := database.DB.Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("BankAccount").
Preload("Licence").
Preload("Licence.Categories").
Joins("JOIN memberships ON users.membership_id = memberships.id").
Joins("JOIN subscription_models ON memberships.subscription_model_id = subscription_models.id").
Where("subscription_models.id = ?", subscriptionID).
Find(&users).Error
if err != nil {
return nil, err
}
return &users, nil
}

View File

@@ -0,0 +1,10 @@
package repositories
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
)
func (r *UserRepository) SetUserStatus(id uint, status uint) error {
return database.DB.Model(&models.User{}).Where("id = ?", id).Update("status", status).Error
}

View File

@@ -0,0 +1,159 @@
package repositories
import (
"gorm.io/gorm"
"GoMembership/internal/database"
"gorm.io/gorm/clause"
"GoMembership/internal/models"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
)
type UserRepositoryInterface interface {
CreateUser(user *models.User) (uint, error)
UpdateUser(user *models.User) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error)
GetUserByEmail(email string) (*models.User, error)
IsVerified(userID *uint) (bool, error)
GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error)
SetVerificationToken(verification *models.Verification) (token string, err error)
DeleteVerification(id uint, verificationType string) error
DeleteUser(id uint) error
SetUserStatus(id uint, status uint) error
}
type UserRepository struct{}
func (ur *UserRepository) DeleteUser(id uint) error {
return database.DB.Delete(&models.User{}, "id = ?", id).Error
}
func PasswordExists(userID *uint) (bool, error) {
var user models.User
result := database.DB.Select("password").First(&user, userID)
if result.Error != nil {
return false, result.Error
}
return user.Password != "", nil
}
func (ur *UserRepository) CreateUser(user *models.User) (uint, error) {
result := database.DB.Create(user)
if result.Error != nil {
logger.Error.Printf("Create User error: %#v", result.Error)
return 0, result.Error
}
return user.ID, nil
}
func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
if user == nil {
return nil, errors.ErrNoData
}
err := database.DB.Transaction(func(tx *gorm.DB) error {
// Check if the user exists in the database
var existingUser models.User
if err := tx.Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
First(&existingUser, user.ID).Error; err != nil {
return err
}
// Update the user's main fields
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Omit("Password").Updates(user)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected
}
if user.Password != "" {
if err := tx.Model(&models.User{}).
Where("id = ?", user.ID).
Update("Password", user.Password).Error; err != nil {
return err
}
}
// Update the Membership if provided
if user.Membership.ID != 0 {
if err := tx.Model(&existingUser.Membership).Updates(user.Membership).Error; err != nil {
return err
}
}
// Replace categories if Licence and Categories are provided
if user.Licence != nil {
if err := tx.Model(&user.Licence).Association("Categories").Replace(user.Licence.Categories); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
var updatedUser models.User
if err := database.DB.Preload("Licence.Categories").
Preload("Membership").
First(&updatedUser, user.ID).Error; err != nil {
return nil, err
}
return &updatedUser, nil
}
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
var users []models.User
result := database.DB.
Preload(clause.Associations).
Preload("Membership.SubscriptionModel").
Preload("Licence.Categories").
Where(where).Find(&users)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &users, nil
}
func GetUserByID(userID *uint) (*models.User, error) {
var user models.User
result := database.DB.
Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence.Categories").
First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &user, nil
}
func (ur *UserRepository) GetUserByEmail(email string) (*models.User, error) {
var user models.User
result := database.DB.Where("email = ?", email).First(&user)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &user, nil
}

View File

@@ -0,0 +1,57 @@
package repositories
import (
"GoMembership/internal/constants"
"GoMembership/internal/database"
"GoMembership/internal/models"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func (ur *UserRepository) IsVerified(userID *uint) (bool, error) {
var user models.User
result := database.DB.Select("status").First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return false, gorm.ErrRecordNotFound
}
return false, result.Error
}
return user.Status > constants.DisabledStatus, nil
}
func (ur *UserRepository) GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error) {
var emailVerification models.Verification
result := database.DB.Where("verification_token = ? AND type = ?", token, verificationType).First(&emailVerification)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &emailVerification, nil
}
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (token string, err error) {
result := database.DB.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}},
DoUpdates: clause.AssignmentColumns([]string{"verification_token", "created_at", "type"}),
}).Create(&verification)
if result.Error != nil {
return "", result.Error
}
return verification.VerificationToken, nil
}
func (ur *UserRepository) DeleteVerification(id uint, verificationType string) error {
result := database.DB.Where("user_id = ? AND type = ?", id, verificationType).Delete(&models.Verification{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,50 @@
package routes
import (
"GoMembership/internal/controllers"
"GoMembership/internal/middlewares"
"github.com/gin-gonic/gin"
)
func RegisterRoutes(router *gin.Engine, userController *controllers.UserController, membershipcontroller *controllers.MembershipController, contactController *controllers.ContactController, licenceController *controllers.LicenceController) {
router.GET("/users/verify", userController.VerifyMailHandler)
router.POST("/users/register", userController.RegisterUser)
router.POST("/users/contact", contactController.RelayContactRequest)
router.POST("/users/password/request-change", userController.RequestPasswordChangeHandler)
router.PATCH("/users/password/change/:id", userController.ChangePassword)
router.POST("/users/login", userController.LoginHandler)
router.POST("/csp-report", middlewares.CSPReportHandling)
// apiRouter := router.Group("/api")
// apiRouter.Use(middlewares.APIKeyMiddleware())
// {
// apiRouter.POST("/v1/subscription", membershipcontroller.RegisterSubscription)
// }
userRouter := router.Group("/backend")
userRouter.Use(middlewares.AuthMiddleware())
{
userRouter.GET("/users/current", userController.CurrentUserHandler)
userRouter.POST("/logout", userController.LogoutHandler)
userRouter.PUT("/users", userController.UpdateHandler)
userRouter.POST("/users", userController.RegisterUser)
userRouter.GET("/users/all", userController.GetAllUsers)
userRouter.DELETE("/users", userController.DeleteUser)
}
membershipRouter := router.Group("/backend/membership")
membershipRouter.Use(middlewares.AuthMiddleware())
{
membershipRouter.GET("/subscriptions", membershipcontroller.GetSubscriptions)
membershipRouter.PUT("/subscriptions", membershipcontroller.UpdateHandler)
membershipRouter.POST("/subscriptions", membershipcontroller.RegisterSubscription)
membershipRouter.DELETE("/subscriptions", membershipcontroller.DeleteSubscription)
}
licenceRouter := router.Group("/backend/licence")
licenceRouter.Use(middlewares.AuthMiddleware())
{
licenceRouter.GET("/categories", licenceController.GetAllCategories)
}
}

View File

@@ -0,0 +1,95 @@
// Package server initializes and runs the application server.
// It sets up configurations, initializes the database, services, and controllers,
// loads HTML templates, and starts the HTTP server.
package server
import (
"context"
"net/http"
"path/filepath"
"time"
"GoMembership/internal/config"
"GoMembership/internal/controllers"
"GoMembership/internal/middlewares"
"GoMembership/internal/repositories"
"GoMembership/internal/validation"
"GoMembership/internal/routes"
"GoMembership/internal/services"
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin"
)
var shutdownChannel = make(chan struct{})
var srv *http.Server
// Run initializes the server configuration, sets up services and controllers, and starts the HTTP server.
func Run() {
emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password)
var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{}
consentService := &services.ConsentService{Repo: consentRepo}
var bankAccountRepo repositories.BankAccountRepositoryInterface = &repositories.BankAccountRepository{}
bankAccountService := &services.BankAccountService{Repo: bankAccountRepo}
var membershipRepo repositories.MembershipRepositoryInterface = &repositories.MembershipRepository{}
var subscriptionRepo repositories.SubscriptionModelsRepositoryInterface = &repositories.SubscriptionModelsRepository{}
membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo}
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
licenceService := &services.LicenceService{Repo: licenceRepo}
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{}
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
userController := &controllers.UserController{Service: userService, EmailService: emailService, ConsentService: consentService, LicenceService: licenceService, BankAccountService: bankAccountService, MembershipService: membershipService}
membershipController := &controllers.MembershipController{Service: *membershipService, UserController: userController}
licenceController := &controllers.LicenceController{Service: *licenceService}
contactController := &controllers.ContactController{EmailService: emailService}
router := gin.Default()
// gin.SetMode(gin.ReleaseMode)
router.Static(config.Templates.StaticPath, "./style")
// Load HTML templates
router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*"))
router.Use(gin.Logger())
router.Use(middlewares.CORSMiddleware())
router.Use(middlewares.CSPMiddleware())
router.Use(middlewares.SecurityHeadersMiddleware())
limiter := middlewares.NewIPRateLimiter(config.Security.Ratelimits.Limit, config.Security.Ratelimits.Burst)
router.Use(middlewares.RateLimitMiddleware(limiter))
routes.RegisterRoutes(router, userController, membershipController, contactController, licenceController)
validation.SetupValidators()
logger.Info.Println("Starting server on :8080")
srv = &http.Server{
Addr: ":8080",
Handler: router,
}
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error.Fatalf("could not start server: %v", err)
}
}()
// Wait for the shutdown signal
<-shutdownChannel
}
func Shutdown(ctx context.Context) error {
if srv == nil {
return nil
}
// Graceful shutdown with a timeout
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Attempt to shutdown the server
return srv.Shutdown(shutdownCtx)
}

View File

@@ -0,0 +1,12 @@
package services
import (
"GoMembership/internal/repositories"
)
type BankAccountServiceInterface interface {
}
type BankAccountService struct {
Repo repositories.BankAccountRepositoryInterface
}

View File

@@ -0,0 +1,22 @@
package services
import (
"time"
"GoMembership/internal/models"
"GoMembership/internal/repositories"
)
type ConsentServiceInterface interface {
RegisterConsent(consent *models.Consent) (uint, error)
}
type ConsentService struct {
Repo repositories.ConsentRepositoryInterface
}
func (service *ConsentService) RegisterConsent(consent *models.Consent) (uint, error) {
consent.CreatedAt = time.Now()
consent.UpdatedAt = time.Now()
return service.Repo.CreateConsent(consent)
}

View File

@@ -0,0 +1,240 @@
package services
import (
"bytes"
"html/template"
"gopkg.in/gomail.v2"
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
)
type EmailService struct {
dialer *gomail.Dialer
}
func NewEmailService(host string, port int, username string, password string) *EmailService {
dialer := gomail.NewDialer(host, port, username, password)
return &EmailService{dialer: dialer}
}
func (s *EmailService) SendEmail(to string, subject string, body string, bodyTXT string, replyTo string) error {
msg := gomail.NewMessage()
msg.SetHeader("From", s.dialer.Username)
msg.SetHeader("To", to)
msg.SetHeader("Subject", subject)
if replyTo != "" {
msg.SetHeader("REPLY_TO", replyTo)
}
if bodyTXT != "" {
msg.SetBody("text/plain", bodyTXT)
}
msg.AddAlternative("text/html", body)
// msg.WriteTo(os.Stdout)
if err := s.dialer.DialAndSend(msg); err != nil {
logger.Error.Printf("Could not send email to %s: %v", to, err)
return err
}
logger.Info.Printf("Email sent to %s", to)
return nil
}
func ParseTemplate(filename string, data interface{}) (string, error) {
// Read the email template file
templateDir := config.Templates.MailPath
tpl, err := template.ParseFiles(templateDir + "/" + filename)
if err != nil {
logger.Error.Printf("Failed to parse email template: %v", err)
return "", err
}
// Buffer to hold the rendered template
var tplBuffer bytes.Buffer
if err := tpl.Execute(&tplBuffer, data); err != nil {
logger.Error.Printf("Failed to execute email template: %v", err)
return "", err
}
return tplBuffer.String(), nil
}
func (s *EmailService) SendVerificationEmail(user *models.User, token *string) error {
// Prepare data to be injected into the template
data := struct {
FirstName string
LastName string
Token string
BASEURL string
}{
FirstName: user.FirstName,
LastName: user.LastName,
Token: *token,
BASEURL: config.Site.BaseURL,
}
subject := constants.MailVerificationSubject
body, err := ParseTemplate("mail_verification.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't send verification mail")
return err
}
return s.SendEmail(user.Email, subject, body, "", "")
}
func (s *EmailService) SendChangePasswordEmail(user *models.User, token *string) error {
// Prepare data to be injected into the template
data := struct {
FirstName string
LastName string
Token string
BASEURL string
UserID uint
}{
FirstName: user.FirstName,
LastName: user.LastName,
Token: *token,
BASEURL: config.Site.BaseURL,
UserID: user.ID,
}
subject := constants.MailChangePasswordSubject
htmlBody, err := ParseTemplate("mail_change_password.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't parse password mail")
return err
}
plainBody, err := ParseTemplate("mail_change_password.txt.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't parse password mail")
return err
}
return s.SendEmail(user.Email, subject, htmlBody, plainBody, "")
}
func (s *EmailService) SendWelcomeEmail(user *models.User) error {
// Prepare data to be injected into the template
data := struct {
Company string
FirstName string
MembershipModel string
BASEURL string
MembershipID uint
MembershipFee float32
Logo string
WebsiteTitle string
RentalFee float32
}{
Company: user.Company,
FirstName: user.FirstName,
MembershipModel: user.Membership.SubscriptionModel.Name,
MembershipID: user.Membership.ID,
MembershipFee: float32(user.Membership.SubscriptionModel.MonthlyFee),
RentalFee: float32(user.Membership.SubscriptionModel.HourlyRate),
BASEURL: config.Site.BaseURL,
WebsiteTitle: config.Site.WebsiteTitle,
Logo: config.Templates.LogoURI,
}
subject := constants.MailWelcomeSubject
htmlBody, err := ParseTemplate("mail_welcome.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't send welcome mail")
return err
}
plainBody, err := ParseTemplate("mail_welcome.txt.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't parse password mail")
return err
}
return s.SendEmail(user.Email, subject, htmlBody, plainBody, "")
}
func (s *EmailService) SendRegistrationNotification(user *models.User) error {
// Prepare data to be injected into the template
data := struct {
FirstName string
DateOfBirth string
LastName string
MembershipModel string
Address string
IBAN string
Email string
Phone string
City string
Company string
ZipCode string
BASEURL string
MembershipID uint
RentalFee float32
MembershipFee float32
Logo string
WebsiteTitle string
}{
Company: user.Company,
FirstName: user.FirstName,
LastName: user.LastName,
MembershipModel: user.Membership.SubscriptionModel.Name,
MembershipID: user.Membership.ID,
MembershipFee: float32(user.Membership.SubscriptionModel.MonthlyFee),
RentalFee: float32(user.Membership.SubscriptionModel.HourlyRate),
Address: user.Address,
ZipCode: user.ZipCode,
City: user.City,
DateOfBirth: user.DateOfBirth.Format("20060102"),
Email: user.Email,
Phone: user.Phone,
IBAN: user.BankAccount.IBAN,
BASEURL: config.Site.BaseURL,
Logo: config.Templates.LogoURI,
WebsiteTitle: config.Site.WebsiteTitle,
}
subject := constants.MailRegistrationSubject
htmlBody, err := ParseTemplate("mail_registration.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't send admin notification mail")
return err
}
plainBody, err := ParseTemplate("mail_registration.txt.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't parse password mail")
return err
}
return s.SendEmail(config.Recipients.UserRegistration, subject, htmlBody, plainBody, "")
}
func (s *EmailService) RelayContactFormMessage(sender string, name string, message string) error {
data := struct {
Message string
Name string
BASEURL string
Logo string
WebsiteTitle string
}{
Message: message,
Name: name,
BASEURL: config.Site.BaseURL,
Logo: config.Templates.LogoURI,
WebsiteTitle: config.Site.WebsiteTitle,
}
subject := constants.MailContactSubject
htmlBody, err := ParseTemplate("mail_contact_form.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't send contact form message mail")
return err
}
plainBody, err := ParseTemplate("mail_contact_form.txt.tmpl", data)
if err != nil {
logger.Error.Print("Couldn't parse password mail")
return err
}
return s.SendEmail(config.Recipients.ContactForm, subject, htmlBody, plainBody, sender)
}

View File

@@ -0,0 +1,18 @@
package services
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
)
type LicenceInterface interface {
GetAllCategories() ([]models.Category, error)
}
type LicenceService struct {
Repo repositories.LicenceInterface
}
func (s *LicenceService) GetAllCategories() ([]models.Category, error) {
return s.Repo.GetAllCategories()
}

View File

@@ -0,0 +1,100 @@
package services
import (
"time"
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/errors"
)
type MembershipServiceInterface interface {
RegisterMembership(membership *models.Membership) (uint, error)
FindMembershipByUserID(userID uint) (*models.Membership, error)
RegisterSubscription(subscription *models.SubscriptionModel) (uint, error)
UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error)
DeleteSubscription(id *uint, name *string) error
GetSubscriptionModelNames() ([]string, error)
GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error)
GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error)
}
type MembershipService struct {
Repo repositories.MembershipRepositoryInterface
SubscriptionRepo repositories.SubscriptionModelsRepositoryInterface
}
func (service *MembershipService) RegisterMembership(membership *models.Membership) (uint, error) {
membership.StartDate = time.Now()
return service.Repo.CreateMembership(membership)
}
func (service *MembershipService) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) {
existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name)
if err != nil {
return nil, err
}
if existingSubscription == nil {
return nil, errors.ErrSubscriptionNotFound
}
if existingSubscription.MonthlyFee != subscription.MonthlyFee ||
existingSubscription.HourlyRate != subscription.HourlyRate ||
existingSubscription.Conditions != subscription.Conditions ||
existingSubscription.IncludedPerYear != subscription.IncludedPerYear ||
existingSubscription.IncludedPerMonth != subscription.IncludedPerMonth {
return nil, errors.ErrInvalidSubscriptionData
}
subscription.ID = existingSubscription.ID
return service.SubscriptionRepo.UpdateSubscription(subscription)
}
func (service *MembershipService) DeleteSubscription(id *uint, name *string) error {
if *name == "" {
return errors.ErrNoData
}
exists, err := repositories.GetSubscriptionByName(name)
if err != nil {
return err
}
if exists == nil {
return errors.ErrNotFound
}
if *id != exists.ID {
return errors.ErrInvalidSubscriptionData
}
usersInSubscription, err := repositories.GetUsersBySubscription(*id)
if err != nil {
return err
}
if len(*usersInSubscription) > 0 {
return errors.ErrSubscriptionInUse
}
return service.SubscriptionRepo.DeleteSubscription(id)
}
func (service *MembershipService) FindMembershipByUserID(userID uint) (*models.Membership, error) {
return service.Repo.FindMembershipByUserID(userID)
}
// Membership_Subscriptions
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (uint, error) {
return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
}
func (service *MembershipService) GetSubscriptionModelNames() ([]string, error) {
return service.SubscriptionRepo.GetSubscriptionModelNames()
}
func (service *MembershipService) GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) {
return repositories.GetSubscriptionByName(modelname)
}
func (service *MembershipService) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) {
if where == nil {
where = map[string]interface{}{}
}
return service.SubscriptionRepo.GetSubscriptions(where)
}

View File

@@ -0,0 +1,21 @@
package services
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
)
func (s *UserService) HandlePasswordChangeRequest(user *models.User) (token string, err error) {
// Deactivate user and reset Verification
if err := s.SetUserStatus(user.ID, constants.DisabledStatus); err != nil {
return "", err
}
if err := s.RevokeVerification(&user.ID, constants.VerificationTypes.Password); err != nil {
return "", err
}
// Generate a token
return s.SetVerificationToken(&user.ID, &constants.VerificationTypes.Password)
}

View File

@@ -0,0 +1,5 @@
package services
func (s *UserService) SetUserStatus(id uint, status uint) error {
return s.Repo.SetUserStatus(id, status)
}

View File

@@ -0,0 +1,116 @@
package services
import (
"strings"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/errors"
"gorm.io/gorm"
"time"
)
type UserServiceInterface interface {
RegisterUser(user *models.User) (id uint, token string, err error)
GetUserByEmail(email string) (*models.User, error)
GetUserByID(id uint) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error)
UpdateUser(user *models.User) (*models.User, error)
DeleteUser(lastname string, id uint) error
SetUserStatus(id uint, status uint) error
VerifyUser(token *string, verificationType *string) (*models.Verification, error)
SetVerificationToken(id *uint, verificationType *string) (string, error)
RevokeVerification(id *uint, verificationType string) error
HandlePasswordChangeRequest(user *models.User) (token string, err error)
}
type UserService struct {
Repo repositories.UserRepositoryInterface
Licences repositories.LicenceInterface
}
func (service *UserService) DeleteUser(lastname string, id uint) error {
if id == 0 || lastname == "" {
return errors.ErrNoData
}
user, err := service.GetUserByID(id)
if err != nil {
return err
}
if user == nil {
return errors.ErrUserNotFound
}
return service.Repo.DeleteUser(id)
}
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) {
if user.ID == 0 {
return nil, errors.ErrUserNotFound
}
user.SetPassword(user.Password)
// Validate subscription model
selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name)
if err != nil {
return nil, errors.ErrSubscriptionNotFound
}
user.Membership.SubscriptionModel = *selectedModel
user.Membership.SubscriptionModelID = selectedModel.ID
updatedUser, err := service.Repo.UpdateUser(user)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.ErrUserNotFound
}
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
return nil, errors.ErrDuplicateEntry
}
return nil, err
}
return updatedUser, nil
}
func (service *UserService) RegisterUser(user *models.User) (id uint, token string, err error) {
user.SetPassword(user.Password)
user.Status = constants.UnverifiedStatus
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
user.PaymentStatus = constants.AwaitingPaymentStatus
user.BankAccount.MandateDateSigned = time.Now()
id, err = service.Repo.CreateUser(user)
if err != nil {
return 0, "", err
}
token, err = service.SetVerificationToken(&id, &constants.VerificationTypes.Email)
if err != nil {
return 0, "", err
}
return id, token, nil
}
func (service *UserService) GetUserByID(id uint) (*models.User, error) {
return repositories.GetUserByID(&id)
}
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
return service.Repo.GetUserByEmail(email)
}
func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) {
if where == nil {
where = map[string]interface{}{}
}
return service.Repo.GetUsers(where)
}

View File

@@ -0,0 +1,59 @@
package services
import (
"GoMembership/internal/models"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
"time"
)
func (s *UserService) SetVerificationToken(id *uint, verificationType *string) (string, error) {
token, err := utils.GenerateVerificationToken()
if err != nil {
return "", err
}
// Check if user is already verified
verified, err := s.Repo.IsVerified(id)
if err != nil {
return "", err
}
if verified {
return "", errors.ErrAlreadyVerified
}
// Prepare the Verification record
verification := models.Verification{
UserID: *id,
VerificationToken: token,
Type: *verificationType,
}
return s.Repo.SetVerificationToken(&verification)
}
func (s *UserService) RevokeVerification(id *uint, verificationType string) error {
return s.Repo.DeleteVerification(*id, verificationType)
}
func (service *UserService) VerifyUser(token *string, verificationType *string) (*models.Verification, error) {
verification, err := service.Repo.GetVerificationOfToken(token, verificationType)
if err != nil {
return nil, err
}
// Check if the user is already verified
verified, err := service.Repo.IsVerified(&verification.UserID)
if err != nil {
return nil, err
}
if verified {
return nil, errors.ErrAlreadyVerified
}
t := time.Now()
verification.VerifiedAt = &t
return verification, nil
}

View File

@@ -0,0 +1,20 @@
package utils
import (
"net/http"
"github.com/gin-gonic/gin"
)
func SetCookie(c *gin.Context, token string) {
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(
"jwt",
token,
5*24*60*60, // 5 days
"/",
"",
true,
true,
)
}

View File

@@ -0,0 +1,101 @@
package utils
import (
"bytes"
"crypto/rand"
"encoding/base64"
"io"
"mime"
"mime/quotedprintable"
"net/mail"
"strings"
)
type Email struct {
MimeVersion string
Date string
From string
To string
Subject string
ContentType string
Body string
}
func GenerateRandomString(length int) (string, error) {
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes), nil
}
func GenerateVerificationToken() (string, error) {
return GenerateRandomString(32)
}
func DecodeMail(message string) (*Email, error) {
msg, err := mail.ReadMessage(strings.NewReader(message))
if err != nil {
return nil, err
}
decodedBody, err := io.ReadAll(msg.Body)
if err != nil {
return nil, err
}
decodedBodyString, err := DecodeQuotedPrintable(string(decodedBody))
if err != nil {
return nil, err
}
decodedSubject, err := DecodeRFC2047(msg.Header.Get("Subject"))
if err != nil {
return nil, err
}
email := &Email{}
// Populate the headers
email.MimeVersion = msg.Header.Get("Mime-Version")
email.Date = msg.Header.Get("Date")
email.From = msg.Header.Get("From")
email.To = msg.Header.Get("To")
email.Subject = decodedSubject
email.Body = decodedBodyString
email.ContentType = msg.Header.Get("Content-Type")
return email, nil
}
func DecodeRFC2047(encoded string) (string, error) {
decoder := new(mime.WordDecoder)
decoded, err := decoder.DecodeHeader(encoded)
if err != nil {
return "", err
}
return decoded, nil
}
func DecodeQuotedPrintable(encodedString string) (string, error) {
// Decode quoted-printable encoding
reader := quotedprintable.NewReader(strings.NewReader(encodedString))
decodedBytes := new(bytes.Buffer)
_, err := decodedBytes.ReadFrom(reader)
if err != nil {
return "", err
}
return decodedBytes.String(), nil
}
func EncodeQuotedPrintable(s string) string {
var buf bytes.Buffer
// Use Quoted-Printable encoder
qp := quotedprintable.NewWriter(&buf)
// Write the UTF-8 encoded string to the Quoted-Printable encoder
qp.Write([]byte(s))
qp.Close()
// Encode the result into a MIME header
return mime.QEncoding.Encode("UTF-8", buf.String())
}

View File

@@ -0,0 +1,33 @@
package utils
import (
smtpmock "github.com/mocktools/go-smtp-mock/v2"
)
var Server smtpmock.Server
// StartMockSMTPServer starts a mock SMTP server for testing
func SMTPStart(host string, port int) error {
Server = *smtpmock.New(smtpmock.ConfigurationAttr{
HostAddress: host,
PortNumber: port,
LogToStdout: false,
LogServerActivity: false,
})
if err := Server.Start(); err != nil {
return err
}
return nil
}
func SMTPGetMessages() []smtpmock.Message {
return Server.MessagesAndPurge()
}
func SMTPStop() error {
if err := Server.Stop(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,131 @@
package utils
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"errors"
"reflect"
)
func HasPrivilige(user *models.User, privilige int8) bool {
switch privilige {
case constants.Priviliges.View:
return user.RoleID >= constants.Roles.Viewer
case constants.Priviliges.Update:
return user.RoleID >= constants.Roles.Editor
case constants.Priviliges.Create:
return user.RoleID >= constants.Roles.Editor
case constants.Priviliges.Delete:
return user.RoleID >= constants.Roles.Editor
default:
return false
}
}
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
v := reflect.ValueOf(input)
origin := reflect.ValueOf(existing)
// Ensure both input and target are pointers to structs
if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
return errors.New("both input and existing must be pointers to structs")
}
v = v.Elem()
origin = origin.Elem()
if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
return errors.New("both input and existing must be structs")
}
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
key := field.Name
// Skip unexported fields
if !field.IsExported() {
continue
}
// Build the full field path
fullKey := key
if prefix != "" {
fullKey = prefix + "." + key
}
fieldValue := v.Field(i)
originField := origin.Field(i)
// Handle nil pointers
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
// If the field is nil, skip it or initialize it
if !allowedFields[fullKey] {
// If the field is not allowed, set it to the corresponding field from existing
fieldValue.Set(originField)
}
continue
}
// Dereference the pointer for further processing
fieldValue = fieldValue.Elem()
originField = originField.Elem()
}
// Handle slices
if fieldValue.Kind() == reflect.Slice {
if !allowedFields[fullKey] {
// If the slice is not allowed, set it to the corresponding slice from existing
fieldValue.Set(originField)
continue
} else {
originField.Set(fieldValue)
}
continue
}
// Handle nested structs (including pointers to structs)
if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) {
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
continue
}
fieldValue = fieldValue.Elem()
originField = originField.Elem() // May result in an invalid originField
}
var originCopy reflect.Value
// Check if originField is valid (non-zero)
if originField.IsValid() {
originCopy = reflect.New(originField.Type()).Elem()
originCopy.Set(originField)
} else {
// If originField is invalid (e.g., existing had a nil pointer),
// create a new instance of the type from fieldValue
originCopy = reflect.New(fieldValue.Type()).Elem()
}
err := FilterAllowedStructFields(
fieldValue.Addr().Interface(),
originCopy.Addr().Interface(),
allowedFields,
fullKey,
)
if err != nil {
return err
}
continue
}
// Only allow whitelisted fields
if !allowedFields[fullKey] {
logger.Error.Printf("denying update of field: %#v", fullKey)
fieldValue.Set(originField)
} else {
logger.Error.Printf("updating whitelisted field: %#v", fullKey)
}
}
return nil
}

View File

@@ -0,0 +1,176 @@
package utils
import (
"reflect"
"testing"
)
type User struct {
Name string
Age int
Address *Address
Tags []string
License License
}
type Address struct {
City string
Country string
}
type License struct {
ID string
Categories []string
}
func TestFilterAllowedStructFields(t *testing.T) {
tests := []struct {
name string
input interface{}
existing interface{}
allowedFields map[string]bool
expectedResult interface{}
expectError bool
}{
{
name: "Filter top-level fields",
input: &User{
Name: "Alice",
Age: 30,
},
existing: &User{
Name: "Bob",
Age: 25,
},
allowedFields: map[string]bool{
"Name": true,
},
expectedResult: &User{
Name: "Alice", // Allowed field
Age: 25, // Kept from existing
},
expectError: false,
},
{
name: "Filter nested struct fields",
input: &User{
Name: "Alice",
Address: &Address{
City: "New York",
Country: "USA",
},
},
existing: &User{
Name: "Bob",
Address: &Address{
City: "London",
Country: "UK",
},
},
allowedFields: map[string]bool{
"Address.City": true,
},
expectedResult: &User{
Name: "Bob", // Kept from existing
Address: &Address{
City: "New York", // Allowed field
Country: "UK", // Kept from existing
},
},
expectError: false,
},
{
name: "Filter slice fields",
input: &User{
Tags: []string{"admin", "user"},
},
existing: &User{
Tags: []string{"guest"},
},
allowedFields: map[string]bool{
"Tags": true,
},
expectedResult: &User{
Tags: []string{"admin", "user"}, // Allowed slice
},
expectError: false,
},
{
name: "Filter slice of structs",
input: &User{
License: License{
ID: "123",
Categories: []string{"A", "B"},
},
},
existing: &User{
License: License{
ID: "456",
Categories: []string{"C"},
},
},
allowedFields: map[string]bool{
"License.ID": true,
},
expectedResult: &User{
License: License{
ID: "123", // Allowed field
Categories: []string{"C"}, // Kept from existing
},
},
expectError: false,
},
{
name: "Filter pointer fields",
input: &User{
Address: &Address{
City: "Paris",
},
},
existing: &User{
Address: &Address{
City: "Berlin",
Country: "Germany",
},
},
allowedFields: map[string]bool{
"Address.City": true,
},
expectedResult: &User{
Address: &Address{
City: "Paris", // Allowed field
Country: "Germany", // Kept from existing
},
},
expectError: false,
},
{
name: "Invalid input (non-pointer)",
input: User{
Name: "Alice",
},
existing: &User{
Name: "Bob",
},
allowedFields: map[string]bool{
"Name": true,
},
expectedResult: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "")
if (err != nil) != tt.expectError {
t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError)
return
}
if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) {
t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult)
}
})
}
}

View File

@@ -0,0 +1,50 @@
package utils
import (
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
func RespondWithError(c *gin.Context, err error, context string, code int, field string, key string) {
logger.Error.Printf("Sending %v Error Response(Field: %v Key: %v) %v: %v", code, field, key, context, err.Error())
c.JSON(code, gin.H{"errors": []gin.H{{
"field": field,
"key": key,
}}})
}
func HandleValidationError(c *gin.Context, err error) {
var validationErrors []gin.H
logger.Error.Printf("Sending validation error response Error %v", err.Error())
if ve, ok := err.(validator.ValidationErrors); ok {
for _, e := range ve {
validationErrors = append(validationErrors, gin.H{
"field": e.Field(),
"key": "server.validation." + e.Tag(),
})
}
} else {
validationErrors = append(validationErrors, gin.H{
"field": "general",
"key": "server.error.invalid_json",
})
}
c.JSON(http.StatusBadRequest, gin.H{"errors": validationErrors})
}
func HandleUserUpdateError(c *gin.Context, err error) {
switch err {
case errors.ErrUserNotFound:
RespondWithError(c, err, "Error while updating user", http.StatusNotFound, "user.user", "server.validation.user_not_found")
case errors.ErrInvalidUserData:
RespondWithError(c, err, "Error while updating user", http.StatusBadRequest, "user.user", "server.validation.invalid_user_data")
case errors.ErrSubscriptionNotFound:
RespondWithError(c, err, "Error while updating user", http.StatusBadRequest, "subscription", "server.validation.subscription_data")
default:
RespondWithError(c, err, "Error while updating user", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
}
}

View File

@@ -0,0 +1,53 @@
package validation
import (
"GoMembership/internal/models"
"time"
"github.com/go-playground/validator/v10"
)
func validateDriverslicence(sl validator.StructLevel) {
dl := sl.Current().Interface().(models.User).Licence
// if !vValidateLicence(dl.Number) {
if dl.Number == "" {
sl.ReportError(dl.Number, "licence_number", "", "invalid", "")
}
if dl.IssuedDate.After(time.Now()) {
sl.ReportError(dl.IssuedDate, "issued_date", "", "invalid", "")
}
if dl.ExpirationDate.Before(time.Now().AddDate(0, 0, 3)) {
sl.ReportError(dl.ExpirationDate, "expiration_date", "", "too_soon", "")
}
}
// seems like not every country has to have an licence id and it seems that germany changed their id generation type..
// func validateLicence(fieldValue string) bool {
// if len(fieldValue) != 11 {
// return false
// }
// id, tenthChar := string(fieldValue[:9]), string(fieldValue[9])
// if tenthChar == "X" {
// tenthChar = "10"
// }
// tenthValue, _ := strconv.ParseInt(tenthChar, 10, 8)
// // for readability
// weights := []int{9, 8, 7, 6, 5, 4, 3, 2, 1}
// sum := 0
// for i := 0; i < 9; i++ {
// char := string(id[i])
// value, _ := strconv.ParseInt(char, 36, 64)
// sum += int(value) * weights[i]
// }
// calcCheckDigit := sum % 11
// if calcCheckDigit != int(tenthValue) {
// return false
// }
// return true
// }

View File

@@ -0,0 +1,27 @@
package validation
import (
"GoMembership/internal/models"
"github.com/go-playground/validator/v10"
"github.com/jbub/banking/iban"
"github.com/jbub/banking/swift"
)
func validateBankAccount(sl validator.StructLevel) {
ba := sl.Current().Interface().(models.User).BankAccount
if !ibanValidator(ba.IBAN) {
sl.ReportError(ba.IBAN, "IBAN", "BankAccount.IBAN", "required", "")
}
if ba.BIC != "" && !bicValidator(ba.BIC) {
sl.ReportError(ba.IBAN, "IBAN", "BankAccount.IBAN", "required", "")
}
}
func ibanValidator(fieldValue string) bool {
return iban.Validate(fieldValue) == nil
}
func bicValidator(fieldValue string) bool {
return swift.Validate(fieldValue) == nil
}

View File

@@ -0,0 +1,34 @@
package validation
import (
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)(vbscript|data):`),
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
regexp.MustCompile(`(?i)<\?`),
regexp.MustCompile(`(?i)<%`),
regexp.MustCompile(`(?i)<!\[CDATA\[`),
regexp.MustCompile(`(?i)<(svg|animate)`),
regexp.MustCompile(`(?i)<(audio|video|source)`),
regexp.MustCompile(`(?i)base64`),
}
func ValidateSafeContent(fl validator.FieldLevel) bool {
input := strings.ToLower(fl.Field().String())
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
return false
}
}
return true
}

View File

@@ -0,0 +1,42 @@
package validation
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"github.com/go-playground/validator/v10"
)
func validateMembership(sl validator.StructLevel) {
membership := sl.Current().Interface().(models.User).Membership
if membership.SubscriptionModel.RequiredMembershipField != "" {
switch membership.SubscriptionModel.RequiredMembershipField {
case "ParentMembershipID":
if err := CheckParentMembershipID(membership); err != nil {
logger.Error.Printf("Error ParentMembershipValidation: %v", err.Error())
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "invalid", "")
}
default:
logger.Error.Printf("Error no matching RequiredMembershipField: %v", errors.ErrInvalidValue.Error())
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "not_implemented", "")
}
}
}
func CheckParentMembershipID(membership models.Membership) error {
if membership.ParentMembershipID == 0 {
return errors.ValErrParentIDNotSet
} else {
_, err := repositories.GetUserByID(&membership.ParentMembershipID)
if err != nil {
return errors.ValErrParentIDNotFound
}
}
return nil
}

View File

@@ -0,0 +1,20 @@
package validation
import (
"GoMembership/internal/models"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
func SetupValidators() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// Register custom validators
v.RegisterValidation("safe_content", ValidateSafeContent)
// Register struct-level validations
v.RegisterStructValidation(validateUser, models.User{})
v.RegisterStructValidation(ValidateSubscription, models.SubscriptionModel{})
}
}

View File

@@ -0,0 +1,46 @@
package validation
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"github.com/go-playground/validator/v10"
)
// ValidateNewSubscription validates a new subscription model being created
func ValidateSubscription(sl validator.StructLevel) {
subscription := sl.Current().Interface().(models.SubscriptionModel)
if subscription.Name == "" {
sl.ReportError(subscription.Name, "Name", "name", "required", "")
}
if sl.Parent().Type().Name() == "MembershipData" {
// This is modifying a subscription directly
if subscription.Details == "" {
sl.ReportError(subscription.Details, "Details", "details", "required", "")
}
if subscription.MonthlyFee < 0 {
sl.ReportError(subscription.MonthlyFee, "MonthlyFee", "monthly_fee", "gte", "0")
}
if subscription.HourlyRate < 0 {
sl.ReportError(subscription.HourlyRate, "HourlyRate", "hourly_rate", "gte", "0")
}
if subscription.IncludedPerYear < 0 {
sl.ReportError(subscription.IncludedPerYear, "IncludedPerYear", "included_hours_per_year", "gte", "0")
}
if subscription.IncludedPerMonth < 0 {
sl.ReportError(subscription.IncludedPerMonth, "IncludedPerMonth", "included_hours_per_month", "gte", "0")
}
} else {
// This is a nested probably user struct. We are only checking if the model exists
existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name)
if err != nil || existingSubscription == nil {
sl.ReportError(subscription.Name, "Subscription_Name", "name", "exists", "")
}
}
}

View File

@@ -0,0 +1,49 @@
package validation
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/logger"
"time"
"github.com/go-playground/validator/v10"
)
func validateUser(sl validator.StructLevel) {
user := sl.Current().Interface().(models.User)
isSuper := user.RoleID >= constants.Roles.Admin
if user.RoleID > constants.Roles.Member && user.Password == "" {
passwordExists, err := repositories.PasswordExists(&user.ID)
if err != nil || !passwordExists {
logger.Error.Printf("Error checking password exists for user %v: %v", user.Email, err)
sl.ReportError(user.Password, "Password", "password", "required", "")
}
}
// Validate User > 18 years old
if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
sl.ReportError(user.DateOfBirth, "DateOfBirth", "dateofbirth", "age", "")
}
// validate subscriptionModel
if user.Membership.SubscriptionModel.Name == "" {
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "required", "")
} else {
selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name)
if err != nil {
logger.Error.Printf("Error finding subscription model for user %v: %v", user.Email, err)
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "invalid", "")
} else {
user.Membership.SubscriptionModel = *selectedModel
}
}
validateMembership(sl)
if !isSuper {
validateBankAccount(sl)
if user.Licence != nil {
validateDriverslicence(sl)
}
}
}