Compare commits

...

9 Commits

Author SHA1 Message Date
$(pass /github/name)
00facf8758 add: update handling 2024-09-20 08:29:00 +02:00
$(pass /github/name)
62624cd0f8 membership input validation improved & tests 2024-09-20 08:28:23 +02:00
$(pass /github/name)
361fa1316a add sql injection test 2024-09-20 08:27:34 +02:00
$(pass /github/name)
851e62dbac add xss validation 2024-09-20 08:26:07 +02:00
$(pass /github/name)
1e68e7d390 Add: session handling 2024-09-20 08:25:26 +02:00
$(pass /github/name)
31c47270ab chg Routing again 2024-09-20 08:24:42 +02:00
$(pass /github/name)
81e9068eba add: Cookie generation 2024-09-20 08:00:24 +02:00
$(pass /github/name)
74ef7efdec add: custom errors 2024-09-20 07:58:17 +02:00
$(pass /github/name)
46afa417b7 xss mitigation & test 2024-09-20 07:57:54 +02:00
21 changed files with 849 additions and 291 deletions

View File

@@ -0,0 +1,71 @@
package controllers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
type SQLInjectionTest struct {
name string
email string
password string
expectedStatus int
}
func (sit *SQLInjectionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
loginData := loginInput{
Email: sit.email,
Password: sit.password,
}
jsonData, _ := json.Marshal(loginData)
return GetMockedJSONContext(jsonData, "/login")
}
func (sit *SQLInjectionTest) RunHandler(c *gin.Context, router *gin.Engine) {
router.POST("/login", Uc.LoginHandler)
router.ServeHTTP(c.Writer, c.Request)
}
func (sit *SQLInjectionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
if sit.expectedStatus != w.Code {
responseBody, _ := io.ReadAll(w.Body)
return fmt.Errorf("SQL Injection Attempt: Didn't get the expected response code: got: %v; expected: %v. Context: %#v", w.Code, sit.expectedStatus, string(responseBody))
}
return nil
}
func (sit *SQLInjectionTest) ValidateResult() error {
// Add any additional validation if needed
return nil
}
func testSQLInjectionAttempt(t *testing.T) {
tests := []SQLInjectionTest{
{
name: "SQL Injection Attempt in Email",
email: "' OR '1'='1",
password: "password123",
expectedStatus: http.StatusNotFound,
},
{
name: "SQL Injection Attempt in Password",
email: "user@example.com",
password: "' OR '1'='1",
expectedStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
}
})
}
}

View File

@@ -0,0 +1,31 @@
package controllers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func testXSSAttempt(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.POST("/register", Uc.RegisterUser)
xssPayload := "<script>alert('XSS')</script>"
user := getBaseUser()
user.FirstName = xssPayload
user.Email = "user@xss.hack"
jsonData, _ := json.Marshal(RegistrationData{User: user})
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotAcceptable, w.Code)
assert.NotContains(t, w.Body.String(), xssPayload)
}

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"strconv"
"testing"
"time"
"log"
@@ -36,6 +37,11 @@ const (
Port int = 2525
)
type loginInput struct {
Email string `json:"email"`
Password string `json:"password"`
}
var (
Uc *UserController
Mc *MembershipController
@@ -73,6 +79,9 @@ func TestSuite(t *testing.T) {
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("DB_PATH", "test.db"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
config.LoadConfig()
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil {
log.Fatalf("Failed to create DB: %#v", err)
@@ -100,13 +109,14 @@ func TestSuite(t *testing.T) {
log.Fatalf("Failed to init Subscription plans: %#v", err)
}
// Run all tests
// code := m.Run()
t.Run("userController", func(t *testing.T) {
testUserController(t)
})
t.Run("SQL_Injection", func(t *testing.T) {
testSQLInjectionAttempt(t)
})
t.Run("contactController", func(t *testing.T) {
testContactController(t)
})
@@ -115,6 +125,10 @@ func TestSuite(t *testing.T) {
testMembershipController(t)
})
t.Run("XSSAttempt", func(t *testing.T) {
testXSSAttempt(t)
})
if err := utils.SMTPStop(); err != nil {
log.Fatalf("Failed to stop SMTP Mockup Server: %#v", err)
}
@@ -195,6 +209,24 @@ func GetMockedFormContext(formData url.Values, url string) (*gin.Context, *httpt
return c, w, router
}
func getBaseUser() models.User {
return models.User{
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
FirstName: "John",
LastName: "Doe",
Email: "john.doe@example.com",
Address: "Pablo Escobar Str. 4",
ZipCode: "25474",
City: "Hasloh",
Phone: "01738484993",
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
ProfilePicture: "",
Password: "password123",
Company: "",
}
}
func deleteTestDB(dbPath string) error {
err := os.Remove(dbPath)
if err != nil {

View File

@@ -3,6 +3,7 @@ package controllers
import (
"GoMembership/internal/models"
"GoMembership/internal/services"
"strings"
"net/http"
// "strconv"
@@ -31,7 +32,11 @@ func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
id, err := mc.Service.RegisterSubscription(&regData.Model)
if err != nil {
logger.Error.Printf("Couldn't register Membershipmodel: %v", err)
c.JSON(int(id), "Couldn't register Membershipmodel")
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.JSON(http.StatusConflict, "Duplicate subscription name")
return
}
c.JSON(http.StatusNotAcceptable, "Couldn't register Membershipmodel")
return
}
logger.Info.Printf("registering subscription: %+v", regData)

View File

@@ -6,7 +6,9 @@ import (
"net/http/httptest"
"testing"
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin"
)
@@ -23,6 +25,9 @@ func testMembershipController(t *testing.T) {
tests := getSubscriptionData()
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("MembershipController : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Errorf("Test failed: %v", err.Error())
@@ -63,8 +68,9 @@ func validateSubscription(assert bool, wantDBData map[string]interface{}) error
func getBaseSubscription() MembershipData {
return MembershipData{
APIKey: config.Auth.APIKEY,
Model: models.SubscriptionModel{
Name: "Just a Subscription",
Name: "Premium",
Details: "A subscription detail",
MonthlyFee: 12.0,
HourlyRate: 14.0,
@@ -79,7 +85,7 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member
func getSubscriptionData() []RegisterSubscriptionTest {
return []RegisterSubscriptionTest{
{
Name: "No Details should fail",
Name: "Missing details should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
Assert: false,
@@ -90,7 +96,7 @@ func getSubscriptionData() []RegisterSubscriptionTest {
})),
},
{
Name: "No Model Name should fail",
Name: "Missing model name should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": ""},
Assert: false,
@@ -100,10 +106,30 @@ func getSubscriptionData() []RegisterSubscriptionTest {
return subscription
})),
},
{
Name: "Negative monthly fee should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
sub.Model.MonthlyFee = -10.0
return sub
})),
},
{
Name: "Negative hourly rate should fail",
WantResponse: http.StatusNotAcceptable,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
sub.Model.HourlyRate = -1.0
return sub
})),
},
{
Name: "correct entry should pass",
WantResponse: http.StatusCreated,
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
@@ -113,5 +139,12 @@ func getSubscriptionData() []RegisterSubscriptionTest {
return subscription
})),
},
{
Name: "Duplicate subscription name should fail",
WantResponse: http.StatusConflict,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true, // The original subscription should still exist
Input: GenerateInputJSON(getBaseSubscription()),
},
}
}

