Compare commits

...

21 Commits

Author SHA1 Message Date
Alex
073d353764 tests 2025-03-11 20:52:54 +01:00
Alex
9d2b33f832 adapted new user model. 2025-03-11 20:52:39 +01:00
Alex
e60aaa1d69 refactored auth.go & tests 2025-03-11 20:52:11 +01:00
Alex
ca99e28433 removed logging in csp 2025-03-11 20:51:54 +01:00
Alex
9427492cb1 removed logging in headers 2025-03-11 20:51:44 +01:00
Alex
60d3f075bf naming 2025-03-11 20:51:28 +01:00
Alex
d473aef3a9 refactored user service for new model repo style 2025-03-11 20:51:05 +01:00
Alex
ef4d3c9576 add userid to drivers licence model 2025-03-11 20:50:34 +01:00
Alex
9a8b386931 cleaned verification model 2025-03-11 20:50:19 +01:00
Alex
9c429185dc add route for new mail verification 2025-03-11 20:47:57 +01:00
Alex
c7865d0582 add user id to mail verification url 2025-03-11 20:47:31 +01:00
Alex
feb8abcc42 refactored validation 2025-03-11 20:46:45 +01:00
Alex
c8d0904fd7 del obsolete repos and services 2025-03-11 20:46:24 +01:00
Alex
294ad76e4b first step to remove global database.db 2025-03-11 20:45:49 +01:00
Alex
ca441d51e7 moved repo to user model 2025-03-11 20:44:29 +01:00
Alex
39c060794a del obsolete handleVerifyUserError 2025-03-11 20:43:42 +01:00
Alex
c6ea179eca moved field validation to validation package 2025-03-11 20:42:45 +01:00
Alex
0d6013d566 add new errors 2025-03-11 20:42:05 +01:00
Alex
0c3204df15 fix: moved to licenceServiceInterface 2025-03-11 20:39:45 +01:00
Alex
cfc10ab087 moved generateRandomString to local config class 2025-03-11 20:39:12 +01:00
Alex
df6125b7cb locale 2025-03-11 20:30:58 +01:00
40 changed files with 855 additions and 693 deletions

View File

