frontend: disabled button while processing password reset
This commit is contained in:
161
go-backend/internal/config/config.go
Normal file
161
go-backend/internal/config/config.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Package config provides functionality for loading application configuration from a JSON file and environment variables.
|
||||
// It defines structs for different configuration sections (database, authentication, SMTP, templates) and functions
|
||||
// to read and populate these configurations. It also generates secrets for JWT and CSRF tokens.
|
||||
//
|
||||
// This package uses the `envconfig` library to map environment variables to struct fields, falls back to variables of a config
|
||||
// file and provides functions for error handling and logging during the configuration loading process.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
|
||||
"GoMembership/internal/utils"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `json:"Path" default:"data/db.sqlite3" envconfig:"DB_PATH"`
|
||||
}
|
||||
|
||||
type SiteConfig struct {
|
||||
AllowOrigins string `json:"AllowOrigins" envconfig:"ALLOW_ORIGINS"`
|
||||
WebsiteTitle string `json:"WebsiteTitle" envconfig:"WEBSITE_TITLE"`
|
||||
BaseURL string `json:"BaseUrl" envconfig:"BASE_URL"`
|
||||
}
|
||||
type AuthenticationConfig struct {
|
||||
JWTSecret string
|
||||
CSRFSecret string
|
||||
APIKEY string `json:"APIKey" envconfig:"API_KEY"`
|
||||
}
|
||||
|
||||
type SMTPConfig struct {
|
||||
Host string `json:"Host" envconfig:"SMTP_HOST"`
|
||||
User string `json:"User" envconfig:"SMTP_USER"`
|
||||
Password string `json:"Password" envconfig:"SMTP_PASS"`
|
||||
Port int `json:"Port" default:"465" envconfig:"SMTP_PORT"`
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
MailPath string `json:"MailPath" default:"templates/email" envconfig:"TEMPLATE_MAIL_PATH"`
|
||||
HTMLPath string `json:"HTMLPath" default:"templates/html" envconfig:"TEMPLATE_HTML_PATH"`
|
||||
StaticPath string `json:"StaticPath" default:"templates/css" envconfig:"TEMPLATE_STATIC_PATH"`
|
||||
LogoURI string `json:"LogoURI" envconfig:"LOGO_URI"`
|
||||
}
|
||||
|
||||
type RecipientsConfig struct {
|
||||
ContactForm string `json:"ContactForm" envconfig:"RECIPIENT_CONTACT_FORM"`
|
||||
UserRegistration string `json:"UserRegistration" envconfig:"RECIPIENT_USER_REGISTRATION"`
|
||||
AdminEmail string `json:"AdminEmail" envconfig:"ADMIN_MAIL"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Ratelimits struct {
|
||||
Limit int `json:"Limit" default:"1" envconfig:"RATE_LIMIT"`
|
||||
Burst int `json:"Burst" default:"60" envconfig:"BURST_LIMIT"`
|
||||
} `json:"RateLimits"`
|
||||
}
|
||||
type Config struct {
|
||||
Auth AuthenticationConfig `json:"auth"`
|
||||
Site SiteConfig `json:"site"`
|
||||
Templates TemplateConfig `json:"templates"`
|
||||
Recipients RecipientsConfig `json:"recipients"`
|
||||
ConfigFilePath string `json:"config_file_path" envconfig:"CONFIG_FILE_PATH"`
|
||||
Env string `json:"Environment" default:"development" envconfig:"ENV"`
|
||||
DB DatabaseConfig `json:"db"`
|
||||
SMTP SMTPConfig `json:"smtp"`
|
||||
Security SecurityConfig `json:"security"`
|
||||
}
|
||||
|
||||
var (
|
||||
Site SiteConfig
|
||||
CFGPath string
|
||||
CFG Config
|
||||
Auth AuthenticationConfig
|
||||
DB DatabaseConfig
|
||||
Templates TemplateConfig
|
||||
SMTP SMTPConfig
|
||||
Recipients RecipientsConfig
|
||||
Env string
|
||||
Security SecurityConfig
|
||||
)
|
||||
var environmentOptions map[string]bool = map[string]bool{
|
||||
"development": true,
|
||||
"production": true,
|
||||
"dev": true,
|
||||
"prod": true,
|
||||
}
|
||||
|
||||
// LoadConfig initializes the configuration by reading from a file and environment variables.
|
||||
// It also generates JWT and CSRF secrets. Returns a Config pointer or an error if any step fails.
|
||||
func LoadConfig() {
|
||||
CFGPath = os.Getenv("CONFIG_FILE_PATH")
|
||||
logger.Info.Printf("Config file environment: %v", CFGPath)
|
||||
readFile(&CFG)
|
||||
readEnv(&CFG)
|
||||
csrfSecret, err := utils.GenerateRandomString(32)
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not generate CSRF secret: %v", err)
|
||||
}
|
||||
|
||||
jwtSecret, err := utils.GenerateRandomString(32)
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not generate JWT secret: %v", err)
|
||||
}
|
||||
CFG.Auth.JWTSecret = jwtSecret
|
||||
CFG.Auth.CSRFSecret = csrfSecret
|
||||
if environmentOptions[CFG.Env] && strings.Contains("development", CFG.Env) {
|
||||
CFG.Env = "development"
|
||||
} else {
|
||||
CFG.Env = "production"
|
||||
}
|
||||
Auth = CFG.Auth
|
||||
DB = CFG.DB
|
||||
Templates = CFG.Templates
|
||||
SMTP = CFG.SMTP
|
||||
Recipients = CFG.Recipients
|
||||
Security = CFG.Security
|
||||
Env = CFG.Env
|
||||
Site = CFG.Site
|
||||
logger.Info.Printf("Config loaded: %#v", CFG)
|
||||
}
|
||||
|
||||
// readFile reads the configuration from the specified file path into the provided Config struct.
|
||||
// If the file path is empty, it defaults to "configs/config.json" in the current working directory.
|
||||
// Returns an error if the file cannot be opened or the JSON cannot be decoded.
|
||||
func readFile(cfg *Config) {
|
||||
if CFGPath == "" {
|
||||
path, err := os.Getwd()
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not get working directory: %v", err)
|
||||
}
|
||||
CFGPath = filepath.Join(path, "configs", "config.json")
|
||||
}
|
||||
|
||||
configFile, err := os.Open(CFGPath)
|
||||
// configFile, err := os.Open("config.json")
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not open config file: %v", err)
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
decoder := json.NewDecoder(configFile)
|
||||
err = decoder.Decode(cfg)
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not decode config file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// readEnv populates the Config struct with values from environment variables using the envconfig package.
|
||||
// Returns an error if environment variable decoding fails.
|
||||
func readEnv(cfg *Config) {
|
||||
err := envconfig.Process("", cfg)
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("could not decode env variables: %#v", err)
|
||||
}
|
||||
}
|
||||
101
go-backend/internal/constants/constants.go
Normal file
101
go-backend/internal/constants/constants.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
UnverifiedStatus = iota + 1
|
||||
DisabledStatus
|
||||
VerifiedStatus
|
||||
ActiveStatus
|
||||
PassiveStatus
|
||||
DelayedPaymentStatus
|
||||
SettledPaymentStatus
|
||||
AwaitingPaymentStatus
|
||||
MailVerificationSubject = "Nur noch ein kleiner Schritt!"
|
||||
MailChangePasswordSubject = "Passwort Änderung angefordert"
|
||||
MailRegistrationSubject = "Neues Mitglied hat sich registriert"
|
||||
MailWelcomeSubject = "Willkommen beim Dörpsmobil Hasloh e.V."
|
||||
MailContactSubject = "Jemand hat das Kontaktformular gefunden"
|
||||
)
|
||||
|
||||
var Roles = struct {
|
||||
Member int8
|
||||
Viewer int8
|
||||
Editor int8
|
||||
Admin int8
|
||||
}{
|
||||
Member: 0,
|
||||
Viewer: 1,
|
||||
Editor: 4,
|
||||
Admin: 8,
|
||||
}
|
||||
|
||||
var Licences = struct {
|
||||
AM string
|
||||
A1 string
|
||||
A2 string
|
||||
A string
|
||||
B string
|
||||
C1 string
|
||||
C string
|
||||
D1 string
|
||||
D string
|
||||
BE string
|
||||
C1E string
|
||||
CE string
|
||||
D1E string
|
||||
DE string
|
||||
L string
|
||||
T string
|
||||
}{
|
||||
AM: "AM",
|
||||
A1: "A1",
|
||||
A2: "A2",
|
||||
A: "A",
|
||||
B: "B",
|
||||
C1: "C1",
|
||||
C: "C",
|
||||
D1: "D1",
|
||||
D: "D",
|
||||
BE: "BE",
|
||||
C1E: "C1E",
|
||||
CE: "CE",
|
||||
D1E: "D1E",
|
||||
DE: "DE",
|
||||
L: "L",
|
||||
T: "T",
|
||||
}
|
||||
|
||||
var VerificationTypes = struct {
|
||||
Email string
|
||||
Password string
|
||||
}{
|
||||
Email: "email",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
var Priviliges = struct {
|
||||
View int8
|
||||
Create int8
|
||||
Update int8
|
||||
Delete int8
|
||||
}{
|
||||
View: 0,
|
||||
Update: 10,
|
||||
Create: 20,
|
||||
Delete: 30,
|
||||
}
|
||||
|
||||
var MemberUpdateFields = map[string]bool{
|
||||
"Email": true,
|
||||
"Phone": true,
|
||||
"Company": true,
|
||||
"Address": true,
|
||||
"ZipCode": true,
|
||||
"City": true,
|
||||
"Licence.Categories": true,
|
||||
"BankAccount.Bank": true,
|
||||
"BankAccount.AccountHolderName": true,
|
||||
"BankAccount.IBAN": true,
|
||||
"BankAccount.BIC": true,
|
||||
}
|
||||
|
||||
// "Password": true,
|
||||
71
go-backend/internal/controllers/SQLInjection_test.go
Normal file
71
go-backend/internal/controllers/SQLInjection_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SQLInjectionTest struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectedStatus int
|
||||
}
|
||||
|
||||
func (sit *SQLInjectionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
loginData := loginInput{
|
||||
Email: sit.email,
|
||||
Password: sit.password,
|
||||
}
|
||||
jsonData, _ := json.Marshal(loginData)
|
||||
return GetMockedJSONContext(jsonData, "/login")
|
||||
}
|
||||
|
||||
func (sit *SQLInjectionTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||
router.POST("/login", Uc.LoginHandler)
|
||||
router.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (sit *SQLInjectionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||
if sit.expectedStatus != w.Code {
|
||||
responseBody, _ := io.ReadAll(w.Body)
|
||||
return fmt.Errorf("SQL Injection Attempt: Didn't get the expected response code: got: %v; expected: %v. Context: %#v", w.Code, sit.expectedStatus, string(responseBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sit *SQLInjectionTest) ValidateResult() error {
|
||||
// Add any additional validation if needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func testSQLInjectionAttempt(t *testing.T) {
|
||||
tests := []SQLInjectionTest{
|
||||
{
|
||||
name: "SQL Injection Attempt in Email",
|
||||
email: "' OR '1'='1",
|
||||
password: "password123",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "SQL Injection Attempt in Password",
|
||||
email: "user@example.com",
|
||||
password: "' OR '1'='1",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := runSingleTest(&tt); err != nil {
|
||||
t.Errorf("Test failed: %v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
31
go-backend/internal/controllers/XSS_test.go
Normal file
31
go-backend/internal/controllers/XSS_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testXSSAttempt(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.POST("/register", Uc.RegisterUser)
|
||||
|
||||
xssPayload := "<script>alert('XSS')</script>"
|
||||
user := getBaseUser()
|
||||
user.FirstName = xssPayload
|
||||
user.Email = "user@xss.hack"
|
||||
jsonData, _ := json.Marshal(RegistrationData{User: user})
|
||||
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.NotContains(t, w.Body.String(), xssPayload)
|
||||
}
|
||||
50
go-backend/internal/controllers/contactController.go
Normal file
50
go-backend/internal/controllers/contactController.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type ContactController struct {
|
||||
EmailService *services.EmailService
|
||||
}
|
||||
type contactData struct {
|
||||
Email string `form:"REPLY_TO" validate:"required,email"`
|
||||
Name string `form:"name"`
|
||||
Message string `form:"message" validate:"required"`
|
||||
Honeypot string `form:"username" validate:"eq="`
|
||||
}
|
||||
|
||||
func (cc *ContactController) RelayContactRequest(c *gin.Context) {
|
||||
var msgData contactData
|
||||
|
||||
if err := c.ShouldBind(&msgData); err != nil {
|
||||
// A bot is talking to us
|
||||
c.JSON(http.StatusNotAcceptable, gin.H{"error": "Not Acceptable"})
|
||||
return
|
||||
}
|
||||
|
||||
validate := validator.New()
|
||||
if err := validate.Struct(msgData); err != nil {
|
||||
logger.Error.Printf("Couldn't validate contact form data: %#v: %v", msgData, err)
|
||||
c.HTML(http.StatusNotAcceptable, "contactForm_reply.html", gin.H{"Error": "Form validation failed. Please check again."})
|
||||
// c.JSON(http.StatusNotAcceptable, gin.H{"error": "Couldn't validate contact form data"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := cc.EmailService.RelayContactFormMessage(msgData.Email, msgData.Name, msgData.Message); err != nil {
|
||||
logger.Error.Printf("Couldn't send contact message mail: %v", err)
|
||||
c.HTML(http.StatusInternalServerError, "contactForm_reply.html", gin.H{"Error": "Email submission failed. Please try again."})
|
||||
// c.JSON(http.StatusInternalServerError, gin.H{"error": "Couldn't send mail"})
|
||||
return
|
||||
}
|
||||
|
||||
// c.JSON(http.StatusAccepted, "Your message has been sent")
|
||||
c.HTML(http.StatusAccepted, "contactForm_reply.html", gin.H{"Success": true})
|
||||
}
|
||||
159
go-backend/internal/controllers/contactController_test.go
Normal file
159
go-backend/internal/controllers/contactController_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RelayContactRequestTest struct {
|
||||
Input url.Values
|
||||
Name string
|
||||
WantResponse int
|
||||
Assert bool
|
||||
}
|
||||
|
||||
func testContactController(t *testing.T) {
|
||||
|
||||
tests := getContactData()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.Name, func(t *testing.T) {
|
||||
if err := runSingleTest(&tt); err != nil {
|
||||
t.Errorf("Test failed: %v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RelayContactRequestTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
return GetMockedFormContext(rt.Input, "/contact")
|
||||
}
|
||||
|
||||
func (rt *RelayContactRequestTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||
router.POST("/contact", Cc.RelayContactRequest)
|
||||
router.ServeHTTP(c.Writer, c.Request)
|
||||
// Cc.RelayContactRequest(c)
|
||||
}
|
||||
|
||||
func (rt *RelayContactRequestTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||
if w.Code != rt.WantResponse {
|
||||
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, rt.WantResponse)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rt *RelayContactRequestTest) ValidateResult() error {
|
||||
|
||||
messages := utils.SMTPGetMessages()
|
||||
|
||||
for _, message := range messages {
|
||||
|
||||
mail, err := utils.DecodeMail(message.MsgRequest())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(mail.Subject, constants.MailContactSubject) {
|
||||
|
||||
if err := checkContactRequestMail(mail, rt); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Subject not expected: %v", mail.Subject)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkContactRequestMail(mail *utils.Email, rt *RelayContactRequestTest) error {
|
||||
|
||||
if !strings.Contains(mail.To, config.Recipients.ContactForm) {
|
||||
return fmt.Errorf("Contact Information didn't reach the admin! Recipient was: %v instead of %v", mail.To, config.Recipients.ContactForm)
|
||||
}
|
||||
if !strings.Contains(mail.From, config.SMTP.User) {
|
||||
return fmt.Errorf("Contact Information was sent from unexpected address! Sender was: %v instead of %v", mail.From, config.SMTP.User)
|
||||
}
|
||||
|
||||
//Check if all the relevant data has been passed to the mail.
|
||||
if !strings.Contains(mail.Body, rt.Input.Get("name")) {
|
||||
return fmt.Errorf("User name(%v) has not been rendered in contact mail.", rt.Input.Get("name"))
|
||||
}
|
||||
|
||||
if !strings.Contains(mail.Body, rt.Input.Get("message")) {
|
||||
return fmt.Errorf("User message(%v) has not been rendered in contact mail.", rt.Input.Get("message"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBaseRequest() *url.Values {
|
||||
return &url.Values{
|
||||
"username": {""},
|
||||
"name": {"My-First and-Last-Name"},
|
||||
"REPLY_TO": {"name@domain.de"},
|
||||
"message": {"My message to the world"},
|
||||
}
|
||||
}
|
||||
|
||||
func customizeRequest(updates map[string]string) *url.Values {
|
||||
form := getBaseRequest()
|
||||
for key, value := range updates {
|
||||
form.Set(key, value)
|
||||
}
|
||||
return form
|
||||
}
|
||||
|
||||
func getContactData() []RelayContactRequestTest {
|
||||
return []RelayContactRequestTest{
|
||||
{
|
||||
Name: "mail empty, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
Assert: false,
|
||||
Input: *customizeRequest(
|
||||
map[string]string{
|
||||
"REPLY_TO": "",
|
||||
}),
|
||||
},
|
||||
{
|
||||
Name: "mail invalid, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
Assert: false,
|
||||
Input: *customizeRequest(
|
||||
map[string]string{
|
||||
"REPLY_TO": "novalid#email.de",
|
||||
}),
|
||||
},
|
||||
{
|
||||
Name: "No message should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
Assert: true,
|
||||
Input: *customizeRequest(
|
||||
map[string]string{
|
||||
"message": "",
|
||||
}),
|
||||
},
|
||||
{
|
||||
Name: "Honeypot set, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
Assert: true,
|
||||
Input: *customizeRequest(
|
||||
map[string]string{
|
||||
"username": "I'm a bot",
|
||||
}),
|
||||
},
|
||||
{
|
||||
Name: "Correct message, should pass",
|
||||
WantResponse: http.StatusAccepted,
|
||||
Assert: true,
|
||||
Input: *customizeRequest(
|
||||
map[string]string{}),
|
||||
},
|
||||
}
|
||||
}
|
||||
313
go-backend/internal/controllers/controllers_test.go
Normal file
313
go-backend/internal/controllers/controllers_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/internal/utils"
|
||||
"GoMembership/internal/validation"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type TestCase interface {
|
||||
SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine)
|
||||
RunHandler(*gin.Context, *gin.Engine)
|
||||
ValidateResponse(*httptest.ResponseRecorder) error
|
||||
ValidateResult() error
|
||||
}
|
||||
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
Port int = 2525
|
||||
)
|
||||
|
||||
type loginInput struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
Uc *UserController
|
||||
Mc *MembershipController
|
||||
Cc *ContactController
|
||||
)
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
_ = deleteTestDB("test.db")
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
|
||||
// Build paths relative to the current working directory
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_HOST", Host); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("DB_PATH", "test.db"); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
config.LoadConfig()
|
||||
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil {
|
||||
log.Fatalf("Failed to create DB: %#v", err)
|
||||
}
|
||||
utils.SMTPStart(Host, Port)
|
||||
emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password)
|
||||
var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{}
|
||||
consentService := &services.ConsentService{Repo: consentRepo}
|
||||
|
||||
var bankAccountRepo repositories.BankAccountRepositoryInterface = &repositories.BankAccountRepository{}
|
||||
bankAccountService := &services.BankAccountService{Repo: bankAccountRepo}
|
||||
|
||||
var membershipRepo repositories.MembershipRepositoryInterface = &repositories.MembershipRepository{}
|
||||
var subscriptionRepo repositories.SubscriptionModelsRepositoryInterface = &repositories.SubscriptionModelsRepository{}
|
||||
membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo}
|
||||
|
||||
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
|
||||
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{}
|
||||
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
|
||||
|
||||
licenceService := &services.LicenceService{Repo: licenceRepo}
|
||||
|
||||
Uc = &UserController{Service: userService, LicenceService: licenceService, EmailService: emailService, ConsentService: consentService, BankAccountService: bankAccountService, MembershipService: membershipService}
|
||||
Mc = &MembershipController{UserController: &MockUserController{}, Service: *membershipService}
|
||||
Cc = &ContactController{EmailService: emailService}
|
||||
|
||||
if err := initSubscriptionPlans(); err != nil {
|
||||
log.Fatalf("Failed to init Subscription plans: %#v", err)
|
||||
}
|
||||
|
||||
if err := initLicenceCategories(); err != nil {
|
||||
log.Fatalf("Failed to init Categories: %v", err)
|
||||
}
|
||||
admin := models.User{
|
||||
FirstName: "Ad",
|
||||
LastName: "min",
|
||||
Email: "admin@example.com",
|
||||
DateOfBirth: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
Company: "SampleCorp",
|
||||
Phone: "+123456789",
|
||||
Address: "123 Main Street",
|
||||
ZipCode: "12345",
|
||||
City: "SampleCity",
|
||||
Status: constants.ActiveStatus,
|
||||
RoleID: 8,
|
||||
}
|
||||
admin.SetPassword("securepassword")
|
||||
database.DB.Create(&admin)
|
||||
validation.SetupValidators()
|
||||
t.Run("userController", func(t *testing.T) {
|
||||
testUserController(t)
|
||||
})
|
||||
|
||||
t.Run("SQL_Injection", func(t *testing.T) {
|
||||
testSQLInjectionAttempt(t)
|
||||
})
|
||||
|
||||
t.Run("contactController", func(t *testing.T) {
|
||||
testContactController(t)
|
||||
})
|
||||
|
||||
t.Run("membershipController", func(t *testing.T) {
|
||||
testMembershipController(t)
|
||||
})
|
||||
|
||||
t.Run("XSSAttempt", func(t *testing.T) {
|
||||
testXSSAttempt(t)
|
||||
})
|
||||
|
||||
if err := utils.SMTPStop(); err != nil {
|
||||
log.Fatalf("Failed to stop SMTP Mockup Server: %#v", err)
|
||||
}
|
||||
|
||||
// if err := deleteTestDB("test.db"); err != nil {
|
||||
// log.Fatalf("Failed to tear down DB: %#v", err)
|
||||
// }
|
||||
}
|
||||
|
||||
func initLicenceCategories() error {
|
||||
categories := []models.Category{
|
||||
{Name: "AM"},
|
||||
{Name: "A1"},
|
||||
{Name: "A2"},
|
||||
{Name: "A"},
|
||||
{Name: "B"},
|
||||
{Name: "C1"},
|
||||
{Name: "C"},
|
||||
{Name: "D1"},
|
||||
{Name: "D"},
|
||||
{Name: "BE"},
|
||||
{Name: "C1E"},
|
||||
{Name: "CE"},
|
||||
{Name: "D1E"},
|
||||
{Name: "DE"},
|
||||
{Name: "T"},
|
||||
{Name: "L"},
|
||||
}
|
||||
for _, category := range categories {
|
||||
result := database.DB.Create(&category)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initSubscriptionPlans() error {
|
||||
subscriptions := []models.SubscriptionModel{
|
||||
{
|
||||
Name: "Basic",
|
||||
Details: "Test Plan",
|
||||
MonthlyFee: 2,
|
||||
HourlyRate: 3,
|
||||
},
|
||||
{
|
||||
Name: "additional",
|
||||
Details: "This plan needs another membership id to validate",
|
||||
RequiredMembershipField: "ParentMembershipID",
|
||||
MonthlyFee: 2,
|
||||
HourlyRate: 3,
|
||||
},
|
||||
}
|
||||
for _, subscription := range subscriptions {
|
||||
|
||||
result := database.DB.Create(&subscription)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetMockedJSONContext(jsonStr []byte, url string) (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
router := gin.New()
|
||||
|
||||
// Load HTML templates
|
||||
router.LoadHTMLGlob(config.Templates.HTMLPath + "/*")
|
||||
|
||||
var err error
|
||||
c.Request, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonStr))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create new Request: %#v", err)
|
||||
}
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
return c, w, router
|
||||
}
|
||||
|
||||
func GetMockedFormContext(formData url.Values, url string) (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
router := gin.New()
|
||||
|
||||
// Load HTML templates
|
||||
router.LoadHTMLGlob(config.Templates.HTMLPath + "/*")
|
||||
|
||||
req, err := http.NewRequest("POST",
|
||||
url,
|
||||
bytes.NewBufferString(formData.Encode()))
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create new Request: %#v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
c.Request = req
|
||||
|
||||
return c, w, router
|
||||
}
|
||||
|
||||
func getBaseUser() models.User {
|
||||
return models.User{
|
||||
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
Email: "john.doe@example.com",
|
||||
Address: "Pablo Escobar Str. 4",
|
||||
ZipCode: "25474",
|
||||
City: "Hasloh",
|
||||
Phone: "01738484993",
|
||||
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
|
||||
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
|
||||
Licence: nil,
|
||||
ProfilePicture: "",
|
||||
Password: "password123",
|
||||
Company: "",
|
||||
RoleID: 8,
|
||||
}
|
||||
}
|
||||
|
||||
func deleteTestDB(dbPath string) error {
|
||||
err := os.Remove(dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSingleTest(tc TestCase) error {
|
||||
c, w, router := tc.SetupContext()
|
||||
tc.RunHandler(c, router)
|
||||
|
||||
if err := tc.ValidateResponse(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tc.ValidateResult()
|
||||
}
|
||||
|
||||
func GenerateInputJSON(aStruct interface{}) string {
|
||||
|
||||
// Marshal the object into JSON
|
||||
jsonBytes, err := json.Marshal(aStruct)
|
||||
if err != nil {
|
||||
logger.Error.Fatalf("Couldn't generate JSON: %#v\nERROR: %#v", aStruct, err)
|
||||
return ""
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
26
go-backend/internal/controllers/licenceController.go
Normal file
26
go-backend/internal/controllers/licenceController.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/internal/utils"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LicenceController struct {
|
||||
Service services.LicenceService
|
||||
}
|
||||
|
||||
func (lc *LicenceController) GetAllCategories(c *gin.Context) {
|
||||
|
||||
categories, err := lc.Service.GetAllCategories()
|
||||
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error retrieving licence categories", http.StatusInternalServerError, "general", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"licence_categories": categories,
|
||||
})
|
||||
}
|
||||
153
go-backend/internal/controllers/membershipController.go
Normal file
153
go-backend/internal/controllers/membershipController.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/internal/utils"
|
||||
"strings"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type MembershipController struct {
|
||||
Service services.MembershipService
|
||||
UserController interface {
|
||||
ExtractUserFromContext(*gin.Context) (*models.User, error)
|
||||
}
|
||||
}
|
||||
|
||||
type MembershipData struct {
|
||||
// APIKey string `json:"api_key"`
|
||||
Subscription models.SubscriptionModel `json:"subscription"`
|
||||
}
|
||||
|
||||
func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
|
||||
var regData MembershipData
|
||||
|
||||
requestUser, err := mc.UserController.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in subscription registrationHandler", http.StatusBadRequest, "general", "server.validation.invalid_user_data")
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Create) {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to register subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(®Data); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Register Subscription
|
||||
logger.Info.Printf("Registering subscription %v", regData.Subscription.Name)
|
||||
id, err := mc.Service.RegisterSubscription(®Data.Subscription)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Couldn't register Membershipmodel: %v", err)
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
c.JSON(http.StatusConflict, "Duplicate subscription name")
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusNotAcceptable, "Couldn't register Membershipmodel")
|
||||
return
|
||||
}
|
||||
logger.Info.Printf("registering subscription: %+v", regData)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"status": "success",
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
func (mc *MembershipController) UpdateHandler(c *gin.Context) {
|
||||
var regData MembershipData
|
||||
|
||||
requestUser, err := mc.UserController.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in subscription UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(®Data); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// update Subscription
|
||||
logger.Info.Printf("Updating subscription %v", regData.Subscription.Name)
|
||||
id, err := mc.Service.UpdateSubscription(®Data.Subscription)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Couldn't update Membershipmodel: %v", err)
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
c.JSON(http.StatusConflict, "Duplicate subscription name")
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusNotAcceptable, "Couldn't update Membershipmodel")
|
||||
return
|
||||
}
|
||||
logger.Info.Printf("updating subscription: %+v", regData)
|
||||
c.JSON(http.StatusAccepted, gin.H{
|
||||
"status": "success",
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
func (mc *MembershipController) DeleteSubscription(c *gin.Context) {
|
||||
type deleteData struct {
|
||||
Subscription struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"subscription"`
|
||||
}
|
||||
|
||||
var data deleteData
|
||||
requestUser, err := mc.UserController.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in subscription UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := mc.Service.DeleteSubscription(&data.Subscription.ID, &data.Subscription.Name); err != nil {
|
||||
utils.RespondWithError(c, err, "Error during subscription Deletion", http.StatusExpectationFailed, "subscription", "server.error.not_possible")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Subscription deleted successfully"})
|
||||
}
|
||||
|
||||
func (mc *MembershipController) GetSubscriptions(c *gin.Context) {
|
||||
subscriptions, err := mc.Service.GetSubscriptions(nil)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Error retrieving subscriptions: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"errors": []gin.H{{
|
||||
"field": "general",
|
||||
"key": "validation.internal_server_error",
|
||||
}}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"subscriptions": subscriptions,
|
||||
})
|
||||
}
|
||||
393
go-backend/internal/controllers/membershipController_test.go
Normal file
393
go-backend/internal/controllers/membershipController_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RegisterSubscriptionTest struct {
|
||||
WantDBData map[string]interface{}
|
||||
Input string
|
||||
Name string
|
||||
WantResponse int
|
||||
Assert bool
|
||||
}
|
||||
|
||||
type UpdateSubscriptionTest struct {
|
||||
WantDBData map[string]interface{}
|
||||
Input string
|
||||
Name string
|
||||
WantResponse int
|
||||
Assert bool
|
||||
}
|
||||
|
||||
type DeleteSubscriptionTest struct {
|
||||
WantDBData map[string]interface{}
|
||||
Input string
|
||||
Name string
|
||||
WantResponse int
|
||||
Assert bool
|
||||
}
|
||||
|
||||
type MockUserController struct {
|
||||
UserController // Embed the UserController
|
||||
}
|
||||
|
||||
func (m *MockUserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
|
||||
return &models.User{
|
||||
ID: 1,
|
||||
FirstName: "Admin",
|
||||
LastName: "User",
|
||||
Email: "admin@test.com",
|
||||
RoleID: constants.Roles.Admin,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupMockAuth() {
|
||||
// Create and assign the mock controller
|
||||
mockController := &MockUserController{}
|
||||
Mc.UserController = mockController
|
||||
}
|
||||
|
||||
func testMembershipController(t *testing.T) {
|
||||
|
||||
setupMockAuth()
|
||||
tests := getSubscriptionRegistrationData()
|
||||
for _, tt := range tests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("MembershipController : %v", tt.Name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.Name, func(t *testing.T) {
|
||||
if err := runSingleTest(&tt); err != nil {
|
||||
t.Errorf("Test failed: %v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
updateTests := getSubscriptionUpdateData()
|
||||
for _, tt := range updateTests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Update SubscriptionData : %v", tt.Name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.Name, func(t *testing.T) {
|
||||
if err := runSingleTest(&tt); err != nil {
|
||||
t.Errorf("Test failed: %v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
deleteTests := getSubscriptionDeleteData()
|
||||
for _, tt := range deleteTests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Delete SubscriptionData : %v", tt.Name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.Name, func(t *testing.T) {
|
||||
if err := runSingleTest(&tt); err != nil {
|
||||
t.Errorf("Test failed: %v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RegisterSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
return GetMockedJSONContext([]byte(rt.Input), "api/subscription")
|
||||
}
|
||||
|
||||
func (rt *RegisterSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||
Mc.RegisterSubscription(c)
|
||||
}
|
||||
|
||||
func (rt *RegisterSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||
if w.Code != rt.WantResponse {
|
||||
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, rt.WantResponse)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rt *RegisterSubscriptionTest) ValidateResult() error {
|
||||
return validateSubscription(rt.Assert, rt.WantDBData)
|
||||
}
|
||||
|
||||
func validateSubscription(assert bool, wantDBData map[string]interface{}) error {
|
||||
subscriptions, err := Mc.Service.GetSubscriptions(wantDBData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error in database ops: %#v", err)
|
||||
}
|
||||
if assert != (len(*subscriptions) != 0) {
|
||||
return fmt.Errorf("Subscription entry query didn't met expectation: %v != %#v", assert, *subscriptions)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UpdateSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
return GetMockedJSONContext([]byte(ut.Input), "api/subscription/upsert")
|
||||
}
|
||||
|
||||
func (ut *UpdateSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||
Mc.UpdateHandler(c)
|
||||
}
|
||||
|
||||
func (ut *UpdateSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||
if w.Code != ut.WantResponse {
|
||||
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, ut.WantResponse)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *UpdateSubscriptionTest) ValidateResult() error {
|
||||
return validateSubscription(ut.Assert, ut.WantDBData)
|
||||
}
|
||||
|
||||
func (dt *DeleteSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||
return GetMockedJSONContext([]byte(dt.Input), "api/subscription/delete")
|
||||
}
|
||||
|
||||
func (dt *DeleteSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||
Mc.DeleteSubscription(c)
|
||||
}
|
||||
|
||||
func (dt *DeleteSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||
if w.Code != dt.WantResponse {
|
||||
return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, dt.WantResponse)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *DeleteSubscriptionTest) ValidateResult() error {
|
||||
return validateSubscription(dt.Assert, dt.WantDBData)
|
||||
}
|
||||
|
||||
func getBaseSubscription() MembershipData {
|
||||
return MembershipData{
|
||||
// APIKey: config.Auth.APIKEY,
|
||||
Subscription: models.SubscriptionModel{
|
||||
Name: "Premium",
|
||||
Details: "A subscription detail",
|
||||
MonthlyFee: 12.0,
|
||||
HourlyRate: 14.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
func customizeSubscription(customize func(MembershipData) MembershipData) MembershipData {
|
||||
subscription := getBaseSubscription()
|
||||
return customize(subscription)
|
||||
}
|
||||
|
||||
func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
|
||||
return []RegisterSubscriptionTest{
|
||||
{
|
||||
Name: "Missing details should fail",
|
||||
WantResponse: http.StatusBadRequest,
|
||||
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Details = ""
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Missing model name should fail",
|
||||
WantResponse: http.StatusBadRequest,
|
||||
WantDBData: map[string]interface{}{"name": ""},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = ""
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Negative monthly fee should fail",
|
||||
WantResponse: http.StatusBadRequest,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
|
||||
sub.Subscription.MonthlyFee = -10.0
|
||||
return sub
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Negative hourly rate should fail",
|
||||
WantResponse: http.StatusBadRequest,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
|
||||
sub.Subscription.HourlyRate = -1.0
|
||||
return sub
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "correct entry should pass",
|
||||
WantResponse: http.StatusCreated,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Conditions = "Some Condition"
|
||||
subscription.Subscription.IncludedPerYear = 0
|
||||
subscription.Subscription.IncludedPerMonth = 1
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Duplicate subscription name should fail",
|
||||
WantResponse: http.StatusConflict,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: true, // The original subscription should still exist
|
||||
Input: GenerateInputJSON(getBaseSubscription()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getSubscriptionUpdateData() []UpdateSubscriptionTest {
|
||||
return []UpdateSubscriptionTest{
|
||||
{
|
||||
Name: "Modified Monthly Fee, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "Premium", "monthly_fee": "12"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.MonthlyFee = 123.0
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Missing ID, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.ID = 0
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Modified Hourly Rate, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "Premium", "hourly_rate": "14"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.HourlyRate = 3254.0
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "IncludedPerYear changed, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "Premium", "included_per_year": "0"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.IncludedPerYear = 9873.0
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "IncludedPerMonth changed, should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "Premium", "included_per_month": "1"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.IncludedPerMonth = 23415.0
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Update non-existent subscription should fail",
|
||||
WantResponse: http.StatusNotAcceptable,
|
||||
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = "NonExistentSubscription"
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Correct Update should pass",
|
||||
WantResponse: http.StatusAccepted,
|
||||
WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Details = "Altered Details"
|
||||
subscription.Subscription.Conditions = "Some Condition"
|
||||
subscription.Subscription.IncludedPerYear = 0
|
||||
subscription.Subscription.IncludedPerMonth = 1
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getSubscriptionDeleteData() []DeleteSubscriptionTest {
|
||||
|
||||
var premiumSub, basicSub models.SubscriptionModel
|
||||
database.DB.Where("name = ?", "Premium").First(&premiumSub)
|
||||
database.DB.Where("name = ?", "Basic").First(&basicSub)
|
||||
|
||||
logger.Error.Printf("premiumSub.ID: %v", premiumSub.ID)
|
||||
logger.Error.Printf("basicSub.ID: %v", basicSub.ID)
|
||||
return []DeleteSubscriptionTest{
|
||||
{
|
||||
Name: "Delete non-existent subscription should fail",
|
||||
WantResponse: http.StatusExpectationFailed,
|
||||
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = "NonExistentSubscription"
|
||||
subscription.Subscription.ID = basicSub.ID
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Delete subscription without name should fail",
|
||||
WantResponse: http.StatusExpectationFailed,
|
||||
WantDBData: map[string]interface{}{"name": ""},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = ""
|
||||
subscription.Subscription.ID = basicSub.ID
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Delete subscription with users should fail",
|
||||
WantResponse: http.StatusExpectationFailed,
|
||||
WantDBData: map[string]interface{}{"name": "Basic"},
|
||||
Assert: true,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = "Basic"
|
||||
subscription.Subscription.ID = basicSub.ID
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
{
|
||||
Name: "Delete valid subscription should succeed",
|
||||
WantResponse: http.StatusOK,
|
||||
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||
Assert: false,
|
||||
Input: GenerateInputJSON(
|
||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||
subscription.Subscription.Name = "Premium"
|
||||
subscription.Subscription.ID = premiumSub.ID
|
||||
return subscription
|
||||
})),
|
||||
},
|
||||
}
|
||||
}
|
||||
105
go-backend/internal/controllers/user_Password.go
Normal file
105
go-backend/internal/controllers/user_Password.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/utils"
|
||||
"GoMembership/pkg/errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (uc *UserController) RequestPasswordChangeHandler(c *gin.Context) {
|
||||
|
||||
// Expected data from the user
|
||||
var input struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
// find user
|
||||
db_user, err := uc.Service.GetUserByEmail(input.Email)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "couldn't get user by email", http.StatusNotFound, "user.user", "user.email")
|
||||
return
|
||||
}
|
||||
|
||||
// check if user may change the password
|
||||
if db_user.Status <= constants.DisabledStatus {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "User password change request denied, user is disabled", http.StatusForbidden, errors.Responses.Fields.Login, errors.Responses.Keys.UserDisabled)
|
||||
return
|
||||
}
|
||||
|
||||
// create token
|
||||
token, err := uc.Service.HandlePasswordChangeRequest(db_user)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "couldn't handle password change request", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// send email
|
||||
if err := uc.EmailService.SendChangePasswordEmail(db_user, &token); err != nil {
|
||||
utils.RespondWithError(c, err, "Couldn't send change password email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusAccepted, gin.H{
|
||||
"message": "password_change_requested",
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) ChangePassword(c *gin.Context) {
|
||||
// Expected data from the user
|
||||
var input struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
}
|
||||
userIDint, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Invalid user ID", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.InvalidUserID)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
verification, err := uc.Service.VerifyUser(&input.Token, &constants.VerificationTypes.Password)
|
||||
if err != nil || uint(userIDint) != verification.UserID {
|
||||
if err == errors.ErrAlreadyVerified {
|
||||
utils.RespondWithError(c, err, "User already changed password", http.StatusConflict, errors.Responses.Fields.User, errors.Responses.Keys.PasswordAlreadyChanged)
|
||||
} else if err.Error() == "record not found" {
|
||||
utils.RespondWithError(c, err, "Couldn't find verification. This is most probably a outdated token.", http.StatusGone, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
|
||||
} else {
|
||||
utils.RespondWithError(c, err, "Couldn't verify user", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.Service.GetUserByID(verification.UserID)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
|
||||
return
|
||||
}
|
||||
|
||||
user.Status = constants.ActiveStatus
|
||||
user.Verification = *verification
|
||||
user.ID = verification.UserID
|
||||
user.Password = input.Password
|
||||
|
||||
_, err = uc.Service.UpdateUser(user)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Couldn't update user", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "password_changed",
|
||||
})
|
||||
}
|
||||
366
go-backend/internal/controllers/user_controller.go
Normal file
366
go-backend/internal/controllers/user_controller.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/middlewares"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/internal/utils"
|
||||
"GoMembership/internal/validation"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type UserController struct {
|
||||
Service services.UserServiceInterface
|
||||
EmailService *services.EmailService
|
||||
ConsentService services.ConsentServiceInterface
|
||||
BankAccountService services.BankAccountServiceInterface
|
||||
MembershipService services.MembershipServiceInterface
|
||||
LicenceService services.LicenceInterface
|
||||
}
|
||||
|
||||
type RegistrationData struct {
|
||||
User models.User `json:"user"`
|
||||
}
|
||||
|
||||
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
|
||||
requestUser, err := uc.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in CurrentUserHandler", http.StatusBadRequest, "general", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user": requestUser.Safe(),
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) GetAllUsers(c *gin.Context) {
|
||||
|
||||
requestUser, err := uc.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
|
||||
return
|
||||
}
|
||||
if requestUser.RoleID == constants.Roles.Member {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
users, err := uc.Service.GetUsers(nil)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error getting users in GetAllUsers", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a slice to hold the safe user representations
|
||||
safeUsers := make([]map[string]interface{}, len(*users))
|
||||
|
||||
// Convert each user to its safe representation
|
||||
for i, user := range *users {
|
||||
safeUsers[i] = user.Safe()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"users": users,
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) UpdateHandler(c *gin.Context) {
|
||||
// 1. Extract and validate the user ID from the route
|
||||
requestUser, err := uc.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
|
||||
return
|
||||
}
|
||||
|
||||
var user models.User
|
||||
var updateData RegistrationData
|
||||
if err := c.ShouldBindJSON(&updateData); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
user = updateData.User
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) && user.ID != requestUser.ID {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
existingUser, err := uc.Service.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, "user.user", "server.error.not_found")
|
||||
return
|
||||
}
|
||||
// user.Membership.ID = existingUser.Membership.ID
|
||||
|
||||
// user.MembershipID = existingUser.MembershipID
|
||||
// if existingUser.Licence != nil {
|
||||
// user.Licence.ID = existingUser.Licence.ID
|
||||
// }
|
||||
// user.LicenceID = existingUser.LicenceID
|
||||
// user.BankAccount.ID = existingUser.BankAccount.ID
|
||||
// user.BankAccountID = existingUser.BankAccountID
|
||||
|
||||
if requestUser.RoleID <= constants.Priviliges.View {
|
||||
existingUser.Password = ""
|
||||
if err := utils.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil {
|
||||
if err.Error() == "Not authorized" {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
utils.RespondWithError(c, err, "Error filtering users input fields", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser, err := uc.Service.UpdateUser(&user)
|
||||
if err != nil {
|
||||
utils.HandleUserUpdateError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info.Printf("User %d updated successfully by user %d", updatedUser.ID, requestUser.ID)
|
||||
|
||||
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser.Safe()})
|
||||
}
|
||||
|
||||
func (uc *UserController) DeleteUser(c *gin.Context) {
|
||||
|
||||
requestUser, err := uc.ExtractUserFromContext(c)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error extracting user from context in DeleteUser", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw")
|
||||
return
|
||||
}
|
||||
|
||||
type deleteData struct {
|
||||
User struct {
|
||||
ID uint `json:"id"`
|
||||
LastName string `json:"last_name"`
|
||||
} `json:"user"`
|
||||
}
|
||||
|
||||
var data deleteData
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) && data.User.ID != requestUser.ID {
|
||||
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to delete user", http.StatusForbidden, "user.user", "server.error.unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error.Printf("Deleting user: %v", data.User)
|
||||
if err := uc.Service.DeleteUser(data.User.LastName, data.User.ID); err != nil {
|
||||
utils.RespondWithError(c, err, "Error during user deletion", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"})
|
||||
}
|
||||
|
||||
func (uc *UserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
|
||||
|
||||
tokenString, err := c.Cookie("jwt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, claims, err := middlewares.ExtractContentFrom(tokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwtUserID := uint((*claims)["user_id"].(float64))
|
||||
user, err := uc.Service.GetUserByID(jwtUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (uc *UserController) LogoutHandler(c *gin.Context) {
|
||||
tokenString, err := c.Cookie("jwt")
|
||||
if err != nil {
|
||||
logger.Error.Printf("unable to get token from cookie: %#v", err)
|
||||
}
|
||||
|
||||
middlewares.InvalidateSession(tokenString)
|
||||
|
||||
c.SetCookie("jwt", "", -1, "/", "", true, true)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
||||
}
|
||||
|
||||
func (uc *UserController) LoginHandler(c *gin.Context) {
|
||||
var input struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.RespondWithError(c, err, "Invalid JSON or malformed request", http.StatusBadRequest, errors.Responses.Fields.General, errors.Responses.Keys.Invalid)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.Service.GetUserByEmail(input.Email)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Login Error; user not found", http.StatusNotFound,
|
||||
errors.Responses.Fields.Login,
|
||||
errors.Responses.Keys.UserNotFoundWrongPassword)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status <= constants.DisabledStatus {
|
||||
utils.RespondWithError(c, fmt.Errorf("User banned from login %v %v", user.FirstName, user.LastName),
|
||||
"Login Error; user is disabled",
|
||||
http.StatusNotAcceptable,
|
||||
errors.Responses.Fields.Login,
|
||||
errors.Responses.Keys.UserDisabled)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := user.PasswordMatches(input.Password)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Login Error; password comparisson failed", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.InternalServerError)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
utils.RespondWithError(c, fmt.Errorf("%v %v(%v)", user.FirstName, user.LastName, user.Email),
|
||||
"Login Error; wrong password",
|
||||
http.StatusNotAcceptable,
|
||||
errors.Responses.Fields.Login,
|
||||
errors.Responses.Keys.UserNotFoundWrongPassword)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error.Printf("jwtsecret: %v", config.Auth.JWTSecret)
|
||||
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "")
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error generating token in LoginHandler", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.JwtGenerationFailed)
|
||||
return
|
||||
}
|
||||
|
||||
utils.SetCookie(c, token)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Login successful",
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) RegisterUser(c *gin.Context) {
|
||||
|
||||
var regData RegistrationData
|
||||
logger.Error.Printf("registering user...")
|
||||
if err := c.ShouldBindJSON(®Data); err != nil {
|
||||
utils.HandleValidationError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info.Printf("Registering user %v", regData.User.Email)
|
||||
selectedModel, err := uc.MembershipService.GetSubscriptionByName(®Data.User.Membership.SubscriptionModel.Name)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error in Registeruser, couldn't get selected model", http.StatusNotFound, "subscription_model", "server.validation.subscription_model_not_found")
|
||||
return
|
||||
}
|
||||
regData.User.Membership.SubscriptionModel = *selectedModel
|
||||
if selectedModel.RequiredMembershipField != "" {
|
||||
if err := validation.CheckParentMembershipID(regData.User.Membership); err != nil {
|
||||
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't check parent membership id", http.StatusBadRequest, "parent_membership_id", "server.validation.parent_membership_id_not_found")
|
||||
return
|
||||
}
|
||||
}
|
||||
regData.User.RoleID = constants.Roles.Member
|
||||
|
||||
// Register User
|
||||
id, token, err := uc.Service.RegisterUser(®Data.User)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed: users.email") {
|
||||
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, "email", "server.validation.email_already_exists")
|
||||
} else {
|
||||
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, "general", "server.error.internal_server_error")
|
||||
}
|
||||
return
|
||||
}
|
||||
regData.User.ID = id
|
||||
|
||||
// Register Consents
|
||||
var consents = [2]models.Consent{
|
||||
{
|
||||
FirstName: regData.User.FirstName,
|
||||
LastName: regData.User.LastName,
|
||||
Email: regData.User.Email,
|
||||
ConsentType: "TermsOfService",
|
||||
},
|
||||
{
|
||||
FirstName: regData.User.FirstName,
|
||||
LastName: regData.User.LastName,
|
||||
Email: regData.User.Email,
|
||||
ConsentType: "Privacy",
|
||||
},
|
||||
}
|
||||
for _, consent := range consents {
|
||||
_, err = uc.ConsentService.RegisterConsent(&consent)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register consent", http.StatusInternalServerError, "general", "server.error.internal_server_error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Send notifications
|
||||
if err := uc.EmailService.SendVerificationEmail(®Data.User, &token); err != nil {
|
||||
logger.Error.Printf("Failed to send email verification email to user(%v): %v", regData.User.Email, err)
|
||||
// Proceed without returning error since user registration is successful
|
||||
// TODO Notify Admin
|
||||
}
|
||||
|
||||
// Notify admin of new user registration
|
||||
if err := uc.EmailService.SendRegistrationNotification(®Data.User); err != nil {
|
||||
logger.Error.Printf("Failed to notify admin of new user(%v) registration: %v", regData.User.Email, err)
|
||||
// Proceed without returning error since user registration is successful
|
||||
// TODO Notify Admin
|
||||
}
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Registration successuful",
|
||||
"id": regData.User.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) VerifyMailHandler(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
logger.Error.Println("Missing token to verify mail")
|
||||
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
verification, err := uc.Service.VerifyUser(&token, &constants.VerificationTypes.Email)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Cannot verify user: %v", err)
|
||||
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.Service.GetUserByID(verification.UserID)
|
||||
if err != nil {
|
||||
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
|
||||
return
|
||||
}
|
||||
|
||||
user.Status = constants.VerifiedStatus
|
||||
user.Verification = *verification
|
||||
user.ID = verification.UserID
|
||||
user.Password = ""
|
||||
|
||||
uc.Service.UpdateUser(user)
|
||||
logger.Info.Printf("Verified User: %#v", user.Email)
|
||||
|
||||
uc.EmailService.SendWelcomeEmail(user)
|
||||
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})
|
||||
}
|
||||
1228
go-backend/internal/controllers/user_controller_test.go
Normal file
1228
go-backend/internal/controllers/user_controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
177
go-backend/internal/database/db.go
Normal file
177
go-backend/internal/database/db.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/argon2id"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func Open(dbPath string, adminMail string) error {
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := db.AutoMigrate(
|
||||
&models.User{},
|
||||
&models.SubscriptionModel{},
|
||||
&models.Membership{},
|
||||
&models.Consent{},
|
||||
&models.Verification{},
|
||||
&models.Licence{},
|
||||
&models.Category{},
|
||||
&models.BankAccount{}); err != nil {
|
||||
logger.Error.Fatalf("Couldn't create database: %v", err)
|
||||
return err
|
||||
}
|
||||
DB = db
|
||||
|
||||
logger.Info.Print("Opened DB")
|
||||
|
||||
var categoriesCount int64
|
||||
db.Model(&models.Category{}).Count(&categoriesCount)
|
||||
if categoriesCount == 0 {
|
||||
categories := createLicenceCategories()
|
||||
for _, model := range categories {
|
||||
result := db.Create(&model)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var subscriptionsCount int64
|
||||
db.Model(&models.SubscriptionModel{}).Count(&subscriptionsCount)
|
||||
subscriptionModels := createSubscriptionModels()
|
||||
for _, model := range subscriptionModels {
|
||||
var exists int64
|
||||
db.
|
||||
Model(&models.SubscriptionModel{}).
|
||||
Where("name = ?", model.Name).
|
||||
Count(&exists)
|
||||
logger.Error.Printf("looked for model.name %v and found %v", model.Name, exists)
|
||||
if exists == 0 {
|
||||
result := db.Create(&model)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var userCount int64
|
||||
db.Model(&models.User{}).Count(&userCount)
|
||||
if userCount == 0 {
|
||||
var createdModel models.SubscriptionModel
|
||||
if err := db.First(&createdModel).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
admin, err := createAdmin(adminMail, createdModel.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result := db.Session(&gorm.Session{FullSaveAssociations: true}).Create(&admin)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSubscriptionModels() []models.SubscriptionModel {
|
||||
return []models.SubscriptionModel{
|
||||
{
|
||||
Name: "Keins",
|
||||
Details: "Dieses Modell ist für Vereinsmitglieder, die keinen Wunsch haben, an dem Carhsharing teilzunehmen.",
|
||||
HourlyRate: 999,
|
||||
MonthlyFee: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createLicenceCategories() []models.Category {
|
||||
return []models.Category{
|
||||
{Name: "AM"},
|
||||
{Name: "A1"},
|
||||
{Name: "A2"},
|
||||
{Name: "A"},
|
||||
{Name: "B"},
|
||||
{Name: "C1"},
|
||||
{Name: "C"},
|
||||
{Name: "D1"},
|
||||
{Name: "D"},
|
||||
{Name: "BE"},
|
||||
{Name: "C1E"},
|
||||
{Name: "CE"},
|
||||
{Name: "D1E"},
|
||||
{Name: "DE"},
|
||||
{Name: "T"},
|
||||
{Name: "L"},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Landing page to create an admin
|
||||
|
||||
func createAdmin(userMail string, subscriptionModelID uint) (*models.User, error) {
|
||||
passwordBytes := make([]byte, 12)
|
||||
_, err := rand.Read(passwordBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encode into a URL-safe base64 string
|
||||
password := base64.URLEncoding.EncodeToString(passwordBytes)[:12]
|
||||
|
||||
hash, err := argon2id.CreateHash(password, argon2id.DefaultParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Admin Email: %v", userMail)
|
||||
logger.Error.Printf("Admin Password: %v", password)
|
||||
logger.Error.Print("==============================================================")
|
||||
|
||||
return &models.User{
|
||||
FirstName: "ad",
|
||||
LastName: "min",
|
||||
DateOfBirth: time.Now().AddDate(-20, 0, 0),
|
||||
Password: hash,
|
||||
Address: "Downhill 4",
|
||||
ZipCode: "99999",
|
||||
City: "TechTown",
|
||||
Phone: "0123455678",
|
||||
Email: userMail,
|
||||
Status: constants.ActiveStatus,
|
||||
RoleID: constants.Roles.Admin,
|
||||
Membership: models.Membership{
|
||||
Status: constants.DisabledStatus,
|
||||
StartDate: time.Now(),
|
||||
SubscriptionModelID: subscriptionModelID,
|
||||
},
|
||||
BankAccount: models.BankAccount{},
|
||||
Licence: &models.Licence{
|
||||
Status: constants.UnverifiedStatus,
|
||||
},
|
||||
}, nil
|
||||
//"DE49700500000008447644", //fake
|
||||
}
|
||||
|
||||
func Close() error {
|
||||
logger.Info.Print("Closing DB")
|
||||
db, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Close()
|
||||
}
|
||||
31
go-backend/internal/middlewares/api.go
Normal file
31
go-backend/internal/middlewares/api.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
)
|
||||
|
||||
func APIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientAPIKey := c.GetHeader("X-API-Key")
|
||||
|
||||
if clientAPIKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "API key is missing"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Using subtle.ConstantTimeCompare to mitigate timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(clientAPIKey), []byte(config.Auth.APIKEY)) != 1 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
61
go-backend/internal/middlewares/api_test.go
Normal file
61
go-backend/internal/middlewares/api_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
)
|
||||
|
||||
func TestAPIKeyMiddleware(t *testing.T) {
|
||||
// Set up a test API key
|
||||
testAPIKey := "test-api-key-12345"
|
||||
config.Auth.APIKEY = testAPIKey
|
||||
|
||||
// Set Gin to Test Mode
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Tests table
|
||||
tests := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
wantStatus int
|
||||
}{
|
||||
{"Valid API Key", testAPIKey, http.StatusOK},
|
||||
{"Missing API Key", "", http.StatusUnauthorized},
|
||||
{"Invalid API Key", "wrong-key", http.StatusUnauthorized},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up a new test router and handler
|
||||
router := gin.New()
|
||||
router.Use(APIKeyMiddleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create a test request
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
if tt.apiKey != "" {
|
||||
req.Header.Set("X-API-Key", tt.apiKey)
|
||||
}
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Assert the response
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
|
||||
// Additional assertions for specific cases
|
||||
if tt.wantStatus == http.StatusUnauthorized {
|
||||
assert.Contains(t, w.Body.String(), "API key")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
179
go-backend/internal/middlewares/auth.go
Normal file
179
go-backend/internal/middlewares/auth.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/utils"
|
||||
customerrors "GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
UserID uint
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
sessionDuration = 5 * 24 * time.Hour
|
||||
jwtSigningMethod = jwt.SigningMethodHS256
|
||||
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
||||
sessions = make(map[string]*Session)
|
||||
)
|
||||
|
||||
func verifyAndRenewToken(tokenString string) (string, uint, error) {
|
||||
if tokenString == "" {
|
||||
logger.Error.Printf("empty tokenstring")
|
||||
return "", 0, fmt.Errorf("Authorization token is required")
|
||||
}
|
||||
token, claims, err := ExtractContentFrom(tokenString)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
|
||||
return "", 0, err
|
||||
}
|
||||
sessionID := (*claims)["session_id"].(string)
|
||||
userID := uint((*claims)["user_id"].(float64))
|
||||
roleID := int8((*claims)["role_id"].(float64))
|
||||
|
||||
session, ok := sessions[sessionID]
|
||||
if !ok {
|
||||
logger.Error.Printf("session not found")
|
||||
return "", 0, fmt.Errorf("session not found")
|
||||
}
|
||||
if userID != session.UserID {
|
||||
return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
|
||||
}
|
||||
if token.Valid {
|
||||
// token is valid, so we can return the old tokenString
|
||||
return tokenString, session.UserID, customerrors.ErrValidToken
|
||||
}
|
||||
|
||||
if time.Now().After(sessions[sessionID].ExpiresAt) {
|
||||
delete(sessions, sessionID)
|
||||
logger.Error.Printf("session expired")
|
||||
return "", 0, fmt.Errorf("session expired")
|
||||
}
|
||||
session.ExpiresAt = time.Now().Add(sessionDuration)
|
||||
|
||||
logger.Error.Printf("Session still valid generating new token")
|
||||
// Session is still valid, generate a new token
|
||||
user := models.User{ID: userID, RoleID: roleID}
|
||||
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return newTokenString, session.UserID, nil
|
||||
}
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString, err := c.Cookie("jwt")
|
||||
if err != nil {
|
||||
logger.Error.Printf("No Auth token: %v\n", err)
|
||||
c.JSON(http.StatusUnauthorized,
|
||||
gin.H{"errors": []gin.H{{
|
||||
"field": "general",
|
||||
"key": "server.error.no_auth_token",
|
||||
}}})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
newToken, userID, err := verifyAndRenewToken(tokenString)
|
||||
if err != nil {
|
||||
if err == customerrors.ErrValidToken {
|
||||
c.Set("user_id", uint(userID))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
|
||||
c.JSON(http.StatusUnauthorized,
|
||||
gin.H{"errors": []gin.H{{
|
||||
"field": "general",
|
||||
"key": "server.error.no_auth_token",
|
||||
}}})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
utils.SetCookie(c, newToken)
|
||||
c.Set("user_id", uint(userID))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
|
||||
if sessionID == "" {
|
||||
sessionID = uuid.New().String()
|
||||
}
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"role_id": user.RoleID,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
|
||||
})
|
||||
UpdateSession(sessionID, user.ID)
|
||||
return token.SignedString([]byte(jwtKey))
|
||||
}
|
||||
|
||||
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
|
||||
|
||||
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
||||
return []byte(config.Auth.JWTSecret), nil
|
||||
})
|
||||
|
||||
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
|
||||
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Token is expired, check if session is still valid
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
logger.Error.Printf("Invalid Token Claims")
|
||||
return nil, nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
logger.Error.Printf("invalid session_id in token")
|
||||
return nil, nil, fmt.Errorf("invalid session_id in token")
|
||||
}
|
||||
return token, &claims, nil
|
||||
}
|
||||
|
||||
func UpdateSession(sessionID string, userID uint) {
|
||||
sessions[sessionID] = &Session{
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(sessionDuration),
|
||||
}
|
||||
}
|
||||
|
||||
func InvalidateSession(token string) (bool, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
_, err := jwt.ParseWithClaims(
|
||||
token,
|
||||
claims,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
return config.Auth.JWTSecret, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
|
||||
}
|
||||
|
||||
sessionID, ok := claims["session_id"].(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("No SessionID found")
|
||||
}
|
||||
|
||||
delete(sessions, sessionID)
|
||||
return true, nil
|
||||
}
|
||||
191
go-backend/internal/middlewares/auth_test.go
Normal file
191
go-backend/internal/middlewares/auth_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func setupTestEnvironment() {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
config.LoadConfig()
|
||||
logger.Info.Printf("Config: %#v", config.CFG)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
setupTestEnvironment()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAuth func(r *http.Request)
|
||||
expectedStatus int
|
||||
expectNewCookie bool
|
||||
expectedUserID uint
|
||||
}{
|
||||
{
|
||||
name: "Valid Token",
|
||||
setupAuth: func(r *http.Request) {
|
||||
user := models.User{ID: 123, RoleID: constants.Roles.Member}
|
||||
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedUserID: 123,
|
||||
},
|
||||
{
|
||||
name: "Missing Cookie",
|
||||
setupAuth: func(r *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Token",
|
||||
setupAuth: func(r *http.Request) {
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: "InvalidToken"})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Expired Token with Valid Session",
|
||||
setupAuth: func(r *http.Request) {
|
||||
sessionID := "test-session"
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"role_id": constants.Roles.Member,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
UpdateSession(sessionID, 123) // Add a valid session
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectNewCookie: true,
|
||||
expectedUserID: 123,
|
||||
},
|
||||
{
|
||||
name: "Expired Token with Expired Session",
|
||||
setupAuth: func(r *http.Request) {
|
||||
sessionID := "expired-session"
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"role_id": constants.Roles.Member,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
// Don't add a session, simulating an expired session
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Signature",
|
||||
setupAuth: func(r *http.Request) {
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"session_id": "some-session",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Signing Method",
|
||||
setupAuth: func(r *http.Request) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"session_id": "some-session",
|
||||
"role_id": constants.Roles.Member,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Testing : %v", tt.name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, gin.H{"user_id": userID})
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"user_id": 0})
|
||||
}
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
tt.setupAuth(req)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var response map[string]uint
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedUserID, response["user_id"])
|
||||
|
||||
// Check if a new cookie was set
|
||||
cookies := w.Result().Cookies()
|
||||
if tt.expectNewCookie {
|
||||
assert.GreaterOrEqual(t, len(cookies), 1)
|
||||
assert.Equal(t, "jwt", cookies[0].Name)
|
||||
assert.NotEmpty(t, cookies[0].Value)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, 0, len(w.Result().Cookies()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
go-backend/internal/middlewares/cors.go
Normal file
22
go-backend/internal/middlewares/cors.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
logger.Info.Print("Applying CORS")
|
||||
return cors.New(cors.Config{
|
||||
AllowOrigins: strings.Split(config.Site.AllowOrigins, ","),
|
||||
AllowMethods: []string{"GET", "POST", "PATCH", "PUT", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-CSRF-Token"},
|
||||
ExposeHeaders: []string{"Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * 60 * 60, // 12 hours
|
||||
})
|
||||
}
|
||||
104
go-backend/internal/middlewares/cors_test.go
Normal file
104
go-backend/internal/middlewares/cors_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
Port int = 2525
|
||||
)
|
||||
|
||||
func TestCORSMiddleware(t *testing.T) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_HOST", Host); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
// Load your configuration
|
||||
config.LoadConfig()
|
||||
|
||||
// Create a gin router with the CORS middleware
|
||||
router := gin.New()
|
||||
router.Use(CORSMiddleware())
|
||||
|
||||
// Add a simple handler
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "test")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Allowed origin",
|
||||
origin: config.Site.AllowOrigins,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": config.Site.AllowOrigins,
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin",
|
||||
origin: "http://evil.com",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
logger.Info.Printf("Recieved Headers: %#v", w.Header())
|
||||
for key, value := range tt.expectedHeaders {
|
||||
assert.Equal(t, value, w.Header().Get(key))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
44
go-backend/internal/middlewares/csp.go
Normal file
44
go-backend/internal/middlewares/csp.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CSPMiddleware() gin.HandlerFunc {
|
||||
logger.Error.Printf("applying CSP")
|
||||
return func(c *gin.Context) {
|
||||
policy := "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests;"
|
||||
|
||||
if config.Env == "development" {
|
||||
policy += " report-uri /csp-report;"
|
||||
c.Header("Content-Security-Policy-Report-Only", policy)
|
||||
} else {
|
||||
c.Header("Content-Security-Policy", policy)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CSPReportHandling(c *gin.Context) {
|
||||
var report map[string]interface{}
|
||||
if err := c.BindJSON(&report); err != nil {
|
||||
|
||||
logger.Error.Printf("Couldn't Bind JSON: %#v", err)
|
||||
return
|
||||
}
|
||||
logger.Info.Printf("CSP Violation: %+v", report)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
81
go-backend/internal/middlewares/csp_test.go
Normal file
81
go-backend/internal/middlewares/csp_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCSPMiddleware(t *testing.T) {
|
||||
// Save the current environment and restore it after the test
|
||||
originalEnv := config.Env
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
environment string
|
||||
expectedHeader string
|
||||
expectedPolicy string
|
||||
}{
|
||||
{
|
||||
name: "Development Environment",
|
||||
environment: "development",
|
||||
expectedHeader: "Content-Security-Policy-Report-Only",
|
||||
expectedPolicy: "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests; report-uri /csp-report;",
|
||||
},
|
||||
{
|
||||
name: "Production Environment",
|
||||
environment: "production",
|
||||
expectedHeader: "Content-Security-Policy",
|
||||
expectedPolicy: "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up the test environment
|
||||
config.Env = tt.environment
|
||||
|
||||
// Create a new Gin router with the middleware
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(CSPMiddleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Create a test request
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, tt.expectedPolicy, w.Header().Get(tt.expectedHeader))
|
||||
})
|
||||
}
|
||||
config.Env = originalEnv
|
||||
}
|
||||
21
go-backend/internal/middlewares/headers.go
Normal file
21
go-backend/internal/middlewares/headers.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SecurityHeadersMiddleware() gin.HandlerFunc {
|
||||
logger.Error.Printf("applying headers")
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Feature-Policy", "geolocation 'none'; midi 'none'; sync-xhr 'none'; microphone 'none'; camera 'none'; magnetometer 'none'; gyroscope 'none'; speaker 'none'; fullscreen 'self'; payment 'none'")
|
||||
c.Header("Permissions-Policy", "geolocation=(), midi=(), sync-xhr=(), microphone=(), camera=(), magnetometer=(), gyroscope=(), fullscreen=(self), payment=()")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
83
go-backend/internal/middlewares/rate_limit.go
Normal file
83
go-backend/internal/middlewares/rate_limit.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rate.Limiter
|
||||
mu *sync.RWMutex
|
||||
r rate.Limit
|
||||
b int
|
||||
}
|
||||
|
||||
func NewIPRateLimiter(r int, b int) *IPRateLimiter {
|
||||
return &IPRateLimiter{
|
||||
ips: make(map[string]*rate.Limiter),
|
||||
mu: &sync.RWMutex{},
|
||||
r: rate.Limit(r),
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
limiter, exists := i.ips[ip]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(i.r, i.b)
|
||||
i.ips[ip] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// func RateLimitMiddleware() gin.HandlerFunc {
|
||||
// if iPLimiter == nil {
|
||||
|
||||
// iPLimiter := NewIPRateLimiter(
|
||||
// rate.Limit(config.Security.Ratelimits.Limit),
|
||||
// config.Security.Ratelimits.Burst)
|
||||
// }
|
||||
|
||||
// return func(c *gin.Context) {
|
||||
// ip := c.ClientIP()
|
||||
// l := iPLimiter.GetLimiter(ip)
|
||||
// if !l.Allow() {
|
||||
// c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
// "error": "Too many requests",
|
||||
// })
|
||||
// c.Abort()
|
||||
// return
|
||||
// }
|
||||
// c.Next()
|
||||
// }
|
||||
// }
|
||||
|
||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b)
|
||||
return func(c *gin.Context) {
|
||||
if limiter == nil {
|
||||
logger.Error.Println("Limiter missing")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
l := limiter.GetLimiter(ip)
|
||||
if !l.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
143
go-backend/internal/middlewares/rate_limit_test.go
Normal file
143
go-backend/internal/middlewares/rate_limit_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
limit = 1
|
||||
burst = 60
|
||||
)
|
||||
|
||||
func TestRateLimitMiddleware(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a new rate limiter that allows 2 requests per second with a burst of 4
|
||||
|
||||
// Create a new Gin router with the rate limit middleware
|
||||
router := setupRouter()
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
requests int
|
||||
expectedStatus int
|
||||
sleep time.Duration
|
||||
}{
|
||||
{"Allow first request", 1, http.StatusOK, 0},
|
||||
{"Allow up to burst limit", burst - 1, http.StatusOK, 0},
|
||||
{"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0},
|
||||
{"Allow after rate limit replenishes", 1, http.StatusOK, time.Second},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
time.Sleep(tt.sleep)
|
||||
var status int
|
||||
for i := 0; i < tt.requests; i++ {
|
||||
status = makeRequest(router, "192.168.0.2")
|
||||
|
||||
}
|
||||
if status != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRateLimiter(t *testing.T) {
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
|
||||
limiter1 := limiter.GetLimiter("127.0.0.1")
|
||||
limiter2 := limiter.GetLimiter("192.168.0.1")
|
||||
|
||||
if limiter1 == limiter2 {
|
||||
t.Error("Expected different limiters for different IPs")
|
||||
}
|
||||
|
||||
limiter3 := limiter.GetLimiter("127.0.0.1")
|
||||
if limiter1 != limiter3 {
|
||||
t.Error("Expected the same limiter for the same IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentRateLimits(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
requests int
|
||||
expectedRequests int
|
||||
ip string
|
||||
}{
|
||||
{"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"},
|
||||
{"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"},
|
||||
{"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"},
|
||||
{"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
router := setupRouter()
|
||||
testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) {
|
||||
start := time.Now()
|
||||
successCount := 0
|
||||
totalRequests := 0
|
||||
|
||||
t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests)))
|
||||
for time.Since(start) < duration {
|
||||
status := makeRequest(router, ip)
|
||||
totalRequests++
|
||||
|
||||
if status == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests)))
|
||||
}
|
||||
|
||||
actualDuration := time.Since(start)
|
||||
|
||||
t.Logf("limit: %v, burst: %v", limit, burst)
|
||||
t.Logf("Test duration: %v", actualDuration)
|
||||
t.Logf("Successful requests: %d", successCount)
|
||||
t.Logf("Expected successful requests: %d", expectedRequests)
|
||||
t.Logf("Total requests: %d", totalRequests)
|
||||
|
||||
if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 {
|
||||
t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount)
|
||||
}
|
||||
|
||||
if requests-expectedRequests != 0 && totalRequests <= successCount {
|
||||
t.Errorf("Expected some requests to be rate limited")
|
||||
}
|
||||
}
|
||||
|
||||
func setupRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
router.Use(RateLimitMiddleware(limiter))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
func makeRequest(router *gin.Engine, ip string) int {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP
|
||||
router.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
15
go-backend/internal/models/bank_account.go
Normal file
15
go-backend/internal/models/bank_account.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
type BankAccount struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
MandateDateSigned time.Time `gorm:"not null" json:"mandate_date_signed"`
|
||||
Bank string `json:"bank_name" binding:"safe_content"`
|
||||
AccountHolderName string `json:"account_holder_name" binding:"safe_content"`
|
||||
IBAN string `json:"iban"`
|
||||
BIC string `json:"bic"`
|
||||
MandateReference string `gorm:"not null" json:"mandate_reference"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
}
|
||||
17
go-backend/internal/models/consents.go
Normal file
17
go-backend/internal/models/consents.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Consent struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
FirstName string `gorm:"not null" json:"first_name" binding:"safe_content"`
|
||||
LastName string `gorm:"not null" json:"last_name" binding:"safe_content"`
|
||||
Email string `json:"email" binding:"email,safe_content"`
|
||||
ConsentType string `gorm:"not null" json:"consent_type" binding:"safe_content"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
User User
|
||||
UserID uint
|
||||
}
|
||||
22
go-backend/internal/models/drivers_licence.go
Normal file
22
go-backend/internal/models/drivers_licence.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Licence struct {
|
||||
ID uint `json:"id"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Status int8 `json:"status" binding:"omitempty,number"`
|
||||
Number string `json:"number" binding:"omitempty,safe_content"`
|
||||
IssuedDate time.Time `json:"issued_date" binding:"omitempty"`
|
||||
ExpirationDate time.Time `json:"expiration_date" binding:"omitempty"`
|
||||
IssuingCountry string `json:"country" binding:"safe_content"`
|
||||
Categories []Category `json:"categories" gorm:"many2many:licence_2_categories"`
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"category" binding:"safe_content"`
|
||||
}
|
||||
15
go-backend/internal/models/membership.go
Normal file
15
go-backend/internal/models/membership.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
type Membership struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
StartDate time.Time `json:"start_date"`
|
||||
EndDate time.Time `json:"end_date"`
|
||||
Status int8 `json:"status" binding:"number,safe_content"`
|
||||
SubscriptionModel SubscriptionModel `gorm:"foreignKey:SubscriptionModelID" json:"subscription_model"`
|
||||
SubscriptionModelID uint `json:"subsription_model_id"`
|
||||
ParentMembershipID uint `json:"parent_member_id" binding:"omitempty,omitnil,number"`
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
19
go-backend/internal/models/subscription_model.go
Normal file
19
go-backend/internal/models/subscription_model.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type SubscriptionModel struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Name string `gorm:"unique" json:"name" binding:"required"`
|
||||
Details string `json:"details"`
|
||||
Conditions string `json:"conditions"`
|
||||
RequiredMembershipField string `json:"required_membership_field"`
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
MonthlyFee float32 `json:"monthly_fee"`
|
||||
HourlyRate float32 `json:"hourly_rate"`
|
||||
IncludedPerYear int16 `json:"included_hours_per_year"`
|
||||
IncludedPerMonth int16 `json:"included_hours_per_month"`
|
||||
}
|
||||
132
go-backend/internal/models/user.go
Normal file
132
go-backend/internal/models/user.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/argon2id"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primarykey" json:"id"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
DateOfBirth time.Time `gorm:"not null" json:"dateofbirth" binding:"required,safe_content"`
|
||||
Company string `json:"company" binding:"omitempty,omitnil,safe_content"`
|
||||
Phone string `json:"phone" binding:"omitempty,omitnil,safe_content"`
|
||||
Notes string `json:"notes" binding:"safe_content"`
|
||||
FirstName string `gorm:"not null" json:"first_name" binding:"required,safe_content"`
|
||||
Password string `json:"password" binding:"safe_content"`
|
||||
Email string `gorm:"unique;not null" json:"email" binding:"required,email,safe_content"`
|
||||
LastName string `gorm:"not null" json:"last_name" binding:"required,safe_content"`
|
||||
ProfilePicture string `json:"profile_picture" binding:"omitempty,omitnil,image,safe_content"`
|
||||
Address string `gorm:"not null" json:"address" binding:"required,safe_content"`
|
||||
ZipCode string `gorm:"not null" json:"zip_code" binding:"required,alphanum,safe_content"`
|
||||
City string `form:"not null" json:"city" binding:"required,alphaunicode,safe_content"`
|
||||
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
|
||||
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
|
||||
BankAccountID uint
|
||||
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
|
||||
VerificationID uint
|
||||
Membership Membership `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"membership"`
|
||||
MembershipID uint
|
||||
Licence *Licence `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"licence"`
|
||||
LicenceID uint
|
||||
PaymentStatus int8 `json:"payment_status"`
|
||||
Status int8 `json:"status"`
|
||||
RoleID int8 `json:"role_id"`
|
||||
}
|
||||
|
||||
func (u *User) AfterCreate(tx *gorm.DB) (err error) {
|
||||
if u.BankAccount.ID != 0 && u.BankAccount.MandateReference == "" {
|
||||
mandateReference := u.GenerateMandateReference()
|
||||
|
||||
return tx.Model(&u.BankAccount).Update("MandateReference", mandateReference).Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) GenerateMandateReference() string {
|
||||
return fmt.Sprintf("%s%d%s", time.Now().Format("20060102"), u.ID, u.BankAccount.IBAN)
|
||||
}
|
||||
|
||||
func (u *User) SetPassword(plaintextPassword string) error {
|
||||
if plaintextPassword == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
hash, err := argon2id.CreateHash(plaintextPassword, argon2id.DefaultParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Password = hash
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) PasswordMatches(plaintextPassword string) (bool, error) {
|
||||
logger.Error.Printf("plaintext: %v user password: %v", plaintextPassword, u.Password)
|
||||
return argon2id.ComparePasswordAndHash(plaintextPassword, u.Password)
|
||||
}
|
||||
|
||||
func (u *User) Safe() map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"email": u.Email,
|
||||
"first_name": u.FirstName,
|
||||
"last_name": u.LastName,
|
||||
"phone": u.Phone,
|
||||
"notes": u.Notes,
|
||||
"address": u.Address,
|
||||
"zip_code": u.ZipCode,
|
||||
"city": u.City,
|
||||
"status": u.Status,
|
||||
"id": u.ID,
|
||||
"role_id": u.RoleID,
|
||||
"company": u.Company,
|
||||
"dateofbirth": u.DateOfBirth,
|
||||
"membership": map[string]interface{}{
|
||||
"id": u.Membership.ID,
|
||||
"start_date": u.Membership.StartDate,
|
||||
"end_date": u.Membership.EndDate,
|
||||
"status": u.Membership.Status,
|
||||
"subscription_model": map[string]interface{}{
|
||||
"id": u.Membership.SubscriptionModel.ID,
|
||||
"name": u.Membership.SubscriptionModel.Name,
|
||||
"details": u.Membership.SubscriptionModel.Details,
|
||||
"conditions": u.Membership.SubscriptionModel.Conditions,
|
||||
"monthly_fee": u.Membership.SubscriptionModel.MonthlyFee,
|
||||
"hourly_rate": u.Membership.SubscriptionModel.HourlyRate,
|
||||
"included_per_year": u.Membership.SubscriptionModel.IncludedPerYear,
|
||||
"included_per_month": u.Membership.SubscriptionModel.IncludedPerMonth,
|
||||
},
|
||||
},
|
||||
"licence": map[string]interface{}{
|
||||
"id": 0,
|
||||
},
|
||||
"bank_account": map[string]interface{}{
|
||||
"id": u.BankAccount.ID,
|
||||
"mandate_date_signed": u.BankAccount.MandateDateSigned,
|
||||
"bank": u.BankAccount.Bank,
|
||||
"account_holder_name": u.BankAccount.AccountHolderName,
|
||||
"iban": u.BankAccount.IBAN,
|
||||
"bic": u.BankAccount.BIC,
|
||||
"mandate_reference": u.BankAccount.MandateReference,
|
||||
},
|
||||
}
|
||||
|
||||
if u.Licence != nil {
|
||||
result["licence"] = map[string]interface{}{
|
||||
"id": u.Licence.ID,
|
||||
"number": u.Licence.Number,
|
||||
"status": u.Licence.Status,
|
||||
"issued_date": u.Licence.IssuedDate,
|
||||
"expiration_date": u.Licence.ExpirationDate,
|
||||
"country": u.Licence.IssuingCountry,
|
||||
"categories": u.Licence.Categories,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
13
go-backend/internal/models/verification.go
Normal file
13
go-backend/internal/models/verification.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
type Verification struct {
|
||||
UpdatedAt time.Time
|
||||
CreatedAt time.Time
|
||||
VerifiedAt *time.Time `gorm:"Default:NULL" json:"verified_at"`
|
||||
VerificationToken string `json:"token"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"unique;" json:"user_id"`
|
||||
Type string
|
||||
}
|
||||
20
go-backend/internal/repositories/banking_repository.go
Normal file
20
go-backend/internal/repositories/banking_repository.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
type BankAccountRepositoryInterface interface {
|
||||
CreateBankAccount(account *models.BankAccount) (uint, error)
|
||||
}
|
||||
|
||||
type BankAccountRepository struct{}
|
||||
|
||||
func (repo *BankAccountRepository) CreateBankAccount(account *models.BankAccount) (uint, error) {
|
||||
result := database.DB.Create(account)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return account.ID, nil
|
||||
}
|
||||
21
go-backend/internal/repositories/consents_repository.go
Normal file
21
go-backend/internal/repositories/consents_repository.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
type ConsentRepositoryInterface interface {
|
||||
CreateConsent(consent *models.Consent) (uint, error)
|
||||
}
|
||||
|
||||
type ConsentRepository struct{}
|
||||
|
||||
func (repo *ConsentRepository) CreateConsent(consent *models.Consent) (uint, error) {
|
||||
result := database.DB.Create(consent)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return consent.ID, nil
|
||||
}
|
||||
31
go-backend/internal/repositories/licence_repository.go
Normal file
31
go-backend/internal/repositories/licence_repository.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
type LicenceInterface interface {
|
||||
FindCategoryByName(categoryName string) (models.Category, error)
|
||||
FindCategoriesByIDs(ids []uint) ([]models.Category, error)
|
||||
GetAllCategories() ([]models.Category, error)
|
||||
}
|
||||
|
||||
type LicenceRepository struct{}
|
||||
|
||||
func (r *LicenceRepository) GetAllCategories() ([]models.Category, error) {
|
||||
var categories []models.Category
|
||||
err := database.DB.Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
func (r *LicenceRepository) FindCategoriesByIDs(ids []uint) ([]models.Category, error) {
|
||||
var categories []models.Category
|
||||
err := database.DB.Where("id IN ?", ids).Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
func (r *LicenceRepository) FindCategoryByName(categoryName string) (models.Category, error) {
|
||||
var category models.Category
|
||||
err := database.DB.Where("name = ?", categoryName).First(&category).Error
|
||||
return category, err
|
||||
}
|
||||
37
go-backend/internal/repositories/membership_repository.go
Normal file
37
go-backend/internal/repositories/membership_repository.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
type MembershipRepositoryInterface interface {
|
||||
CreateMembership(membership *models.Membership) (uint, error)
|
||||
FindMembershipByUserID(userID uint) (*models.Membership, error)
|
||||
}
|
||||
|
||||
type MembershipRepository struct{}
|
||||
|
||||
func (repo *MembershipRepository) CreateMembership(membership *models.Membership) (uint, error) {
|
||||
result := database.DB.Create(membership)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return membership.ID, nil
|
||||
}
|
||||
|
||||
func (repo *MembershipRepository) FindMembershipByUserID(userID uint) (*models.Membership, error) {
|
||||
|
||||
var membership models.Membership
|
||||
result := database.DB.First(&membership, userID)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &membership, nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
type SubscriptionModelsRepositoryInterface interface {
|
||||
CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error)
|
||||
UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error)
|
||||
GetSubscriptionModelNames() ([]string, error)
|
||||
GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error)
|
||||
// GetUsersBySubscription(id uint) (*[]models.SubscriptionModel, error)
|
||||
DeleteSubscription(id *uint) error
|
||||
}
|
||||
|
||||
type SubscriptionModelsRepository struct{}
|
||||
|
||||
func (sr *SubscriptionModelsRepository) CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error) {
|
||||
|
||||
result := database.DB.Create(subscriptionModel)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
return subscriptionModel.ID, nil
|
||||
}
|
||||
|
||||
func (sr *SubscriptionModelsRepository) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) {
|
||||
|
||||
result := database.DB.Model(&models.SubscriptionModel{ID: subscription.ID}).Updates(subscription)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return subscription, nil
|
||||
}
|
||||
|
||||
func (sr *SubscriptionModelsRepository) DeleteSubscription(id *uint) error {
|
||||
|
||||
result := database.DB.Delete(&models.SubscriptionModel{}, id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) {
|
||||
var model models.SubscriptionModel
|
||||
result := database.DB.Where("name = ?", modelname).First(&model)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
func (sr *SubscriptionModelsRepository) GetSubscriptionModelNames() ([]string, error) {
|
||||
var names []string
|
||||
if err := database.DB.Model(&models.SubscriptionModel{}).Pluck("name", &names).Error; err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func (sr *SubscriptionModelsRepository) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) {
|
||||
var subscriptions []models.SubscriptionModel
|
||||
result := database.DB.Where(where).Find(&subscriptions)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &subscriptions, nil
|
||||
}
|
||||
|
||||
func GetUsersBySubscription(subscriptionID uint) (*[]models.User, error) {
|
||||
var users []models.User
|
||||
|
||||
err := database.DB.Preload("Membership").
|
||||
Preload("Membership.SubscriptionModel").
|
||||
Preload("BankAccount").
|
||||
Preload("Licence").
|
||||
Preload("Licence.Categories").
|
||||
Joins("JOIN memberships ON users.membership_id = memberships.id").
|
||||
Joins("JOIN subscription_models ON memberships.subscription_model_id = subscription_models.id").
|
||||
Where("subscription_models.id = ?", subscriptionID).
|
||||
Find(&users).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &users, nil
|
||||
|
||||
}
|
||||
10
go-backend/internal/repositories/user_permissions.go
Normal file
10
go-backend/internal/repositories/user_permissions.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
func (r *UserRepository) SetUserStatus(id uint, status uint) error {
|
||||
return database.DB.Model(&models.User{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
159
go-backend/internal/repositories/user_repository.go
Normal file
159
go-backend/internal/repositories/user_repository.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"GoMembership/internal/database"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type UserRepositoryInterface interface {
|
||||
CreateUser(user *models.User) (uint, error)
|
||||
UpdateUser(user *models.User) (*models.User, error)
|
||||
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
||||
GetUserByEmail(email string) (*models.User, error)
|
||||
IsVerified(userID *uint) (bool, error)
|
||||
GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error)
|
||||
SetVerificationToken(verification *models.Verification) (token string, err error)
|
||||
DeleteVerification(id uint, verificationType string) error
|
||||
DeleteUser(id uint) error
|
||||
SetUserStatus(id uint, status uint) error
|
||||
}
|
||||
|
||||
type UserRepository struct{}
|
||||
|
||||
func (ur *UserRepository) DeleteUser(id uint) error {
|
||||
return database.DB.Delete(&models.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func PasswordExists(userID *uint) (bool, error) {
|
||||
var user models.User
|
||||
result := database.DB.Select("password").First(&user, userID)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return user.Password != "", nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) CreateUser(user *models.User) (uint, error) {
|
||||
result := database.DB.Create(user)
|
||||
if result.Error != nil {
|
||||
logger.Error.Printf("Create User error: %#v", result.Error)
|
||||
return 0, result.Error
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
|
||||
if user == nil {
|
||||
return nil, errors.ErrNoData
|
||||
}
|
||||
|
||||
err := database.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// Check if the user exists in the database
|
||||
var existingUser models.User
|
||||
|
||||
if err := tx.Preload(clause.Associations).
|
||||
Preload("Membership").
|
||||
Preload("Membership.SubscriptionModel").
|
||||
Preload("Licence").
|
||||
Preload("Licence.Categories").
|
||||
First(&existingUser, user.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Update the user's main fields
|
||||
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Omit("Password").Updates(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.ErrNoRowsAffected
|
||||
}
|
||||
|
||||
if user.Password != "" {
|
||||
if err := tx.Model(&models.User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Update("Password", user.Password).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Update the Membership if provided
|
||||
if user.Membership.ID != 0 {
|
||||
if err := tx.Model(&existingUser.Membership).Updates(user.Membership).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace categories if Licence and Categories are provided
|
||||
if user.Licence != nil {
|
||||
if err := tx.Model(&user.Licence).Association("Categories").Replace(user.Licence.Categories); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var updatedUser models.User
|
||||
if err := database.DB.Preload("Licence.Categories").
|
||||
Preload("Membership").
|
||||
First(&updatedUser, user.ID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &updatedUser, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
|
||||
var users []models.User
|
||||
result := database.DB.
|
||||
Preload(clause.Associations).
|
||||
Preload("Membership.SubscriptionModel").
|
||||
Preload("Licence.Categories").
|
||||
Where(where).Find(&users)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &users, nil
|
||||
}
|
||||
|
||||
func GetUserByID(userID *uint) (*models.User, error) {
|
||||
var user models.User
|
||||
result := database.DB.
|
||||
Preload(clause.Associations).
|
||||
Preload("Membership").
|
||||
Preload("Membership.SubscriptionModel").
|
||||
Preload("Licence.Categories").
|
||||
First(&user, userID)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) GetUserByEmail(email string) (*models.User, error) {
|
||||
var user models.User
|
||||
result := database.DB.Where("email = ?", email).First(&user)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
57
go-backend/internal/repositories/user_verification.go
Normal file
57
go-backend/internal/repositories/user_verification.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/database"
|
||||
"GoMembership/internal/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (ur *UserRepository) IsVerified(userID *uint) (bool, error) {
|
||||
var user models.User
|
||||
result := database.DB.Select("status").First(&user, userID)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return false, gorm.ErrRecordNotFound
|
||||
}
|
||||
return false, result.Error
|
||||
}
|
||||
return user.Status > constants.DisabledStatus, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error) {
|
||||
|
||||
var emailVerification models.Verification
|
||||
result := database.DB.Where("verification_token = ? AND type = ?", token, verificationType).First(&emailVerification)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &emailVerification, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (token string, err error) {
|
||||
|
||||
result := database.DB.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"verification_token", "created_at", "type"}),
|
||||
}).Create(&verification)
|
||||
|
||||
if result.Error != nil {
|
||||
return "", result.Error
|
||||
}
|
||||
|
||||
return verification.VerificationToken, nil
|
||||
}
|
||||
|
||||
func (ur *UserRepository) DeleteVerification(id uint, verificationType string) error {
|
||||
result := database.DB.Where("user_id = ? AND type = ?", id, verificationType).Delete(&models.Verification{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
50
go-backend/internal/routes/routes.go
Normal file
50
go-backend/internal/routes/routes.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"GoMembership/internal/controllers"
|
||||
"GoMembership/internal/middlewares"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RegisterRoutes(router *gin.Engine, userController *controllers.UserController, membershipcontroller *controllers.MembershipController, contactController *controllers.ContactController, licenceController *controllers.LicenceController) {
|
||||
router.GET("/users/verify", userController.VerifyMailHandler)
|
||||
router.POST("/users/register", userController.RegisterUser)
|
||||
router.POST("/users/contact", contactController.RelayContactRequest)
|
||||
router.POST("/users/password/request-change", userController.RequestPasswordChangeHandler)
|
||||
router.PATCH("/users/password/change/:id", userController.ChangePassword)
|
||||
router.POST("/users/login", userController.LoginHandler)
|
||||
router.POST("/csp-report", middlewares.CSPReportHandling)
|
||||
|
||||
// apiRouter := router.Group("/api")
|
||||
// apiRouter.Use(middlewares.APIKeyMiddleware())
|
||||
// {
|
||||
// apiRouter.POST("/v1/subscription", membershipcontroller.RegisterSubscription)
|
||||
// }
|
||||
|
||||
userRouter := router.Group("/backend")
|
||||
userRouter.Use(middlewares.AuthMiddleware())
|
||||
{
|
||||
userRouter.GET("/users/current", userController.CurrentUserHandler)
|
||||
userRouter.POST("/logout", userController.LogoutHandler)
|
||||
userRouter.PUT("/users", userController.UpdateHandler)
|
||||
userRouter.POST("/users", userController.RegisterUser)
|
||||
userRouter.GET("/users/all", userController.GetAllUsers)
|
||||
userRouter.DELETE("/users", userController.DeleteUser)
|
||||
}
|
||||
|
||||
membershipRouter := router.Group("/backend/membership")
|
||||
membershipRouter.Use(middlewares.AuthMiddleware())
|
||||
{
|
||||
membershipRouter.GET("/subscriptions", membershipcontroller.GetSubscriptions)
|
||||
membershipRouter.PUT("/subscriptions", membershipcontroller.UpdateHandler)
|
||||
membershipRouter.POST("/subscriptions", membershipcontroller.RegisterSubscription)
|
||||
membershipRouter.DELETE("/subscriptions", membershipcontroller.DeleteSubscription)
|
||||
}
|
||||
|
||||
licenceRouter := router.Group("/backend/licence")
|
||||
licenceRouter.Use(middlewares.AuthMiddleware())
|
||||
{
|
||||
licenceRouter.GET("/categories", licenceController.GetAllCategories)
|
||||
}
|
||||
}
|
||||
95
go-backend/internal/server/server.go
Normal file
95
go-backend/internal/server/server.go
Normal file
@@ -0,0 +1,95 @@
|
||||
// Package server initializes and runs the application server.
|
||||
// It sets up configurations, initializes the database, services, and controllers,
|
||||
// loads HTML templates, and starts the HTTP server.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/controllers"
|
||||
"GoMembership/internal/middlewares"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/internal/validation"
|
||||
|
||||
"GoMembership/internal/routes"
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var shutdownChannel = make(chan struct{})
|
||||
var srv *http.Server
|
||||
|
||||
// Run initializes the server configuration, sets up services and controllers, and starts the HTTP server.
|
||||
func Run() {
|
||||
|
||||
emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password)
|
||||
var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{}
|
||||
consentService := &services.ConsentService{Repo: consentRepo}
|
||||
|
||||
var bankAccountRepo repositories.BankAccountRepositoryInterface = &repositories.BankAccountRepository{}
|
||||
bankAccountService := &services.BankAccountService{Repo: bankAccountRepo}
|
||||
|
||||
var membershipRepo repositories.MembershipRepositoryInterface = &repositories.MembershipRepository{}
|
||||
var subscriptionRepo repositories.SubscriptionModelsRepositoryInterface = &repositories.SubscriptionModelsRepository{}
|
||||
membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo}
|
||||
|
||||
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
|
||||
licenceService := &services.LicenceService{Repo: licenceRepo}
|
||||
|
||||
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{}
|
||||
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
|
||||
|
||||
userController := &controllers.UserController{Service: userService, EmailService: emailService, ConsentService: consentService, LicenceService: licenceService, BankAccountService: bankAccountService, MembershipService: membershipService}
|
||||
membershipController := &controllers.MembershipController{Service: *membershipService, UserController: userController}
|
||||
licenceController := &controllers.LicenceController{Service: *licenceService}
|
||||
contactController := &controllers.ContactController{EmailService: emailService}
|
||||
|
||||
router := gin.Default()
|
||||
// gin.SetMode(gin.ReleaseMode)
|
||||
router.Static(config.Templates.StaticPath, "./style")
|
||||
// Load HTML templates
|
||||
router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*"))
|
||||
|
||||
router.Use(gin.Logger())
|
||||
router.Use(middlewares.CORSMiddleware())
|
||||
router.Use(middlewares.CSPMiddleware())
|
||||
router.Use(middlewares.SecurityHeadersMiddleware())
|
||||
|
||||
limiter := middlewares.NewIPRateLimiter(config.Security.Ratelimits.Limit, config.Security.Ratelimits.Burst)
|
||||
router.Use(middlewares.RateLimitMiddleware(limiter))
|
||||
|
||||
routes.RegisterRoutes(router, userController, membershipController, contactController, licenceController)
|
||||
validation.SetupValidators()
|
||||
|
||||
logger.Info.Println("Starting server on :8080")
|
||||
srv = &http.Server{
|
||||
Addr: ":8080",
|
||||
Handler: router,
|
||||
}
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error.Fatalf("could not start server: %v", err)
|
||||
}
|
||||
}()
|
||||
// Wait for the shutdown signal
|
||||
<-shutdownChannel
|
||||
}
|
||||
|
||||
func Shutdown(ctx context.Context) error {
|
||||
if srv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Graceful shutdown with a timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Attempt to shutdown the server
|
||||
return srv.Shutdown(shutdownCtx)
|
||||
}
|
||||
12
go-backend/internal/services/bank_account_service.go
Normal file
12
go-backend/internal/services/bank_account_service.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"GoMembership/internal/repositories"
|
||||
)
|
||||
|
||||
type BankAccountServiceInterface interface {
|
||||
}
|
||||
|
||||
type BankAccountService struct {
|
||||
Repo repositories.BankAccountRepositoryInterface
|
||||
}
|
||||
22
go-backend/internal/services/consent_service.go
Normal file
22
go-backend/internal/services/consent_service.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
)
|
||||
|
||||
type ConsentServiceInterface interface {
|
||||
RegisterConsent(consent *models.Consent) (uint, error)
|
||||
}
|
||||
|
||||
type ConsentService struct {
|
||||
Repo repositories.ConsentRepositoryInterface
|
||||
}
|
||||
|
||||
func (service *ConsentService) RegisterConsent(consent *models.Consent) (uint, error) {
|
||||
consent.CreatedAt = time.Now()
|
||||
consent.UpdatedAt = time.Now()
|
||||
return service.Repo.CreateConsent(consent)
|
||||
}
|
||||
240
go-backend/internal/services/email_service.go
Normal file
240
go-backend/internal/services/email_service.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"html/template"
|
||||
|
||||
"gopkg.in/gomail.v2"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
type EmailService struct {
|
||||
dialer *gomail.Dialer
|
||||
}
|
||||
|
||||
func NewEmailService(host string, port int, username string, password string) *EmailService {
|
||||
dialer := gomail.NewDialer(host, port, username, password)
|
||||
return &EmailService{dialer: dialer}
|
||||
}
|
||||
|
||||
func (s *EmailService) SendEmail(to string, subject string, body string, bodyTXT string, replyTo string) error {
|
||||
msg := gomail.NewMessage()
|
||||
msg.SetHeader("From", s.dialer.Username)
|
||||
msg.SetHeader("To", to)
|
||||
msg.SetHeader("Subject", subject)
|
||||
if replyTo != "" {
|
||||
msg.SetHeader("REPLY_TO", replyTo)
|
||||
}
|
||||
if bodyTXT != "" {
|
||||
msg.SetBody("text/plain", bodyTXT)
|
||||
}
|
||||
|
||||
msg.AddAlternative("text/html", body)
|
||||
// msg.WriteTo(os.Stdout)
|
||||
|
||||
if err := s.dialer.DialAndSend(msg); err != nil {
|
||||
logger.Error.Printf("Could not send email to %s: %v", to, err)
|
||||
return err
|
||||
}
|
||||
logger.Info.Printf("Email sent to %s", to)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseTemplate(filename string, data interface{}) (string, error) {
|
||||
// Read the email template file
|
||||
|
||||
templateDir := config.Templates.MailPath
|
||||
tpl, err := template.ParseFiles(templateDir + "/" + filename)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Failed to parse email template: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Buffer to hold the rendered template
|
||||
var tplBuffer bytes.Buffer
|
||||
if err := tpl.Execute(&tplBuffer, data); err != nil {
|
||||
logger.Error.Printf("Failed to execute email template: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tplBuffer.String(), nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendVerificationEmail(user *models.User, token *string) error {
|
||||
// Prepare data to be injected into the template
|
||||
data := struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
Token string
|
||||
BASEURL string
|
||||
}{
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
Token: *token,
|
||||
BASEURL: config.Site.BaseURL,
|
||||
}
|
||||
|
||||
subject := constants.MailVerificationSubject
|
||||
body, err := ParseTemplate("mail_verification.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't send verification mail")
|
||||
return err
|
||||
}
|
||||
return s.SendEmail(user.Email, subject, body, "", "")
|
||||
|
||||
}
|
||||
|
||||
func (s *EmailService) SendChangePasswordEmail(user *models.User, token *string) error {
|
||||
// Prepare data to be injected into the template
|
||||
data := struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
Token string
|
||||
BASEURL string
|
||||
UserID uint
|
||||
}{
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
Token: *token,
|
||||
BASEURL: config.Site.BaseURL,
|
||||
UserID: user.ID,
|
||||
}
|
||||
|
||||
subject := constants.MailChangePasswordSubject
|
||||
htmlBody, err := ParseTemplate("mail_change_password.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't parse password mail")
|
||||
return err
|
||||
}
|
||||
plainBody, err := ParseTemplate("mail_change_password.txt.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't parse password mail")
|
||||
return err
|
||||
}
|
||||
return s.SendEmail(user.Email, subject, htmlBody, plainBody, "")
|
||||
|
||||
}
|
||||
|
||||
func (s *EmailService) SendWelcomeEmail(user *models.User) error {
|
||||
// Prepare data to be injected into the template
|
||||
data := struct {
|
||||
Company string
|
||||
FirstName string
|
||||
MembershipModel string
|
||||
BASEURL string
|
||||
MembershipID uint
|
||||
MembershipFee float32
|
||||
Logo string
|
||||
WebsiteTitle string
|
||||
RentalFee float32
|
||||
}{
|
||||
Company: user.Company,
|
||||
FirstName: user.FirstName,
|
||||
MembershipModel: user.Membership.SubscriptionModel.Name,
|
||||
MembershipID: user.Membership.ID,
|
||||
MembershipFee: float32(user.Membership.SubscriptionModel.MonthlyFee),
|
||||
RentalFee: float32(user.Membership.SubscriptionModel.HourlyRate),
|
||||
BASEURL: config.Site.BaseURL,
|
||||
WebsiteTitle: config.Site.WebsiteTitle,
|
||||
Logo: config.Templates.LogoURI,
|
||||
}
|
||||
|
||||
subject := constants.MailWelcomeSubject
|
||||
htmlBody, err := ParseTemplate("mail_welcome.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't send welcome mail")
|
||||
return err
|
||||
}
|
||||
plainBody, err := ParseTemplate("mail_welcome.txt.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't parse password mail")
|
||||
return err
|
||||
}
|
||||
return s.SendEmail(user.Email, subject, htmlBody, plainBody, "")
|
||||
}
|
||||
|
||||
func (s *EmailService) SendRegistrationNotification(user *models.User) error {
|
||||
// Prepare data to be injected into the template
|
||||
data := struct {
|
||||
FirstName string
|
||||
DateOfBirth string
|
||||
LastName string
|
||||
MembershipModel string
|
||||
Address string
|
||||
IBAN string
|
||||
Email string
|
||||
Phone string
|
||||
City string
|
||||
Company string
|
||||
ZipCode string
|
||||
BASEURL string
|
||||
MembershipID uint
|
||||
RentalFee float32
|
||||
MembershipFee float32
|
||||
Logo string
|
||||
WebsiteTitle string
|
||||
}{
|
||||
Company: user.Company,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
MembershipModel: user.Membership.SubscriptionModel.Name,
|
||||
MembershipID: user.Membership.ID,
|
||||
MembershipFee: float32(user.Membership.SubscriptionModel.MonthlyFee),
|
||||
RentalFee: float32(user.Membership.SubscriptionModel.HourlyRate),
|
||||
Address: user.Address,
|
||||
ZipCode: user.ZipCode,
|
||||
City: user.City,
|
||||
DateOfBirth: user.DateOfBirth.Format("20060102"),
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
IBAN: user.BankAccount.IBAN,
|
||||
BASEURL: config.Site.BaseURL,
|
||||
Logo: config.Templates.LogoURI,
|
||||
WebsiteTitle: config.Site.WebsiteTitle,
|
||||
}
|
||||
|
||||
subject := constants.MailRegistrationSubject
|
||||
htmlBody, err := ParseTemplate("mail_registration.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't send admin notification mail")
|
||||
return err
|
||||
}
|
||||
plainBody, err := ParseTemplate("mail_registration.txt.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't parse password mail")
|
||||
return err
|
||||
}
|
||||
return s.SendEmail(config.Recipients.UserRegistration, subject, htmlBody, plainBody, "")
|
||||
}
|
||||
|
||||
func (s *EmailService) RelayContactFormMessage(sender string, name string, message string) error {
|
||||
data := struct {
|
||||
Message string
|
||||
Name string
|
||||
BASEURL string
|
||||
Logo string
|
||||
WebsiteTitle string
|
||||
}{
|
||||
Message: message,
|
||||
Name: name,
|
||||
BASEURL: config.Site.BaseURL,
|
||||
Logo: config.Templates.LogoURI,
|
||||
WebsiteTitle: config.Site.WebsiteTitle,
|
||||
}
|
||||
subject := constants.MailContactSubject
|
||||
htmlBody, err := ParseTemplate("mail_contact_form.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't send contact form message mail")
|
||||
return err
|
||||
}
|
||||
plainBody, err := ParseTemplate("mail_contact_form.txt.tmpl", data)
|
||||
if err != nil {
|
||||
logger.Error.Print("Couldn't parse password mail")
|
||||
return err
|
||||
}
|
||||
return s.SendEmail(config.Recipients.ContactForm, subject, htmlBody, plainBody, sender)
|
||||
}
|
||||
18
go-backend/internal/services/licence_service.go
Normal file
18
go-backend/internal/services/licence_service.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
)
|
||||
|
||||
type LicenceInterface interface {
|
||||
GetAllCategories() ([]models.Category, error)
|
||||
}
|
||||
|
||||
type LicenceService struct {
|
||||
Repo repositories.LicenceInterface
|
||||
}
|
||||
|
||||
func (s *LicenceService) GetAllCategories() ([]models.Category, error) {
|
||||
return s.Repo.GetAllCategories()
|
||||
}
|
||||
100
go-backend/internal/services/membership_service.go
Normal file
100
go-backend/internal/services/membership_service.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/pkg/errors"
|
||||
)
|
||||
|
||||
type MembershipServiceInterface interface {
|
||||
RegisterMembership(membership *models.Membership) (uint, error)
|
||||
FindMembershipByUserID(userID uint) (*models.Membership, error)
|
||||
RegisterSubscription(subscription *models.SubscriptionModel) (uint, error)
|
||||
UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error)
|
||||
DeleteSubscription(id *uint, name *string) error
|
||||
GetSubscriptionModelNames() ([]string, error)
|
||||
GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error)
|
||||
GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error)
|
||||
}
|
||||
|
||||
type MembershipService struct {
|
||||
Repo repositories.MembershipRepositoryInterface
|
||||
SubscriptionRepo repositories.SubscriptionModelsRepositoryInterface
|
||||
}
|
||||
|
||||
func (service *MembershipService) RegisterMembership(membership *models.Membership) (uint, error) {
|
||||
membership.StartDate = time.Now()
|
||||
return service.Repo.CreateMembership(membership)
|
||||
}
|
||||
|
||||
func (service *MembershipService) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) {
|
||||
|
||||
existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingSubscription == nil {
|
||||
return nil, errors.ErrSubscriptionNotFound
|
||||
}
|
||||
if existingSubscription.MonthlyFee != subscription.MonthlyFee ||
|
||||
existingSubscription.HourlyRate != subscription.HourlyRate ||
|
||||
existingSubscription.Conditions != subscription.Conditions ||
|
||||
existingSubscription.IncludedPerYear != subscription.IncludedPerYear ||
|
||||
existingSubscription.IncludedPerMonth != subscription.IncludedPerMonth {
|
||||
return nil, errors.ErrInvalidSubscriptionData
|
||||
}
|
||||
subscription.ID = existingSubscription.ID
|
||||
return service.SubscriptionRepo.UpdateSubscription(subscription)
|
||||
|
||||
}
|
||||
|
||||
func (service *MembershipService) DeleteSubscription(id *uint, name *string) error {
|
||||
if *name == "" {
|
||||
return errors.ErrNoData
|
||||
}
|
||||
exists, err := repositories.GetSubscriptionByName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists == nil {
|
||||
return errors.ErrNotFound
|
||||
}
|
||||
if *id != exists.ID {
|
||||
return errors.ErrInvalidSubscriptionData
|
||||
}
|
||||
usersInSubscription, err := repositories.GetUsersBySubscription(*id)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(*usersInSubscription) > 0 {
|
||||
return errors.ErrSubscriptionInUse
|
||||
}
|
||||
return service.SubscriptionRepo.DeleteSubscription(id)
|
||||
}
|
||||
|
||||
func (service *MembershipService) FindMembershipByUserID(userID uint) (*models.Membership, error) {
|
||||
return service.Repo.FindMembershipByUserID(userID)
|
||||
}
|
||||
|
||||
// Membership_Subscriptions
|
||||
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (uint, error) {
|
||||
return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
|
||||
}
|
||||
|
||||
func (service *MembershipService) GetSubscriptionModelNames() ([]string, error) {
|
||||
return service.SubscriptionRepo.GetSubscriptionModelNames()
|
||||
}
|
||||
|
||||
func (service *MembershipService) GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) {
|
||||
return repositories.GetSubscriptionByName(modelname)
|
||||
}
|
||||
|
||||
func (service *MembershipService) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) {
|
||||
if where == nil {
|
||||
where = map[string]interface{}{}
|
||||
}
|
||||
return service.SubscriptionRepo.GetSubscriptions(where)
|
||||
}
|
||||
21
go-backend/internal/services/user_password.go
Normal file
21
go-backend/internal/services/user_password.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
)
|
||||
|
||||
func (s *UserService) HandlePasswordChangeRequest(user *models.User) (token string, err error) {
|
||||
// Deactivate user and reset Verification
|
||||
if err := s.SetUserStatus(user.ID, constants.DisabledStatus); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := s.RevokeVerification(&user.ID, constants.VerificationTypes.Password); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate a token
|
||||
return s.SetVerificationToken(&user.ID, &constants.VerificationTypes.Password)
|
||||
|
||||
}
|
||||
5
go-backend/internal/services/user_permissions.go
Normal file
5
go-backend/internal/services/user_permissions.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package services
|
||||
|
||||
func (s *UserService) SetUserStatus(id uint, status uint) error {
|
||||
return s.Repo.SetUserStatus(id, status)
|
||||
}
|
||||
116
go-backend/internal/services/user_service.go
Normal file
116
go-backend/internal/services/user_service.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/pkg/errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserServiceInterface interface {
|
||||
RegisterUser(user *models.User) (id uint, token string, err error)
|
||||
GetUserByEmail(email string) (*models.User, error)
|
||||
GetUserByID(id uint) (*models.User, error)
|
||||
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
||||
UpdateUser(user *models.User) (*models.User, error)
|
||||
DeleteUser(lastname string, id uint) error
|
||||
SetUserStatus(id uint, status uint) error
|
||||
VerifyUser(token *string, verificationType *string) (*models.Verification, error)
|
||||
SetVerificationToken(id *uint, verificationType *string) (string, error)
|
||||
RevokeVerification(id *uint, verificationType string) error
|
||||
HandlePasswordChangeRequest(user *models.User) (token string, err error)
|
||||
}
|
||||
|
||||
type UserService struct {
|
||||
Repo repositories.UserRepositoryInterface
|
||||
Licences repositories.LicenceInterface
|
||||
}
|
||||
|
||||
func (service *UserService) DeleteUser(lastname string, id uint) error {
|
||||
if id == 0 || lastname == "" {
|
||||
return errors.ErrNoData
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user == nil {
|
||||
return errors.ErrUserNotFound
|
||||
}
|
||||
|
||||
return service.Repo.DeleteUser(id)
|
||||
}
|
||||
|
||||
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) {
|
||||
|
||||
if user.ID == 0 {
|
||||
return nil, errors.ErrUserNotFound
|
||||
}
|
||||
|
||||
user.SetPassword(user.Password)
|
||||
|
||||
// Validate subscription model
|
||||
selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name)
|
||||
if err != nil {
|
||||
return nil, errors.ErrSubscriptionNotFound
|
||||
}
|
||||
user.Membership.SubscriptionModel = *selectedModel
|
||||
user.Membership.SubscriptionModelID = selectedModel.ID
|
||||
|
||||
updatedUser, err := service.Repo.UpdateUser(user)
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, errors.ErrUserNotFound
|
||||
}
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
return nil, errors.ErrDuplicateEntry
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updatedUser, nil
|
||||
}
|
||||
|
||||
func (service *UserService) RegisterUser(user *models.User) (id uint, token string, err error) {
|
||||
|
||||
user.SetPassword(user.Password)
|
||||
|
||||
user.Status = constants.UnverifiedStatus
|
||||
user.CreatedAt = time.Now()
|
||||
user.UpdatedAt = time.Now()
|
||||
user.PaymentStatus = constants.AwaitingPaymentStatus
|
||||
user.BankAccount.MandateDateSigned = time.Now()
|
||||
id, err = service.Repo.CreateUser(user)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
token, err = service.SetVerificationToken(&id, &constants.VerificationTypes.Email)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return id, token, nil
|
||||
}
|
||||
|
||||
func (service *UserService) GetUserByID(id uint) (*models.User, error) {
|
||||
return repositories.GetUserByID(&id)
|
||||
}
|
||||
|
||||
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
|
||||
return service.Repo.GetUserByEmail(email)
|
||||
}
|
||||
|
||||
func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) {
|
||||
if where == nil {
|
||||
where = map[string]interface{}{}
|
||||
}
|
||||
return service.Repo.GetUsers(where)
|
||||
}
|
||||
59
go-backend/internal/services/user_verification.go
Normal file
59
go-backend/internal/services/user_verification.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/utils"
|
||||
"GoMembership/pkg/errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *UserService) SetVerificationToken(id *uint, verificationType *string) (string, error) {
|
||||
|
||||
token, err := utils.GenerateVerificationToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if user is already verified
|
||||
verified, err := s.Repo.IsVerified(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if verified {
|
||||
return "", errors.ErrAlreadyVerified
|
||||
}
|
||||
|
||||
// Prepare the Verification record
|
||||
verification := models.Verification{
|
||||
UserID: *id,
|
||||
VerificationToken: token,
|
||||
Type: *verificationType,
|
||||
}
|
||||
|
||||
return s.Repo.SetVerificationToken(&verification)
|
||||
}
|
||||
|
||||
func (s *UserService) RevokeVerification(id *uint, verificationType string) error {
|
||||
return s.Repo.DeleteVerification(*id, verificationType)
|
||||
|
||||
}
|
||||
|
||||
func (service *UserService) VerifyUser(token *string, verificationType *string) (*models.Verification, error) {
|
||||
verification, err := service.Repo.GetVerificationOfToken(token, verificationType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the user is already verified
|
||||
verified, err := service.Repo.IsVerified(&verification.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if verified {
|
||||
return nil, errors.ErrAlreadyVerified
|
||||
}
|
||||
t := time.Now()
|
||||
verification.VerifiedAt = &t
|
||||
|
||||
return verification, nil
|
||||
}
|
||||
20
go-backend/internal/utils/cookies.go
Normal file
20
go-backend/internal/utils/cookies.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetCookie(c *gin.Context, token string) {
|
||||
c.SetSameSite(http.SameSiteLaxMode)
|
||||
c.SetCookie(
|
||||
"jwt",
|
||||
token,
|
||||
5*24*60*60, // 5 days
|
||||
"/",
|
||||
"",
|
||||
true,
|
||||
true,
|
||||
)
|
||||
}
|
||||
101
go-backend/internal/utils/crypto.go
Normal file
101
go-backend/internal/utils/crypto.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/quotedprintable"
|
||||
"net/mail"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Email struct {
|
||||
MimeVersion string
|
||||
Date string
|
||||
From string
|
||||
To string
|
||||
Subject string
|
||||
ContentType string
|
||||
Body string
|
||||
}
|
||||
|
||||
func GenerateRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateVerificationToken() (string, error) {
|
||||
return GenerateRandomString(32)
|
||||
}
|
||||
|
||||
func DecodeMail(message string) (*Email, error) {
|
||||
msg, err := mail.ReadMessage(strings.NewReader(message))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedBody, err := io.ReadAll(msg.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedBodyString, err := DecodeQuotedPrintable(string(decodedBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedSubject, err := DecodeRFC2047(msg.Header.Get("Subject"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
email := &Email{}
|
||||
|
||||
// Populate the headers
|
||||
email.MimeVersion = msg.Header.Get("Mime-Version")
|
||||
email.Date = msg.Header.Get("Date")
|
||||
email.From = msg.Header.Get("From")
|
||||
email.To = msg.Header.Get("To")
|
||||
email.Subject = decodedSubject
|
||||
email.Body = decodedBodyString
|
||||
email.ContentType = msg.Header.Get("Content-Type")
|
||||
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func DecodeRFC2047(encoded string) (string, error) {
|
||||
decoder := new(mime.WordDecoder)
|
||||
decoded, err := decoder.DecodeHeader(encoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func DecodeQuotedPrintable(encodedString string) (string, error) {
|
||||
// Decode quoted-printable encoding
|
||||
reader := quotedprintable.NewReader(strings.NewReader(encodedString))
|
||||
decodedBytes := new(bytes.Buffer)
|
||||
_, err := decodedBytes.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return decodedBytes.String(), nil
|
||||
}
|
||||
|
||||
func EncodeQuotedPrintable(s string) string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Use Quoted-Printable encoder
|
||||
qp := quotedprintable.NewWriter(&buf)
|
||||
|
||||
// Write the UTF-8 encoded string to the Quoted-Printable encoder
|
||||
qp.Write([]byte(s))
|
||||
qp.Close()
|
||||
|
||||
// Encode the result into a MIME header
|
||||
return mime.QEncoding.Encode("UTF-8", buf.String())
|
||||
}
|
||||
33
go-backend/internal/utils/mock_smtp.go
Normal file
33
go-backend/internal/utils/mock_smtp.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
smtpmock "github.com/mocktools/go-smtp-mock/v2"
|
||||
)
|
||||
|
||||
var Server smtpmock.Server
|
||||
|
||||
// StartMockSMTPServer starts a mock SMTP server for testing
|
||||
func SMTPStart(host string, port int) error {
|
||||
Server = *smtpmock.New(smtpmock.ConfigurationAttr{
|
||||
HostAddress: host,
|
||||
PortNumber: port,
|
||||
LogToStdout: false,
|
||||
LogServerActivity: false,
|
||||
})
|
||||
if err := Server.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SMTPGetMessages() []smtpmock.Message {
|
||||
return Server.MessagesAndPurge()
|
||||
}
|
||||
|
||||
func SMTPStop() error {
|
||||
|
||||
if err := Server.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
131
go-backend/internal/utils/priviliges.go
Normal file
131
go-backend/internal/utils/priviliges.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func HasPrivilige(user *models.User, privilige int8) bool {
|
||||
switch privilige {
|
||||
case constants.Priviliges.View:
|
||||
return user.RoleID >= constants.Roles.Viewer
|
||||
case constants.Priviliges.Update:
|
||||
return user.RoleID >= constants.Roles.Editor
|
||||
case constants.Priviliges.Create:
|
||||
return user.RoleID >= constants.Roles.Editor
|
||||
case constants.Priviliges.Delete:
|
||||
return user.RoleID >= constants.Roles.Editor
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
|
||||
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
|
||||
v := reflect.ValueOf(input)
|
||||
origin := reflect.ValueOf(existing)
|
||||
|
||||
// Ensure both input and target are pointers to structs
|
||||
if v.Kind() != reflect.Ptr || origin.Kind() != reflect.Ptr {
|
||||
return errors.New("both input and existing must be pointers to structs")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
origin = origin.Elem()
|
||||
|
||||
if v.Kind() != reflect.Struct || origin.Kind() != reflect.Struct {
|
||||
return errors.New("both input and existing must be structs")
|
||||
}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Type().Field(i)
|
||||
key := field.Name
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build the full field path
|
||||
fullKey := key
|
||||
if prefix != "" {
|
||||
fullKey = prefix + "." + key
|
||||
}
|
||||
fieldValue := v.Field(i)
|
||||
originField := origin.Field(i)
|
||||
|
||||
// Handle nil pointers
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
// If the field is nil, skip it or initialize it
|
||||
if !allowedFields[fullKey] {
|
||||
// If the field is not allowed, set it to the corresponding field from existing
|
||||
fieldValue.Set(originField)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Dereference the pointer for further processing
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem()
|
||||
}
|
||||
|
||||
// Handle slices
|
||||
if fieldValue.Kind() == reflect.Slice {
|
||||
if !allowedFields[fullKey] {
|
||||
// If the slice is not allowed, set it to the corresponding slice from existing
|
||||
fieldValue.Set(originField)
|
||||
continue
|
||||
} else {
|
||||
originField.Set(fieldValue)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nested structs (including pointers to structs)
|
||||
if fieldValue.Kind() == reflect.Struct || (fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct) {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
originField = originField.Elem() // May result in an invalid originField
|
||||
}
|
||||
|
||||
var originCopy reflect.Value
|
||||
|
||||
// Check if originField is valid (non-zero)
|
||||
if originField.IsValid() {
|
||||
originCopy = reflect.New(originField.Type()).Elem()
|
||||
originCopy.Set(originField)
|
||||
} else {
|
||||
// If originField is invalid (e.g., existing had a nil pointer),
|
||||
// create a new instance of the type from fieldValue
|
||||
originCopy = reflect.New(fieldValue.Type()).Elem()
|
||||
}
|
||||
|
||||
err := FilterAllowedStructFields(
|
||||
fieldValue.Addr().Interface(),
|
||||
originCopy.Addr().Interface(),
|
||||
allowedFields,
|
||||
fullKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Only allow whitelisted fields
|
||||
if !allowedFields[fullKey] {
|
||||
logger.Error.Printf("denying update of field: %#v", fullKey)
|
||||
fieldValue.Set(originField)
|
||||
} else {
|
||||
logger.Error.Printf("updating whitelisted field: %#v", fullKey)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
176
go-backend/internal/utils/priviliges_test.go
Normal file
176
go-backend/internal/utils/priviliges_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Age int
|
||||
Address *Address
|
||||
Tags []string
|
||||
License License
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
City string
|
||||
Country string
|
||||
}
|
||||
|
||||
type License struct {
|
||||
ID string
|
||||
Categories []string
|
||||
}
|
||||
|
||||
func TestFilterAllowedStructFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
existing interface{}
|
||||
allowedFields map[string]bool
|
||||
expectedResult interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Filter top-level fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Age: 30,
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Age: 25,
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Alice", // Allowed field
|
||||
Age: 25, // Kept from existing
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter nested struct fields",
|
||||
input: &User{
|
||||
Name: "Alice",
|
||||
Address: &Address{
|
||||
City: "New York",
|
||||
Country: "USA",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
Address: &Address{
|
||||
City: "London",
|
||||
Country: "UK",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Name: "Bob", // Kept from existing
|
||||
Address: &Address{
|
||||
City: "New York", // Allowed field
|
||||
Country: "UK", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice fields",
|
||||
input: &User{
|
||||
Tags: []string{"admin", "user"},
|
||||
},
|
||||
existing: &User{
|
||||
Tags: []string{"guest"},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Tags": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Tags: []string{"admin", "user"}, // Allowed slice
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter slice of structs",
|
||||
input: &User{
|
||||
License: License{
|
||||
ID: "123",
|
||||
Categories: []string{"A", "B"},
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
License: License{
|
||||
ID: "456",
|
||||
Categories: []string{"C"},
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"License.ID": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
License: License{
|
||||
ID: "123", // Allowed field
|
||||
Categories: []string{"C"}, // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Filter pointer fields",
|
||||
input: &User{
|
||||
Address: &Address{
|
||||
City: "Paris",
|
||||
},
|
||||
},
|
||||
existing: &User{
|
||||
Address: &Address{
|
||||
City: "Berlin",
|
||||
Country: "Germany",
|
||||
},
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Address.City": true,
|
||||
},
|
||||
expectedResult: &User{
|
||||
Address: &Address{
|
||||
City: "Paris", // Allowed field
|
||||
Country: "Germany", // Kept from existing
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid input (non-pointer)",
|
||||
input: User{
|
||||
Name: "Alice",
|
||||
},
|
||||
existing: &User{
|
||||
Name: "Bob",
|
||||
},
|
||||
allowedFields: map[string]bool{
|
||||
"Name": true,
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := FilterAllowedStructFields(tt.input, tt.existing, tt.allowedFields, "")
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("FilterAllowedStructFields() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.expectError && !reflect.DeepEqual(tt.input, tt.expectedResult) {
|
||||
t.Errorf("FilterAllowedStructFields() = %+v, expected %+v", tt.input, tt.expectedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
50
go-backend/internal/utils/response_handler.go
Normal file
50
go-backend/internal/utils/response_handler.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func RespondWithError(c *gin.Context, err error, context string, code int, field string, key string) {
|
||||
logger.Error.Printf("Sending %v Error Response(Field: %v Key: %v) %v: %v", code, field, key, context, err.Error())
|
||||
c.JSON(code, gin.H{"errors": []gin.H{{
|
||||
"field": field,
|
||||
"key": key,
|
||||
}}})
|
||||
}
|
||||
|
||||
func HandleValidationError(c *gin.Context, err error) {
|
||||
var validationErrors []gin.H
|
||||
logger.Error.Printf("Sending validation error response Error %v", err.Error())
|
||||
if ve, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, e := range ve {
|
||||
validationErrors = append(validationErrors, gin.H{
|
||||
"field": e.Field(),
|
||||
"key": "server.validation." + e.Tag(),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
validationErrors = append(validationErrors, gin.H{
|
||||
"field": "general",
|
||||
"key": "server.error.invalid_json",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"errors": validationErrors})
|
||||
}
|
||||
|
||||
func HandleUserUpdateError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case errors.ErrUserNotFound:
|
||||
RespondWithError(c, err, "Error while updating user", http.StatusNotFound, "user.user", "server.validation.user_not_found")
|
||||
case errors.ErrInvalidUserData:
|
||||
RespondWithError(c, err, "Error while updating user", http.StatusBadRequest, "user.user", "server.validation.invalid_user_data")
|
||||
case errors.ErrSubscriptionNotFound:
|
||||
RespondWithError(c, err, "Error while updating user", http.StatusBadRequest, "subscription", "server.validation.subscription_data")
|
||||
default:
|
||||
RespondWithError(c, err, "Error while updating user", http.StatusInternalServerError, "user.user", "server.error.internal_server_error")
|
||||
}
|
||||
}
|
||||
53
go-backend/internal/validation/DriversLicence_validation.go
Normal file
53
go-backend/internal/validation/DriversLicence_validation.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func validateDriverslicence(sl validator.StructLevel) {
|
||||
dl := sl.Current().Interface().(models.User).Licence
|
||||
// if !vValidateLicence(dl.Number) {
|
||||
if dl.Number == "" {
|
||||
sl.ReportError(dl.Number, "licence_number", "", "invalid", "")
|
||||
}
|
||||
if dl.IssuedDate.After(time.Now()) {
|
||||
sl.ReportError(dl.IssuedDate, "issued_date", "", "invalid", "")
|
||||
}
|
||||
if dl.ExpirationDate.Before(time.Now().AddDate(0, 0, 3)) {
|
||||
sl.ReportError(dl.ExpirationDate, "expiration_date", "", "too_soon", "")
|
||||
}
|
||||
}
|
||||
|
||||
// seems like not every country has to have an licence id and it seems that germany changed their id generation type..
|
||||
// func validateLicence(fieldValue string) bool {
|
||||
// if len(fieldValue) != 11 {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// id, tenthChar := string(fieldValue[:9]), string(fieldValue[9])
|
||||
|
||||
// if tenthChar == "X" {
|
||||
// tenthChar = "10"
|
||||
// }
|
||||
// tenthValue, _ := strconv.ParseInt(tenthChar, 10, 8)
|
||||
|
||||
// // for readability
|
||||
// weights := []int{9, 8, 7, 6, 5, 4, 3, 2, 1}
|
||||
// sum := 0
|
||||
|
||||
// for i := 0; i < 9; i++ {
|
||||
// char := string(id[i])
|
||||
// value, _ := strconv.ParseInt(char, 36, 64)
|
||||
// sum += int(value) * weights[i]
|
||||
// }
|
||||
|
||||
// calcCheckDigit := sum % 11
|
||||
// if calcCheckDigit != int(tenthValue) {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// return true
|
||||
// }
|
||||
27
go-backend/internal/validation/bankAccount_validation.go
Normal file
27
go-backend/internal/validation/bankAccount_validation.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/jbub/banking/iban"
|
||||
"github.com/jbub/banking/swift"
|
||||
)
|
||||
|
||||
func validateBankAccount(sl validator.StructLevel) {
|
||||
ba := sl.Current().Interface().(models.User).BankAccount
|
||||
if !ibanValidator(ba.IBAN) {
|
||||
sl.ReportError(ba.IBAN, "IBAN", "BankAccount.IBAN", "required", "")
|
||||
}
|
||||
if ba.BIC != "" && !bicValidator(ba.BIC) {
|
||||
sl.ReportError(ba.IBAN, "IBAN", "BankAccount.IBAN", "required", "")
|
||||
}
|
||||
}
|
||||
|
||||
func ibanValidator(fieldValue string) bool {
|
||||
return iban.Validate(fieldValue) == nil
|
||||
}
|
||||
|
||||
func bicValidator(fieldValue string) bool {
|
||||
return swift.Validate(fieldValue) == nil
|
||||
}
|
||||
34
go-backend/internal/validation/general_validation.go
Normal file
34
go-backend/internal/validation/general_validation.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var xssPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script`),
|
||||
regexp.MustCompile(`(?i)javascript:`),
|
||||
regexp.MustCompile(`(?i)on\w+\s*=`),
|
||||
regexp.MustCompile(`(?i)(vbscript|data):`),
|
||||
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
|
||||
regexp.MustCompile(`(?i)expression\s*\(`),
|
||||
regexp.MustCompile(`(?i)url\s*\(`),
|
||||
regexp.MustCompile(`(?i)<\?`),
|
||||
regexp.MustCompile(`(?i)<%`),
|
||||
regexp.MustCompile(`(?i)<!\[CDATA\[`),
|
||||
regexp.MustCompile(`(?i)<(svg|animate)`),
|
||||
regexp.MustCompile(`(?i)<(audio|video|source)`),
|
||||
regexp.MustCompile(`(?i)base64`),
|
||||
}
|
||||
|
||||
func ValidateSafeContent(fl validator.FieldLevel) bool {
|
||||
input := strings.ToLower(fl.Field().String())
|
||||
for _, pattern := range xssPatterns {
|
||||
if pattern.MatchString(input) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
42
go-backend/internal/validation/membership_validation.go
Normal file
42
go-backend/internal/validation/membership_validation.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func validateMembership(sl validator.StructLevel) {
|
||||
membership := sl.Current().Interface().(models.User).Membership
|
||||
if membership.SubscriptionModel.RequiredMembershipField != "" {
|
||||
switch membership.SubscriptionModel.RequiredMembershipField {
|
||||
case "ParentMembershipID":
|
||||
if err := CheckParentMembershipID(membership); err != nil {
|
||||
logger.Error.Printf("Error ParentMembershipValidation: %v", err.Error())
|
||||
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
|
||||
"RequiredMembershipField", "invalid", "")
|
||||
}
|
||||
default:
|
||||
logger.Error.Printf("Error no matching RequiredMembershipField: %v", errors.ErrInvalidValue.Error())
|
||||
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField,
|
||||
"RequiredMembershipField", "not_implemented", "")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func CheckParentMembershipID(membership models.Membership) error {
|
||||
|
||||
if membership.ParentMembershipID == 0 {
|
||||
return errors.ValErrParentIDNotSet
|
||||
} else {
|
||||
_, err := repositories.GetUserByID(&membership.ParentMembershipID)
|
||||
if err != nil {
|
||||
return errors.ValErrParentIDNotFound
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
20
go-backend/internal/validation/setup.go
Normal file
20
go-backend/internal/validation/setup.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func SetupValidators() {
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
// Register custom validators
|
||||
v.RegisterValidation("safe_content", ValidateSafeContent)
|
||||
|
||||
// Register struct-level validations
|
||||
v.RegisterStructValidation(validateUser, models.User{})
|
||||
v.RegisterStructValidation(ValidateSubscription, models.SubscriptionModel{})
|
||||
}
|
||||
}
|
||||
46
go-backend/internal/validation/subscription_validation.go
Normal file
46
go-backend/internal/validation/subscription_validation.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// ValidateNewSubscription validates a new subscription model being created
|
||||
func ValidateSubscription(sl validator.StructLevel) {
|
||||
subscription := sl.Current().Interface().(models.SubscriptionModel)
|
||||
|
||||
if subscription.Name == "" {
|
||||
sl.ReportError(subscription.Name, "Name", "name", "required", "")
|
||||
}
|
||||
|
||||
if sl.Parent().Type().Name() == "MembershipData" {
|
||||
// This is modifying a subscription directly
|
||||
if subscription.Details == "" {
|
||||
sl.ReportError(subscription.Details, "Details", "details", "required", "")
|
||||
}
|
||||
|
||||
if subscription.MonthlyFee < 0 {
|
||||
sl.ReportError(subscription.MonthlyFee, "MonthlyFee", "monthly_fee", "gte", "0")
|
||||
}
|
||||
|
||||
if subscription.HourlyRate < 0 {
|
||||
sl.ReportError(subscription.HourlyRate, "HourlyRate", "hourly_rate", "gte", "0")
|
||||
}
|
||||
|
||||
if subscription.IncludedPerYear < 0 {
|
||||
sl.ReportError(subscription.IncludedPerYear, "IncludedPerYear", "included_hours_per_year", "gte", "0")
|
||||
}
|
||||
|
||||
if subscription.IncludedPerMonth < 0 {
|
||||
sl.ReportError(subscription.IncludedPerMonth, "IncludedPerMonth", "included_hours_per_month", "gte", "0")
|
||||
}
|
||||
} else {
|
||||
// This is a nested probably user struct. We are only checking if the model exists
|
||||
existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name)
|
||||
if err != nil || existingSubscription == nil {
|
||||
sl.ReportError(subscription.Name, "Subscription_Name", "name", "exists", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
49
go-backend/internal/validation/user_validation.go
Normal file
49
go-backend/internal/validation/user_validation.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/repositories"
|
||||
"GoMembership/pkg/logger"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func validateUser(sl validator.StructLevel) {
|
||||
user := sl.Current().Interface().(models.User)
|
||||
|
||||
isSuper := user.RoleID >= constants.Roles.Admin
|
||||
|
||||
if user.RoleID > constants.Roles.Member && user.Password == "" {
|
||||
passwordExists, err := repositories.PasswordExists(&user.ID)
|
||||
if err != nil || !passwordExists {
|
||||
logger.Error.Printf("Error checking password exists for user %v: %v", user.Email, err)
|
||||
sl.ReportError(user.Password, "Password", "password", "required", "")
|
||||
}
|
||||
}
|
||||
// Validate User > 18 years old
|
||||
if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
|
||||
sl.ReportError(user.DateOfBirth, "DateOfBirth", "dateofbirth", "age", "")
|
||||
}
|
||||
// validate subscriptionModel
|
||||
if user.Membership.SubscriptionModel.Name == "" {
|
||||
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "required", "")
|
||||
} else {
|
||||
selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Error finding subscription model for user %v: %v", user.Email, err)
|
||||
sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "invalid", "")
|
||||
} else {
|
||||
user.Membership.SubscriptionModel = *selectedModel
|
||||
}
|
||||
}
|
||||
|
||||
validateMembership(sl)
|
||||
if !isSuper {
|
||||
validateBankAccount(sl)
|
||||
if user.Licence != nil {
|
||||
validateDriverslicence(sl)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user