View File

@@ -1,17 +1,18 @@
package controllers
import (
"fmt"
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/middlewares"
"GoMembership/internal/models"
"GoMembership/internal/services"
"GoMembership/internal/utils"
"net/http"
"github.com/gin-gonic/gin"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
)
@@ -27,12 +28,80 @@ type RegistrationData struct {
User models.User `json:"user"`
}
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
userIDString, ok := c.Get("user_id")
if !ok || userIDString == nil {
logger.Error.Printf("Error getting user_id from header")
func (uc *UserController) UpdateHandler(c *gin.Context) {
var user models.User
if err := c.ShouldBindJSON(&user); err != nil {
logger.Error.Printf("Couldn't decode input: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Couldn't decode request data"})
return
}
userID := userIDString.(float64)
tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("No Auth token: %v\n", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"})
c.Abort()
return
}
_, claims, err := middlewares.ExtractContentFrom(tokenString)
if err != nil {
logger.Error.Printf("Error retrieving token and claims from JWT")
c.JSON(http.StatusInternalServerError, gin.H{"error": "JWT parsing error"})
return
}
jwtUserID := int64((*claims)["user_id"].(float64))
userRole := int8((*claims)["role_id"].(float64))
if user.ID == 0 {
logger.Error.Printf("No User.ID in request from user with id: %v, aborting", jwtUserID)
c.JSON(http.StatusBadRequest, gin.H{"error": "No user id provided"})
return
}
if user.ID != jwtUserID && userRole < constants.Roles.Editor {
c.JSON(http.StatusForbidden, gin.H{"error": "You are not authorized to update this user"})
return
}
// TODO: If it's not an admin, prevent changes to critical fields
// if userRole != constants.Roles.Admin {
// existingUser, err := uc.Service.GetUserByID(jwtUserID)
// if err != nil {
// c.JSON(http.StatusInternalServerError, gin.H{"error": "Error retrieving user data"})
// return
// }
// user.Email = existingUser.Email
// user.RoleID = existingUser.RoleID
// }
updatedUser, err := uc.Service.UpdateUser(&user)
if err != nil {
switch err {
case errors.ErrUserNotFound:
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
case errors.ErrInvalidUserData:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user data"})
default:
logger.Error.Printf("Failed to update user: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server error"})
}
return
}
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser})
}
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
userIDInterface, ok := c.Get("user_id")
if !ok || userIDInterface == nil {
logger.Error.Printf("Error getting user_id from header")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Missing or invalid user ID type"})
return
}
userID, ok := userIDInterface.(int64)
if !ok {
logger.Error.Printf("Error: user_id is not of type int64")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user ID type"})
return
}
user, err := uc.Service.GetUserByID(int64(userID))
if err != nil {
logger.Error.Printf("Error retrieving valid user: %v", err)
@@ -44,7 +113,13 @@ func (uc *UserController) CurrentUserHandler(c *gin.Context) {
}
func (uc *UserController) LogoutHandler(c *gin.Context) {
// just clear the JWT cookie
tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("unable to get token from cookie: %#v", err)
}
middlewares.InvalidateSession(tokenString)
c.SetCookie("jwt", "", -1, "/", "", true, true)
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
}
@@ -82,25 +157,17 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
return
}
token, err := middlewares.GenerateToken(user.ID)
logger.Error.Printf("jwtsevret: %v", config.Auth.JWTSecret)
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
return
}
c.SetCookie(
"jwt",
token,
10*60, // 10 minutes
"/",
"",
true,
true,
)
utils.SetCookie(c, token)
c.JSON(http.StatusOK, gin.H{
"message": "Login successful",
"set-token": token,
"message": "Login successful",
})
}
@@ -118,7 +185,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"})
return
}
logger.Error.Printf("user.membership: %#v", regData.User.Membership)
selectedModel, err := uc.MembershipService.GetModelByName(&regData.User.Membership.SubscriptionModel.Name)
if err != nil {
logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err)
@@ -134,7 +201,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
id, token, err := uc.Service.RegisterUser(&regData.User)
if err != nil {
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
c.JSON(int(id), gin.H{"error": fmt.Sprintf("Couldn't register User: %v", err)})
c.JSON(int(id), gin.H{"error": "Couldn't register User"})
return
}
regData.User.ID = id
@@ -194,7 +261,7 @@ func (uc *UserController) VerifyMailHandler(c *gin.Context) {
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
return
}
logger.Info.Printf("User: %#v", user)
logger.Info.Printf("VerificationMailHandler User: %#v", user.Email)
uc.EmailService.SendWelcomeEmail(user)
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})