@@ -75,6 +75,7 @@ export default {
user_not_found_or_wrong_password: 'Existiert nicht oder falsches Passwort', user_not_found_or_wrong_password: 'Existiert nicht oder falsches Passwort',
email_already_registered: 'Ein Mitglied wurde schon mit dieser Emailadresse erstellt.', email_already_registered: 'Ein Mitglied wurde schon mit dieser Emailadresse erstellt.',
password_already_changed: 'Das Passwort wurde schon geändert.', password_already_changed: 'Das Passwort wurde schon geändert.',
user_already_verified: 'Ihre Email Adresse wurde schon bestätigt.',
insecure: 'Unsicheres Passwort, versuchen Sie {message}', insecure: 'Unsicheres Passwort, versuchen Sie {message}',
longer: 'oder verwenden Sie ein längeres Passwort', longer: 'oder verwenden Sie ein längeres Passwort',
special: 'mehr Sonderzeichen einzufügen', special: 'mehr Sonderzeichen einzufügen',
@@ -204,6 +205,7 @@ export default {
payments: 'Zahlungen', payments: 'Zahlungen',
add_new: 'Neu', add_new: 'Neu',
email_sent: 'Email wurde gesendet..', email_sent: 'Email wurde gesendet..',
verification: 'Verifikation',
// For payments section // For payments section
payment: { payment: {
id: 'Zahlungs-Nr', id: 'Zahlungs-Nr',

View File

@@ -18,18 +18,18 @@ func main() {
config.LoadConfig() config.LoadConfig()
err := database.Open(config.DB.Path, config.Recipients.AdminEmail) db, err := database.Open(config.DB.Path, config.Recipients.AdminEmail)
if err != nil { if err != nil {
logger.Error.Fatalf("Couldn't init database: %v", err) logger.Error.Fatalf("Couldn't init database: %v", err)
} }
defer func() { defer func() {
if err := database.Close(); err != nil { if err := database.Close(db); err != nil {
logger.Error.Fatalf("Failed to close database: %v", err) logger.Error.Fatalf("Failed to close database: %v", err)
} }
}() }()
go server.Run() go server.Run(db)
gracefulShutdown() gracefulShutdown()
} }

View File

@@ -8,6 +8,8 @@
package config package config
import ( import (
"crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"os" "os"
"path/filepath" "path/filepath"
@@ -15,7 +17,6 @@ import (
"github.com/kelseyhightower/envconfig" "github.com/kelseyhightower/envconfig"
"GoMembership/internal/utils"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
) )
@@ -99,12 +100,12 @@ func LoadConfig() {
readFile(&CFG) readFile(&CFG)
readEnv(&CFG) readEnv(&CFG)
logger.Info.Printf("Config file environment: %v", CFGPath) logger.Info.Printf("Config file environment: %v", CFGPath)
csrfSecret, err := utils.GenerateRandomString(32) csrfSecret, err := generateRandomString(32)
if err != nil { if err != nil {
logger.Error.Fatalf("could not generate CSRF secret: %v", err) logger.Error.Fatalf("could not generate CSRF secret: %v", err)
} }
jwtSecret, err := utils.GenerateRandomString(32) jwtSecret, err := generateRandomString(32)
if err != nil { if err != nil {
logger.Error.Fatalf("could not generate JWT secret: %v", err) logger.Error.Fatalf("could not generate JWT secret: %v", err)
} }
@@ -160,3 +161,12 @@ func readEnv(cfg *Config) {
logger.Error.Fatalf("could not decode env variables: %#v", err) logger.Error.Fatalf("could not decode env variables: %#v", err)
} }
} }
func generateRandomString(length int) (string, error) {
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes), nil
}

View File

@@ -48,6 +48,8 @@ var (
Uc *UserController Uc *UserController
Mc *MembershipController Mc *MembershipController
Cc *ContactController Cc *ContactController
AdminCookie *http.Cookie
MemberCookie *http.Cookie
) )
func TestMain(t *testing.T) { func TestMain(t *testing.T) {
@@ -85,7 +87,8 @@ func TestMain(t *testing.T) {
log.Fatalf("Error setting environment variable: %v", err) log.Fatalf("Error setting environment variable: %v", err)
} }
config.LoadConfig() config.LoadConfig()
if err := database.Open("test.db", config.Recipients.AdminEmail); err != nil { db, err := database.Open("test.db", config.Recipients.AdminEmail)
if err != nil {
log.Fatalf("Failed to create DB: %#v", err) log.Fatalf("Failed to create DB: %#v", err)
} }
utils.SMTPStart(Host, Port) utils.SMTPStart(Host, Port)
@@ -101,13 +104,12 @@ func TestMain(t *testing.T) {
membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo} membershipService := &services.MembershipService{Repo: membershipRepo, SubscriptionRepo: subscriptionRepo}
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{} var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{} userService := &services.UserService{DB: db, Licences: licenceRepo}
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
licenceService := &services.LicenceService{Repo: licenceRepo} licenceService := &services.LicenceService{Repo: licenceRepo}
Uc = &UserController{Service: userService, LicenceService: licenceService, EmailService: emailService, ConsentService: consentService, BankAccountService: bankAccountService, MembershipService: membershipService} Uc = &UserController{Service: userService, LicenceService: licenceService, EmailService: emailService, ConsentService: consentService, BankAccountService: bankAccountService, MembershipService: membershipService}
Mc = &MembershipController{UserController: &MockUserController{}, Service: *membershipService} Mc = &MembershipController{UserService: userService, Service: membershipService}
Cc = &ContactController{EmailService: emailService} Cc = &ContactController{EmailService: emailService}
if err := initSubscriptionPlans(); err != nil { if err := initSubscriptionPlans(); err != nil {
@@ -132,7 +134,7 @@ func TestMain(t *testing.T) {
} }
admin.SetPassword("securepassword") admin.SetPassword("securepassword")
database.DB.Create(&admin) database.DB.Create(&admin)
validation.SetupValidators() validation.SetupValidators(db)
t.Run("userController", func(t *testing.T) { t.Run("userController", func(t *testing.T) {
testUserController(t) testUserController(t)
}) })

View File

@@ -10,7 +10,7 @@ import (
) )
type LicenceController struct { type LicenceController struct {
Service services.LicenceService Service services.LicenceServiceInterface
} }
func (lc *LicenceController) GetAllCategories(c *gin.Context) { func (lc *LicenceController) GetAllCategories(c *gin.Context) {

View File

@@ -48,7 +48,7 @@ func TestGetAllCategories_Success(t *testing.T) {
service := &services.LicenceService{Repo: mockRepo} service := &services.LicenceService{Repo: mockRepo}
// Create controller with service // Create controller with service
lc := &LicenceController{Service: *service} lc := &LicenceController{Service: service}
// Setup router and request // Setup router and request
router := gin.Default() router := gin.Default()
@@ -76,7 +76,7 @@ func TestGetAllCategories_Error(t *testing.T) {
service := &services.LicenceService{Repo: mockRepo} service := &services.LicenceService{Repo: mockRepo}
// Create controller with service // Create controller with service
lc := &LicenceController{Service: *service} lc := &LicenceController{Service: service}
// Setup router and request // Setup router and request
router := gin.Default() router := gin.Default()

View File

@@ -16,10 +16,8 @@ import (
) )
type MembershipController struct { type MembershipController struct {
Service services.MembershipService Service services.MembershipServiceInterface
UserController interface { UserService services.UserServiceInterface
ExtractUserFromContext(*gin.Context) (*models.User, error)
}
} }
type MembershipData struct { type MembershipData struct {
@@ -30,14 +28,14 @@ type MembershipData struct {
func (mc *MembershipController) RegisterSubscription(c *gin.Context) { func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
var regData MembershipData var regData MembershipData
requestUser, err := mc.UserController.ExtractUserFromContext(c) requestUser, err := mc.UserService.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription registrationHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in subscription registrationHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.Create) { if !requestUser.HasPrivilege(constants.Priviliges.Create) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to register subscription", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to register subscription", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
@@ -66,14 +64,14 @@ func (mc *MembershipController) RegisterSubscription(c *gin.Context) {
func (mc *MembershipController) UpdateHandler(c *gin.Context) { func (mc *MembershipController) UpdateHandler(c *gin.Context) {
var regData MembershipData var regData MembershipData
requestUser, err := mc.UserController.ExtractUserFromContext(c) requestUser, err := mc.UserService.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription Updatehandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in subscription Updatehandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) { if !requestUser.HasPrivilege(constants.Priviliges.Update) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
@@ -105,14 +103,14 @@ func (mc *MembershipController) DeleteSubscription(c *gin.Context) {
} }
var data deleteData var data deleteData
requestUser, err := mc.UserController.ExtractUserFromContext(c) requestUser, err := mc.UserService.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in subscription deleteSubscription", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in subscription deleteSubscription", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) { if !requestUser.HasPrivilege(constants.Priviliges.Delete) {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }

View File

@@ -6,7 +6,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"GoMembership/internal/constants"
"GoMembership/internal/database" "GoMembership/internal/database"
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
@@ -15,6 +14,7 @@ import (
) )
type RegisterSubscriptionTest struct { type RegisterSubscriptionTest struct {
SetupCookie func(r *http.Request)
WantDBData map[string]interface{} WantDBData map[string]interface{}
Input string Input string
Name string Name string
@@ -23,6 +23,7 @@ type RegisterSubscriptionTest struct {
} }
type UpdateSubscriptionTest struct { type UpdateSubscriptionTest struct {
SetupCookie func(r *http.Request)
WantDBData map[string]interface{} WantDBData map[string]interface{}
Input string Input string
Name string Name string
@@ -31,6 +32,7 @@ type UpdateSubscriptionTest struct {
} }
type DeleteSubscriptionTest struct { type DeleteSubscriptionTest struct {
SetupCookie func(r *http.Request)
WantDBData map[string]interface{} WantDBData map[string]interface{}
Input string Input string
Name string Name string
@@ -38,29 +40,8 @@ type DeleteSubscriptionTest struct {
Assert bool Assert bool
} }
type MockUserController struct {
UserController // Embed the UserController
}
func (m *MockUserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
return &models.User{
ID: 1,
FirstName: "Admin",
LastName: "User",
Email: "admin@test.com",
RoleID: constants.Roles.Admin,
}, nil
}
func setupMockAuth() {
// Create and assign the mock controller
mockController := &MockUserController{}
Mc.UserController = mockController
}
func testMembershipController(t *testing.T) { func testMembershipController(t *testing.T) {
setupMockAuth()
tests := getSubscriptionRegistrationData() tests := getSubscriptionRegistrationData()
for _, tt := range tests { for _, tt := range tests {
logger.Error.Print("==============================================================") logger.Error.Print("==============================================================")
@@ -101,6 +82,7 @@ func (rt *RegisterSubscriptionTest) SetupContext() (*gin.Context, *httptest.Resp
} }
func (rt *RegisterSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) { func (rt *RegisterSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
rt.SetupCookie(c.Request)
Mc.RegisterSubscription(c) Mc.RegisterSubscription(c)
} }
@@ -131,6 +113,7 @@ func (ut *UpdateSubscriptionTest) SetupContext() (*gin.Context, *httptest.Respon
} }
func (ut *UpdateSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) { func (ut *UpdateSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
ut.SetupCookie(c.Request)
Mc.UpdateHandler(c) Mc.UpdateHandler(c)
} }
@@ -150,6 +133,7 @@ func (dt *DeleteSubscriptionTest) SetupContext() (*gin.Context, *httptest.Respon
} }
func (dt *DeleteSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) { func (dt *DeleteSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) {
dt.SetupCookie(c.Request)
Mc.DeleteSubscription(c) Mc.DeleteSubscription(c)
} }
@@ -183,6 +167,9 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member
func getSubscriptionRegistrationData() []RegisterSubscriptionTest { func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
return []RegisterSubscriptionTest{ return []RegisterSubscriptionTest{
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Missing details should fail", Name: "Missing details should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Just a Subscription"}, WantDBData: map[string]interface{}{"name": "Just a Subscription"},
@@ -194,6 +181,9 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Missing model name should fail", Name: "Missing model name should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": ""}, WantDBData: map[string]interface{}{"name": ""},
@@ -205,6 +195,9 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Negative monthly fee should fail", Name: "Negative monthly fee should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},
@@ -215,6 +208,9 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Negative hourly rate should fail", Name: "Negative hourly rate should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},
@@ -225,6 +221,25 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(MemberCookie)
},
Name: "correct entry but not authorized",
WantResponse: http.StatusUnauthorized,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Conditions = "Some Condition"
subscription.Subscription.IncludedPerYear = 0
subscription.Subscription.IncludedPerMonth = 1
return subscription
})),
},
{
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "correct entry should pass", Name: "correct entry should pass",
WantResponse: http.StatusCreated, WantResponse: http.StatusCreated,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},
@@ -238,6 +253,9 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Duplicate subscription name should fail", Name: "Duplicate subscription name should fail",
WantResponse: http.StatusConflict, WantResponse: http.StatusConflict,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},
@@ -250,6 +268,9 @@ func getSubscriptionRegistrationData() []RegisterSubscriptionTest {
func getSubscriptionUpdateData() []UpdateSubscriptionTest { func getSubscriptionUpdateData() []UpdateSubscriptionTest {
return []UpdateSubscriptionTest{ return []UpdateSubscriptionTest{
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Modified Monthly Fee, should fail", Name: "Modified Monthly Fee, should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium", "monthly_fee": "12"}, WantDBData: map[string]interface{}{"name": "Premium", "monthly_fee": "12"},
@@ -261,6 +282,9 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Missing ID, should fail", Name: "Missing ID, should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},
@@ -272,6 +296,9 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Modified Hourly Rate, should fail", Name: "Modified Hourly Rate, should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium", "hourly_rate": "14"}, WantDBData: map[string]interface{}{"name": "Premium", "hourly_rate": "14"},
@@ -283,6 +310,9 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "IncludedPerYear changed, should fail", Name: "IncludedPerYear changed, should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium", "included_per_year": "0"}, WantDBData: map[string]interface{}{"name": "Premium", "included_per_year": "0"},
@@ -294,6 +324,9 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "IncludedPerMonth changed, should fail", Name: "IncludedPerMonth changed, should fail",
WantResponse: http.StatusBadRequest, WantResponse: http.StatusBadRequest,
WantDBData: map[string]interface{}{"name": "Premium", "included_per_month": "1"}, WantDBData: map[string]interface{}{"name": "Premium", "included_per_month": "1"},
@@ -305,6 +338,9 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Update non-existent subscription should fail", Name: "Update non-existent subscription should fail",
WantResponse: http.StatusNotFound, WantResponse: http.StatusNotFound,
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"}, WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
@@ -316,6 +352,26 @@ func getSubscriptionUpdateData() []UpdateSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(MemberCookie)
},
Name: "Correct Update but unauthorized",
WantResponse: http.StatusUnauthorized,
WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Details = "Altered Details"
subscription.Subscription.Conditions = "Some Condition"
subscription.Subscription.IncludedPerYear = 0
subscription.Subscription.IncludedPerMonth = 1
return subscription
})),
},
{
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Correct Update should pass", Name: "Correct Update should pass",
WantResponse: http.StatusAccepted, WantResponse: http.StatusAccepted,
WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"}, WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"},
@@ -338,10 +394,11 @@ func getSubscriptionDeleteData() []DeleteSubscriptionTest {
database.DB.Where("name = ?", "Premium").First(&premiumSub) database.DB.Where("name = ?", "Premium").First(&premiumSub)
database.DB.Where("name = ?", "Basic").First(&basicSub) database.DB.Where("name = ?", "Basic").First(&basicSub)
logger.Error.Printf("premiumSub.ID: %v", premiumSub.ID)
logger.Error.Printf("basicSub.ID: %v", basicSub.ID)
return []DeleteSubscriptionTest{ return []DeleteSubscriptionTest{
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Delete non-existent subscription should fail", Name: "Delete non-existent subscription should fail",
WantResponse: http.StatusNotFound, WantResponse: http.StatusNotFound,
WantDBData: map[string]interface{}{"name": "NonExistentSubscription"}, WantDBData: map[string]interface{}{"name": "NonExistentSubscription"},
@@ -354,6 +411,9 @@ func getSubscriptionDeleteData() []DeleteSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Delete subscription without name should fail", Name: "Delete subscription without name should fail",
WantResponse: http.StatusExpectationFailed, WantResponse: http.StatusExpectationFailed,
WantDBData: map[string]interface{}{"name": ""}, WantDBData: map[string]interface{}{"name": ""},
@@ -366,6 +426,9 @@ func getSubscriptionDeleteData() []DeleteSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Delete subscription with users should fail", Name: "Delete subscription with users should fail",
WantResponse: http.StatusExpectationFailed, WantResponse: http.StatusExpectationFailed,
WantDBData: map[string]interface{}{"name": "Basic"}, WantDBData: map[string]interface{}{"name": "Basic"},
@@ -378,6 +441,24 @@ func getSubscriptionDeleteData() []DeleteSubscriptionTest {
})), })),
}, },
{ {
SetupCookie: func(req *http.Request) {
req.AddCookie(MemberCookie)
},
Name: "Delete valid subscription should succeed",
WantResponse: http.StatusUnauthorized,
WantDBData: map[string]interface{}{"name": "Premium"},
Assert: true,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.Subscription.Name = "Premium"
subscription.Subscription.ID = premiumSub.ID
return subscription
})),
},
{
SetupCookie: func(req *http.Request) {
req.AddCookie(AdminCookie)
},
Name: "Delete valid subscription should succeed", Name: "Delete valid subscription should succeed",
WantResponse: http.StatusOK, WantResponse: http.StatusOK,
WantDBData: map[string]interface{}{"name": "Premium"}, WantDBData: map[string]interface{}{"name": "Premium"},

View File

@@ -4,7 +4,6 @@ import (
"GoMembership/internal/constants" "GoMembership/internal/constants"
"GoMembership/internal/utils" "GoMembership/internal/utils"
"GoMembership/pkg/errors" "GoMembership/pkg/errors"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
@@ -15,16 +14,15 @@ import (
func (uc *UserController) CreatePasswordHandler(c *gin.Context) { func (uc *UserController) CreatePasswordHandler(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c) requestUser, err := uc.Service.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Couldn't get User from Request Context", http.StatusBadRequest, errors.Responses.Fields.General, errors.Responses.Keys.NoAuthToken)
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.AccessControl) { if !requestUser.IsAdmin() {
utils.RespondWithError(c, errors.ErrNotAuthorized, fmt.Sprintf("Not allowed to handle other users. RoleID(%v)<Privilige(%v)", requestUser.RoleID, constants.Priviliges.View), http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Requesting user not authorized to grant user access", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
// Expected data from the user // Expected data from the user
var input struct { var input struct {
User struct { User struct {
@@ -38,21 +36,26 @@ func (uc *UserController) CreatePasswordHandler(c *gin.Context) {
} }
// find user // find user
db_user, err := uc.Service.GetUserByID(input.User.ID) user, err := uc.Service.FromID(&input.User.ID)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "couldn't get user by id", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound) utils.RespondWithError(c, err, "couldn't get user by id", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound)
return return
} }
// create token // Deactivate user and reset Verification
token, err := uc.Service.HandlePasswordChangeRequest(db_user) user.Status = constants.DisabledStatus
v, err := user.SetVerification(constants.VerificationTypes.Password)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "couldn't handle password change request", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "couldn't set verification", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError)
return return
} }
if _, err := uc.Service.Update(user); err != nil {
utils.RespondWithError(c, err, "Couldn't update user in createPasswordHandler", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return
}
// send email // send email
if err := uc.EmailService.SendGrantBackendAccessEmail(db_user, &token); err != nil { if err := uc.EmailService.SendGrantBackendAccessEmail(user, &v.VerificationToken); err != nil {
utils.RespondWithError(c, err, "Couldn't send grant backend access email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "Couldn't send grant backend access email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return return
} }
@@ -74,27 +77,30 @@ func (uc *UserController) RequestPasswordChangeHandler(c *gin.Context) {
return return
} }
// find user // find user
db_user, err := uc.Service.GetUserByEmail(input.Email) user, err := uc.Service.FromEmail(&input.Email)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "couldn't get user by email", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound) utils.RespondWithError(c, err, "couldn't get user by email", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound)
return return
} }
// check if user may change the password // check if user may change the password
if db_user.Status <= constants.DisabledStatus { if !user.IsVerified() {
utils.RespondWithError(c, errors.ErrNotAuthorized, "User password change request denied, user is disabled", http.StatusForbidden, errors.Responses.Fields.Login, errors.Responses.Keys.UserDisabled) utils.RespondWithError(c, errors.ErrNotAuthorized, "User password change request denied, user is not verified or disabled", http.StatusForbidden, errors.Responses.Fields.Login, errors.Responses.Keys.UserDisabled)
return return
} }
// create token user.Status = constants.DisabledStatus
token, err := uc.Service.HandlePasswordChangeRequest(db_user) v, err := user.SetVerification(constants.VerificationTypes.Password)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "couldn't handle password change request", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "couldn't set verification", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError)
return return
} }
if _, err := uc.Service.Update(user); err != nil {
utils.RespondWithError(c, err, "Couldn't update user in createPasswordHandler", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return
}
// send email // send email
if err := uc.EmailService.SendChangePasswordEmail(db_user, &token); err != nil { if err := uc.EmailService.SendChangePasswordEmail(user, &v.VerificationToken); err != nil {
utils.RespondWithError(c, err, "Couldn't send change password email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "Couldn't send change password email", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
return return
} }
@@ -115,27 +121,24 @@ func (uc *UserController) ChangePassword(c *gin.Context) {
utils.RespondWithError(c, err, "Invalid user ID", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.InvalidUserID) utils.RespondWithError(c, err, "Invalid user ID", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.InvalidUserID)
return return
} }
userID := uint(userIDint)
user, err := uc.Service.FromID(&userID)
if err != nil {
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
return
}
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandleValidationError(c, err) utils.HandleValidationError(c, err)
return return
} }
verification, err := uc.Service.VerifyUser(&input.Token, &constants.VerificationTypes.Password) if !user.Verify(input.Token, constants.VerificationTypes.Password) {
if err != nil || uint(userIDint) != verification.UserID { utils.RespondWithError(c, errors.ErrAlreadyVerified, "Couldn't verify user", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
utils.HandleVerifyUserError(c, err)
return
}
user, err := uc.Service.GetUserByID(verification.UserID)
if err != nil {
utils.RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.UserNotFoundWrongPassword)
return return
} }
user.Status = constants.ActiveStatus user.Status = constants.ActiveStatus
user.Verification = *verification
user.ID = verification.UserID
user.Password = input.Password user.Password = input.Password
// Get Gin's binding validator engine with all registered validators // Get Gin's binding validator engine with all registered validators
@@ -146,7 +149,7 @@ func (uc *UserController) ChangePassword(c *gin.Context) {
utils.HandleValidationError(c, err) utils.HandleValidationError(c, err)
return return
} }
_, err = uc.Service.UpdateUser(user) _, err = uc.Service.Update(user)
if err != nil { if err != nil {
utils.HandleUserUpdateError(c, err) utils.HandleUserUpdateError(c, err)
return return

View File

@@ -28,8 +28,8 @@ type TestContext struct {
} }
func setupTestContext() (*TestContext, error) { func setupTestContext() (*TestContext, error) {
testEmail := "john.doe@example.com"
user, err := Uc.Service.GetUserByEmail("john.doe@example.com") user, err := Uc.Service.FromEmail(&testEmail)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -39,7 +39,7 @@ func setupTestContext() (*TestContext, error) {
user: user, user: user,
}, nil }, nil
} }
func testCreatePasswordHandler(t *testing.T, loginCookie http.Cookie, adminCookie http.Cookie) { func testCreatePasswordHandler(t *testing.T) {
invalidCookie := http.Cookie{ invalidCookie := http.Cookie{
Name: "jwt", Name: "jwt",
Value: "invalid.token.here", Value: "invalid.token.here",
@@ -58,7 +58,7 @@ func testCreatePasswordHandler(t *testing.T, loginCookie http.Cookie, adminCooki
body, _ := json.Marshal(requestBody) body, _ := json.Marshal(requestBody)
t.Run("successful password creation request from admin", func(t *testing.T) { t.Run("successful password creation request from admin", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/password", bytes.NewBuffer(body)) req, _ := http.NewRequest("POST", "/password", bytes.NewBuffer(body))
req.AddCookie(&adminCookie) req.AddCookie(AdminCookie)
tc.router.ServeHTTP(tc.response, req) tc.router.ServeHTTP(tc.response, req)
logger.Error.Printf("Test results for %#v", t.Name()) logger.Error.Printf("Test results for %#v", t.Name())
assert.Equal(t, http.StatusAccepted, tc.response.Code) assert.Equal(t, http.StatusAccepted, tc.response.Code)
@@ -73,11 +73,11 @@ func testCreatePasswordHandler(t *testing.T, loginCookie http.Cookie, adminCooki
tc.response = httptest.NewRecorder() tc.response = httptest.NewRecorder()
t.Run("failed password creation request from member", func(t *testing.T) { t.Run("failed password creation request from member", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/password", bytes.NewBuffer(body)) req, _ := http.NewRequest("POST", "/password", bytes.NewBuffer(body))
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
tc.router.ServeHTTP(tc.response, req) tc.router.ServeHTTP(tc.response, req)
logger.Error.Printf("Test results for %#v", t.Name()) logger.Error.Printf("Test results for %#v", t.Name())
assert.Equal(t, http.StatusForbidden, tc.response.Code) assert.Equal(t, http.StatusUnauthorized, tc.response.Code)
assert.JSONEq(t, `{"errors":[{"field":"user.user","key":"server.error.unauthorized"}]}`, tc.response.Body.String()) assert.JSONEq(t, `{"errors":[{"field":"user.user","key":"server.error.unauthorized"}]}`, tc.response.Body.String())
err = checkEmailDelivery(tc.user, false) err = checkEmailDelivery(tc.user, false)
assert.NoError(t, err) assert.NoError(t, err)
@@ -203,6 +203,7 @@ func checkPasswordMail(message *utils.Email, user *models.User) error {
if !strings.Contains(message.Body, verification.VerificationToken) { if !strings.Contains(message.Body, verification.VerificationToken) {
return fmt.Errorf("Token(%v) has not been rendered in password mail.", verification.VerificationToken) return fmt.Errorf("Token(%v) has not been rendered in password mail.", verification.VerificationToken)
} }
if strings.Trim(tokenURL, " ") != fmt.Sprintf("%v%v/auth/password/change/%v?token=%v", config.Site.BaseURL, config.Site.FrontendPath, user.ID, verification.VerificationToken) { if strings.Trim(tokenURL, " ") != fmt.Sprintf("%v%v/auth/password/change/%v?token=%v", config.Site.BaseURL, config.Site.FrontendPath, user.ID, verification.VerificationToken) {
return fmt.Errorf("Token has not been rendered correctly in password mail: %v%v/auth/password/change/%v?token=%v", config.Site.BaseURL, config.Site.FrontendPath, user.ID, verification.VerificationToken) return fmt.Errorf("Token has not been rendered correctly in password mail: %v%v/auth/password/change/%v?token=%v", config.Site.BaseURL, config.Site.FrontendPath, user.ID, verification.VerificationToken)
} }

View File

@@ -9,11 +9,14 @@ import (
"GoMembership/internal/utils" "GoMembership/internal/utils"
"GoMembership/internal/validation" "GoMembership/internal/validation"
"fmt" "fmt"
"strconv"
"strings" "strings"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"GoMembership/pkg/errors" "GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
@@ -25,7 +28,7 @@ type UserController struct {
ConsentService services.ConsentServiceInterface ConsentService services.ConsentServiceInterface
BankAccountService services.BankAccountServiceInterface BankAccountService services.BankAccountServiceInterface
MembershipService services.MembershipServiceInterface MembershipService services.MembershipServiceInterface
LicenceService services.LicenceInterface LicenceService services.LicenceServiceInterface
} }
type RegistrationData struct { type RegistrationData struct {
@@ -33,7 +36,7 @@ type RegistrationData struct {
} }
func (uc *UserController) CurrentUserHandler(c *gin.Context) { func (uc *UserController) CurrentUserHandler(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c) requestUser, err := uc.Service.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in CurrentUserHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in CurrentUserHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
@@ -46,19 +49,20 @@ func (uc *UserController) CurrentUserHandler(c *gin.Context) {
func (uc *UserController) GetAllUsers(c *gin.Context) { func (uc *UserController) GetAllUsers(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c) requestUser, err := uc.Service.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.View) {
if !requestUser.HasPrivilege(constants.Priviliges.View) {
utils.RespondWithError(c, errors.ErrNotAuthorized, fmt.Sprintf("Not allowed to handle all users. RoleID(%v)<Privilige(%v)", requestUser.RoleID, constants.Priviliges.View), http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, fmt.Sprintf("Not allowed to handle all users. RoleID(%v)<Privilige(%v)", requestUser.RoleID, constants.Priviliges.View), http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
users, err := uc.Service.GetUsers(nil) users, err := uc.Service.GetUsers(nil)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error getting users in GetAllUsers", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "Error getting all users", http.StatusInternalServerError, errors.Responses.Fields.User, errors.Responses.Keys.InternalServerError)
return return
} }
@@ -69,6 +73,7 @@ func (uc *UserController) GetAllUsers(c *gin.Context) {
for i, user := range *users { for i, user := range *users {
safeUsers[i] = user.Safe() safeUsers[i] = user.Safe()
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"users": users, "users": users,
}) })
@@ -76,7 +81,7 @@ func (uc *UserController) GetAllUsers(c *gin.Context) {
func (uc *UserController) UpdateHandler(c *gin.Context) { func (uc *UserController) UpdateHandler(c *gin.Context) {
// 1. Extract and validate the user ID from the route // 1. Extract and validate the user ID from the route
requestUser, err := uc.ExtractUserFromContext(c) requestUser, err := uc.Service.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in UpdateHandler", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
@@ -89,28 +94,20 @@ func (uc *UserController) UpdateHandler(c *gin.Context) {
} }
user := updateData.User user := updateData.User
if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) && user.ID != requestUser.ID { if !requestUser.HasPrivilege(constants.Priviliges.Update) && user.ID != requestUser.ID {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update user", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
existingUser, err := uc.Service.GetUserByID(user.ID)
if requestUser.IsMember() {
existingUser, err := uc.Service.FromID(&user.ID)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound) utils.RespondWithError(c, err, "Error finding an existing user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound)
return return
} }
user.MembershipID = existingUser.MembershipID
user.Membership.ID = existingUser.Membership.ID
if existingUser.Licence != nil {
user.Licence.ID = existingUser.Licence.ID
}
user.LicenceID = existingUser.LicenceID
user.BankAccount.ID = existingUser.BankAccount.ID
user.BankAccountID = existingUser.BankAccountID
if requestUser.RoleID <= constants.Priviliges.View {
// deleting existing Users Password to prevent it from being recognized as changed in any case. (Incoming Password is empty if not changed) // deleting existing Users Password to prevent it from being recognized as changed in any case. (Incoming Password is empty if not changed)
existingUser.Password = "" existingUser.Password = ""
if err := utils.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil { if err := validation.FilterAllowedStructFields(&user, existingUser, constants.MemberUpdateFields, ""); err != nil {
if err.Error() == "Not authorized" { if err.Error() == "Not authorized" {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Trying to update unauthorized fields", http.StatusUnauthorized, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
} else { } else {
@@ -120,20 +117,20 @@ func (uc *UserController) UpdateHandler(c *gin.Context) {
} }
} }
updatedUser, err := uc.Service.UpdateUser(&user) updatedUser, err := uc.Service.Update(&user)
if err != nil { if err != nil {
utils.HandleUserUpdateError(c, err) utils.HandleUserUpdateError(c, err)
return return
} }
logger.Info.Printf("User %d updated successfully by user %d", updatedUser.ID, requestUser.ID) logger.Info.Printf("User %v updated successfully by user %v", updatedUser.Email, requestUser.Email)
c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser.Safe()}) c.JSON(http.StatusAccepted, gin.H{"message": "User updated successfully", "user": updatedUser.Safe()})
} }
func (uc *UserController) DeleteUser(c *gin.Context) { func (uc *UserController) DeleteUser(c *gin.Context) {
requestUser, err := uc.ExtractUserFromContext(c) requestUser, err := uc.Service.FromContext(c)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error extracting user from context in DeleteUser", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken) utils.RespondWithError(c, err, "Error extracting user from context in DeleteUser", http.StatusBadRequest, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
return return
@@ -152,13 +149,13 @@ func (uc *UserController) DeleteUser(c *gin.Context) {
return return
} }
if !utils.HasPrivilige(requestUser, constants.Priviliges.Delete) && data.User.ID != requestUser.ID { if !requestUser.HasPrivilege(constants.Priviliges.Delete) && data.User.ID != requestUser.ID {
utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to delete user", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized) utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to delete user", http.StatusForbidden, errors.Responses.Fields.User, errors.Responses.Keys.Unauthorized)
return return
} }
logger.Error.Printf("Deleting user: %v", data.User) logger.Error.Printf("Deleting user: %v", data.User)
if err := uc.Service.DeleteUser(data.User.LastName, data.User.ID); err != nil { if err := uc.Service.Delete(&data.User.ID); err != nil {
utils.HandleDeleteUserError(c, err) utils.HandleDeleteUserError(c, err)
return return
} }
@@ -166,24 +163,6 @@ func (uc *UserController) DeleteUser(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"}) c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"})
} }
func (uc *UserController) ExtractUserFromContext(c *gin.Context) (*models.User, error) {
tokenString, err := c.Cookie("jwt")
if err != nil {
return nil, err
}
_, claims, err := middlewares.ExtractContentFrom(tokenString)
if err != nil {
return nil, err
}
jwtUserID := uint((*claims)["user_id"].(float64))
user, err := uc.Service.GetUserByID(jwtUserID)
if err != nil {
return nil, err
}
return user, nil
}
func (uc *UserController) LogoutHandler(c *gin.Context) { func (uc *UserController) LogoutHandler(c *gin.Context) {
tokenString, err := c.Cookie("jwt") tokenString, err := c.Cookie("jwt")
if err != nil { if err != nil {
@@ -207,7 +186,7 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
return return
} }
user, err := uc.Service.GetUserByEmail(input.Email) user, err := uc.Service.FromEmail(&input.Email)
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Login Error; user not found", http.StatusNotFound, utils.RespondWithError(c, err, "Login Error; user not found", http.StatusNotFound,
errors.Responses.Fields.Login, errors.Responses.Fields.Login,
@@ -215,9 +194,9 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
return return
} }
if user.Status <= constants.DisabledStatus { if !user.IsVerified() {
utils.RespondWithError(c, fmt.Errorf("User banned from login %v %v", user.FirstName, user.LastName), utils.RespondWithError(c, fmt.Errorf("User banned from login or not verified %v %v", user.FirstName, user.LastName),
"Login Error; user is disabled", "Login Error; user is disabled or not verified",
http.StatusNotAcceptable, http.StatusNotAcceptable,
errors.Responses.Fields.Login, errors.Responses.Fields.Login,
errors.Responses.Keys.UserDisabled) errors.Responses.Keys.UserDisabled)
@@ -238,8 +217,10 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
return return
} }
logger.Error.Printf("jwtsecret: %v", config.Auth.JWTSecret) // "user_id": user.ID,
token, err := middlewares.GenerateToken(config.Auth.JWTSecret, user, "") // "role_id": user.RoleID,
claims := map[string]interface{}{"user_id": user.ID, "role_id": user.RoleID}
token, err := middlewares.GenerateToken(&config.Auth.JWTSecret, claims, "")
if err != nil { if err != nil {
utils.RespondWithError(c, err, "Error generating token in LoginHandler", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.JwtGenerationFailed) utils.RespondWithError(c, err, "Error generating token in LoginHandler", http.StatusInternalServerError, errors.Responses.Fields.Login, errors.Responses.Keys.JwtGenerationFailed)
return return
@@ -256,7 +237,6 @@ func (uc *UserController) LoginHandler(c *gin.Context) {
func (uc *UserController) RegisterUser(c *gin.Context) { func (uc *UserController) RegisterUser(c *gin.Context) {
var regData RegistrationData var regData RegistrationData
logger.Error.Printf("registering user...")
if err := c.ShouldBindJSON(&regData); err != nil { if err := c.ShouldBindJSON(&regData); err != nil {
utils.HandleValidationError(c, err) utils.HandleValidationError(c, err)
return return
@@ -269,12 +249,14 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
return return
} }
regData.User.Membership.SubscriptionModel = *selectedModel regData.User.Membership.SubscriptionModel = *selectedModel
if selectedModel.RequiredMembershipField != "" { // Get Gin's binding validator engine with all registered validators
if err := validation.CheckParentMembershipID(regData.User.Membership); err != nil { validate := binding.Validator.Engine().(*validator.Validate)
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't check parent membership id", http.StatusBadRequest, errors.Responses.Fields.ParentMemberShipID, errors.Responses.Keys.NotFound)
// Validate the populated user struct
if err := validate.Struct(regData.User); err != nil {
utils.HandleValidationError(c, err)
return return
} }
}
if regData.User.Membership.SubscriptionModel.Name == constants.SupporterSubscriptionModelName { if regData.User.Membership.SubscriptionModel.Name == constants.SupporterSubscriptionModelName {
regData.User.RoleID = constants.Roles.Supporter regData.User.RoleID = constants.Roles.Supporter
} else { } else {
@@ -282,9 +264,9 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
} }
// Register User // Register User
id, token, err := uc.Service.RegisterUser(&regData.User) id, token, err := uc.Service.Register(&regData.User)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed: users.email") { if strings.Contains(err.Error(), "UNIQUE constraint failed:") {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, errors.Responses.Fields.Email, errors.Responses.Keys.Duplicate) utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, errors.Responses.Fields.Email, errors.Responses.Keys.Duplicate)
} else { } else {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError) utils.RespondWithError(c, err, "Error in RegisterUser, couldn't register user", http.StatusConflict, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
@@ -294,7 +276,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
regData.User.ID = id regData.User.ID = id
// if this is a supporter don't send mails and he never did give any consent. So stop here // if this is a supporter don't send mails and he never did give any consent. So stop here
if regData.User.RoleID == constants.Roles.Supporter { if regData.User.IsSupporter() {
c.JSON(http.StatusCreated, gin.H{ c.JSON(http.StatusCreated, gin.H{
"message": "Supporter Registration successuful", "message": "Supporter Registration successuful",
@@ -318,6 +300,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
ConsentType: "Privacy", ConsentType: "Privacy",
}, },
} }
for _, consent := range consents { for _, consent := range consents {
_, err = uc.ConsentService.RegisterConsent(&consent) _, err = uc.ConsentService.RegisterConsent(&consent)
if err != nil { if err != nil {
@@ -326,6 +309,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
} }
} }
logger.Error.Printf("Sending Verification mail to user with id: %#v", id)
// Send notifications // Send notifications
if err := uc.EmailService.SendVerificationEmail(&regData.User, &token); err != nil { if err := uc.EmailService.SendVerificationEmail(&regData.User, &token); err != nil {
utils.RespondWithError(c, err, "Error in RegisterUser, couldn't send verification email", http.StatusInternalServerError, errors.Responses.Fields.Email, errors.Responses.Keys.UndeliveredVerificationMail) utils.RespondWithError(c, err, "Error in RegisterUser, couldn't send verification email", http.StatusInternalServerError, errors.Responses.Fields.Email, errors.Responses.Keys.UndeliveredVerificationMail)
@@ -351,26 +335,35 @@ func (uc *UserController) VerifyMailHandler(c *gin.Context) {
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Missing token"}) c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Missing token"})
return return
} }
userIDint, err := strconv.Atoi(c.Param("id"))
verification, err := uc.Service.VerifyUser(&token, &constants.VerificationTypes.Email)
if err != nil { if err != nil {
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Couldn't verify user"}) logger.Error.Println("Missing user ID to verify mail")
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Missing user"})
return return
} }
userID := uint(userIDint)
user, err := uc.Service.GetUserByID(verification.UserID) user, err := uc.Service.FromID(&userID)
if err != nil { if err != nil {
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Internal server error, couldn't verify user"}) logger.Error.Printf("Couldn't find user in verifyMailHandler: %#v", err)
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Couldn't find user"})
return
}
if !user.Verify(token, constants.VerificationTypes.Email) {
logger.Error.Printf("Couldn't find user verification in verifyMailHandler: %v", err)
c.HTML(http.StatusBadRequest, "verification_error.html", gin.H{"ErrorMessage": "Couldn't find user verification request"})
return return
} }
user.Status = constants.VerifiedStatus user.Status = constants.VerifiedStatus
user.Verification = *verification
user.ID = verification.UserID
user.Password = "" user.Password = ""
uc.Service.UpdateUser(user) updatedUser, err := uc.Service.Update(user)
logger.Info.Printf("Verified User: %#v", user.Email) if err != nil {
logger.Error.Printf("Failed to update user(%v) after verification: %v", user.Email, err)
c.HTML(http.StatusInternalServerError, "verification_error.html", gin.H{"ErrorMessage": "Internal server error, couldn't verify user"})
return
}
logger.Info.Printf("Verified User: %#v", updatedUser.Email)
uc.EmailService.SendWelcomeEmail(user) uc.EmailService.SendWelcomeEmail(user)
c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName}) c.HTML(http.StatusOK, "verification_success.html", gin.H{"FirstName": user.FirstName})

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -75,8 +76,8 @@ func testUserController(t *testing.T) {
} }
// activate user for login // activate user for login
database.DB.Model(&models.User{}).Where("email = ?", "john.doe@example.com").Update("status", constants.ActiveStatus) database.DB.Model(&models.User{}).Where("email = ?", "john.doe@example.com").Update("status", constants.ActiveStatus)
loginEmail, loginCookie := testLoginHandler(t) loginEmail := testLoginHandler(t)
logoutCookie := testCurrentUserHandler(t, loginEmail, loginCookie) testCurrentUserHandler(t, loginEmail)
// creating a admin cookie // creating a admin cookie
c, w, _ := GetMockedJSONContext([]byte(`{ c, w, _ := GetMockedJSONContext([]byte(`{
@@ -91,28 +92,27 @@ func testUserController(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "Login successful", response["message"]) assert.Equal(t, "Login successful", response["message"])
var adminCookie http.Cookie
for _, cookie := range w.Result().Cookies() { for _, cookie := range w.Result().Cookies() {
if cookie.Name == "jwt" { if cookie.Name == "jwt" {
adminCookie = *cookie AdminCookie = cookie
tokenString := adminCookie.Value tokenString := AdminCookie.Value
_, claims, err := middlewares.ExtractContentFrom(tokenString) _, claims, err := middlewares.ExtractContentFrom(tokenString)
assert.NoError(t, err, "FAiled getting cookie string") assert.NoError(t, err, "Failed getting cookie string")
jwtUserID := uint((*claims)["user_id"].(float64)) jwtUserID := uint((*claims)["user_id"].(float64))
user, err := Uc.Service.GetUserByID(jwtUserID) user, err := Uc.Service.FromID(&jwtUserID)
assert.NoError(t, err, "FAiled getting cookie string") assert.NoError(t, err, "Failed getting cookie string")
logger.Error.Printf("ADMIN USER: %#v", user) logger.Error.Printf("ADMIN USER: %#v", user)
break break
} }
} }
assert.NotEmpty(t, adminCookie) assert.NotEmpty(t, AdminCookie)
testUpdateUser(t, loginCookie, adminCookie) testUpdateUser(t)
testLogoutHandler(t, logoutCookie) testLogoutHandler(t)
testCreatePasswordHandler(t, loginCookie, adminCookie) testCreatePasswordHandler(t)
} }
func testLogoutHandler(t *testing.T, loginCookie http.Cookie) { func testLogoutHandler(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -122,7 +122,7 @@ func testLogoutHandler(t *testing.T, loginCookie http.Cookie) {
{ {
name: "Logout with valid cookie", name: "Logout with valid cookie",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
@@ -180,9 +180,8 @@ func testLogoutHandler(t *testing.T, loginCookie http.Cookie) {
} }
} }
func testLoginHandler(t *testing.T) (string, http.Cookie) { func testLoginHandler(t *testing.T) string {
// This test should run after the user registration test // This test should run after the user registration test
var loginCookie http.Cookie
var loginInput loginInput var loginInput loginInput
t.Run("LoginHandler", func(t *testing.T) { t.Run("LoginHandler", func(t *testing.T) {
// Test cases // Test cases
@@ -244,7 +243,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
assert.Equal(t, "Login successful", response["message"]) assert.Equal(t, "Login successful", response["message"])
for _, cookie := range w.Result().Cookies() { for _, cookie := range w.Result().Cookies() {
if cookie.Name == "jwt" { if cookie.Name == "jwt" {
loginCookie = *cookie MemberCookie = cookie
// tokenString := loginCookie.Value // tokenString := loginCookie.Value
// _, claims, err := middlewares.ExtractContentFrom(tokenString) // _, claims, err := middlewares.ExtractContentFrom(tokenString)
@@ -260,7 +259,7 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
break break
} }
} }
assert.NotEmpty(t, loginCookie) assert.NotEmpty(t, MemberCookie)
} else { } else {
assert.Contains(t, response, "errors") assert.Contains(t, response, "errors")
assert.NotEmpty(t, response["errors"]) assert.NotEmpty(t, response["errors"])
@@ -269,10 +268,10 @@ func testLoginHandler(t *testing.T) (string, http.Cookie) {
} }
}) })
return loginInput.Email, loginCookie return loginInput.Email
} }
func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Cookie) http.Cookie { func testCurrentUserHandler(t *testing.T, loginEmail string) http.Cookie {
// This test should run after the user login test // This test should run after the user login test
invalidCookie := http.Cookie{ invalidCookie := http.Cookie{
Name: "jwt", Name: "jwt",
@@ -289,7 +288,7 @@ func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Co
{ {
name: "With valid cookie", name: "With valid cookie",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
expectedUserMail: loginEmail, expectedUserMail: loginEmail,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
@@ -369,7 +368,7 @@ func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Co
} }
if tt.expectNewCookie { if tt.expectNewCookie {
assert.NotNil(t, newCookie, "New cookie should be set for expired token") assert.NotNil(t, newCookie, "New cookie should be set for expired token")
assert.NotEqual(t, loginCookie.Value, newCookie.Value, "Cookie value should be different") assert.NotEqual(t, MemberCookie.Value, newCookie.Value, "Cookie value should be different")
assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired") assert.True(t, newCookie.MaxAge > 0, "New cookie should not be expired")
} else { } else {
assert.Nil(t, newCookie, "No new cookie should be set for non-expired token") assert.Nil(t, newCookie, "No new cookie should be set for non-expired token")
@@ -395,7 +394,7 @@ func testCurrentUserHandler(t *testing.T, loginEmail string, loginCookie http.Co
}) })
} }
return loginCookie return *MemberCookie
} }
func validateUser(assert bool, wantDBData map[string]interface{}) error { func validateUser(assert bool, wantDBData map[string]interface{}) error {
@@ -420,6 +419,11 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
return fmt.Errorf("Mandate reference is invalid. Expected: %s, Got: %s", expected, user.BankAccount.MandateReference) return fmt.Errorf("Mandate reference is invalid. Expected: %s, Got: %s", expected, user.BankAccount.MandateReference)
} }
// Supoorter don't get mails
if user.IsSupporter() {
return nil
}
//check for email delivery //check for email delivery
messages := utils.SMTPGetMessages() messages := utils.SMTPGetMessages()
for _, message := range messages { for _, message := range messages {
@@ -454,18 +458,18 @@ func validateUser(assert bool, wantDBData map[string]interface{}) error {
return nil return nil
} }
func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cookie) { func testUpdateUser(t *testing.T) {
invalidCookie := http.Cookie{ invalidCookie := http.Cookie{
Name: "jwt", Name: "jwt",
Value: "invalid.token.here", Value: "invalid.token.here",
} }
// Get the user we just created // Get the user we just created
users, err := Uc.Service.GetUsers(map[string]interface{}{"email": "john.doe@example.com"}) johnsMail := "john.doe@example.com"
if err != nil || len(*users) == 0 { user, err := Uc.Service.FromEmail(&johnsMail)
if err != nil {
t.Fatalf("Failed to get test user: %v", err) t.Fatalf("Failed to get test user: %v", err)
} }
user := (*users)[0]
if user.Licence == nil { if user.Licence == nil {
user.Licence = &models.Licence{ user.Licence = &models.Licence{
Number: "Z021AB37X13", Number: "Z021AB37X13",
@@ -485,7 +489,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Valid Admin Update", name: "Valid Admin Update",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&adminCookie) req.AddCookie(AdminCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -514,7 +518,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Invalid Email Update", name: "Invalid Email Update",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -532,7 +536,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "admin may change licence number", name: "admin may change licence number",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&adminCookie) req.AddCookie(AdminCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -546,7 +550,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Change phone number", name: "Change phone number",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -560,7 +564,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Add category", name: "Add category",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -578,7 +582,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Delete 1 and add 1 category", name: "Delete 1 and add 1 category",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -597,7 +601,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Delete 1 category", name: "Delete 1 category",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -615,7 +619,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Delete all categories", name: "Delete all categories",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -630,7 +634,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "User ID mismatch while not admin", name: "User ID mismatch while not admin",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -649,7 +653,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Password Update low entropy should fail", name: "Password Update low entropy should fail",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.FirstName = "John Updated" u.FirstName = "John Updated"
@@ -666,7 +670,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Password Update", name: "Password Update",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.FirstName = "John Updated" u.FirstName = "John Updated"
@@ -687,7 +691,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Admin Password Update", name: "Admin Password Update",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&adminCookie) req.AddCookie(AdminCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.LastName = "Doe Updated" u.LastName = "Doe Updated"
@@ -700,7 +704,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
{ {
name: "Non-existent User", name: "Non-existent User",
setupCookie: func(req *http.Request) { setupCookie: func(req *http.Request) {
req.AddCookie(&loginCookie) req.AddCookie(MemberCookie)
}, },
updateFunc: func(u *models.User) { updateFunc: func(u *models.User) {
u.Password = "" u.Password = ""
@@ -719,7 +723,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
logger.Error.Print("==============================================================") logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create a copy of the user and apply the updates // Create a copy of the user and apply the updates
updatedUser := user updatedUser := *user
// logger.Error.Printf("users licence to be updated: %+v", user.Licence) // logger.Error.Printf("users licence to be updated: %+v", user.Licence)
tt.updateFunc(&updatedUser) tt.updateFunc(&updatedUser)
@@ -784,7 +788,7 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
assert.Equal(t, "User updated successfully", message) assert.Equal(t, "User updated successfully", message)
// Verify the update in the database // Verify the update in the database
updatedUserFromDB, err := Uc.Service.GetUserByID(user.ID) updatedUserFromDB, err := Uc.Service.FromID(&user.ID)
assert.NoError(t, err) assert.NoError(t, err)
if updatedUser.Password == "" { if updatedUser.Password == "" {
@@ -823,11 +827,17 @@ func testUpdateUser(t *testing.T, loginCookie http.Cookie, adminCookie http.Cook
assert.Equal(t, updatedUser.Membership.SubscriptionModelID, updatedUserFromDB.Membership.SubscriptionModelID, "Membership.SubscriptionModelID mismatch") assert.Equal(t, updatedUser.Membership.SubscriptionModelID, updatedUserFromDB.Membership.SubscriptionModelID, "Membership.SubscriptionModelID mismatch")
assert.Equal(t, updatedUser.Membership.ParentMembershipID, updatedUserFromDB.Membership.ParentMembershipID, "Membership.ParentMembershipID mismatch") assert.Equal(t, updatedUser.Membership.ParentMembershipID, updatedUserFromDB.Membership.ParentMembershipID, "Membership.ParentMembershipID mismatch")
if updatedUser.Licence == nil {
assert.Nil(t, updatedUserFromDB.Licence, "database licence of user is not nil, but user.licence is nil")
} else {
logger.Error.Printf("updatedUser licence: %#v", updatedUser.Licence)
logger.Error.Printf("dbUser licence: %#v", updatedUserFromDB.Licence)
assert.Equal(t, updatedUser.Licence.Status, updatedUserFromDB.Licence.Status, "Licence.Status mismatch") assert.Equal(t, updatedUser.Licence.Status, updatedUserFromDB.Licence.Status, "Licence.Status mismatch")
assert.Equal(t, updatedUser.Licence.Number, updatedUserFromDB.Licence.Number, "Licence.Number mismatch") assert.Equal(t, updatedUser.Licence.Number, updatedUserFromDB.Licence.Number, "Licence.Number mismatch")
assert.Equal(t, updatedUser.Licence.IssuedDate, updatedUserFromDB.Licence.IssuedDate, "Licence.IssuedDate mismatch") assert.Equal(t, updatedUser.Licence.IssuedDate, updatedUserFromDB.Licence.IssuedDate, "Licence.IssuedDate mismatch")
assert.Equal(t, updatedUser.Licence.ExpirationDate, updatedUserFromDB.Licence.ExpirationDate, "Licence.ExpirationDate mismatch") assert.Equal(t, updatedUser.Licence.ExpirationDate, updatedUserFromDB.Licence.ExpirationDate, "Licence.ExpirationDate mismatch")
assert.Equal(t, updatedUser.Licence.IssuingCountry, updatedUserFromDB.Licence.IssuingCountry, "Licence.IssuingCountry mismatch") assert.Equal(t, updatedUser.Licence.IssuingCountry, updatedUserFromDB.Licence.IssuingCountry, "Licence.IssuingCountry mismatch")
}
// For slices or more complex nested structures, you might want to use deep equality checks // For slices or more complex nested structures, you might want to use deep equality checks
assert.ElementsMatch(t, updatedUser.Consents, updatedUserFromDB.Consents, "Consents mismatch") assert.ElementsMatch(t, updatedUser.Consents, updatedUserFromDB.Consents, "Consents mismatch")
@@ -936,15 +946,19 @@ func checkVerificationMail(message *utils.Email, user *models.User) error {
if err != nil { if err != nil {
return fmt.Errorf("Error parsing verification URL: %#v", err.Error()) return fmt.Errorf("Error parsing verification URL: %#v", err.Error())
} }
if !strings.Contains(verificationURL, user.Verification.VerificationToken) { v, err := user.GetVerification(constants.VerificationTypes.Email)
return fmt.Errorf("Users Verification link token(%v) has not been rendered in email verification mail. %v", user.Verification.VerificationToken, verificationURL) if err != nil {
return fmt.Errorf("Error getting verification token: %v", err.Error())
}
if !strings.Contains(verificationURL, v.VerificationToken) {
return fmt.Errorf("Users Verification link token(%v) has not been rendered in email verification mail. %v", v.VerificationToken, verificationURL)
} }
if !strings.Contains(message.Body, config.Site.BaseURL) { if !strings.Contains(message.Body, config.Site.BaseURL) {
return fmt.Errorf("Base Url (%v) has not been rendered in email verification mail.", config.Site.BaseURL) return fmt.Errorf("Base Url (%v) has not been rendered in email verification mail.", config.Site.BaseURL)
} }
// open the provided link: // open the provided link:
if err := verifyMail(verificationURL); err != nil { if err := verifyMail(verificationURL, user.ID); err != nil {
return err return err
} }
messages := utils.SMTPGetMessages() messages := utils.SMTPGetMessages()
@@ -961,12 +975,14 @@ func checkVerificationMail(message *utils.Email, user *models.User) error {
} }
func verifyMail(verificationURL string) error { func verifyMail(verificationURL string, user_id uint) error {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New() router := gin.New()
router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*")) router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*"))
router.GET("api/users/verify", Uc.VerifyMailHandler) expectedUrl := fmt.Sprintf("/api/users/verify/%v", user_id)
log.Printf("Expected URL: %v", expectedUrl)
router.GET("/api/users/verify/:id", Uc.VerifyMailHandler)
wv := httptest.NewRecorder() wv := httptest.NewRecorder()
cv, _ := gin.CreateTestContext(wv) cv, _ := gin.CreateTestContext(wv)
var err error var err error
@@ -1109,8 +1125,9 @@ func getTestUsers() []RegisterUserTest {
Assert: false, Assert: false,
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
user.BankAccount.IBAN = "DE1234234123134" user.BankAccount.IBAN = "DE1234234123134"
user.RoleID = 0 user.RoleID = constants.Roles.Supporter
user.Email = "john.supporter@example.com" user.Email = "john.supporter@example.com"
user.Membership.SubscriptionModel.Name = constants.SupporterSubscriptionModelName
return user return user
})), })),
}, },
@@ -1121,8 +1138,9 @@ func getTestUsers() []RegisterUserTest {
Assert: true, Assert: true,
Input: GenerateInputJSON(customizeInput(func(user models.User) models.User { Input: GenerateInputJSON(customizeInput(func(user models.User) models.User {
user.BankAccount.IBAN = "" user.BankAccount.IBAN = ""
user.RoleID = 0 user.RoleID = constants.Roles.Supporter
user.Email = "john.supporter@example.com" user.Email = "john.supporter@example.com"
user.Membership.SubscriptionModel.Name = constants.SupporterSubscriptionModelName
return user return user
})), })),
}, },

View File

@@ -15,11 +15,11 @@ import (
var DB *gorm.DB var DB *gorm.DB
func Open(dbPath string, adminMail string) error { func Open(dbPath string, adminMail string) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil { if err != nil {
return err return nil, err
} }
if err := db.AutoMigrate( if err := db.AutoMigrate(
&models.User{}, &models.User{},
@@ -31,12 +31,11 @@ func Open(dbPath string, adminMail string) error {
&models.Category{}, &models.Category{},
&models.BankAccount{}); err != nil { &models.BankAccount{}); err != nil {
logger.Error.Fatalf("Couldn't create database: %v", err) logger.Error.Fatalf("Couldn't create database: %v", err)
return err return nil, err
} }
DB = db
logger.Info.Print("Opened DB") logger.Info.Print("Opened DB")
DB = db
var categoriesCount int64 var categoriesCount int64
db.Model(&models.Category{}).Count(&categoriesCount) db.Model(&models.Category{}).Count(&categoriesCount)
if categoriesCount == 0 { if categoriesCount == 0 {
@@ -44,7 +43,7 @@ func Open(dbPath string, adminMail string) error {
for _, model := range categories { for _, model := range categories {
result := db.Create(&model) result := db.Create(&model)
if result.Error != nil { if result.Error != nil {
return result.Error return nil, result.Error
} }
} }
} }
@@ -62,7 +61,7 @@ func Open(dbPath string, adminMail string) error {
if exists == 0 { if exists == 0 {
result := db.Create(&model) result := db.Create(&model)
if result.Error != nil { if result.Error != nil {
return result.Error return nil, result.Error
} }
} }
} }
@@ -72,20 +71,20 @@ func Open(dbPath string, adminMail string) error {
if userCount == 0 { if userCount == 0 {
var createdModel models.SubscriptionModel var createdModel models.SubscriptionModel
if err := db.First(&createdModel).Error; err != nil { if err := db.First(&createdModel).Error; err != nil {
return err return nil, err
} }
admin, err := createAdmin(adminMail, createdModel.ID) admin, err := createAdmin(adminMail, createdModel.ID)
if err != nil { if err != nil {
return err return nil, err
} }
result := db.Session(&gorm.Session{FullSaveAssociations: true}).Create(&admin) result := db.Session(&gorm.Session{FullSaveAssociations: true}).Create(&admin)
if result.Error != nil { if result.Error != nil {
return result.Error return nil, result.Error
} }
} }
return nil return db, nil
} }
func createSubscriptionModels() []models.SubscriptionModel { func createSubscriptionModels() []models.SubscriptionModel {
@@ -167,11 +166,11 @@ func createAdmin(userMail string, subscriptionModelID uint) (*models.User, error
//"DE49700500000008447644", //fake //"DE49700500000008447644", //fake
} }
func Close() error { func Close(db *gorm.DB) error {
logger.Info.Print("Closing DB") logger.Info.Print("Closing DB")
db, err := DB.DB() database, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
return db.Close() return database.Close()
} }

View File

@@ -2,9 +2,7 @@ package middlewares
import ( import (
"GoMembership/internal/config" "GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils" "GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"errors" "errors"
"fmt" "fmt"
@@ -34,26 +32,43 @@ func verifyAndRenewToken(tokenString string) (string, uint, error) {
return "", 0, fmt.Errorf("Authorization token is required") return "", 0, fmt.Errorf("Authorization token is required")
} }
token, claims, err := ExtractContentFrom(tokenString) token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
logger.Error.Printf("Couldn't parse JWT token String: %v", err) logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", 0, err return "", 0, err
} }
sessionID := (*claims)["session_id"].(string)
userID := uint((*claims)["user_id"].(float64)) if token.Valid {
roleID := int8((*claims)["role_id"].(float64)) // token is valid, so we can return the old tokenString
return tokenString, uint((*claims)["user_id"].(float64)), nil
}
// Token is expired but valid
sessionID, ok := (*claims)["session_id"].(string)
if !ok || sessionID == "" {
return "", 0, fmt.Errorf("invalid session ID")
}
id, ok := (*claims)["user_id"]
if !ok {
return "", 0, fmt.Errorf("missing user_id claim")
}
userID := uint(id.(float64))
id, ok = (*claims)["role_id"]
if !ok {
return "", 0, fmt.Errorf("missing role_id claim")
}
roleID := int8(id.(float64))
session, ok := sessions[sessionID] session, ok := sessions[sessionID]
if !ok { if !ok {
logger.Error.Printf("session not found") logger.Error.Printf("session not found")
return "", 0, fmt.Errorf("session not found") return "", 0, fmt.Errorf("session not found")
} }
if userID != session.UserID { if userID != session.UserID {
return "", 0, fmt.Errorf("Cookie has been altered, aborting..") return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
} }
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
if time.Now().After(sessions[sessionID].ExpiresAt) { if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID) delete(sessions, sessionID)
@@ -64,8 +79,8 @@ func verifyAndRenewToken(tokenString string) (string, uint, error) {
logger.Error.Printf("Session still valid generating new token") logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token // Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID} user := map[string]interface{}{"user_id": userID, "role_id": roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID) newTokenString, err := GenerateToken(&config.Auth.JWTSecret, user, sessionID)
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
@@ -89,11 +104,6 @@ func AuthMiddleware() gin.HandlerFunc {
newToken, userID, err := verifyAndRenewToken(tokenString) newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil { if err != nil {
if err == customerrors.ErrValidToken {
c.Set("user_id", uint(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err) logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized, c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{ gin.H{"errors": []gin.H{{
@@ -104,24 +114,30 @@ func AuthMiddleware() gin.HandlerFunc {
return return
} }
if newToken != tokenString {
utils.SetCookie(c, newToken) utils.SetCookie(c, newToken)
}
c.Set("user_id", uint(userID)) c.Set("user_id", uint(userID))
c.Next() c.Next()
} }
} }
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) { // GenerateToken generates a new JWT token with the given claims and session ID.
// "user_id": user.ID, "role_id": user.RoleID
func GenerateToken(jwtKey *string, claims map[string]interface{}, sessionID string) (string, error) {
if sessionID == "" { if sessionID == "" {
sessionID = uuid.New().String() sessionID = uuid.New().String()
} }
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ claims["session_id"] = sessionID
"user_id": user.ID, claims["exp"] = time.Now().Add(time.Minute * 1).Unix() // Token expires in 10 Minutes
"role_id": user.RoleID, token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims(claims))
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes userID, ok := claims["user_id"].(uint)
}) if !ok {
UpdateSession(sessionID, user.ID) return "", fmt.Errorf("invalid user_id in claims")
return token.SignedString([]byte(jwtKey)) }
UpdateSession(sessionID, userID)
return token.SignedString([]byte(*jwtKey))
} }
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) { func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
@@ -130,23 +146,33 @@ func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error)
return []byte(config.Auth.JWTSecret), nil return []byte(config.Auth.JWTSecret), nil
}) })
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil { // Handle parsing errors (excluding expiration error)
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err) if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
logger.Error.Printf("Error parsing token: %v", err)
return nil, nil, err return nil, nil, err
} }
// Token is expired, check if session is still valid // Ensure token is not nil (e.g., malformed tokens)
claims, ok := token.Claims.(jwt.MapClaims) if token == nil {
if !ok { logger.Error.Print("Token is nil after parsing")
logger.Error.Printf("Invalid Token Claims") return nil, nil, fmt.Errorf("invalid token")
return nil, nil, fmt.Errorf("invalid token claims")
} }
// Extract and validate claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok { if !ok {
logger.Error.Printf("invalid session_id in token") logger.Error.Print("Invalid token claims structure")
return nil, nil, fmt.Errorf("invalid session_id in token") return nil, nil, fmt.Errorf("invalid token claims format")
} }
return token, &claims, nil
// Validate required session_id claim
if _, exists := claims["session_id"]; !exists {
logger.Error.Print("Missing session_id in token claims")
return nil, nil, fmt.Errorf("missing session_id claim")
}
// Return token, claims, and original error (might be expiration)
return token, &claims, err
} }
func UpdateSession(sessionID string, userID uint) { func UpdateSession(sessionID string, userID uint) {

View File

@@ -3,7 +3,6 @@ package middlewares
import ( import (
"GoMembership/internal/config" "GoMembership/internal/config"
"GoMembership/internal/constants" "GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"encoding/json" "encoding/json"
"log" "log"
@@ -56,8 +55,11 @@ func TestAuthMiddleware(t *testing.T) {
{ {
name: "Valid Token", name: "Valid Token",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
user := models.User{ID: 123, RoleID: constants.Roles.Member} claims := map[string]interface{}{"user_id": uint(123), "role_id": constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "") token, err := GenerateToken(&config.Auth.JWTSecret, claims, "")
if err != nil {
t.Fatal(err)
}
r.AddCookie(&http.Cookie{Name: "jwt", Value: token}) r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
@@ -82,7 +84,7 @@ func TestAuthMiddleware(t *testing.T) {
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
sessionID := "test-session" sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"session_id": sessionID, "session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
@@ -100,7 +102,7 @@ func TestAuthMiddleware(t *testing.T) {
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
sessionID := "expired-session" sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"session_id": sessionID, "session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
@@ -116,7 +118,7 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signature", name: "Invalid Signature",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"session_id": "some-session", "session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),
}) })
@@ -130,7 +132,7 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signing Method", name: "Invalid Signing Method",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"session_id": "some-session", "session_id": "some-session",
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),

View File

@@ -9,7 +9,6 @@ import (
) )
func CSPMiddleware() gin.HandlerFunc { func CSPMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying CSP")
return func(c *gin.Context) { return func(c *gin.Context) {
policy := "default-src 'self'; " + policy := "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" + "script-src 'self' 'unsafe-inline'" +
@@ -35,7 +34,6 @@ func CSPMiddleware() gin.HandlerFunc {
func CSPReportHandling(c *gin.Context) { func CSPReportHandling(c *gin.Context) {
var report map[string]interface{} var report map[string]interface{}
if err := c.BindJSON(&report); err != nil { if err := c.BindJSON(&report); err != nil {
logger.Error.Printf("Couldn't Bind JSON: %#v", err) logger.Error.Printf("Couldn't Bind JSON: %#v", err)
return return
} }

View File

@@ -1,13 +1,10 @@
package middlewares package middlewares
import ( import (
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func SecurityHeadersMiddleware() gin.HandlerFunc { func SecurityHeadersMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying headers")
return func(c *gin.Context) { return func(c *gin.Context) {
c.Header("X-Frame-Options", "DENY") c.Header("X-Frame-Options", "DENY")
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")

View File

@@ -6,6 +6,7 @@ import (
type Licence struct { type Licence struct {
ID uint `json:"id"` ID uint `json:"id"`
UserID uint `json:"user_id"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
Status int8 `json:"status" binding:"omitempty,number"` Status int8 `json:"status" binding:"omitempty,number"`

View File

@@ -1,11 +1,20 @@
package models package models
import ( import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"fmt" "fmt"
"slices"
"time" "time"
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type User struct { type User struct {
@@ -28,8 +37,7 @@ type User struct {
Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"` Consents []Consent `gorm:"constraint:OnUpdate:CASCADE"`
BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"` BankAccount BankAccount `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"bank_account"`
BankAccountID uint BankAccountID uint
Verification Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` Verifications *[]Verification `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
VerificationID uint
Membership Membership `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"membership"` Membership Membership `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"membership"`
MembershipID uint MembershipID uint
Licence *Licence `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"licence"` Licence *Licence `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"licence"`
@@ -69,6 +77,268 @@ func (u *User) PasswordMatches(plaintextPassword string) (bool, error) {
return argon2id.ComparePasswordAndHash(plaintextPassword, u.Password) return argon2id.ComparePasswordAndHash(plaintextPassword, u.Password)
} }
func (u *User) PasswordExists() bool {
return u.Password != ""
}
func (u *User) Delete(db *gorm.DB) error {
return db.Delete(&User{}, "id = ?", u.ID).Error
}
func (u *User) Create(db *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// Create the base User record (omit associations to handle them separately)
if err := tx.Create(u).Error; err != nil {
return err
}
// Replace associated Categories (assumes Categories already exist)
if u.Licence != nil && len(u.Licence.Categories) > 0 {
if err := tx.Model(u.Licence).Association("Categories").Replace(u.Licence.Categories); err != nil {
return err
}
}
logger.Info.Printf("user created: %#v", u.Safe())
// Preload all associations to return the fully populated User
return tx.
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
First(u, u.ID).Error // Refresh the user object with all associations
})
}
func (u *User) Update(db *gorm.DB) error {
err := db.Transaction(func(tx *gorm.DB) error {
// Check if the user exists in the database
var existingUser User
logger.Info.Printf("updating user: %#v", u)
if err := tx.
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
Preload("Verifications").
First(&existingUser, u.ID).Error; err != nil {
return err
}
// Update the user's main fields
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Omit("Password", "Membership", "Licence", "Verifications").Updates(u)
if result.Error != nil {
logger.Error.Printf("User update error in update user: %#v", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected
}
if u.Password != "" {
if err := tx.Model(&existingUser).
Update("Password", u.Password).Error; err != nil {
logger.Error.Printf("Password update error in update user: %#v", err)
return err
}
}
// Update the Membership if provided
if u.Membership.ID != 0 {
if err := tx.Model(&existingUser.Membership).Updates(u.Membership).Error; err != nil {
logger.Error.Printf("Membership update error in update user: %#v", err)
return err
}
}
if u.Licence != nil {
u.Licence.UserID = existingUser.ID
if err := tx.Save(u.Licence).Error; err != nil {
return err
}
if err := tx.Model(&existingUser).Update("LicenceID", u.Licence.ID).Error; err != nil {
return err
}
if err := tx.Model(u.Licence).Association("Categories").Replace(u.Licence.Categories); err != nil {
return err
}
}
// if u.Licence != nil {
// if existingUser.Licence == nil || existingUser.LicenceID == 0 {
// u.Licence.UserID = existingUser.ID // Ensure Licence belongs to User
// if err := tx.Create(u.Licence).Error; err != nil {
// return err
// }
// existingUser.Licence = u.Licence
// existingUser.LicenceID = u.Licence.ID
// if err := tx.Model(&existingUser).Update("LicenceID", u.Licence.ID).Error; err != nil {
// return err
// }
// }
// if err := tx.Model(existingUser.Licence).Updates(u.Licence).Error; err != nil {
// return err
// }
// // Update Categories association
// if err := tx.Model(existingUser.Licence).Association("Categories").Replace(u.Licence.Categories); err != nil {
// return err
// }
// }
if u.Verifications != nil {
if err := tx.Save(*u.Verifications).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
return db.
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
Preload("Verifications").
First(&u, u.ID).Error
}
func (u *User) FromID(db *gorm.DB, userID *uint) error {
var user User
result := db.
Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
Preload("Verifications").
First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return gorm.ErrRecordNotFound
}
return result.Error
}
*u = user
return nil
}
func (u *User) FromEmail(db *gorm.DB, email *string) error {
var user User
result := db.
Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
Preload("Verifications").
Where("email = ?", email).First(&user)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return gorm.ErrRecordNotFound
}
return result.Error
}
*u = user
return nil
}
func (u *User) FromContext(db *gorm.DB, c *gin.Context) error {
tokenString, err := c.Cookie("jwt")
if err != nil {
return err
}
jwtUserID, err := extractUserIDFrom(tokenString)
if err != nil {
return err
}
if err = u.FromID(db, &jwtUserID); err != nil {
return err
}
return nil
}
func (u *User) IsVerified() bool {
return u.Status > constants.DisabledStatus
}
func (u *User) HasPrivilege(privilege int8) bool {
return u.RoleID >= privilege
}
func (u *User) IsAdmin() bool {
return u.RoleID == constants.Roles.Admin
}
func (u *User) IsMember() bool {
return u.RoleID == constants.Roles.Member
}
func (u *User) IsSupporter() bool {
return u.RoleID == constants.Roles.Supporter
}
func (u *User) SetVerification(verificationType string) (*Verification, error) {
if u.Verifications == nil {
u.Verifications = new([]Verification)
}
token, err := utils.GenerateVerificationToken()
if err != nil {
return nil, err
}
v := Verification{
UserID: u.ID,
VerificationToken: token,
Type: verificationType,
}
if vi := slices.IndexFunc(*u.Verifications, func(vsl Verification) bool { return vsl.Type == v.Type }); vi > -1 {
(*u.Verifications)[vi] = v
} else {
*u.Verifications = append(*u.Verifications, v)
}
return &v, nil
}
func (u *User) GetVerification(verificationType string) (*Verification, error) {
if u.Verifications == nil {
return nil, errors.ErrNoData
}
vi := slices.IndexFunc(*u.Verifications, func(vsl Verification) bool { return vsl.Type == verificationType })
if vi == -1 {
return nil, errors.ErrNotFound
}
return &(*u.Verifications)[vi], nil
}
func (u *User) Verify(token string, verificationType string) bool {
if token == "" || verificationType == "" {
logger.Error.Printf("token or verification type are empty in user.Verify")
return false
}
vi := slices.IndexFunc(*u.Verifications, func(vsl Verification) bool {
return vsl.Type == verificationType && vsl.VerificationToken == token
})
if vi == -1 {
logger.Error.Printf("Couldn't find verification in users verifications")
return false
}
if (*u.Verifications)[vi].VerifiedAt != nil {
logger.Error.Printf("VerifiedAt is not nil, already verified?: %#v", (*u.Verifications)[vi])
return false
}
t := time.Now()
(*u.Verifications)[vi].VerifiedAt = &t
return true
}
func (u *User) Safe() map[string]interface{} { func (u *User) Safe() map[string]interface{} {
result := map[string]interface{}{ result := map[string]interface{}{
"email": u.Email, "email": u.Email,
@@ -128,3 +398,61 @@ func (u *User) Safe() map[string]interface{} {
return result return result
} }
func extractUserIDFrom(tokenString string) (uint, error) {
jwtSigningMethod := jwt.SigningMethodHS256
jwtParser := jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.Auth.JWTSecret), nil
})
// Handle parsing errors (excluding expiration error)
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) || token == nil {
logger.Error.Printf("Error parsing token: %v", err)
return 0, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.Error.Print("Invalid token claims structure")
return 0, fmt.Errorf("invalid token claims format")
}
// Validate required session_id claim
if _, exists := claims["session_id"]; !exists {
logger.Error.Print("Missing session_id in token claims")
return 0, fmt.Errorf("missing session_id claim")
}
// Return token, claims, and original error (might be expiration)
if _, exists := claims["session_id"]; !exists {
logger.Error.Print("Missing session_id in token claims")
return 0, fmt.Errorf("missing session_id claim")
}
id, ok := claims["user_id"]
if !ok {
return 0, fmt.Errorf("missing user_id claim")
}
return uint(id.(float64)), nil
}
func GetUsersWhere(db *gorm.DB, where map[string]interface{}) (*[]User, error) {
var users []User
result := db.
Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
Preload("Verifications").
Where(where).Find(&users)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &users, nil
}

View File

@@ -1,13 +1,15 @@
package models package models
import "time" import (
"time"
)
type Verification struct { type Verification struct {
UpdatedAt time.Time UpdatedAt time.Time
CreatedAt time.Time CreatedAt time.Time
VerifiedAt *time.Time `gorm:"Default:NULL" json:"verified_at"` VerifiedAt *time.Time `json:"verified_at"`
VerificationToken string `json:"token"` VerificationToken string `json:"token"`
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
UserID uint `gorm:"unique;" json:"user_id"` UserID uint `json:"user_id"`
Type string Type string
} }

View File

@@ -1,10 +0,0 @@
package repositories
import (
"GoMembership/internal/database"
"GoMembership/internal/models"
)
func (r *UserRepository) SetUserStatus(id uint, status uint) error {
return database.DB.Model(&models.User{}).Where("id = ?", id).Update("status", status).Error
}

View File

@@ -1,159 +0,0 @@
package repositories
import (
"gorm.io/gorm"
"GoMembership/internal/database"
"gorm.io/gorm/clause"
"GoMembership/internal/models"
"GoMembership/pkg/errors"
"GoMembership/pkg/logger"
)
type UserRepositoryInterface interface {
CreateUser(user *models.User) (uint, error)
UpdateUser(user *models.User) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error)
GetUserByEmail(email string) (*models.User, error)
IsVerified(userID *uint) (bool, error)
GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error)
SetVerificationToken(verification *models.Verification) (token string, err error)
DeleteVerification(id uint, verificationType string) error
DeleteUser(id uint) error
SetUserStatus(id uint, status uint) error
}
type UserRepository struct{}
func (ur *UserRepository) DeleteUser(id uint) error {
return database.DB.Delete(&models.User{}, "id = ?", id).Error
}
func PasswordExists(userID *uint) (bool, error) {
var user models.User
result := database.DB.Select("password").First(&user, userID)
if result.Error != nil {
return false, result.Error
}
return user.Password != "", nil
}
func (ur *UserRepository) CreateUser(user *models.User) (uint, error) {
result := database.DB.Create(user)
if result.Error != nil {
logger.Error.Printf("Create User error: %#v", result.Error)
return 0, result.Error
}
return user.ID, nil
}
func (ur *UserRepository) UpdateUser(user *models.User) (*models.User, error) {
if user == nil {
return nil, errors.ErrNoData
}
err := database.DB.Transaction(func(tx *gorm.DB) error {
// Check if the user exists in the database
var existingUser models.User
if err := tx.Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence").
Preload("Licence.Categories").
First(&existingUser, user.ID).Error; err != nil {
return err
}
// Update the user's main fields
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).Omit("Password").Updates(user)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return errors.ErrNoRowsAffected
}
if user.Password != "" {
if err := tx.Model(&models.User{}).
Where("id = ?", user.ID).
Update("Password", user.Password).Error; err != nil {
return err
}
}
// Update the Membership if provided
if user.Membership.ID != 0 {
if err := tx.Model(&existingUser.Membership).Updates(user.Membership).Error; err != nil {
return err
}
}
// Replace categories if Licence and Categories are provided
if user.Licence != nil {
if err := tx.Model(&user.Licence).Association("Categories").Replace(user.Licence.Categories); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
var updatedUser models.User
if err := database.DB.Preload("Licence.Categories").
Preload("Membership").
First(&updatedUser, user.ID).Error; err != nil {
return nil, err
}
return &updatedUser, nil
}
func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User, error) {
var users []models.User
result := database.DB.
Preload(clause.Associations).
Preload("Membership.SubscriptionModel").
Preload("Licence.Categories").
Where(where).Find(&users)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &users, nil
}
func GetUserByID(userID *uint) (*models.User, error) {
var user models.User
result := database.DB.
Preload(clause.Associations).
Preload("Membership").
Preload("Membership.SubscriptionModel").
Preload("Licence.Categories").
First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &user, nil
}
func (ur *UserRepository) GetUserByEmail(email string) (*models.User, error) {
var user models.User
result := database.DB.Where("email = ?", email).First(&user)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &user, nil
}

View File

@@ -1,57 +0,0 @@
package repositories
import (
"GoMembership/internal/constants"
"GoMembership/internal/database"
"GoMembership/internal/models"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func (ur *UserRepository) IsVerified(userID *uint) (bool, error) {
var user models.User
result := database.DB.Select("status").First(&user, userID)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return false, gorm.ErrRecordNotFound
}
return false, result.Error
}
return user.Status > constants.DisabledStatus, nil
}
func (ur *UserRepository) GetVerificationOfToken(token *string, verificationType *string) (*models.Verification, error) {
var emailVerification models.Verification
result := database.DB.Where("verification_token = ? AND type = ?", token, verificationType).First(&emailVerification)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, gorm.ErrRecordNotFound
}
return nil, result.Error
}
return &emailVerification, nil
}
func (ur *UserRepository) SetVerificationToken(verification *models.Verification) (token string, err error) {
result := database.DB.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}},
DoUpdates: clause.AssignmentColumns([]string{"verification_token", "created_at", "type"}),
}).Create(&verification)
if result.Error != nil {
return "", result.Error
}
return verification.VerificationToken, nil
}
func (ur *UserRepository) DeleteVerification(id uint, verificationType string) error {
result := database.DB.Where("user_id = ? AND type = ?", id, verificationType).Delete(&models.Verification{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -8,7 +8,7 @@ import (
) )
func RegisterRoutes(router *gin.Engine, userController *controllers.UserController, membershipcontroller *controllers.MembershipController, contactController *controllers.ContactController, licenceController *controllers.LicenceController) { func RegisterRoutes(router *gin.Engine, userController *controllers.UserController, membershipcontroller *controllers.MembershipController, contactController *controllers.ContactController, licenceController *controllers.LicenceController) {
router.GET("/api/users/verify", userController.VerifyMailHandler) router.GET("/api/users/verify/:id", userController.VerifyMailHandler)
router.POST("/api/users/register", userController.RegisterUser) router.POST("/api/users/register", userController.RegisterUser)
router.POST("/api/users/contact", contactController.RelayContactRequest) router.POST("/api/users/contact", contactController.RelayContactRequest)
router.POST("/api/users/password/request-change", userController.RequestPasswordChangeHandler) router.POST("/api/users/password/request-change", userController.RequestPasswordChangeHandler)

View File

@@ -20,13 +20,14 @@ import (
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
var shutdownChannel = make(chan struct{}) var shutdownChannel = make(chan struct{})
var srv *http.Server var srv *http.Server
// Run initializes the server configuration, sets up services and controllers, and starts the HTTP server. // Run initializes the server configuration, sets up services and controllers, and starts the HTTP server.
func Run() { func Run(db *gorm.DB) {
emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password) emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password)
var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{} var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{}
@@ -41,13 +42,11 @@ func Run() {
var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{} var licenceRepo repositories.LicenceInterface = &repositories.LicenceRepository{}
licenceService := &services.LicenceService{Repo: licenceRepo} licenceService := &services.LicenceService{Repo: licenceRepo}
userService := &services.UserService{DB: db, Licences: licenceRepo}
var userRepo repositories.UserRepositoryInterface = &repositories.UserRepository{}
userService := &services.UserService{Repo: userRepo, Licences: licenceRepo}
userController := &controllers.UserController{Service: userService, EmailService: emailService, ConsentService: consentService, LicenceService: licenceService, BankAccountService: bankAccountService, MembershipService: membershipService} userController := &controllers.UserController{Service: userService, EmailService: emailService, ConsentService: consentService, LicenceService: licenceService, BankAccountService: bankAccountService, MembershipService: membershipService}
membershipController := &controllers.MembershipController{Service: *membershipService, UserController: userController} membershipController := &controllers.MembershipController{Service: membershipService, UserService: userService}
licenceController := &controllers.LicenceController{Service: *licenceService} licenceController := &controllers.LicenceController{Service: licenceService}
contactController := &controllers.ContactController{EmailService: emailService} contactController := &controllers.ContactController{EmailService: emailService}
router := gin.Default() router := gin.Default()
@@ -65,7 +64,7 @@ func Run() {
router.Use(middlewares.RateLimitMiddleware(limiter)) router.Use(middlewares.RateLimitMiddleware(limiter))
routes.RegisterRoutes(router, userController, membershipController, contactController, licenceController) routes.RegisterRoutes(router, userController, membershipController, contactController, licenceController)
validation.SetupValidators() validation.SetupValidators(db)
logger.Info.Println("Starting server on :8080") logger.Info.Println("Starting server on :8080")
srv = &http.Server{ srv = &http.Server{

View File

@@ -71,11 +71,13 @@ func (s *EmailService) SendVerificationEmail(user *models.User, token *string) e
LastName string LastName string
Token string Token string
BASEURL string BASEURL string
ID uint
}{ }{
FirstName: user.FirstName, FirstName: user.FirstName,
LastName: user.LastName, LastName: user.LastName,
Token: *token, Token: *token,
BASEURL: config.Site.BaseURL, BASEURL: config.Site.BaseURL,
ID: user.ID,
} }
subject := constants.MailVerificationSubject subject := constants.MailVerificationSubject

View File

@@ -5,7 +5,7 @@ import (
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
) )
type LicenceInterface interface { type LicenceServiceInterface interface {
GetAllCategories() ([]models.Category, error) GetAllCategories() ([]models.Category, error)
} }

View File

@@ -1,21 +0,0 @@
package services
import (
"GoMembership/internal/constants"
"GoMembership/internal/models"
)
func (s *UserService) HandlePasswordChangeRequest(user *models.User) (token string, err error) {
// Deactivate user and reset Verification
if err := s.SetUserStatus(user.ID, constants.DisabledStatus); err != nil {
return "", err
}
if err := s.RevokeVerification(&user.ID, constants.VerificationTypes.Password); err != nil {
return "", err
}
// Generate a token
return s.SetVerificationToken(&user.ID, &constants.VerificationTypes.Password)
}

View File

@@ -1,5 +0,0 @@
package services
func (s *UserService) SetUserStatus(id uint, status uint) error {
return s.Repo.SetUserStatus(id, status)
}

View File

@@ -8,51 +8,74 @@ import (
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
"GoMembership/pkg/errors" "GoMembership/pkg/errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"time" "time"
) )
type UserServiceInterface interface { type UserServiceInterface interface {
RegisterUser(user *models.User) (id uint, token string, err error) Register(user *models.User) (id uint, token string, err error)
GetUserByEmail(email string) (*models.User, error) Update(user *models.User) (*models.User, error)
GetUserByID(id uint) (*models.User, error) Delete(id *uint) error
FromContext(c *gin.Context) (*models.User, error)
FromID(id *uint) (*models.User, error)
FromEmail(email *string) (*models.User, error)
GetUsers(where map[string]interface{}) (*[]models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error)
UpdateUser(user *models.User) (*models.User, error)
DeleteUser(lastname string, id uint) error
SetUserStatus(id uint, status uint) error
VerifyUser(token *string, verificationType *string) (*models.Verification, error)
SetVerificationToken(id *uint, verificationType *string) (string, error)
RevokeVerification(id *uint, verificationType string) error
HandlePasswordChangeRequest(user *models.User) (token string, err error)
} }
type UserService struct { type UserService struct {
Repo repositories.UserRepositoryInterface
Licences repositories.LicenceInterface Licences repositories.LicenceInterface
DB *gorm.DB
} }
func (service *UserService) DeleteUser(lastname string, id uint) error { func (s *UserService) FromContext(c *gin.Context) (*models.User, error) {
if id == 0 || lastname == "" { var user models.User
return errors.ErrNoData if err := user.FromContext(s.DB, c); err != nil {
return nil, err
}
return &user, nil
} }
user, err := service.GetUserByID(id) func (s *UserService) FromID(id *uint) (*models.User, error) {
if err != nil { var user models.User
if err := user.FromID(s.DB, id); err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) FromEmail(email *string) (*models.User, error) {
var user models.User
if err := user.FromEmail(s.DB, email); err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) Delete(id *uint) error {
var user models.User
if err := user.FromID(s.DB, id); err != nil {
return err return err
} }
if user == nil {
return errors.ErrUserNotFound return user.Delete(s.DB)
} }
return service.Repo.DeleteUser(id) func (s *UserService) Update(user *models.User) (*models.User, error) {
}
func (service *UserService) UpdateUser(user *models.User) (*models.User, error) { var existingUser models.User
if err := existingUser.FromID(s.DB, &user.ID); err != nil {
if user.ID == 0 { return nil, err
return nil, errors.ErrUserNotFound
} }
user.MembershipID = existingUser.MembershipID
user.Membership.ID = existingUser.Membership.ID
if existingUser.Licence != nil {
user.Licence.ID = existingUser.Licence.ID
user.LicenceID = existingUser.LicenceID
}
user.BankAccount.ID = existingUser.BankAccount.ID
user.BankAccountID = existingUser.BankAccountID
user.SetPassword(user.Password) user.SetPassword(user.Password)
@@ -64,9 +87,7 @@ func (service *UserService) UpdateUser(user *models.User) (*models.User, error)
user.Membership.SubscriptionModel = *selectedModel user.Membership.SubscriptionModel = *selectedModel
user.Membership.SubscriptionModelID = selectedModel.ID user.Membership.SubscriptionModelID = selectedModel.ID
updatedUser, err := service.Repo.UpdateUser(user) if err := user.Update(s.DB); err != nil {
if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, errors.ErrUserNotFound return nil, errors.ErrUserNotFound
} }
@@ -75,42 +96,38 @@ func (service *UserService) UpdateUser(user *models.User) (*models.User, error)
} }
return nil, err return nil, err
} }
return user, nil
return updatedUser, nil
} }
func (service *UserService) RegisterUser(user *models.User) (id uint, token string, err error) { func (s *UserService) Register(user *models.User) (id uint, token string, err error) {
user.SetPassword(user.Password) user.SetPassword(user.Password)
selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name)
if err != nil {
return 0, "", errors.ErrSubscriptionNotFound
}
user.Membership.SubscriptionModel = *selectedModel
user.Membership.SubscriptionModelID = selectedModel.ID
user.Status = constants.UnverifiedStatus user.Status = constants.UnverifiedStatus
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
user.PaymentStatus = constants.AwaitingPaymentStatus user.PaymentStatus = constants.AwaitingPaymentStatus
user.BankAccount.MandateDateSigned = time.Now() user.BankAccount.MandateDateSigned = time.Now()
id, err = service.Repo.CreateUser(user) v, err := user.SetVerification(constants.VerificationTypes.Email)
if err != nil { if err != nil {
return 0, "", err return 0, "", err
} }
token, err = service.SetVerificationToken(&id, &constants.VerificationTypes.Email) if err := user.Create(s.DB); err != nil {
if err != nil {
return 0, "", err return 0, "", err
} }
return id, token, nil
return user.ID, v.VerificationToken, nil
} }
func (service *UserService) GetUserByID(id uint) (*models.User, error) { // GetUsers returns a list of users based on the provided where clause.
return repositories.GetUserByID(&id) // if where == nil: all users are returned
} func (s *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) {
func (service *UserService) GetUserByEmail(email string) (*models.User, error) {
return service.Repo.GetUserByEmail(email)
}
func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) {
if where == nil { if where == nil {
where = map[string]interface{}{} where = map[string]interface{}{}
} }
return service.Repo.GetUsers(where) return models.GetUsersWhere(s.DB, where)
} }

View File

@@ -1,59 +0,0 @@
package services
import (
"GoMembership/internal/models"
"GoMembership/internal/utils"
"GoMembership/pkg/errors"
"time"
)
func (s *UserService) SetVerificationToken(id *uint, verificationType *string) (string, error) {
token, err := utils.GenerateVerificationToken()
if err != nil {
return "", err
}
// Check if user is already verified
verified, err := s.Repo.IsVerified(id)
if err != nil {
return "", err
}
if verified {
return "", errors.ErrAlreadyVerified
}
// Prepare the Verification record
verification := models.Verification{
UserID: *id,
VerificationToken: token,
Type: *verificationType,
}
return s.Repo.SetVerificationToken(&verification)
}
func (s *UserService) RevokeVerification(id *uint, verificationType string) error {
return s.Repo.DeleteVerification(*id, verificationType)
}
func (service *UserService) VerifyUser(token *string, verificationType *string) (*models.Verification, error) {
verification, err := service.Repo.GetVerificationOfToken(token, verificationType)
if err != nil {
return nil, err
}
// Check if the user is already verified
verified, err := service.Repo.IsVerified(&verification.UserID)
if err != nil {
return nil, err
}
if verified {
return nil, errors.ErrAlreadyVerified
}
t := time.Now()
verification.VerifiedAt = &t
return verification, nil
}

View File

@@ -81,18 +81,6 @@ func HandleSubscriptionUpdateError(c *gin.Context, err error) {
} }
} }
func HandleVerifyUserError(c *gin.Context, err error) {
if err.Error() == "record not found" {
RespondWithError(c, err, "Couldn't find verification. This is most probably a outdated token.", http.StatusGone, errors.Responses.Fields.User, errors.Responses.Keys.NoAuthToken)
}
switch err {
case errors.ErrAlreadyVerified:
RespondWithError(c, err, "User already changed password", http.StatusConflict, errors.Responses.Fields.User, errors.Responses.Keys.PasswordAlreadyChanged)
default:
RespondWithError(c, err, "Couldn't verify user", http.StatusInternalServerError, errors.Responses.Fields.General, errors.Responses.Keys.InternalServerError)
}
}
func HandleDeleteUserError(c *gin.Context, err error) { func HandleDeleteUserError(c *gin.Context, err error) {
if err.Error() == "record not found" { if err.Error() == "record not found" {
RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound) RespondWithError(c, err, "Couldn't find user", http.StatusNotFound, errors.Responses.Fields.User, errors.Responses.Keys.NotFound)

View File

@@ -1,15 +1,10 @@
package utils package validation
import ( import (
"GoMembership/internal/models"
"errors" "errors"
"reflect" "reflect"
) )
func HasPrivilige(user *models.User, privilige int8) bool {
return user.RoleID >= privilige
}
// FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place. // FilterAllowedStructFields filters allowed fields recursively in a struct and modifies structToModify in place.
func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error { func FilterAllowedStructFields(input interface{}, existing interface{}, allowedFields map[string]bool, prefix string) error {
v := reflect.ValueOf(input) v := reflect.ValueOf(input)

View File

@@ -1,4 +1,4 @@
package utils package validation
import ( import (
"reflect" "reflect"

View File

@@ -2,39 +2,37 @@ package validation
import ( import (
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/repositories"
"GoMembership/pkg/errors" "GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"gorm.io/gorm"
) )
func validateMembership(sl validator.StructLevel) { func validateMembership(db *gorm.DB, user *models.User, sl validator.StructLevel) {
membership := sl.Current().Interface().(models.User).Membership if user.Membership.SubscriptionModel.RequiredMembershipField != "" {
if membership.SubscriptionModel.RequiredMembershipField != "" { switch user.Membership.SubscriptionModel.RequiredMembershipField {
switch membership.SubscriptionModel.RequiredMembershipField {
case "ParentMembershipID": case "ParentMembershipID":
if err := CheckParentMembershipID(membership); err != nil { if err := CheckParentMembershipID(db, user); err != nil {
logger.Error.Printf("Error ParentMembershipValidation: %v", err.Error()) logger.Error.Printf("Error ParentMembershipValidation: %v", err.Error())
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField, sl.ReportError(user.Membership.ParentMembershipID, user.Membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "invalid", "") "RequiredMembershipField", "invalid", "")
} }
default: default:
logger.Error.Printf("Error no matching RequiredMembershipField: %v", errors.ErrInvalidValue.Error()) logger.Error.Printf("Error no matching RequiredMembershipField: %v", errors.ErrInvalidValue.Error())
sl.ReportError(membership.ParentMembershipID, membership.SubscriptionModel.RequiredMembershipField, sl.ReportError(user.Membership.ParentMembershipID, user.Membership.SubscriptionModel.RequiredMembershipField,
"RequiredMembershipField", "not_implemented", "") "RequiredMembershipField", "not_implemented", "")
} }
} }
} }
func CheckParentMembershipID(membership models.Membership) error { func CheckParentMembershipID(db *gorm.DB, user *models.User) error {
if membership.ParentMembershipID == 0 { if user.Membership.ParentMembershipID == 0 {
return errors.ValErrParentIDNotSet return errors.ValErrParentIDNotSet
} else { } else {
_, err := repositories.GetUserByID(&membership.ParentMembershipID) var parent models.User
if err != nil { if err := parent.FromID(db, &user.Membership.ParentMembershipID); err != nil {
return errors.ValErrParentIDNotFound return errors.ValErrParentIDNotFound
} }
} }

View File

@@ -4,17 +4,18 @@ import (
"GoMembership/internal/models" "GoMembership/internal/models"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"gorm.io/gorm"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
) )
func SetupValidators() { func SetupValidators(db *gorm.DB) {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok { if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// Register custom validators // Register custom validators
v.RegisterValidation("safe_content", ValidateSafeContent) v.RegisterValidation("safe_content", ValidateSafeContent)
// Register struct-level validations // Register struct-level validations
v.RegisterStructValidation(ValidateUser, models.User{}) v.RegisterStructValidation(ValidateUserFactory(db), models.User{})
v.RegisterStructValidation(ValidateSubscription, models.SubscriptionModel{}) v.RegisterStructValidation(ValidateSubscription, models.SubscriptionModel{})
} }
} }

View File

@@ -1,7 +1,6 @@
package validation package validation
import ( import (
"GoMembership/internal/constants"
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/repositories" "GoMembership/internal/repositories"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
@@ -10,6 +9,7 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
passwordvalidator "github.com/wagslane/go-password-validator" passwordvalidator "github.com/wagslane/go-password-validator"
"gorm.io/gorm"
) )
var passwordErrorTranslations = map[string]string{ var passwordErrorTranslations = map[string]string{
@@ -21,11 +21,15 @@ var passwordErrorTranslations = map[string]string{
"using numbers": "server.validation.numbers", "using numbers": "server.validation.numbers",
} }
func ValidateUser(sl validator.StructLevel) { func ValidateUserFactory(db *gorm.DB) validator.StructLevelFunc {
return func(sl validator.StructLevel) {
validateUser(db, sl)
}
}
func validateUser(db *gorm.DB, sl validator.StructLevel) {
user := sl.Current().Interface().(models.User) user := sl.Current().Interface().(models.User)
isSuper := user.RoleID >= constants.Roles.Admin
isSupporter := user.RoleID == constants.Roles.Supporter
// validate subscriptionModel // validate subscriptionModel
if user.Membership.SubscriptionModel.Name == "" { if user.Membership.SubscriptionModel.Name == "" {
sl.ReportError(user.Membership.SubscriptionModel.Name, "subscription.name", "name", "required", "") sl.ReportError(user.Membership.SubscriptionModel.Name, "subscription.name", "name", "required", "")
@@ -38,7 +42,7 @@ func ValidateUser(sl validator.StructLevel) {
user.Membership.SubscriptionModel = *selectedModel user.Membership.SubscriptionModel = *selectedModel
} }
} }
if isSupporter { if user.IsSupporter() {
if user.BankAccount.IBAN != "" { if user.BankAccount.IBAN != "" {
validateBankAccount(sl) validateBankAccount(sl)
} }
@@ -54,9 +58,9 @@ func ValidateUser(sl validator.StructLevel) {
if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) { if user.DateOfBirth.After(time.Now().AddDate(-18, 0, 0)) {
sl.ReportError(user.DateOfBirth, "user.user", "user.dateofbirth", "age", "") sl.ReportError(user.DateOfBirth, "user.user", "user.dateofbirth", "age", "")
} }
validateMembership(sl) validateMembership(db, &user, sl)
if isSuper { if user.IsAdmin() {
return return
} }

View File

@@ -18,6 +18,7 @@ type ValidationKeys struct {
NotFound string NotFound string
InUse string InUse string
UndeliveredVerificationMail string UndeliveredVerificationMail string
UserAlreadyVerified string
} }
type ValidationFields struct { type ValidationFields struct {
@@ -28,6 +29,7 @@ type ValidationFields struct {
Email string Email string
User string User string
Licences string Licences string
Verification string
} }
var ( var (
@@ -72,6 +74,7 @@ var Responses = struct {
NotFound: "server.error.not_found", NotFound: "server.error.not_found",
InUse: "server.error.in_use", InUse: "server.error.in_use",
UndeliveredVerificationMail: "server.error.undelivered_verification_mail", UndeliveredVerificationMail: "server.error.undelivered_verification_mail",
UserAlreadyVerified: "server.validation.user_already_verified",
}, },
Fields: ValidationFields{ Fields: ValidationFields{
General: "server.general", General: "server.general",
@@ -81,5 +84,10 @@ var Responses = struct {
Email: "user.email", Email: "user.email",
User: "user.user", User: "user.user",
Licences: "licence", Licences: "licence",
Verification: "verification",
}, },
} }
func Is(err error, target error) bool {
return errors.Is(err, target)
}

View File

@@ -70,7 +70,7 @@
</div> </div>
<div style="text-align: center; padding: 16px 24px 16px 24px"> <div style="text-align: center; padding: 16px 24px 16px 24px">
<a <a
href="{{.BASEURL}}/api/users/verify?token={{.Token}}" href="{{.BASEURL}}/api/users/verify/{{.ID}}?token={{.Token}}"
style=" style="
color: #ffffff; color: #ffffff;
font-size: 26px; font-size: 26px;

View File

@@ -8,7 +8,7 @@ noch Ihre Emailadresse indem Sie hier klicken:
E-Mail Adresse bestätigen E-Mail Adresse bestätigen
{{.BASEURL}}/api/users/verify?token={{.Token}} {{.BASEURL}}/api/users/verify/{{.ID}}?token={{.Token}}
Nachdem wir Ihre E-Mail Adresse bestätigen konnten, schicken wir Nachdem wir Ihre E-Mail Adresse bestätigen konnten, schicken wir
Ihnen alle weiteren Informationen zu. Wir freuen uns auf die Ihnen alle weiteren Informationen zu. Wir freuen uns auf die