Compare commits
9 Commits
b34a85e9d6
...
00facf8758
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00facf8758 | ||
|
|
62624cd0f8 | ||
|
|
361fa1316a | ||
|
|
851e62dbac | ||
|
|
1e68e7d390 | ||
|
|
31c47270ab | ||
|
|
81e9068eba | ||
|
|
74ef7efdec | ||
|
|
46afa417b7 |
71
internal/controllers/SQLInjection_test.go
Normal file
71
internal/controllers/SQLInjection_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SQLInjectionTest struct {
|
||||||
|
name string
|
||||||
|
email string
|
||||||
|
password string
|
||||||
|
expectedStatus int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sit *SQLInjectionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||||
|
loginData := loginInput{
|
||||||
|
Email: sit.email,
|
||||||
|
Password: sit.password,
|
||||||
|
}
|
||||||
|
jsonData, _ := json.Marshal(loginData)
|
||||||
|
return GetMockedJSONContext(jsonData, "/login")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sit *SQLInjectionTest) RunHandler(c *gin.Context, router *gin.Engine) {
|
||||||
|
router.POST("/login", Uc.LoginHandler)
|
||||||
|
router.ServeHTTP(c.Writer, c.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sit *SQLInjectionTest) ValidateResponse(w *httptest.ResponseRecorder) error {
|
||||||
|
if sit.expectedStatus != w.Code {
|
||||||
|
responseBody, _ := io.ReadAll(w.Body)
|
||||||
|
return fmt.Errorf("SQL Injection Attempt: Didn't get the expected response code: got: %v; expected: %v. Context: %#v", w.Code, sit.expectedStatus, string(responseBody))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sit *SQLInjectionTest) ValidateResult() error {
|
||||||
|
// Add any additional validation if needed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSQLInjectionAttempt(t *testing.T) {
|
||||||
|
tests := []SQLInjectionTest{
|
||||||
|
{
|
||||||
|
name: "SQL Injection Attempt in Email",
|
||||||
|
email: "' OR '1'='1",
|
||||||
|
password: "password123",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SQL Injection Attempt in Password",
|
||||||
|
email: "user@example.com",
|
||||||
|
password: "' OR '1'='1",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := runSingleTest(&tt); err != nil {
|
||||||
|
t.Errorf("Test failed: %v", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
31
internal/controllers/XSS_test.go
Normal file
31
internal/controllers/XSS_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testXSSAttempt(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.POST("/register", Uc.RegisterUser)
|
||||||
|
|
||||||
|
xssPayload := "<script>alert('XSS')</script>"
|
||||||
|
user := getBaseUser()
|
||||||
|
user.FirstName = xssPayload
|
||||||
|
user.Email = "user@xss.hack"
|
||||||
|
jsonData, _ := json.Marshal(RegistrationData{User: user})
|
||||||
|
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotAcceptable, w.Code)
|
||||||
|
assert.NotContains(t, w.Body.String(), xssPayload)
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
@@ -36,6 +37,11 @@ const (
|
|||||||
Port int = 2525
|
Port int = 2525
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type loginInput struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
Uc *UserController
|
Uc *UserController
|
||||||
Mc *MembershipController
|
Mc *MembershipController
|
||||||
@@ -73,6 +79,9 @@ func TestSuite(t *testing.T) {
|
|||||||
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
||||||
log.Fatalf("Error setting environment variable: %v", err)
|
log.Fatalf("Error setting environment variable: %v", err)
|
||||||
}
|
}
|
||||||
|
if err := os.Setenv("DB_PATH", "test.db"); err != nil {
|
||||||
|
log.Fatalf("Error setting environment variable: %v", err)
|
||||||
|
}
|
||||||
config.LoadConfig()
|
config.LoadConfig()
|
||||||
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil {
|
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil {
|
||||||
log.Fatalf("Failed to create DB: %#v", err)
|
log.Fatalf("Failed to create DB: %#v", err)
|
||||||
@@ -100,13 +109,14 @@ func TestSuite(t *testing.T) {
|
|||||||
log.Fatalf("Failed to init Subscription plans: %#v", err)
|
log.Fatalf("Failed to init Subscription plans: %#v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run all tests
|
|
||||||
// code := m.Run()
|
|
||||||
|
|
||||||
t.Run("userController", func(t *testing.T) {
|
t.Run("userController", func(t *testing.T) {
|
||||||
testUserController(t)
|
testUserController(t)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SQL_Injection", func(t *testing.T) {
|
||||||
|
testSQLInjectionAttempt(t)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("contactController", func(t *testing.T) {
|
t.Run("contactController", func(t *testing.T) {
|
||||||
testContactController(t)
|
testContactController(t)
|
||||||
})
|
})
|
||||||
@@ -115,6 +125,10 @@ func TestSuite(t *testing.T) {
|
|||||||
testMembershipController(t)
|
testMembershipController(t)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("XSSAttempt", func(t *testing.T) {
|
||||||
|
testXSSAttempt(t)
|
||||||
|
})
|
||||||
|
|
||||||
if err := utils.SMTPStop(); err != nil {
|
if err := utils.SMTPStop(); err != nil {
|
||||||
log.Fatalf("Failed to stop SMTP Mockup Server: %#v", err)
|
log.Fatalf("Failed to stop SMTP Mockup Server: %#v", err)
|
||||||
}
|
}
|
||||||
@@ -195,6 +209,24 @@ func GetMockedFormContext(formData url.Values, url string) (*gin.Context, *httpt
|
|||||||
return c, w, router
|
return c, w, router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseUser() models.User {
|
||||||
|
return models.User{
|
||||||
|
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
Email: "john.doe@example.com",
|
||||||
|
Address: "Pablo Escobar Str. 4",
|
||||||
|
ZipCode: "25474",
|
||||||
|
City: "Hasloh",
|
||||||
|
Phone: "01738484993",
|
||||||
|
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
|
||||||
|
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
|
||||||
|
ProfilePicture: "",
|
||||||
|
Password: "password123",
|
||||||
|
Company: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func deleteTestDB(dbPath string) error {
|
func deleteTestDB(dbPath string) error {
|
||||||
err := os.Remove(dbPath)
|
err := os.Remove(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controllers
|
|||||||
import (
|
import (
|
||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/internal/services"
|
"GoMembership/internal/services"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
// "strconv"
|
// "strconv"
|
||||||
@@ -31,7 +32,11 @@ func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
|
|||||||
id, err := mc.Service.RegisterSubscription(®Data.Model)
|
id, err := mc.Service.RegisterSubscription(®Data.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("Couldn't register Membershipmodel: %v", err)
|
logger.Error.Printf("Couldn't register Membershipmodel: %v", err)
|
||||||
c.JSON(int(id), "Couldn't register Membershipmodel")
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||||
|
c.JSON(http.StatusConflict, "Duplicate subscription name")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusNotAcceptable, "Couldn't register Membershipmodel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info.Printf("registering subscription: %+v", regData)
|
logger.Info.Printf("registering subscription: %+v", regData)
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"GoMembership/internal/config"
|
||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
|
"GoMembership/pkg/logger"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -23,6 +25,9 @@ func testMembershipController(t *testing.T) {
|
|||||||
|
|
||||||
tests := getSubscriptionData()
|
tests := getSubscriptionData()
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
|
logger.Error.Printf("MembershipController : %v", tt.Name)
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
t.Run(tt.Name, func(t *testing.T) {
|
t.Run(tt.Name, func(t *testing.T) {
|
||||||
if err := runSingleTest(&tt); err != nil {
|
if err := runSingleTest(&tt); err != nil {
|
||||||
t.Errorf("Test failed: %v", err.Error())
|
t.Errorf("Test failed: %v", err.Error())
|
||||||
@@ -63,8 +68,9 @@ func validateSubscription(assert bool, wantDBData map[string]interface{}) error
|
|||||||
|
|
||||||
func getBaseSubscription() MembershipData {
|
func getBaseSubscription() MembershipData {
|
||||||
return MembershipData{
|
return MembershipData{
|
||||||
|
APIKey: config.Auth.APIKEY,
|
||||||
Model: models.SubscriptionModel{
|
Model: models.SubscriptionModel{
|
||||||
Name: "Just a Subscription",
|
Name: "Premium",
|
||||||
Details: "A subscription detail",
|
Details: "A subscription detail",
|
||||||
MonthlyFee: 12.0,
|
MonthlyFee: 12.0,
|
||||||
HourlyRate: 14.0,
|
HourlyRate: 14.0,
|
||||||
@@ -79,7 +85,7 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member
|
|||||||
func getSubscriptionData() []RegisterSubscriptionTest {
|
func getSubscriptionData() []RegisterSubscriptionTest {
|
||||||
return []RegisterSubscriptionTest{
|
return []RegisterSubscriptionTest{
|
||||||
{
|
{
|
||||||
Name: "No Details should fail",
|
Name: "Missing details should fail",
|
||||||
WantResponse: http.StatusNotAcceptable,
|
WantResponse: http.StatusNotAcceptable,
|
||||||
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
|
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
|
||||||
Assert: false,
|
Assert: false,
|
||||||
@@ -90,7 +96,7 @@ func getSubscriptionData() []RegisterSubscriptionTest {
|
|||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "No Model Name should fail",
|
Name: "Missing model name should fail",
|
||||||
WantResponse: http.StatusNotAcceptable,
|
WantResponse: http.StatusNotAcceptable,
|
||||||
WantDBData: map[string]interface{}{"name": ""},
|
WantDBData: map[string]interface{}{"name": ""},
|
||||||
Assert: false,
|
Assert: false,
|
||||||
@@ -100,10 +106,30 @@ func getSubscriptionData() []RegisterSubscriptionTest {
|
|||||||
return subscription
|
return subscription
|
||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Negative monthly fee should fail",
|
||||||
|
WantResponse: http.StatusNotAcceptable,
|
||||||
|
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||||
|
Assert: false,
|
||||||
|
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
|
||||||
|
sub.Model.MonthlyFee = -10.0
|
||||||
|
return sub
|
||||||
|
})),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Negative hourly rate should fail",
|
||||||
|
WantResponse: http.StatusNotAcceptable,
|
||||||
|
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||||
|
Assert: false,
|
||||||
|
Input: GenerateInputJSON(customizeSubscription(func(sub MembershipData) MembershipData {
|
||||||
|
sub.Model.HourlyRate = -1.0
|
||||||
|
return sub
|
||||||
|
})),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "correct entry should pass",
|
Name: "correct entry should pass",
|
||||||
WantResponse: http.StatusCreated,
|
WantResponse: http.StatusCreated,
|
||||||
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
|
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||||
Assert: true,
|
Assert: true,
|
||||||
Input: GenerateInputJSON(
|
Input: GenerateInputJSON(
|
||||||
customizeSubscription(func(subscription MembershipData) MembershipData {
|
customizeSubscription(func(subscription MembershipData) MembershipData {
|
||||||
@@ -113,5 +139,12 @@ func getSubscriptionData() []RegisterSubscriptionTest {
|
|||||||
return subscription
|
return subscription
|
||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Duplicate subscription name should fail",
|
||||||
|
WantResponse: http.StatusConflict,
|
||||||
|
WantDBData: map[string]interface{}{"name": "Premium"},
|
||||||
|
Assert: true, // The original subscription should still exist
|
||||||
|
Input: GenerateInputJSON(getBaseSubscription()),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"GoMembership/internal/config"
|
||||||
|
|
||||||
"GoMembership/internal/constants"
|
"GoMembership/internal/constants"
|
||||||
"GoMembership/internal/middlewares"
|
"GoMembership/internal/middlewares"
|
||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/internal/services"
|
"GoMembership/internal/services"
|
||||||
|
"GoMembership/internal/utils"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"GoMembership/pkg/errors"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,12 +28,80 @@ type RegistrationData struct {
|
|||||||
User models.User `json:"user"`
|
User models.User `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
|
func (uc *UserController) UpdateHandler(c *gin.Context) {
|
||||||
userIDString, ok := c.Get("user_id")
|
var user models.User
|
||||||
if !ok || userIDString == nil {
|
if err := c.ShouldBindJSON(&user); err != nil {
|
||||||
logger.Error.Printf("Error getting user_id from header")
|
logger.Error.Printf("Couldn't decode input: %v", err)
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Couldn't decode request data"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
userID := userIDString.(float64)
|
tokenString, err := c.Cookie("jwt")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Printf("No Auth token: %v\n", err)
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, claims, err := middlewares.ExtractContentFrom(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
logger.Error.Printf("Error retrieving token and claims from JWT")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "JWT parsing error"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwtUserID := int64((*claims)["user_id"].(float64))
|
||||||
|
userRole := int8((*claims)["role_id"].(float64))
|
||||||
|
if user.ID == 0 {
|
||||||
|
logger.Error.Printf("No User.ID in request from user with id: %v, aborting", jwtUserID)
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "No user id provided"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if user.ID != jwtUserID && userRole < constants.Roles.Editor {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{"error": "You are not authorized to update this user"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TODO: If it's not an admin, prevent changes to critical fields
|
||||||
|
// if userRole != constants.Roles.Admin {
|
||||||
|
// existingUser, err := uc.Service.GetUserByID(jwtUserID)
|
||||||
|
// if err != nil {
|
||||||
|
// c.JSON(http.StatusInternalServerError, gin.H{"error": "Error retrieving user data"})
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// user.Email = existingUser.Email
|
||||||
|
// user.RoleID = existingUser.RoleID
|
||||||
|
// }
|
||||||
|
|
||||||
|
updatedUser, err := uc.Service.UpdateUser(&user)
|
||||||
|
if err != nil {
|
||||||
|
switch err {
|
||||||
|
case errors.ErrUserNotFound:
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
|
||||||
|
case errors.ErrInvalidUserData:
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user data"})
|
||||||
|
default:
|
||||||
|
logger.Error.Printf("Failed to update user: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server error"})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (uc *UserController) CurrentUserHandler(c *gin.Context) {
|
||||||
|
userIDInterface, ok := c.Get("user_id")
|
||||||
|
if !ok || userIDInterface == nil {
|
||||||
|
logger.Error.Printf("Error getting user_id from header")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Missing or invalid user ID type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, ok := userIDInterface.(int64)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logger.Error.Printf("Error: user_id is not of type int64")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user ID type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
user, err := uc.Service.GetUserByID(int64(userID))
|
user, err := uc.Service.GetUserByID(int64(userID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("Error retrieving valid user: %v", err)
|
logger.Error.Printf("Error retrieving valid user: %v", err)
|
||||||
@@ -44,7 +113,13 @@ func (uc *UserController) CurrentUserHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) LogoutHandler(c *gin.Context) {
|
func (uc *UserController) LogoutHandler(c *gin.Context) {
|
||||||
// just clear the JWT cookie
|
tokenString, err := c.Cookie("jwt")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Printf("unable to get token from cookie: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares.InvalidateSession(tokenString)
|
||||||
|
|
||||||
c.SetCookie("jwt", "", -1, "/", "", true, true)
|
c.SetCookie("jwt", "", -1, "/", "", true, true)
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
||||||
}
|
}
|
||||||
@@ -82,25 +157,17 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := middlewares.GenerateToken(user.ID)
|
logger.Error.Printf("jwtsevret: %v", config.Auth.JWTSecret)
|
||||||
|
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(
|
utils.SetCookie(c, token)
|
||||||
"jwt",
|
|
||||||
token,
|
|
||||||
10*60, // 10 minutes
|
|
||||||
"/",
|
|
||||||
"",
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
"set-token": token,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +185,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"})
|
c.JSON(http.StatusNotAcceptable, gin.H{"error": "No subscription model provided"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logger.Error.Printf("user.membership: %#v", regData.User.Membership)
|
||||||
selectedModel, err := uc.MembershipService.GetModelByName(®Data.User.Membership.SubscriptionModel.Name)
|
selectedModel, err := uc.MembershipService.GetModelByName(®Data.User.Membership.SubscriptionModel.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err)
|
logger.Error.Printf("%v:No subscription model found: %#v", regData.User.Email, err)
|
||||||
@@ -134,7 +201,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
|
|||||||
id, token, err := uc.Service.RegisterUser(®Data.User)
|
id, token, err := uc.Service.RegisterUser(®Data.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
|
logger.Error.Printf("Couldn't register User(%v): %v", regData.User.Email, err)
|
||||||
c.JSON(int(id), gin.H{"error": fmt.Sprintf("Couldn't register User: %v", err)})
|
c.JSON(int(id), gin.H{"error": "Couldn't register User"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
regData.User.ID = id
|
regData.User.ID = id
|
||||||
@@ -194,7 +261,7 @@ func (uc *UserController) VerifyMailHandler(c *gin.Context) {
|
|||||||
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
|
c.HTML(http.StatusUnauthorized, "verification_error.html", gin.H{"ErrorMessage": "Emailadresse wurde schon bestätigt. Sollte dies nicht der Fall sein, wende Dich bitte an info@carsharing-hasloh.de."})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info.Printf("User: %#v", user)
|
logger.Info.Printf("VerificationMailHandler User: %#v", user.Email)
|
||||||
|
|
||||||
uc.EmailService.SendWelcomeEmail(user)
|
uc.EmailService.SendWelcomeEmail(user)
|
||||||
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})
|
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -22,12 +24,9 @@ import (
|
|||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/internal/utils"
|
"GoMembership/internal/utils"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
)
|
|
||||||
|
|
||||||
type loginInput struct {
|
"github.com/golang-jwt/jwt/v5"
|
||||||
Email string `json:"email"`
|
)
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RegisterUserTest struct {
|
type RegisterUserTest struct {
|
||||||
WantDBData map[string]interface{}
|
WantDBData map[string]interface{}
|
||||||
@@ -37,6 +36,8 @@ type RegisterUserTest struct {
|
|||||||
Assert bool
|
Assert bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var jwtSigningMethod = jwt.SigningMethodHS256
|
||||||
|
|
||||||
func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
func (rt *RegisterUserTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) {
|
||||||
return GetMockedJSONContext([]byte(rt.Input), "register")
|
return GetMockedJSONContext([]byte(rt.Input), "register")
|
||||||
}
|
}
|
||||||
@@ -61,17 +62,23 @@ func testUserController(t *testing.T) {
|
|||||||
|
|
||||||
tests := getTestUsers()
|
tests := getTestUsers()
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
|
logger.Error.Printf("Register User Testing : %v", tt.Name)
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
t.Run(tt.Name, func(t *testing.T) {
|
t.Run(tt.Name, func(t *testing.T) {
|
||||||
if err := runSingleTest(&tt); err != nil {
|
if err := runSingleTest(&tt); err != nil {
|
||||||
t.Fatalf("Test failed: %v", err.Error())
|
t.Fatalf("Test failed: %v", err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
testCurrentUserHandler(t)
|
|
||||||
|
loginEmail, loginCookie := testLoginHandler(t)
|
||||||
|
logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie)
|
||||||
|
testUpdateUser(t, loginEmail, loginCookie)
|
||||||
|
testLogoutHandler(t, logoutCookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testLogoutHandler(t *testing.T) {
|
func testLogoutHandler(t *testing.T, loginCookie http.Cookie) {
|
||||||
loginCookie := testCurrentUserHandler(t)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -93,6 +100,9 @@ func testLogoutHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
|
logger.Error.Printf("Logout User Testing : %v", tt.name)
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
@@ -125,11 +135,11 @@ func testLogoutHandler(t *testing.T) {
|
|||||||
|
|
||||||
// Verify that the user can no longer access protected routes
|
// Verify that the user can no longer access protected routes
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req, _ = http.NewRequest("GET", "/current-user", nil)
|
req, _ = http.NewRequest("GET", "/current", nil)
|
||||||
if logoutCookie != nil {
|
if logoutCookie != nil {
|
||||||
req.AddCookie(logoutCookie)
|
req.AddCookie(logoutCookie)
|
||||||
}
|
}
|
||||||
router.GET("/current-user", middlewares.AuthMiddleware(), Uc.CurrentUserHandler)
|
router.GET("/current", middlewares.AuthMiddleware(), Uc.CurrentUserHandler)
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout")
|
assert.Equal(t, http.StatusUnauthorized, w.Code, "User should not be able to access protected routes after logout")
|
||||||
})
|
})
|
||||||
@@ -196,9 +206,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if tt.wantToken {
|
if tt.wantToken {
|
||||||
logger.Info.Printf("Response: %#v", response)
|
assert.Contains(t, response, "message")
|
||||||
assert.Contains(t, response, "set-token")
|
assert.Equal(t, "Login successful", response["message"])
|
||||||
assert.NotEmpty(t, response["set-token"])
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
for _, cookie := range w.Result().Cookies() {
|
||||||
if cookie.Name == "jwt" {
|
if cookie.Name == "jwt" {
|
||||||
loginCookie = *cookie
|
loginCookie = *cookie
|
||||||
@@ -211,7 +220,8 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
|
|||||||
}
|
}
|
||||||
assert.NotEmpty(t, loginCookie)
|
assert.NotEmpty(t, loginCookie)
|
||||||
} else {
|
} else {
|
||||||
assert.NotContains(t, response, "set-token")
|
assert.Contains(t, response, "error")
|
||||||
|
assert.NotEmpty(t, response["error"])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -220,8 +230,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
|
|||||||
return loginInput.Email, loginCookie
|
return loginInput.Email, loginCookie
|
||||||
}
|
}
|
||||||
|
|
||||||
func testCurrentUserHandler(t *testing.T) http.Cookie {
|
func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Cookie) http.Cookie {
|
||||||
loginEmail, loginCookie := testLoginHandler(t)
|
|
||||||
// This test should run after the user login test
|
// This test should run after the user login test
|
||||||
invalidCookie := http.Cookie{
|
invalidCookie := http.Cookie{
|
||||||
Name: "jwt",
|
Name: "jwt",
|
||||||
@@ -232,6 +241,7 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
|
|||||||
setupCookie func(*http.Request)
|
setupCookie func(*http.Request)
|
||||||
expectedUserMail string
|
expectedUserMail string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
|
expectNewCookie bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "With valid cookie",
|
name: "With valid cookie",
|
||||||
@@ -241,6 +251,24 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
|
|||||||
expectedUserMail: loginEmail,
|
expectedUserMail: loginEmail,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "With valid expired cookie",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
sessionID := "test-session"
|
||||||
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
|
"user_id": 1,
|
||||||
|
"role_id": 0,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||||
|
})
|
||||||
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
|
req.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
|
middlewares.UpdateSession(sessionID, 1) // Add a valid session
|
||||||
|
},
|
||||||
|
expectedUserMail: config.Recipients.AdminEmail,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectNewCookie: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Without cookie",
|
name: "Without cookie",
|
||||||
setupCookie: func(req *http.Request) {},
|
setupCookie: func(req *http.Request) {},
|
||||||
@@ -259,18 +287,15 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
|
|||||||
logger.Error.Print("==============================================================")
|
logger.Error.Print("==============================================================")
|
||||||
logger.Error.Printf("Testing : %v", tt.name)
|
logger.Error.Printf("Testing : %v", tt.name)
|
||||||
logger.Error.Print("==============================================================")
|
logger.Error.Print("==============================================================")
|
||||||
if tt.expectedStatus == http.StatusOK {
|
|
||||||
time.Sleep(time.Second) // Small delay to ensure different timestamps to get a different JWT token
|
|
||||||
}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(middlewares.AuthMiddleware())
|
router.Use(middlewares.AuthMiddleware())
|
||||||
router.GET("/current-user", Uc.CurrentUserHandler)
|
router.GET("/current", Uc.CurrentUserHandler)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "/current-user", nil)
|
req, _ := http.NewRequest("GET", "/current", nil)
|
||||||
tt.setupCookie(req)
|
tt.setupCookie(req)
|
||||||
|
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
@@ -290,9 +315,13 @@ func testCurrentUserHandler(t *testing.T) http.Cookie {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.NotNil(t, newCookie, "Cookie should be renewed")
|
if tt.expectNewCookie {
|
||||||
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different")
|
assert.NotNil(t, newCookie, "New cookie should be set for expired token")
|
||||||
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
|
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different")
|
||||||
|
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, newCookie, "No new cookie should be set for non-expired token")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// For unauthorized requests, check for an error message
|
// For unauthorized requests, check for an error message
|
||||||
var errorResponse map[string]string
|
var errorResponse map[string]string
|
||||||
@@ -316,29 +345,32 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
|
|||||||
if assert != (len(*users) != 0) {
|
if assert != (len(*users) != 0) {
|
||||||
return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users)
|
return fmt.Errorf("User entry query didn't met expectation: %v != %#v", assert, *users)
|
||||||
}
|
}
|
||||||
|
|
||||||
if assert {
|
if assert {
|
||||||
//check for email delivery
|
//check for email delivery
|
||||||
messages := utils.SMTPGetMessages()
|
messages := utils.SMTPGetMessages()
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
mail, err := utils.DecodeMail(message.MsgRequest())
|
mail, err := utils.DecodeMail(message.MsgRequest())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error.Printf("Error in validateUser: %#v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if strings.Contains(mail.Subject, constants.MailRegistrationSubject) {
|
if strings.Contains(mail.Subject, constants.MailRegistrationSubject) {
|
||||||
if err := checkRegistrationMail(mail, &(*users)[0]); err != nil {
|
if err := checkRegistrationMail(mail, &(*users)[0]); err != nil {
|
||||||
|
logger.Error.Printf("Error in checkRegistrationMail: %#v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if strings.Contains(mail.Subject, constants.MailVerificationSubject) {
|
} else if strings.Contains(mail.Subject, constants.MailVerificationSubject) {
|
||||||
if err := checkVerificationMail(mail, &(*users)[0]); err != nil {
|
if err := checkVerificationMail(mail, &(*users)[0]); err != nil {
|
||||||
|
logger.Error.Printf("Error in checkVerificationMail: %#v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
verifiedUsers, err := Uc.Service.GetUsers(wantDBData)
|
verifiedUsers, err := Uc.Service.GetUsers(wantDBData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error.Printf("Error in GetUsers: %#v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if (*verifiedUsers)[0].Status != constants.VerifiedStatus {
|
if (*verifiedUsers)[0].Status != constants.VerifiedStatus {
|
||||||
return fmt.Errorf("Users status isn't verified after email verification. Status is: %#v", (*verifiedUsers)[0].Status)
|
return fmt.Errorf("Users(%v) status isn't verified after email verification. Status is: %v", (*verifiedUsers)[0].Email, (*verifiedUsers)[0].Status)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("Subject not expected: %v", mail.Subject)
|
return fmt.Errorf("Subject not expected: %v", mail.Subject)
|
||||||
@@ -348,6 +380,168 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testUpdateUser(t *testing.T, loginEmail string, loginCookie http.Cookie) {
|
||||||
|
|
||||||
|
invalidCookie := http.Cookie{
|
||||||
|
Name: "jwt",
|
||||||
|
Value: "invalid.token.here",
|
||||||
|
}
|
||||||
|
// Get the user we just created
|
||||||
|
users, err := Uc.Service.GetUsers(map[string]interface{}{"email": "john.doe@example.com"})
|
||||||
|
if err != nil || len(*users) == 0 {
|
||||||
|
t.Fatalf("Failed to get test user: %v", err)
|
||||||
|
}
|
||||||
|
user := (*users)[0]
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupCookie func(*http.Request)
|
||||||
|
updateFunc func(*models.User)
|
||||||
|
expectedStatus int
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Update",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
req.AddCookie(&loginCookie)
|
||||||
|
},
|
||||||
|
updateFunc: func(u *models.User) {
|
||||||
|
u.Password = ""
|
||||||
|
u.FirstName = "John Updated"
|
||||||
|
u.LastName = "Doe Updated"
|
||||||
|
u.Phone = "01738484994"
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Password Update",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
req.AddCookie(&loginCookie)
|
||||||
|
},
|
||||||
|
updateFunc: func(u *models.User) {
|
||||||
|
u.Password = "NewPassword"
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Update, invalid cookie",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
req.AddCookie(&invalidCookie)
|
||||||
|
},
|
||||||
|
updateFunc: func(u *models.User) {
|
||||||
|
u.Password = ""
|
||||||
|
u.FirstName = "John Updated"
|
||||||
|
u.LastName = "Doe Updated"
|
||||||
|
u.Phone = "01738484994"
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
expectedError: "Auth token invalid",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Email Update",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
req.AddCookie(&loginCookie)
|
||||||
|
},
|
||||||
|
updateFunc: func(u *models.User) {
|
||||||
|
u.Password = ""
|
||||||
|
u.Email = "invalid-email"
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedError: "Invalid user data",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User ID mismatch while not admin",
|
||||||
|
setupCookie: func(req *http.Request) {
|
||||||
|
req.AddCookie(&loginCookie)
|
||||||
|
},
|
||||||
|
updateFunc: func(u *models.User) {
|
||||||
|
u.Password = ""
|
||||||
|
u.ID = 1
|
||||||
|
u.FirstName = "John Missing ID"
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusForbidden,
|
||||||
|
expectedError: "You are not authorized to update this user",
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// name: "Non-existent User",
|
||||||
|
// setupCookie: func(req *http.Request) {
|
||||||
|
// req.AddCookie(&loginCookie)
|
||||||
|
// },
|
||||||
|
// updateFunc: func(u *models.User) {
|
||||||
|
// u.Password = ""
|
||||||
|
// u.ID = 99999
|
||||||
|
// u.FirstName = "Non-existent"
|
||||||
|
// },
|
||||||
|
// expectedStatus: http.StatusNotFound,
|
||||||
|
// expectedError: "User not found",
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
|
logger.Error.Printf("Update Testing : %v", tt.name)
|
||||||
|
logger.Error.Print("==============================================================")
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create a copy of the user and apply the updates
|
||||||
|
updatedUser := user
|
||||||
|
tt.updateFunc(&updatedUser)
|
||||||
|
|
||||||
|
// Convert user to JSON
|
||||||
|
jsonData, err := json.Marshal(updatedUser)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal user data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create request
|
||||||
|
req, _ := http.NewRequest("PUT", "/users/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(jsonData))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
tt.setupCookie(req)
|
||||||
|
// Create response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Set up router and add middleware
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(middlewares.AuthMiddleware())
|
||||||
|
router.PUT("/users/:id", Uc.UpdateHandler)
|
||||||
|
|
||||||
|
// Perform request
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Check status code
|
||||||
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var response map[string]interface{}
|
||||||
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.expectedError != "" {
|
||||||
|
assert.Equal(t, tt.expectedError, response["error"])
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, "User updated successfully", response["message"])
|
||||||
|
|
||||||
|
// Verify the update in the database
|
||||||
|
updatedUserFromDB, err := Uc.Service.GetUserByID(user.ID)
|
||||||
|
updatedUserFromDB.UpdatedAt = updatedUser.UpdatedAt
|
||||||
|
updatedUserFromDB.Membership.UpdatedAt = updatedUser.Membership.UpdatedAt
|
||||||
|
updatedUserFromDB.BankAccount.UpdatedAt = updatedUser.BankAccount.UpdatedAt
|
||||||
|
updatedUserFromDB.Verification.UpdatedAt = updatedUser.Verification.UpdatedAt
|
||||||
|
updatedUserFromDB.Membership.SubscriptionModel.UpdatedAt = updatedUser.Membership.SubscriptionModel.UpdatedAt
|
||||||
|
if updatedUser.Password == "" {
|
||||||
|
assert.Equal(t, user.Password, (*updatedUserFromDB).Password)
|
||||||
|
} else {
|
||||||
|
assert.NotEqual(t, user.Password, (*updatedUserFromDB).Password)
|
||||||
|
updatedUser.Password = ""
|
||||||
|
}
|
||||||
|
updatedUserFromDB.Password = ""
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, updatedUser, *updatedUserFromDB, "Updated user in DB does not match expected user")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkWelcomeMail(message *utils.Email, user *models.User) error {
|
func checkWelcomeMail(message *utils.Email, user *models.User) error {
|
||||||
|
|
||||||
if !strings.Contains(message.To, user.Email) {
|
if !strings.Contains(message.To, user.Email) {
|
||||||
@@ -506,23 +700,6 @@ func getVerificationURL(mailBody string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TEST DATA:
|
// TEST DATA:
|
||||||
func getBaseUser() models.User {
|
|
||||||
return models.User{
|
|
||||||
DateOfBirth: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
|
|
||||||
FirstName: "John",
|
|
||||||
LastName: "Doe",
|
|
||||||
Email: "john.doe@example.com",
|
|
||||||
Address: "Pablo Escobar Str. 4",
|
|
||||||
ZipCode: "25474",
|
|
||||||
City: "Hasloh",
|
|
||||||
Phone: "01738484993",
|
|
||||||
BankAccount: models.BankAccount{IBAN: "DE89370400440532013000"},
|
|
||||||
Membership: models.Membership{SubscriptionModel: models.SubscriptionModel{Name: "Basic"}},
|
|
||||||
ProfilePicture: "",
|
|
||||||
Password: "password123",
|
|
||||||
Company: "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func customizeInput(customize func(models.User) models.User) *RegistrationData {
|
func customizeInput(customize func(models.User) models.User) *RegistrationData {
|
||||||
user := getBaseUser()
|
user := getBaseUser()
|
||||||
|
|||||||
@@ -2,68 +2,75 @@ package middlewares
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"GoMembership/internal/config"
|
"GoMembership/internal/config"
|
||||||
|
"GoMembership/internal/models"
|
||||||
|
"GoMembership/internal/utils"
|
||||||
|
customerrors "GoMembership/pkg/errors"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type Session struct {
|
||||||
jwtKey = []byte(config.Auth.JWTSecret)
|
UserID int64
|
||||||
jwtSigningMethod = jwt.SigningMethodHS256
|
ExpiresAt time.Time
|
||||||
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
|
||||||
)
|
|
||||||
|
|
||||||
func GenerateToken(userID int64) (string, error) {
|
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
|
||||||
"user_id": userID,
|
|
||||||
"exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.Error.Printf("token generated: %#v", token)
|
|
||||||
return token.SignedString(jwtKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyToken(tokenString string) (*jwt.Token, error) {
|
var (
|
||||||
|
sessionDuration = 5 * 24 * time.Hour
|
||||||
|
jwtSigningMethod = jwt.SigningMethodHS256
|
||||||
|
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
||||||
|
sessions = make(map[string]*Session)
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyAndRenewToken(tokenString string) (string, int64, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
return nil, fmt.Errorf("Authorization token is required")
|
logger.Error.Printf("empty tokenstring")
|
||||||
|
return "", -1, fmt.Errorf("Authorization token is required")
|
||||||
}
|
}
|
||||||
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
token, claims, err := ExtractContentFrom(tokenString)
|
||||||
return jwtKey, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
|
||||||
|
return "", -1, err
|
||||||
}
|
}
|
||||||
|
sessionID := (*claims)["session_id"].(string)
|
||||||
|
userID := int64((*claims)["user_id"].(float64))
|
||||||
|
roleID := int8((*claims)["role_id"].(float64))
|
||||||
|
|
||||||
if !token.Valid {
|
session, ok := sessions[sessionID]
|
||||||
return nil, fmt.Errorf("invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid token claims")
|
logger.Error.Printf("session not found")
|
||||||
|
return "", -1, fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
if userID != session.UserID {
|
||||||
|
return "", -1, fmt.Errorf("Cookie has been altered, aborting..")
|
||||||
|
}
|
||||||
|
if token.Valid {
|
||||||
|
// token is valid, so we can return the old tokenString
|
||||||
|
return tokenString, session.UserID, customerrors.ErrValidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
exp, ok := claims["exp"].(float64)
|
if time.Now().After(sessions[sessionID].ExpiresAt) {
|
||||||
if !ok {
|
delete(sessions, sessionID)
|
||||||
return nil, fmt.Errorf("invalid expiration claim")
|
logger.Error.Printf("session expired")
|
||||||
|
return "", -1, fmt.Errorf("session expired")
|
||||||
|
}
|
||||||
|
session.ExpiresAt = time.Now().Add(sessionDuration)
|
||||||
|
|
||||||
|
logger.Error.Printf("Session still valid generating new token")
|
||||||
|
// Session is still valid, generate a new token
|
||||||
|
user := models.User{ID: userID, RoleID: roleID}
|
||||||
|
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", -1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, ok := claims["user_id"].(float64)
|
return newTokenString, session.UserID, nil
|
||||||
if !ok {
|
|
||||||
logger.Error.Printf("Invalid user ID: %v", userID)
|
|
||||||
return nil, fmt.Errorf("Invalid user ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().Unix() > int64(exp) {
|
|
||||||
return nil, fmt.Errorf("token expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthMiddleware() gin.HandlerFunc {
|
func AuthMiddleware() gin.HandlerFunc {
|
||||||
@@ -76,34 +83,89 @@ func AuthMiddleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := verifyToken(tokenString)
|
newToken, userID, err := verifyAndRenewToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("Token is invalid: %v\n", err)
|
if err == customerrors.ErrValidToken {
|
||||||
|
c.Set("user_id", int64(userID))
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
claims, _ := token.Claims.(jwt.MapClaims)
|
|
||||||
userID, _ := claims["user_id"].(float64)
|
|
||||||
|
|
||||||
// Generate a new token
|
utils.SetCookie(c, newToken)
|
||||||
newToken, err := GenerateToken(int64(userID))
|
c.Set("user_id", int64(userID))
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie(
|
|
||||||
"jwt",
|
|
||||||
newToken,
|
|
||||||
10*60, // 10 minutes
|
|
||||||
"/",
|
|
||||||
"",
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
c.Set("user_id", userID)
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
|
||||||
|
if sessionID == "" {
|
||||||
|
sessionID = uuid.New().String()
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
|
"user_id": user.ID,
|
||||||
|
"role_id": user.RoleID,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
|
||||||
|
})
|
||||||
|
UpdateSession(sessionID, user.ID)
|
||||||
|
return token.SignedString([]byte(jwtKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
|
||||||
|
|
||||||
|
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(config.Auth.JWTSecret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
|
||||||
|
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is expired, check if session is still valid
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
logger.Error.Printf("Invalid Token Claims")
|
||||||
|
return nil, nil, fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logger.Error.Printf("invalid session_id in token")
|
||||||
|
return nil, nil, fmt.Errorf("invalid session_id in token")
|
||||||
|
}
|
||||||
|
return token, &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSession(sessionID string, userID int64) {
|
||||||
|
sessions[sessionID] = &Session{
|
||||||
|
UserID: userID,
|
||||||
|
ExpiresAt: time.Now().Add(sessionDuration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func InvalidateSession(token string) (bool, error) {
|
||||||
|
claims := jwt.MapClaims{}
|
||||||
|
_, err := jwt.ParseWithClaims(
|
||||||
|
token,
|
||||||
|
claims,
|
||||||
|
func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return config.Auth.JWTSecret, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID, ok := claims["session_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("No SessionID found")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sessions, sessionID)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package middlewares
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"GoMembership/internal/config"
|
"GoMembership/internal/config"
|
||||||
|
"GoMembership/internal/constants"
|
||||||
|
"GoMembership/internal/models"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
@@ -17,9 +19,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
func setupTestEnvironment() {
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
cwd, err := os.Getwd()
|
cwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to get current working directory: %v", err)
|
log.Fatalf("Failed to get current working directory: %v", err)
|
||||||
@@ -39,17 +39,25 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
config.LoadConfig()
|
config.LoadConfig()
|
||||||
logger.Info.Printf("Config: %#v", config.CFG)
|
logger.Info.Printf("Config: %#v", config.CFG)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
setupTestEnvironment()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAuth func(r *http.Request)
|
setupAuth func(r *http.Request)
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
expectedUserID int64
|
expectNewCookie bool
|
||||||
|
expectedUserID int64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Valid Token",
|
name: "Valid Token",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token, _ := GenerateToken(123)
|
user := models.User{ID: 123, RoleID: constants.Roles.Member}
|
||||||
|
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
@@ -70,14 +78,36 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
expectedUserID: 0,
|
expectedUserID: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Expired Token",
|
name: "Expired Token with Valid Session",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
|
sessionID := "test-session"
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
"role_id": constants.Roles.Member,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString(jwtKey)
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
|
UpdateSession(sessionID, 123) // Add a valid session
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectNewCookie: true,
|
||||||
|
expectedUserID: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired Token with Expired Session",
|
||||||
|
setupAuth: func(r *http.Request) {
|
||||||
|
sessionID := "expired-session"
|
||||||
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
|
"user_id": 123,
|
||||||
|
"role_id": constants.Roles.Member,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||||
|
})
|
||||||
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
|
// Don't add a session, simulating an expired session
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusUnauthorized,
|
expectedStatus: http.StatusUnauthorized,
|
||||||
expectedUserID: 0,
|
expectedUserID: 0,
|
||||||
@@ -86,8 +116,9 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
name: "Invalid Signature",
|
name: "Invalid Signature",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"session_id": "some-session",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
@@ -99,8 +130,10 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
name: "Invalid Signing Method",
|
name: "Invalid Signing Method",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"session_id": "some-session",
|
||||||
|
"role_id": constants.Roles.Member,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
@@ -143,9 +176,13 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
// Check if a new cookie was set
|
// Check if a new cookie was set
|
||||||
cookies := w.Result().Cookies()
|
cookies := w.Result().Cookies()
|
||||||
assert.GreaterOrEqual(t, len(cookies), 1)
|
if tt.expectNewCookie {
|
||||||
assert.Equal(t, "jwt", cookies[0].Name)
|
assert.GreaterOrEqual(t, len(cookies), 1)
|
||||||
assert.NotEmpty(t, cookies[0].Value)
|
assert.Equal(t, "jwt", cookies[0].Name)
|
||||||
|
assert.NotEmpty(t, cookies[0].Value)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, 0, len(w.Result().Cookies()))
|
assert.Equal(t, 0, len(w.Result().Cookies()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ type BankAccount struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
MandateDateSigned time.Time `gorm:"not null"` // json:"mandate_date_signed"`
|
MandateDateSigned time.Time `gorm:"not null"` // json:"mandate_date_signed"`
|
||||||
Bank string //`json:"bank_name" validate:"omitempty,alphanumunicode"`
|
Bank string //`json:"bank_name" validate:"omitempty,alphanumunicode,safe_content"`
|
||||||
AccountHolderName string //`json:"account_holder_name" validate:"omitempty,alphaunicode"`
|
AccountHolderName string //`json:"account_holder_name" validate:"omitempty,alphaunicode,safe_content"`
|
||||||
IBAN string `gorm:"not null" json:"iban" validate:"required,iban"`
|
IBAN string `gorm:"not null" json:"iban" validate:"required,iban"`
|
||||||
BIC string //`json:"bic" validate:"omitempty,bic"`
|
BIC string //`json:"bic" validate:"omitempty,bic"`
|
||||||
MandateReference string `gorm:"not null"` //json:"mandate_reference"`
|
MandateReference string `gorm:"not null"` //json:"mandate_reference"`
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ import "time"
|
|||||||
type Consent struct {
|
type Consent struct {
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
FirstName string `gorm:"not null" json:"first_name"`
|
FirstName string `gorm:"not null" json:"first_name" validate:"safe_content"`
|
||||||
LastName string `gorm:"not null" json:"last_name"`
|
LastName string `gorm:"not null" json:"last_name" validate:"safe_content"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email" validate:"email,safe_content"`
|
||||||
ConsentType string `gorm:"not null" json:"consent_type"`
|
ConsentType string `gorm:"not null" json:"consent_type" validate:"safe_content"`
|
||||||
ID int64 `gorm:"primaryKey"`
|
ID int64 `gorm:"primaryKey"`
|
||||||
UserID int64 `gorm:"not null" json:"user_id"`
|
UserID int64 `gorm:"not null" json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ type Membership struct {
|
|||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
StartDate time.Time `json:"start_date"`
|
StartDate time.Time `json:"start_date"`
|
||||||
EndDate time.Time `json:"end_date"`
|
EndDate time.Time `json:"end_date"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status" validate:"safe_content"`
|
||||||
SubscriptionModel SubscriptionModel `gorm:"foreignKey:SubscriptionModelID" json:"subscription_model"`
|
SubscriptionModel SubscriptionModel `gorm:"foreignKey:SubscriptionModelID" json:"subscription_model"`
|
||||||
ParentMembershipID int64 `json:"parent_member_id" validate:"omitempty,omitnil,number"`
|
ParentMembershipID int64 `json:"parent_member_id" validate:"omitempty,omitnil,number"`
|
||||||
SubscriptionModelID int64 `json:"subsription_model_id"`
|
SubscriptionModelID int64 `json:"subsription_model_id"`
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import (
|
|||||||
type SubscriptionModel struct {
|
type SubscriptionModel struct {
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
Name string `json:"name" validate:"required,subscriptionModel"`
|
Name string `gorm:"unique" json:"name" validate:"required,subscriptionModel,safe_content"`
|
||||||
Details string `json:"details" validate:"required"`
|
Details string `json:"details" validate:"required"`
|
||||||
Conditions string `json:"conditions"`
|
Conditions string `json:"conditions"`
|
||||||
RequiredMembershipField string `json:"required_membership_field" validate:"membershipField"`
|
RequiredMembershipField string `json:"required_membership_field" validate:"membershipField"`
|
||||||
ID int64 `gorm:"primaryKey"`
|
ID int64 `gorm:"primaryKey"`
|
||||||
MonthlyFee float32 `json:"monthly_fee" validate:"required,number"`
|
MonthlyFee float32 `json:"monthly_fee" validate:"required,number,gte=0"`
|
||||||
HourlyRate float32 `json:"hourly_rate" validate:"required,number"`
|
HourlyRate float32 `json:"hourly_rate" validate:"required,number,gte=0"`
|
||||||
IncludedPerYear int16 `json:"included_hours_per_year" validate:"omitempty,number"`
|
IncludedPerYear int16 `json:"included_hours_per_year" validate:"omitempty,number,gte=0"`
|
||||||
IncludedPerMonth int16 `json:"included_hours_per_month" validate:"omitempty,number"`
|
IncludedPerMonth int16 `json:"included_hours_per_month" validate:"omitempty,number,gte=0"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,17 +12,17 @@ type User struct {
|
|||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DateOfBirth time.Time `gorm:"not null" json:"date_of_birth" validate:"required,age"`
|
DateOfBirth time.Time `gorm:"not null" json:"date_of_birth" validate:"required,age"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
Company string `json:"company" validate:"omitempty,omitnil"`
|
Company string `json:"company" validate:"omitempty,omitnil,safe_content"`
|
||||||
Phone string `json:"phone" validate:"omitempty,omitnil"`
|
Phone string `json:"phone" validate:"omitempty,omitnil,safe_content"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes,safe_content"`
|
||||||
FirstName string `gorm:"not null" json:"first_name" validate:"required"`
|
FirstName string `gorm:"not null" json:"first_name" validate:"required,safe_content"`
|
||||||
Password string `json:"password" validate:"required_unless=RoleID 0"`
|
Password string `json:"password" validate:"required_unless=RoleID 0,safe_content"`
|
||||||
Email string `gorm:"unique;not null" json:"email" validate:"required,email"`
|
Email string `gorm:"unique;not null" json:"email" validate:"required,email,safe_content"`
|
||||||
LastName string `gorm:"not null" json:"last_name" validate:"required"`
|
LastName string `gorm:"not null" json:"last_name" validate:"required,safe_content"`
|
||||||
ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image"`
|
ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image,safe_content"`
|
||||||
Address string `gorm:"not null" json:"address" validate:"required"`
|
Address string `gorm:"not null" json:"address" validate:"required,safe_content"`
|
||||||
ZipCode string `gorm:"not null" json:"zip_code" validate:"required,alphanum"`
|
ZipCode string `gorm:"not null" json:"zip_code" validate:"required,alphanum,safe_content"`
|
||||||
City string `form:"not null" json:"city" validate:"required,alphaunicode"`
|
City string `form:"not null" json:"city" validate:"required,alphaunicode,safe_content"`
|
||||||
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
|
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
|
||||||
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
|
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
|
||||||
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
|
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package repositories
|
package repositories
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"GoMembership/internal/constants"
|
"GoMembership/internal/constants"
|
||||||
@@ -16,13 +14,13 @@ import (
|
|||||||
|
|
||||||
type UserRepositoryInterface interface {
|
type UserRepositoryInterface interface {
|
||||||
CreateUser(user *models.User) (int64, error)
|
CreateUser(user *models.User) (int64, error)
|
||||||
UpdateUser(userID int64, user *models.User) error
|
UpdateUser(user *models.User) (*models.User, error)
|
||||||
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
||||||
GetUserByID(id int64) (*models.User, error)
|
GetUserByID(userID *int64) (*models.User, error)
|
||||||
GetUserByEmail(email string) (*models.User, error)
|
GetUserByEmail(email string) (*models.User, error)
|
||||||
SetVerificationToken(user *models.User, token *string) (int64, error)
|
SetVerificationToken(verification *models.Verification) (int64, error)
|
||||||
IsVerified(userID *int64) (bool, error)
|
IsVerified(userID *int64) (bool, error)
|
||||||
VerifyUserOfToken(token *string) (*models.User, error)
|
GetVerificationOfToken(token *string) (*models.Verification, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRepository struct{}
|
type UserRepository struct{}
|
||||||
@@ -35,21 +33,36 @@ func (ur *UserRepository) CreateUser(user *models.User) (int64, error) {
|
|||||||
return user.ID, nil
|
return user.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UserRepository) UpdateUser(userID int64, user *models.User) error {
|
func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
|
||||||
// logger.Info.Printf("Updating User: %#v\n", user)
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return errors.ErrNoData
|
return nil, errors.ErrNoData
|
||||||
}
|
|
||||||
result := database.DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&user)
|
|
||||||
if result.Error != nil {
|
|
||||||
return result.Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
err := database.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
return errors.ErrNoRowsAffected
|
if err := tx.First(&models.User{}, user.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return errors.ErrNoRowsAffected
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
var updatedUser models.User
|
||||||
|
if err := database.DB.First(&updatedUser, user.ID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &updatedUser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
|
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
|
||||||
@@ -70,7 +83,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User
|
|||||||
return &users, nil
|
return &users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
|
func (ur *UserRepository) GetUserByID(userID *int64) (*models.User, error) {
|
||||||
var user models.User
|
var user models.User
|
||||||
result := database.DB.
|
result := database.DB.
|
||||||
Preload("Consents").
|
Preload("Consents").
|
||||||
@@ -78,7 +91,7 @@ func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
|
|||||||
Preload("Verification").
|
Preload("Verification").
|
||||||
Preload("Membership", func(db *gorm.DB) *gorm.DB {
|
Preload("Membership", func(db *gorm.DB) *gorm.DB {
|
||||||
return db.Preload("SubscriptionModel")
|
return db.Preload("SubscriptionModel")
|
||||||
}).First(&user, id)
|
}).First(&user, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if result.Error == gorm.ErrRecordNotFound {
|
if result.Error == gorm.ErrRecordNotFound {
|
||||||
return nil, gorm.ErrRecordNotFound
|
return nil, gorm.ErrRecordNotFound
|
||||||
@@ -112,7 +125,8 @@ func (ur *UserRepository) IsVerified(userID *int64) (bool, error) {
|
|||||||
return user.Status != constants.UnverifiedStatus, nil
|
return user.Status != constants.UnverifiedStatus, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error) {
|
func (ur *UserRepository) GetVerificationOfToken(token *string) (*models.Verification, error) {
|
||||||
|
|
||||||
var emailVerification models.Verification
|
var emailVerification models.Verification
|
||||||
result := database.DB.Where("verification_token = ?", token).First(&emailVerification)
|
result := database.DB.Where("verification_token = ?", token).First(&emailVerification)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -121,49 +135,10 @@ func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error)
|
|||||||
}
|
}
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
|
return &emailVerification, nil
|
||||||
// Check if the user is already verified
|
|
||||||
verified, err := ur.IsVerified(&emailVerification.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
user, err := ur.GetUserByID(emailVerification.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if verified {
|
|
||||||
return user, errors.ErrAlreadyVerified
|
|
||||||
}
|
|
||||||
// Update user status to active
|
|
||||||
t := time.Now()
|
|
||||||
emailVerification.EmailVerifiedAt = &t
|
|
||||||
user.Status = constants.VerifiedStatus
|
|
||||||
user.Verification = emailVerification
|
|
||||||
|
|
||||||
err = ur.UpdateUser(emailVerification.UserID, user)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UserRepository) SetVerificationToken(user *models.User, token *string) (int64, error) {
|
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (int64, error) {
|
||||||
// Check if user is already verified
|
|
||||||
verified, err := ur.IsVerified(&user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
if verified {
|
|
||||||
return -1, errors.ErrAlreadyVerified
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the Verification record
|
|
||||||
verification := models.Verification{
|
|
||||||
UserID: user.ID,
|
|
||||||
VerificationToken: *token,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use GORM to insert or update the Verification record
|
// Use GORM to insert or update the Verification record
|
||||||
result := database.DB.Clauses(clause.OnConflict{
|
result := database.DB.Clauses(clause.OnConflict{
|
||||||
Columns: []clause.Column{{Name: "user_id"}},
|
Columns: []clause.Column{{Name: "user_id"}},
|
||||||
|
|||||||
@@ -15,21 +15,18 @@ func RegisterRoutes(router *gin.Engine, userController *controllers.UserControll
|
|||||||
router.POST("/users/login", userController.LoginHandler)
|
router.POST("/users/login", userController.LoginHandler)
|
||||||
router.POST("/csp-report", middlewares.CSPReportHandling)
|
router.POST("/csp-report", middlewares.CSPReportHandling)
|
||||||
|
|
||||||
// create subrouter for teh authenticated area /account
|
|
||||||
// also pthprefix matches everything below /account
|
|
||||||
// accountRouter := router.PathPrefix("/account").Subrouter()
|
|
||||||
// accountRouter.Use(middlewares.AuthMiddleware)
|
|
||||||
//create api key required router
|
|
||||||
apiRouter := router.Group("/api")
|
apiRouter := router.Group("/api")
|
||||||
|
apiRouter.Use(middlewares.APIKeyMiddleware())
|
||||||
{
|
{
|
||||||
router.POST("/v1/subscription", membershipcontroller.RegisterSubscription)
|
router.POST("/v1/subscription", membershipcontroller.RegisterSubscription)
|
||||||
}
|
}
|
||||||
apiRouter.Use(middlewares.APIKeyMiddleware())
|
|
||||||
|
|
||||||
authRouter := router.Group("/users/backend")
|
authRouter := router.Group("/backend/users")
|
||||||
authRouter.Use(middlewares.AuthMiddleware())
|
authRouter.Use(middlewares.AuthMiddleware())
|
||||||
{
|
{
|
||||||
authRouter.GET("/current-user", userController.CurrentUserHandler)
|
authRouter.GET("/current", userController.CurrentUserHandler)
|
||||||
authRouter.POST("/logout", userController.LogoutHandler)
|
authRouter.POST("/logout", userController.LogoutHandler)
|
||||||
|
authRouter.PATCH("/update", userController.UpdateHandler)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/internal/repositories"
|
"GoMembership/internal/repositories"
|
||||||
|
"GoMembership/internal/utils"
|
||||||
"GoMembership/pkg/errors"
|
"GoMembership/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ func (service *MembershipService) FindMembershipByUserID(userID int64) (*models.
|
|||||||
// Membership_Subscriptions
|
// Membership_Subscriptions
|
||||||
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (int64, error) {
|
func (service *MembershipService) RegisterSubscription(subscription *models.SubscriptionModel) (int64, error) {
|
||||||
if err := validateSubscriptionData(subscription); err != nil {
|
if err := validateSubscriptionData(subscription); err != nil {
|
||||||
return http.StatusNotAcceptable, err
|
return -1, err
|
||||||
}
|
}
|
||||||
return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
|
return service.SubscriptionRepo.CreateSubscriptionModel(subscription)
|
||||||
}
|
}
|
||||||
@@ -65,8 +65,9 @@ func (service *MembershipService) GetSubscriptions(where map[string]interface{})
|
|||||||
|
|
||||||
func validateSubscriptionData(subscription *models.SubscriptionModel) error {
|
func validateSubscriptionData(subscription *models.SubscriptionModel) error {
|
||||||
validate := validator.New()
|
validate := validator.New()
|
||||||
|
// subscriptionModel and membershipField don't have to be evaluated if adding a new subscription
|
||||||
validate.RegisterValidation("subscriptionModel", func(fl validator.FieldLevel) bool { return true })
|
validate.RegisterValidation("subscriptionModel", func(fl validator.FieldLevel) bool { return true })
|
||||||
validate.RegisterValidation("membershipField", func(fl validator.FieldLevel) bool { return true })
|
validate.RegisterValidation("membershipField", func(fl validator.FieldLevel) bool { return true })
|
||||||
|
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
|
||||||
return validate.Struct(subscription)
|
return validate.Struct(subscription)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/internal/repositories"
|
"GoMembership/internal/repositories"
|
||||||
"GoMembership/internal/utils"
|
"GoMembership/internal/utils"
|
||||||
|
"GoMembership/pkg/errors"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
|
|
||||||
"github.com/alexedwards/argon2id"
|
"github.com/alexedwards/argon2id"
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -22,50 +24,42 @@ type UserServiceInterface interface {
|
|||||||
GetUserByID(id int64) (*models.User, error)
|
GetUserByID(id int64) (*models.User, error)
|
||||||
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
GetUsers(where map[string]interface{}) (*[]models.User, error)
|
||||||
VerifyUser(token *string) (*models.User, error)
|
VerifyUser(token *string) (*models.User, error)
|
||||||
|
UpdateUser(user *models.User) (*models.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
Repo repositories.UserRepositoryInterface
|
Repo repositories.UserRepositoryInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *UserService) RegisterUser(user *models.User) (int64, string, error) {
|
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) {
|
||||||
if err := validateRegistrationData(user); err != nil {
|
|
||||||
return http.StatusNotAcceptable, "", err
|
if err := validateUserData(user); err != nil {
|
||||||
|
return nil, errors.ErrInvalidUserData
|
||||||
}
|
}
|
||||||
|
|
||||||
setPassword(user.Password, user)
|
if user.Password != "" {
|
||||||
|
setPassword(user.Password, user)
|
||||||
|
}
|
||||||
|
|
||||||
user.Status = constants.UnverifiedStatus
|
|
||||||
user.CreatedAt = time.Now()
|
|
||||||
user.UpdatedAt = time.Now()
|
user.UpdatedAt = time.Now()
|
||||||
|
|
||||||
id, err := service.Repo.CreateUser(user)
|
updatedUser, err := service.Repo.UpdateUser(user)
|
||||||
|
|
||||||
if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
|
||||||
return http.StatusConflict, "", err
|
|
||||||
} else if err != nil {
|
|
||||||
return http.StatusInternalServerError, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.ID = id
|
|
||||||
|
|
||||||
token, err := utils.GenerateVerificationToken()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusInternalServerError, "", err
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil, errors.ErrUserNotFound
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||||
|
return nil, errors.ErrDuplicateEntry
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info.Printf("TOKEN: %v", token)
|
return updatedUser, nil
|
||||||
|
|
||||||
_, err = service.Repo.SetVerificationToken(user, &token)
|
|
||||||
if err != nil {
|
|
||||||
return http.StatusInternalServerError, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *UserService) Update(user *models.User) (int64, string, error) {
|
func (service *UserService) RegisterUser(user *models.User) (int64, string, error) {
|
||||||
if err := validateRegistrationData(user); err != nil {
|
if err := validateUserData(user); err != nil {
|
||||||
return http.StatusNotAcceptable, "", err
|
return http.StatusNotAcceptable, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,17 +86,31 @@ func (service *UserService) Update(user *models.User) (int64, string, error) {
|
|||||||
|
|
||||||
logger.Info.Printf("TOKEN: %v", token)
|
logger.Info.Printf("TOKEN: %v", token)
|
||||||
|
|
||||||
_, err = service.Repo.SetVerificationToken(user, &token)
|
// Check if user is already verified
|
||||||
|
verified, err := service.Repo.IsVerified(&user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusInternalServerError, "", err
|
return http.StatusInternalServerError, "", err
|
||||||
}
|
}
|
||||||
|
if verified {
|
||||||
|
return http.StatusAlreadyReported, "", errors.ErrAlreadyVerified
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the Verification record
|
||||||
|
verification := models.Verification{
|
||||||
|
UserID: user.ID,
|
||||||
|
VerificationToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = service.Repo.SetVerificationToken(&verification); err != nil {
|
||||||
|
return http.StatusInternalServerError, "", err
|
||||||
|
}
|
||||||
|
|
||||||
return id, token, nil
|
return id, token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *UserService) GetUserByID(id int64) (*models.User, error) {
|
func (service *UserService) GetUserByID(id int64) (*models.User, error) {
|
||||||
|
|
||||||
return service.Repo.GetUserByID(id)
|
return service.Repo.GetUserByID(&id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
|
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
|
||||||
@@ -114,19 +122,41 @@ func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.Us
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service *UserService) VerifyUser(token *string) (*models.User, error) {
|
func (service *UserService) VerifyUser(token *string) (*models.User, error) {
|
||||||
user, err := service.Repo.VerifyUserOfToken(token)
|
verification, err := service.Repo.GetVerificationOfToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Check if the user is already verified
|
||||||
|
verified, err := service.Repo.IsVerified(&verification.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
user, err := service.Repo.GetUserByID(&verification.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if verified {
|
||||||
|
return user, errors.ErrAlreadyVerified
|
||||||
|
}
|
||||||
|
// Update user status to active
|
||||||
|
t := time.Now()
|
||||||
|
verification.EmailVerifiedAt = &t
|
||||||
|
|
||||||
|
user.Status = constants.VerifiedStatus
|
||||||
|
user.Verification = *verification
|
||||||
|
user.ID = verification.UserID
|
||||||
|
service.Repo.UpdateUser(user)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRegistrationData(user *models.User) error {
|
func validateUserData(user *models.User) error {
|
||||||
validate := validator.New()
|
validate := validator.New()
|
||||||
validate.RegisterValidation("age", utils.AgeValidator)
|
validate.RegisterValidation("age", utils.AgeValidator)
|
||||||
validate.RegisterValidation("bic", utils.BICValidator)
|
validate.RegisterValidation("bic", utils.BICValidator)
|
||||||
validate.RegisterValidation("iban", utils.IBANValidator)
|
validate.RegisterValidation("iban", utils.IBANValidator)
|
||||||
validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator)
|
validate.RegisterValidation("subscriptionModel", utils.SubscriptionModelValidator)
|
||||||
|
validate.RegisterValidation("safe_content", utils.ValidateSafeContent)
|
||||||
validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField)
|
validate.RegisterValidation("membershipField", utils.ValidateRequiredMembershipField)
|
||||||
|
|
||||||
return validate.Struct(user)
|
return validate.Struct(user)
|
||||||
|
|||||||
20
internal/utils/cookies.go
Normal file
20
internal/utils/cookies.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetCookie(c *gin.Context, token string) {
|
||||||
|
c.SetSameSite(http.SameSiteLaxMode)
|
||||||
|
c.SetCookie(
|
||||||
|
"jwt",
|
||||||
|
token,
|
||||||
|
5*24*60*60, // 5 days
|
||||||
|
"/",
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -7,7 +7,9 @@ import (
|
|||||||
"GoMembership/internal/models"
|
"GoMembership/internal/models"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
@@ -15,19 +17,24 @@ import (
|
|||||||
"github.com/jbub/banking/swift"
|
"github.com/jbub/banking/swift"
|
||||||
)
|
)
|
||||||
|
|
||||||
//
|
var xssPatterns = []*regexp.Regexp{
|
||||||
// func IsEmailValid(email string) bool {
|
regexp.MustCompile(`(?i)<script`),
|
||||||
// regex := `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$`
|
regexp.MustCompile(`(?i)javascript:`),
|
||||||
// re := regexp.MustCompile(regex)
|
regexp.MustCompile(`(?i)on\w+\s*=`),
|
||||||
// return re.MatchString(email)
|
regexp.MustCompile(`(?i)(vbscript|data):`),
|
||||||
// }
|
regexp.MustCompile(`(?i)<(iframe|object|embed|applet)`),
|
||||||
|
regexp.MustCompile(`(?i)expression\s*\(`),
|
||||||
|
regexp.MustCompile(`(?i)url\s*\(`),
|
||||||
|
regexp.MustCompile(`(?i)<\?`),
|
||||||
|
regexp.MustCompile(`(?i)<%`),
|
||||||
|
regexp.MustCompile(`(?i)<!\[CDATA\[`),
|
||||||
|
regexp.MustCompile(`(?i)<(svg|animate)`),
|
||||||
|
regexp.MustCompile(`(?i)<(audio|video|source)`),
|
||||||
|
regexp.MustCompile(`(?i)base64`),
|
||||||
|
}
|
||||||
|
|
||||||
func AgeValidator(fl validator.FieldLevel) bool {
|
func AgeValidator(fl validator.FieldLevel) bool {
|
||||||
fieldValue := fl.Field()
|
fieldValue := fl.Field()
|
||||||
// Ensure the field is of type time.Time
|
|
||||||
// if fieldValue.Kind() != reflect.Struct || !fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
dateOfBirth := fieldValue.Interface().(time.Time)
|
dateOfBirth := fieldValue.Interface().(time.Time)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
age := now.Year() - dateOfBirth.Year()
|
age := now.Year() - dateOfBirth.Year()
|
||||||
@@ -113,3 +120,13 @@ func BICValidator(fl validator.FieldLevel) bool {
|
|||||||
|
|
||||||
return swift.Validate(fieldValue) == nil
|
return swift.Validate(fieldValue) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidateSafeContent(fl validator.FieldLevel) bool {
|
||||||
|
input := strings.ToLower(fl.Field().String())
|
||||||
|
for _, pattern := range xssPatterns {
|
||||||
|
if pattern.MatchString(input) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,4 +15,7 @@ var (
|
|||||||
ErrValueTooLong = errors.New("cookie value too long")
|
ErrValueTooLong = errors.New("cookie value too long")
|
||||||
ErrInvalidValue = errors.New("invalid cookie value")
|
ErrInvalidValue = errors.New("invalid cookie value")
|
||||||
ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm")
|
ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm")
|
||||||
|
ErrValidToken = errors.New("valid token")
|
||||||
|
ErrInvalidUserData = errors.New("invalid user data")
|
||||||
|
ErrDuplicateEntry = errors.New("duplicate entry; unique constraint failed")
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user