View File

@@ -1,6 +1,7 @@
package controllers
import (
"bytes"
"encoding/json"
"fmt"
"io"
@@ -9,6 +10,7 @@ import (
"net/url"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"
@@ -22,12 +24,9 @@ import (
"GoMembership/internal/models"
"GoMembership/internal/utils"
"GoMembership/pkg/logger"
)
type loginInput struct {
Email string `json:"email"`
Password string `json:"password"`
}
"github.com/golang-jwt/jwt/v5"
)
type RegisterUserTest struct {
WantDBData map[string]interface{}
@@ -37,6 +36,8 @@ type RegisterUserTest struct {
Assert bool
}
var jwtSigningMethod = jwt.SigningMethodHS256
func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
return GetMockedJSONContext([]byte(rt.Input), "register")
}
@@ -61,17 +62,23 @@ func testUserController(t *testing.T) {
tests := getTestUsers()
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Register User Testing : %v", tt.Name)
logger.Error.Print("==============================================================")
t.Run(tt.Name, func(t *testing.T) {
if err := runSingleTest(&tt); err != nil {
t.Fatalf("Test failed: %v", err.Error())
}
})
}
testCurrentUserHandler(t)
loginEmail, loginCookie := testLoginHandler(t)
logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie)
testUpdateUser(t, loginEmail, loginCookie)
testLogoutHandler(t, logoutCookie)
}
func testLogoutHandler(t *testing.T) {
loginCookie := testCurrentUserHandler(t)
func testLogoutHandler(t *testing.T, loginCookie http.Cookie) {
tests := []struct {
name string
@@ -93,6 +100,9 @@ func testLogoutHandler(t *testing.T) {
}
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Logout User Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -125,11 +135,11 @@ func testLogoutHandler(t *testing.T) {
// Verify that the user can no longer access protected routes
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/current-user", nil)
req, _ = http.NewRequest("GET", "/current", nil)
if logoutCookie != nil {
req.AddCookie(logoutCookie)
}
router.GET("/current-user", middlewares.AuthMiddleware(), Uc.CurrentUserHandler)
router.GET("/current", middlewares.AuthMiddleware(), Uc.CurrentUserHandler)
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout")
})
@@ -196,9 +206,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
assert.NoError(t, err)
if tt.wantToken {
logger.Info.Printf("Response: %#v", response)
assert.Contains(t, response, "set-token")
assert.NotEmpty(t, response["set-token"])
assert.Contains(t, response, "message")
assert.Equal(t, "Login successful", response["message"])
for _, cookie := range w.Result().Cookies() {
if cookie.Name == "jwt" {
loginCookie = *cookie
@@ -211,7 +220,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
}
assert.NotEmpty(t, loginCookie)
} else {
assert.NotContains(t, response, "set-token")
assert.Contains(t, response, "error")
assert.NotEmpty(t, response["error"])
}
})
@@ -220,8 +230,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
return loginInput.Email, loginCookie
}
func testCurrentUserHandler(t *testing.T) http.Cookie {
loginEmail, loginCookie := testLoginHandler(t)
func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Cookie) http.Cookie {
// This test should run after the user login test
invalidCookie := http.Cookie{
Name: "jwt",
@@ -232,6 +241,7 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
setupCookie func(*http.Request)
expectedUserMail string
expectedStatus int
expectNewCookie bool
}{
{
name: "With valid cookie",
@@ -241,6 +251,24 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
expectedUserMail: loginEmail,
expectedStatus: http.StatusOK,
},
{
name: "With valid expired cookie",
setupCookie: func(req *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 1,
"role_id": 0,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
req.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
middlewares.UpdateSession(sessionID, 1) // Add a valid session
},
expectedUserMail: config.Recipients.AdminEmail,
expectedStatus: http.StatusOK,
expectNewCookie: true,
},
{
name: "Without cookie",
setupCookie: func(req *http.Request) {},
@@ -259,18 +287,15 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
logger.Error.Print("==============================================================")
logger.Error.Printf("Testing : %v", tt.name)
logger.Error.Print("==============================================================")
if tt.expectedStatus == http.StatusOK {
time.Sleep(time.Second) // Small delay to ensure different timestamps to get a different JWT token
}
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(middlewares.AuthMiddleware())
router.GET("/current-user", Uc.CurrentUserHandler)
router.GET("/current", Uc.CurrentUserHandler)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/current-user", nil)
req, _ := http.NewRequest("GET", "/current", nil)
tt.setupCookie(req)
router.ServeHTTP(w, req)
@@ -290,9 +315,13 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
break
}
}
assert.NotNil(t, newCookie, "Cookie should be renewed")
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different")
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
if tt.expectNewCookie {
assert.NotNil(t, newCookie, "New cookie should be set for expired token")
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different")
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
} else {
assert.Nil(t, newCookie, "No new cookie should be set for non-expired token")
}
} else {
// For unauthorized requests, check for an error message
var errorResponse map[string]string
@@ -316,29 +345,32 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
if assert != (len(*users) != 0) {
return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users)
}
if assert {
//check for email delivery
messages := utils.SMTPGetMessages()
for _, message := range messages {
mail, err := utils.DecodeMail(message.MsgRequest())
if err != nil {
logger.Error.Printf("Error in validateUser: %#v", err)
return err
}
if strings.Contains(mail.Subject, constants.MailRegistrationSubject) {
if err := checkRegistrationMail(mail, &(*users)[0]); err != nil {
logger.Error.Printf("Error in checkRegistrationMail: %#v", err)
return err
}
} else if strings.Contains(mail.Subject, constants.MailVerificationSubject) {
if err := checkVerificationMail(mail, &(*users)[0]); err != nil {
logger.Error.Printf("Error in checkVerificationMail: %#v", err)
return err
}
verifiedUsers, err := Uc.Service.GetUsers(wantDBData)
if err != nil {
logger.Error.Printf("Error in GetUsers: %#v", err)
return err
}
if (*verifiedUsers)[0].Status != constants.VerifiedStatus {
return fmt.Errorf("Users status isn't verified after email verification. Status is: %#v", (*verifiedUsers)[0].Status)
return fmt.Errorf("Users(%v) status isn't verified after email verification. Status is: %v", (*verifiedUsers)[0].Email, (*verifiedUsers)[0].Status)
}
} else {
return fmt.Errorf("Subject not expected: %v", mail.Subject)
@@ -348,6 +380,168 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
return nil
}
func testUpdateUser(t *testing.T, loginEmail string, loginCookie http.Cookie) {
invalidCookie := http.Cookie{
Name: "jwt",
Value: "invalid.token.here",
}
// Get the user we just created
users, err := Uc.Service.GetUsers(map[string]interface{}{"email": "john.doe@example.com"})
if err != nil || len(*users) == 0 {
t.Fatalf("Failed to get test user: %v", err)
}
user := (*users)[0]
tests := []struct {
name string
setupCookie func(*http.Request)
updateFunc func(*models.User)
expectedStatus int
expectedError string
}{
{
name: "Valid Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.FirstName = "John Updated"
u.LastName = "Doe Updated"
u.Phone = "01738484994"
},
expectedStatus: http.StatusAccepted,
},
{
name: "Password Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = "NewPassword"
},
expectedStatus: http.StatusAccepted,
},
{
name: "Valid Update, invalid cookie",
setupCookie: func(req *http.Request) {
req.AddCookie(&invalidCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.FirstName = "John Updated"
u.LastName = "Doe Updated"
u.Phone = "01738484994"
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Auth token invalid",
},
{
name: "Invalid Email Update",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.Email = "invalid-email"
},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid user data",
},
{
name: "User ID mismatch while not admin",
setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie)
},
updateFunc: func(u *models.User) {
u.Password = ""
u.ID = 1
u.FirstName = "John Missing ID"
},
expectedStatus: http.StatusForbidden,
expectedError: "You are not authorized to update this user",
},
// {
// name: "Non-existent User",
// setupCookie: func(req *http.Request) {
// req.AddCookie(&loginCookie)
// },
// updateFunc: func(u *models.User) {
// u.Password = ""
// u.ID = 99999
// u.FirstName = "Non-existent"
// },
// expectedStatus: http.StatusNotFound,
// expectedError: "User not found",
// },
}
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Update Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) {
// Create a copy of the user and apply the updates
updatedUser := user
tt.updateFunc(&updatedUser)
// Convert user to JSON
jsonData, err := json.Marshal(updatedUser)
if err != nil {
t.Fatalf("Failed to marshal user data: %v", err)
}
// Create request
req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
tt.setupCookie(req)
// Create response recorder
w := httptest.NewRecorder()
// Set up router and add middleware
router := gin.New()
router.Use(middlewares.AuthMiddleware())
router.PUT("/users/:id", Uc.UpdateHandler)
// Perform request
router.ServeHTTP(w, req)
// Check status code
assert.Equal(t, tt.expectedStatus, w.Code)
// Parse response
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
if tt.expectedError != "" {
assert.Equal(t, tt.expectedError, response["error"])
} else {
assert.Equal(t, "User updated successfully", response["message"])
// Verify the update in the database
updatedUserFromDB, err := Uc.Service.GetUserByID(user.ID)
updatedUserFromDB.UpdatedAt = updatedUser.UpdatedAt
updatedUserFromDB.Membership.UpdatedAt = updatedUser.Membership.UpdatedAt
updatedUserFromDB.BankAccount.UpdatedAt = updatedUser.BankAccount.UpdatedAt
updatedUserFromDB.Verification.UpdatedAt = updatedUser.Verification.UpdatedAt
updatedUserFromDB.Membership.SubscriptionModel.UpdatedAt = updatedUser.Membership.SubscriptionModel.UpdatedAt
if updatedUser.Password == "" {
assert.Equal(t, user.Password, (*updatedUserFromDB).Password)
} else {
assert.NotEqual(t, user.Password, (*updatedUserFromDB).Password)
updatedUser.Password = ""
}
updatedUserFromDB.Password = ""
assert.NoError(t, err)
assert.Equal(t, updatedUser, *updatedUserFromDB, "Updated user in DB does not match expected user")
}
})
}
}
func checkWelcomeMail(message *utils.Email, user *models.User) error {
if !strings.Contains(message.To, user.Email) {
@@ -506,23 +700,6 @@ func getVerificationURL(mailBody string) (string, error) {
}
// TEST DATA:
func getBaseUser() models.User {
return models.User{
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
FirstName: "John",
LastName: "Doe",
Email: "john.doe@example.com",
Address: "Pablo Escobar Str. 4",
ZipCode: "25474",
City: "Hasloh",
Phone: "01738484993",
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
ProfilePicture: "",
Password: "password123",
Company: "",
}
}
func customizeInput(customize func(models.User) models.User) *RegistrationData {
user := getBaseUser()

View File

@@ -2,68 +2,75 @@ package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
var (
jwtKey = []byte(config.Auth.JWTSecret)
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
)
func GenerateToken(userID int64) (string, error) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": userID,
"exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes
})
logger.Error.Printf("token generated: %#v", token)
return token.SignedString(jwtKey)
type Session struct {
UserID int64
ExpiresAt time.Time
}
func verifyToken(tokenString string) (*jwt.Token, error) {
var (
sessionDuration = 5 * 24 * time.Hour
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
sessions = make(map[string]*Session)
)
func verifyAndRenewToken(tokenString string) (string, int64, error) {
if tokenString == "" {
return nil, fmt.Errorf("Authorization token is required")
logger.Error.Printf("empty tokenstring")
return "", -1, fmt.Errorf("Authorization token is required")
}
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
return nil, err
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", -1, err
}
sessionID := (*claims)["session_id"].(string)
userID := int64((*claims)["user_id"].(float64))
roleID := int8((*claims)["role_id"].(float64))
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
session, ok := sessions[sessionID]
if !ok {
return nil, fmt.Errorf("invalid token claims")
logger.Error.Printf("session not found")
return "", -1, fmt.Errorf("session not found")
}
if userID != session.UserID {
return "", -1, fmt.Errorf("Cookie has been altered, aborting..")
}
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
exp, ok := claims["exp"].(float64)
if !ok {
return nil, fmt.Errorf("invalid expiration claim")
if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID)
logger.Error.Printf("session expired")
return "", -1, fmt.Errorf("session expired")
}
session.ExpiresAt = time.Now().Add(sessionDuration)
logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
if err != nil {
return "", -1, err
}
userID, ok := claims["user_id"].(float64)
if !ok {
logger.Error.Printf("Invalid user ID: %v", userID)
return nil, fmt.Errorf("Invalid user ID")
}
if time.Now().Unix() > int64(exp) {
return nil, fmt.Errorf("token expired")
}
return token, nil
return newTokenString, session.UserID, nil
}
func AuthMiddleware() gin.HandlerFunc {
@@ -76,34 +83,89 @@ func AuthMiddleware() gin.HandlerFunc {
return
}
token, err := verifyToken(tokenString)
newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil {
logger.Error.Printf("Token is invalid: %v\n", err)
if err == customerrors.ErrValidToken {
c.Set("user_id", int64(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
c.Abort()
return
}
claims, _ := token.Claims.(jwt.MapClaims)
userID, _ := claims["user_id"].(float64)
// Generate a new token
newToken, err := GenerateToken(int64(userID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
c.Abort()
return
}
c.SetCookie(
"jwt",
newToken,
10*60, // 10 minutes
"/",
"",
true,
true,
)
c.Set("user_id", userID)
utils.SetCookie(c, newToken)
c.Set("user_id", int64(userID))
c.Next()
}
}
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
if sessionID == "" {
sessionID = uuid.New().String()
}
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": user.ID,
"role_id": user.RoleID,
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
})
UpdateSession(sessionID, user.ID)
return token.SignedString([]byte(jwtKey))
}
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.Auth.JWTSecret), nil
})
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
return nil, nil, err
}
// Token is expired, check if session is still valid
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.Error.Printf("Invalid Token Claims")
return nil, nil, fmt.Errorf("invalid token claims")
}
if !ok {
logger.Error.Printf("invalid session_id in token")
return nil, nil, fmt.Errorf("invalid session_id in token")
}
return token, &claims, nil
}
func UpdateSession(sessionID string, userID int64) {
sessions[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(sessionDuration),
}
}
func InvalidateSession(token string) (bool, error) {
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(
token,
claims,
func(token *jwt.Token) (interface{}, error) {
return config.Auth.JWTSecret, nil
},
)
if err != nil {
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
}
sessionID, ok := claims["session_id"].(string)
if !ok {
return false, fmt.Errorf("No SessionID found")
}
delete(sessions, sessionID)
return true, nil
}

View File

@@ -2,6 +2,8 @@ package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"encoding/json"
"log"
@@ -17,9 +19,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
func setupTestEnvironment() {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
@@ -39,17 +39,25 @@ func TestAuthMiddleware(t *testing.T) {
}
config.LoadConfig()
logger.Info.Printf("Config: %#v", config.CFG)
}
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
setupTestEnvironment()
tests := []struct {
name string
setupAuth func(r *http.Request)
expectedStatus int
expectedUserID int64
name string
setupAuth func(r *http.Request)
expectedStatus int
expectNewCookie bool
expectedUserID int64
}{
{
name: "Valid Token",
setupAuth: func(r *http.Request) {
token, _ := GenerateToken(123)
user := models.User{ID: 123, RoleID: constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
},
expectedStatus: http.StatusOK,
@@ -70,14 +78,36 @@ func TestAuthMiddleware(t *testing.T) {
expectedUserID: 0,
},
{
name: "Expired Token",
name: "Expired Token with Valid Session",
setupAuth: func(r *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString(jwtKey)
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
UpdateSession(sessionID, 123) // Add a valid session
},
expectedStatus: http.StatusOK,
expectNewCookie: true,
expectedUserID: 123,
},
{
name: "Expired Token with Expired Session",
setupAuth: func(r *http.Request) {
sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
// Don't add a session, simulating an expired session
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
@@ -86,8 +116,9 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signature",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(time.Hour).Unix(),
"user_id": 123,
"session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte("wrong_secret"))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
@@ -99,8 +130,10 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signing Method",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(time.Hour).Unix(),
"user_id": 123,
"session_id": "some-session",
"role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
@@ -143,9 +176,13 @@ func TestAuthMiddleware(t *testing.T) {
// Check if a new cookie was set
cookies := w.Result().Cookies()
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
if tt.expectNewCookie {
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
} else {
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
}
} else {
assert.Equal(t, 0, len(w.Result().Cookies()))
}

View File

@@ -6,8 +6,8 @@ type BankAccount struct {
CreatedAt time.Time
UpdatedAt time.Time
MandateDateSigned time.Time `gorm:"not null"` // json:"mandate_date_signed"`
Bank string //`json:"bank_name" validate:"omitempty,alphanumunicode"`
AccountHolderName string //`json:"account_holder_name" validate:"omitempty,alphaunicode"`
Bank string //`json:"bank_name" validate:"omitempty,alphanumunicode,safe_content"`
AccountHolderName string //`json:"account_holder_name" validate:"omitempty,alphaunicode,safe_content"`
IBAN string `gorm:"not null" json:"iban" validate:"required,iban"`
BIC string //`json:"bic" validate:"omitempty,bic"`
MandateReference string `gorm:"not null"` //json:"mandate_reference"`

View File

@@ -5,10 +5,10 @@ import "time"
type Consent struct {
CreatedAt time.Time
UpdatedAt time.Time
FirstName string `gorm:"not null" json:"first_name"`
LastName string `gorm:"not null" json:"last_name"`
Email string `json:"email"`
ConsentType string `gorm:"not null" json:"consent_type"`
FirstName string `gorm:"not null" json:"first_name" validate:"safe_content"`
LastName string `gorm:"not null" json:"last_name" validate:"safe_content"`
Email string `json:"email" validate:"email,safe_content"`
ConsentType string `gorm:"not null" json:"consent_type" validate:"safe_content"`
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"not null" json:"user_id"`
}

View File

@@ -7,7 +7,7 @@ type Membership struct {
UpdatedAt time.Time
StartDate time.Time `json:"start_date"`
EndDate time.Time `json:"end_date"`
Status string `json:"status"`
Status string `json:"status" validate:"safe_content"`
SubscriptionModel SubscriptionModel `gorm:"foreignKey:SubscriptionModelID" json:"subscription_model"`
ParentMembershipID int64 `json:"parent_member_id" validate:"omitempty,omitnil,number"`
SubscriptionModelID int64 `json:"subsription_model_id"`

View File

@@ -7,13 +7,13 @@ import (
type SubscriptionModel struct {
CreatedAt time.Time
UpdatedAt time.Time
Name string `json:"name" validate:"required,subscriptionModel"`
Name string `gorm:"unique" json:"name" validate:"required,subscriptionModel,safe_content"`
Details string `json:"details" validate:"required"`
Conditions string `json:"conditions"`
RequiredMembershipField string `json:"required_membership_field" validate:"membershipField"`
ID int64 `gorm:"primaryKey"`
MonthlyFee float32 `json:"monthly_fee" validate:"required,number"`
HourlyRate float32 `json:"hourly_rate" validate:"required,number"`
IncludedPerYear int16 `json:"included_hours_per_year" validate:"omitempty,number"`
IncludedPerMonth int16 `json:"included_hours_per_month" validate:"omitempty,number"`
MonthlyFee float32 `json:"monthly_fee" validate:"required,number,gte=0"`
HourlyRate float32 `json:"hourly_rate" validate:"required,number,gte=0"`
IncludedPerYear int16 `json:"included_hours_per_year" validate:"omitempty,number,gte=0"`
IncludedPerMonth int16 `json:"included_hours_per_month" validate:"omitempty,number,gte=0"`
}

View File

@@ -12,17 +12,17 @@ type User struct {
UpdatedAt time.Time
DateOfBirth time.Time `gorm:"not null" json:"date_of_birth" validate:"required,age"`
CreatedAt time.Time
Company string `json:"company" validate:"omitempty,omitnil"`
Phone string `json:"phone" validate:"omitempty,omitnil"`
Notes *string `json:"notes"`
FirstName string `gorm:"not null" json:"first_name" validate:"required"`
Password string `json:"password" validate:"required_unless=RoleID 0"`
Email string `gorm:"unique;not null" json:"email" validate:"required,email"`
LastName string `gorm:"not null" json:"last_name" validate:"required"`
ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image"`
Address string `gorm:"not null" json:"address" validate:"required"`
ZipCode string `gorm:"not null" json:"zip_code" validate:"required,alphanum"`
City string `form:"not null" json:"city" validate:"required,alphaunicode"`
Company string `json:"company" validate:"omitempty,omitnil,safe_content"`
Phone string `json:"phone" validate:"omitempty,omitnil,safe_content"`
Notes *string `json:"notes,safe_content"`
FirstName string `gorm:"not null" json:"first_name" validate:"required,safe_content"`
Password string `json:"password" validate:"required_unless=RoleID 0,safe_content"`
Email string `gorm:"unique;not null" json:"email" validate:"required,email,safe_content"`
LastName string `gorm:"not null" json:"last_name" validate:"required,safe_content"`
ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image,safe_content"`
Address string `gorm:"not null" json:"address" validate:"required,safe_content"`
ZipCode string `gorm:"not null" json:"zip_code" validate:"required,alphanum,safe_content"`
City string `form:"not null" json:"city" validate:"required,alphaunicode,safe_content"`
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`

View File

@@ -1,8 +1,6 @@
package repositories
import (
"time"
"gorm.io/gorm"
"GoMembership/internal/constants"
@@ -16,13 +14,13 @@ import (
type UserRepositoryInterface interface {
CreateUser(user *models.User) (int64, error)
UpdateUser(userID int64, user *models.User) error
UpdateUser(user *models.User) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error)
GetUserByID(id int64) (*models.User, error)
GetUserByID(userID *int64) (*models.User, error)
GetUserByEmail(email string) (*models.User, error)
SetVerificationToken(user *models.User, token *string) (int64, error)
SetVerificationToken(verification *models.Verification) (int64, error)
IsVerified(userID *int64) (bool, error)
VerifyUserOfToken(token *string) (*models.User, error)
GetVerificationOfToken(token *string) (*models.Verification, error)
}
type UserRepository struct{}
@@ -35,21 +33,36 @@ func (ur *UserRepository) CreateUser(user *models.User) (int64, error) {
return user.ID, nil
}
func (ur *UserRepository) UpdateUser(userID int64, user *models.User) error {
// logger.Info.Printf("Updating User: %#v\n", user)
func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
if user == nil {
return errors.ErrNoData
}
result := database.DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&user)
if result.Error != nil {
return result.Error
return nil, errors.ErrNoData
}
if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected
err := database.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.First(&models.User{}, user.ID).Error; err != nil {
return err
}
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected
}
return nil
})
if err != nil {
return nil, err
}
return nil
var updatedUser models.User
if err := database.DB.First(&updatedUser, user.ID).Error; err != nil {
return nil, err
}
return &updatedUser, nil
}
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
@@ -70,7 +83,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User
return &users, nil
}
func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
func (ur *UserRepository) GetUserByID(userID *int64) (*models.User, error) {
var user models.User
result := database.DB.
Preload("Consents").
@@ -78,7 +91,7 @@ func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
Preload("Verification").
Preload("Membership", func(db *gorm.DB) *gorm.DB {
return db.Preload("SubscriptionModel")
}).First(&user, id)
}).First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
@@ -112,7 +125,8 @@ func (ur *UserRepository) IsVerified(userID *int64) (bool, error) {
return user.Status != constants.UnverifiedStatus, nil
}
func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error) {
func (ur *UserRepository) GetVerificationOfToken(token *string) (*models.Verification, error) {
var emailVerification models.Verification
result := database.DB.Where("verification_token = ?", token).First(&emailVerification)
if result.Error != nil {
@@ -121,49 +135,10 @@ func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error)
}
return nil, result.Error
}
// Check if the user is already verified
verified, err := ur.IsVerified(&emailVerification.UserID)
if err != nil {
return nil, err
}
user, err := ur.GetUserByID(emailVerification.UserID)
if err != nil {
return nil, err
}
if verified {
return user, errors.ErrAlreadyVerified
}
// Update user status to active
t := time.Now()
emailVerification.EmailVerifiedAt = &t
user.Status = constants.VerifiedStatus
user.Verification = emailVerification
err = ur.UpdateUser(emailVerification.UserID, user)
if err != nil {
return nil, err
}
return user, nil
return &emailVerification, nil
}
func (ur *UserRepository) SetVerificationToken(user *models.User, token *string) (int64, error) {
// Check if user is already verified
verified, err := ur.IsVerified(&user.ID)
if err != nil {
return -1, err
}
if verified {
return -1, errors.ErrAlreadyVerified
}
// Prepare the Verification record
verification := models.Verification{
UserID: user.ID,
VerificationToken: *token,
}
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (int64, error) {
// Use GORM to insert or update the Verification record
result := database.DB.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}},

View File

@@ -15,21 +15,18 @@ func RegisterRoutes(router *gin.Engine, userController *controllers.UserControll
router.POST("/users/login", userController.LoginHandler)
router.POST("/csp-report", middlewares.CSPReportHandling)
// create subrouter for teh authenticated area /account
// also pthprefix matches everything below /account
// accountRouter := router.PathPrefix("/account").Subrouter()
// accountRouter.Use(middlewares.AuthMiddleware)
//create api key required router
apiRouter := router.Group("/api")
apiRouter.Use(middlewares.APIKeyMiddleware())
{
router.POST("/v1/subscription", membershipcontroller.RegisterSubscription)
}
apiRouter.Use(middlewares.APIKeyMiddleware())
authRouter := router.Group("/users/backend")
authRouter := router.Group("/backend/users")
authRouter.Use(middlewares.AuthMiddleware())
{
authRouter.GET("/current-user", userController.CurrentUserHandler)
authRouter.GET("/current", userController.CurrentUserHandler)
authRouter.POST("/logout", userController.LogoutHandler)
authRouter.PATCH("/update", userController.UpdateHandler)
}
}

View File

@@ -1,7 +1,6 @@
package services
import (
"net/http"
"slices"
"time"
@@ -9,6 +8,7 @@ import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
)
@@ -38,7 +38,7 @@ func (service *MembershipService) FindMembershipByUserID(userID int64) (*models.
// Membership_Subscriptions
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (int64, error) {
if err := validateSubscriptionData(subscription); err != nil {
return http.StatusNotAcceptable, err
return -1, err
}
return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
}
@@ -65,8 +65,9 @@ func (service *MembershipService) GetSubscriptions(where map[string]interface{})
func validateSubscriptionData(subscription *models.SubscriptionModel) error {
validate := validator.New()
// subscriptionModel and membershipField don't have to be evaluated if adding a new subscription
validate.RegisterValidation("subscriptionModel", func(fl validator.FieldLevel) bool { return true })
validate.RegisterValidation("membershipField", func(fl validator.FieldLevel) bool { return true })
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
return validate.Struct(subscription)
}

View File

@@ -8,10 +8,12 @@ import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"github.com/alexedwards/argon2id"
"github.com/go-playground/validator/v10"
"gorm.io/gorm"
"time"
)
@@ -22,50 +24,42 @@ type UserServiceInterface interface {
GetUserByID(id int64) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error)
VerifyUser(token *string) (*models.User, error)
UpdateUser(user *models.User) (*models.User, error)
}
type UserService struct {
Repo repositories.UserRepositoryInterface
}
func (service *UserService) RegisterUser(user *models.User) (int64, string, error) {
if err := validateRegistrationData(user); err != nil {
return http.StatusNotAcceptable, "", err
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) {
if err := validateUserData(user); err != nil {
return nil, errors.ErrInvalidUserData
}
setPassword(user.Password, user)
if user.Password != "" {
setPassword(user.Password, user)
}
user.Status = constants.UnverifiedStatus
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
id, err := service.Repo.CreateUser(user)
updatedUser, err := service.Repo.UpdateUser(user)
if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") {
return http.StatusConflict, "", err
} else if err != nil {
return http.StatusInternalServerError, "", err
}
user.ID = id
token, err := utils.GenerateVerificationToken()
if err != nil {
return http.StatusInternalServerError, "", err
if err == gorm.ErrRecordNotFound {
return nil, errors.ErrUserNotFound
}
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
return nil, errors.ErrDuplicateEntry
}
return nil, err
}
logger.Info.Printf("TOKEN: %v", token)
_, err = service.Repo.SetVerificationToken(user, &token)
if err != nil {
return http.StatusInternalServerError, "", err
}
return id, token, nil
return updatedUser, nil
}
func (service *UserService) Update(user *models.User) (int64, string, error) {
if err := validateRegistrationData(user); err != nil {
func (service *UserService) RegisterUser(user *models.User) (int64, string, error) {
if err := validateUserData(user); err != nil {
return http.StatusNotAcceptable, "", err
}
@@ -92,17 +86,31 @@ func (service *UserService) Update(user *models.User) (int64, string, error) {
logger.Info.Printf("TOKEN: %v", token)
_, err = service.Repo.SetVerificationToken(user, &token)
// Check if user is already verified
verified, err := service.Repo.IsVerified(&user.ID)
if err != nil {
return http.StatusInternalServerError, "", err
}
if verified {
return http.StatusAlreadyReported, "", errors.ErrAlreadyVerified
}
// Prepare the Verification record
verification := models.Verification{
UserID: user.ID,
VerificationToken: token,
}
if _, err = service.Repo.SetVerificationToken(&verification); err != nil {
return http.StatusInternalServerError, "", err
}
return id, token, nil
}
func (service *UserService) GetUserByID(id int64) (*models.User, error) {
return service.Repo.GetUserByID(id)
return service.Repo.GetUserByID(&id)
}
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
@@ -114,19 +122,41 @@ func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.Us
}
func (service *UserService) VerifyUser(token *string) (*models.User, error) {
user, err := service.Repo.VerifyUserOfToken(token)
verification, err := service.Repo.GetVerificationOfToken(token)
if err != nil {
return nil, err
}
// Check if the user is already verified
verified, err := service.Repo.IsVerified(&verification.UserID)
if err != nil {
return nil, err
}
user, err := service.Repo.GetUserByID(&verification.UserID)
if err != nil {
return nil, err
}
if verified {
return user, errors.ErrAlreadyVerified
}
// Update user status to active
t := time.Now()
verification.EmailVerifiedAt = &t
user.Status = constants.VerifiedStatus
user.Verification = *verification
user.ID = verification.UserID
service.Repo.UpdateUser(user)
return user, nil
}
func validateRegistrationData(user *models.User) error {
func validateUserData(user *models.User) error {
validate := validator.New()
validate.RegisterValidation("age", utils.AgeValidator)
validate.RegisterValidation("bic", utils.BICValidator)
validate.RegisterValidation("iban", utils.IBANValidator)
validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator)
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField)
return validate.Struct(user)

20
internal/utils/cookies.go Normal file
View File

@@ -0,0 +1,20 @@
package utils
import (
"net/http"
"github.com/gin-gonic/gin"
)
func SetCookie(c *gin.Context, token string) {
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(
"jwt",
token,
5*24*60*60, // 5 days
"/",
"",
true,
true,
)
}

View File

@@ -7,7 +7,9 @@ import (
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"reflect"
"regexp"
"slices"
"strings"
"time"
"github.com/go-playground/validator/v10"
@@ -15,19 +17,24 @@ import (
"github.com/jbub/banking/swift"
)
//
// func IsEmailValid(email string) bool {
// regex := `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$`
// re := regexp.MustCompile(regex)
// return re.MatchString(email)
// }
var xssPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
regexp.MustCompile(`(?i)on\w+\s*=`),
regexp.MustCompile(`(?i)(vbscript|data):`),
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
regexp.MustCompile(`(?i)expression\s*\(`),
regexp.MustCompile(`(?i)url\s*\(`),
regexp.MustCompile(`(?i)<\?`),
regexp.MustCompile(`(?i)<%`),
regexp.MustCompile(`(?i)<!\[CDATA\[`),
regexp.MustCompile(`(?i)<(svg|animate)`),
regexp.MustCompile(`(?i)<(audio|video|source)`),
regexp.MustCompile(`(?i)base64`),
}
func AgeValidator(fl validator.FieldLevel) bool {
fieldValue := fl.Field()
// Ensure the field is of type time.Time
// if fieldValue.Kind() != reflect.Struct || !fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
// return false
// }
dateOfBirth := fieldValue.Interface().(time.Time)
now := time.Now()
age := now.Year() - dateOfBirth.Year()
@@ -113,3 +120,13 @@ func BICValidator(fl validator.FieldLevel) bool {
return swift.Validate(fieldValue) == nil
}
func ValidateSafeContent(fl validator.FieldLevel) bool {
input := strings.ToLower(fl.Field().String())
for _, pattern := range xssPatterns {
if pattern.MatchString(input) {
return false
}
}
return true
}

View File

@@ -15,4 +15,7 @@ var (
ErrValueTooLong = errors.New("cookie value too long")
ErrInvalidValue = errors.New("invalid cookie value")
ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm")
ErrValidToken = errors.New("valid token")
ErrInvalidUserData = errors.New("invalid user data")
ErrDuplicateEntry = errors.New("duplicate entry; unique constraint failed")
)