diff --git a/internal/controllers/membershipController.go b/internal/controllers/membershipController.go index 1df8931..07c3f35 100644 --- a/internal/controllers/membershipController.go +++ b/internal/controllers/membershipController.go @@ -84,10 +84,9 @@ func (mc *MembershipController) UpdateHandler(c *gin.Context) { return } - // Register Subscription - logger.Info.Printf("Registering subscription %v", regData.Subscription.Name) - // id, err := mc.Service.UpdateSubscription(®Data.Subscription) - id := 1 + // update Subscription + logger.Info.Printf("Updating subscription %v", regData.Subscription.Name) + id, err := mc.Service.UpdateSubscription(®Data.Subscription) if err != nil { logger.Error.Printf("Couldn't update Membershipmodel: %v", err) if strings.Contains(err.Error(), "UNIQUE constraint failed") { @@ -103,6 +102,33 @@ func (mc *MembershipController) UpdateHandler(c *gin.Context) { "id": id, }) } + +func (mc *MembershipController) DeleteSubscription(c *gin.Context) { + var membershipdata MembershipData + requestUser, err := mc.UserController.ExtractUserFromContext(c) + if err != nil { + utils.RespondWithError(c, err, "Error extracting user from context in subscription UpdateHandler", http.StatusBadRequest, "general", "server.validation.no_auth_tokenw") + return + } + + if !utils.HasPrivilige(requestUser, constants.Priviliges.Update) { + utils.RespondWithError(c, errors.ErrNotAuthorized, "Not allowed to update subscription", http.StatusForbidden, "user", "server.error.unauthorized") + return + } + + if err := c.ShouldBindJSON(&membershipdata); err != nil { + utils.HandleValidationError(c, err) + return + } + + if err := mc.Service.DeleteSubscription(&membershipdata.Subscription); err != nil { + utils.RespondWithError(c, err, "Error during subscription Deletion", http.StatusExpectationFailed, "subscription", "server.error.not_possible") + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Subscription deleted successfully"}) +} + func (mc *MembershipController) GetSubscriptions(c *gin.Context) { subscriptions, err := mc.Service.GetSubscriptions(nil) if err != nil { diff --git a/internal/controllers/membershipController_test.go b/internal/controllers/membershipController_test.go index 2686290..ca2de7d 100644 --- a/internal/controllers/membershipController_test.go +++ b/internal/controllers/membershipController_test.go @@ -21,6 +21,22 @@ type RegisterSubscriptionTest struct { Assert bool } +type UpdateSubscriptionTest struct { + WantDBData map[string]interface{} + Input string + Name string + WantResponse int + Assert bool +} + +type DeleteSubscriptionTest struct { + WantDBData map[string]interface{} + Input string + Name string + WantResponse int + Assert bool +} + type MockUserController struct { UserController // Embed the UserController } @@ -44,7 +60,7 @@ func setupMockAuth() { func testMembershipController(t *testing.T) { setupMockAuth() - tests := getSubscriptionData() + tests := getSubscriptionRegistrationData() for _, tt := range tests { logger.Error.Print("==============================================================") logger.Error.Printf("MembershipController : %v", tt.Name) @@ -55,6 +71,28 @@ func testMembershipController(t *testing.T) { } }) } + updateTests := getSubscriptionUpdateData() + for _, tt := range updateTests { + logger.Error.Print("==============================================================") + logger.Error.Printf("Update SubscriptionData : %v", tt.Name) + logger.Error.Print("==============================================================") + t.Run(tt.Name, func(t *testing.T) { + if err := runSingleTest(&tt); err != nil { + t.Errorf("Test failed: %v", err.Error()) + } + }) + } + deleteTests := getSubscriptionDeleteData() + for _, tt := range deleteTests { + logger.Error.Print("==============================================================") + logger.Error.Printf("Delete SubscriptionData : %v", tt.Name) + logger.Error.Print("==============================================================") + t.Run(tt.Name, func(t *testing.T) { + if err := runSingleTest(&tt); err != nil { + t.Errorf("Test failed: %v", err.Error()) + } + }) + } } func (rt *RegisterSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) { @@ -87,6 +125,44 @@ func validateSubscription(assert bool, wantDBData map[string]interface{}) error return nil } +func (ut *UpdateSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) { + return GetMockedJSONContext([]byte(ut.Input), "api/subscription/upsert") +} + +func (ut *UpdateSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) { + Mc.UpdateHandler(c) +} + +func (ut *UpdateSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error { + if w.Code != ut.WantResponse { + return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, ut.WantResponse) + } + return nil +} + +func (ut *UpdateSubscriptionTest) ValidateResult() error { + return validateSubscription(ut.Assert, ut.WantDBData) +} + +func (dt *DeleteSubscriptionTest) SetupContext() (*gin.Context, *httptest.ResponseRecorder, *gin.Engine) { + return GetMockedJSONContext([]byte(dt.Input), "api/subscription/delete") +} + +func (dt *DeleteSubscriptionTest) RunHandler(c *gin.Context, router *gin.Engine) { + Mc.DeleteSubscription(c) +} + +func (dt *DeleteSubscriptionTest) ValidateResponse(w *httptest.ResponseRecorder) error { + if w.Code != dt.WantResponse { + return fmt.Errorf("Didn't get the expected response code: got: %v; expected: %v", w.Code, dt.WantResponse) + } + return nil +} + +func (dt *DeleteSubscriptionTest) ValidateResult() error { + return validateSubscription(dt.Assert, dt.WantDBData) +} + func getBaseSubscription() MembershipData { return MembershipData{ // APIKey: config.Auth.APIKEY, @@ -103,7 +179,7 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member return customize(subscription) } -func getSubscriptionData() []RegisterSubscriptionTest { +func getSubscriptionRegistrationData() []RegisterSubscriptionTest { return []RegisterSubscriptionTest{ { Name: "Missing details should fail", @@ -169,3 +245,137 @@ func getSubscriptionData() []RegisterSubscriptionTest { }, } } + +func getSubscriptionUpdateData() []UpdateSubscriptionTest { + return []UpdateSubscriptionTest{ + { + Name: "Modified Monthly Fee, should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "Premium", "monthly_fee": "12"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.MonthlyFee = 123.0 + return subscription + })), + }, + { + Name: "Missing ID, should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "Premium"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.ID = 0 + return subscription + })), + }, + { + Name: "Modified Hourly Rate, should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "Premium", "hourly_rate": "14"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.HourlyRate = 3254.0 + return subscription + })), + }, + { + Name: "IncludedPerYear changed, should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "Premium", "included_per_year": "0"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.IncludedPerYear = 9873.0 + return subscription + })), + }, + { + Name: "IncludedPerMonth changed, should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "Premium", "included_per_month": "1"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.IncludedPerMonth = 23415.0 + return subscription + })), + }, + { + Name: "Update non-existent subscription should fail", + WantResponse: http.StatusNotAcceptable, + WantDBData: map[string]interface{}{"name": "NonExistentSubscription"}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Name = "NonExistentSubscription" + return subscription + })), + }, + { + Name: "Correct Update should pass", + WantResponse: http.StatusAccepted, + WantDBData: map[string]interface{}{"name": "Premium", "details": "Altered Details"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Details = "Altered Details" + subscription.Subscription.Conditions = "Some Condition" + subscription.Subscription.IncludedPerYear = 0 + subscription.Subscription.IncludedPerMonth = 1 + return subscription + })), + }, + } +} + +func getSubscriptionDeleteData() []DeleteSubscriptionTest { + return []DeleteSubscriptionTest{ + { + Name: "Delete non-existent subscription should fail", + WantResponse: http.StatusExpectationFailed, + WantDBData: map[string]interface{}{"name": "NonExistentSubscription"}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Name = "NonExistentSubscription" + return subscription + })), + }, + { + Name: "Delete subscription without name should fail", + WantResponse: http.StatusBadRequest, + WantDBData: map[string]interface{}{"name": ""}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Name = "" + return subscription + })), + }, + { + Name: "Delete subscription with users should fail", + WantResponse: http.StatusExpectationFailed, + WantDBData: map[string]interface{}{"name": "Basic"}, + Assert: true, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Name = "Basic" + return subscription + })), + }, + { + Name: "Delete valid subscription should succeed", + WantResponse: http.StatusOK, + WantDBData: map[string]interface{}{"name": "Premium"}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.Subscription.Name = "Premium" + return subscription + })), + }, + } +} diff --git a/internal/controllers/user_controller.go b/internal/controllers/user_controller.go index ba789e9..08c86bb 100644 --- a/internal/controllers/user_controller.go +++ b/internal/controllers/user_controller.go @@ -212,7 +212,7 @@ func (uc *UserController) RegisterUser(c *gin.Context) { } logger.Info.Printf("Registering user %v", regData.User.Email) - selectedModel, err := uc.MembershipService.GetModelByName(®Data.User.Membership.SubscriptionModel.Name) + selectedModel, err := uc.MembershipService.GetSubscriptionByName(®Data.User.Membership.SubscriptionModel.Name) if err != nil { utils.RespondWithError(c, err, "Error in Registeruser, couldn't get selected model", http.StatusNotFound, "subscription_model", "server.validation.subscription_model_not_found") return diff --git a/internal/repositories/subscription_model_repository.go b/internal/repositories/subscription_model_repository.go index 8616227..eb2259e 100644 --- a/internal/repositories/subscription_model_repository.go +++ b/internal/repositories/subscription_model_repository.go @@ -10,8 +10,11 @@ import ( type SubscriptionModelsRepositoryInterface interface { CreateSubscriptionModel(subscriptionModel *models.SubscriptionModel) (uint, error) - GetMembershipModelNames() ([]string, error) + UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) + GetSubscriptionModelNames() ([]string, error) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) + // GetUsersBySubscription(id uint) (*[]models.SubscriptionModel, error) + DeleteSubscription(subscription *models.SubscriptionModel) error } type SubscriptionModelsRepository struct{} @@ -25,15 +28,34 @@ func (sr *SubscriptionModelsRepository) CreateSubscriptionModel(subscriptionMode return subscriptionModel.ID, nil } -func GetModelByName(modelname *string) (*models.SubscriptionModel, error) { +func (sr *SubscriptionModelsRepository) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) { + + result := database.DB.Model(&models.SubscriptionModel{ID: subscription.ID}).Updates(subscription) + if result.Error != nil { + return nil, result.Error + } + return subscription, nil +} + +func (sr *SubscriptionModelsRepository) DeleteSubscription(subscription *models.SubscriptionModel) error { + + result := database.DB.Delete(&models.SubscriptionModel{}, subscription.ID) + if result.Error != nil { + return result.Error + } + return nil +} + +func GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) { var model models.SubscriptionModel - if err := database.DB.Where("name = ?", modelname).First(&model).Error; err != nil { - return nil, err + result := database.DB.Where("name = ?", modelname).First(&model) + if result.Error != nil { + return nil, result.Error } return &model, nil } -func (sr *SubscriptionModelsRepository) GetMembershipModelNames() ([]string, error) { +func (sr *SubscriptionModelsRepository) GetSubscriptionModelNames() ([]string, error) { var names []string if err := database.DB.Model(&models.SubscriptionModel{}).Pluck("name", &names).Error; err != nil { return []string{}, err @@ -52,3 +74,24 @@ func (sr *SubscriptionModelsRepository) GetSubscriptions(where map[string]interf } return &subscriptions, nil } + +func GetUsersBySubscription(subscriptionID uint) (*[]models.User, error) { + var users []models.User + + err := database.DB.Preload("Membership"). + Preload("Membership.SubscriptionModel"). + Preload("BankAccount"). + Preload("Licence"). + Preload("Licence.Categories"). + Joins("JOIN memberships ON users.membership_id = memberships.id"). + Joins("JOIN subscription_models ON memberships.subscription_model_id = subscription_models.id"). + Where("subscription_models.id = ?", subscriptionID). + Find(&users).Error + + if err != nil { + return nil, err + } + + return &users, nil + +} diff --git a/internal/services/membership_service.go b/internal/services/membership_service.go index 85aa128..e9db8d8 100644 --- a/internal/services/membership_service.go +++ b/internal/services/membership_service.go @@ -5,14 +5,17 @@ import ( "GoMembership/internal/models" "GoMembership/internal/repositories" + "GoMembership/pkg/errors" ) type MembershipServiceInterface interface { RegisterMembership(membership *models.Membership) (uint, error) FindMembershipByUserID(userID uint) (*models.Membership, error) RegisterSubscription(subscription *models.SubscriptionModel) (uint, error) - GetMembershipModelNames() ([]string, error) - GetModelByName(modelname *string) (*models.SubscriptionModel, error) + UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) + DeleteSubscription(subscription *models.SubscriptionModel) error + GetSubscriptionModelNames() ([]string, error) + GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) } @@ -26,6 +29,48 @@ func (service *MembershipService) RegisterMembership(membership *models.Membersh return service.Repo.CreateMembership(membership) } +func (service *MembershipService) UpdateSubscription(subscription *models.SubscriptionModel) (*models.SubscriptionModel, error) { + + existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name) + if err != nil { + return nil, err + } + if existingSubscription == nil { + return nil, errors.ErrSubscriptionNotFound + } + if existingSubscription.MonthlyFee != subscription.MonthlyFee || + existingSubscription.HourlyRate != subscription.HourlyRate || + existingSubscription.Conditions != subscription.Conditions || + existingSubscription.IncludedPerYear != subscription.IncludedPerYear || + existingSubscription.IncludedPerMonth != subscription.IncludedPerMonth { + return nil, errors.ErrInvalidSubscriptionData + } + subscription.ID = existingSubscription.ID + return service.SubscriptionRepo.UpdateSubscription(subscription) + +} + +func (service *MembershipService) DeleteSubscription(subscription *models.SubscriptionModel) error { + exists, err := repositories.GetSubscriptionByName(&subscription.Name) + if err != nil { + return err + } + if exists == nil { + return errors.ErrNotFound + } + + subscription.ID = exists.ID + usersInSubscription, err := repositories.GetUsersBySubscription(subscription.ID) + + if err != nil { + return err + } + if len(*usersInSubscription) > 0 { + return errors.ErrSubscriptionInUse + } + return service.SubscriptionRepo.DeleteSubscription(subscription) +} + func (service *MembershipService) FindMembershipByUserID(userID uint) (*models.Membership, error) { return service.Repo.FindMembershipByUserID(userID) } @@ -35,12 +80,12 @@ func (service *MembershipService) RegisterSubscription(subscription *models.Subs return service.SubscriptionRepo.CreateSubscriptionModel(subscription) } -func (service *MembershipService) GetMembershipModelNames() ([]string, error) { - return service.SubscriptionRepo.GetMembershipModelNames() +func (service *MembershipService) GetSubscriptionModelNames() ([]string, error) { + return service.SubscriptionRepo.GetSubscriptionModelNames() } -func (service *MembershipService) GetModelByName(modelname *string) (*models.SubscriptionModel, error) { - return repositories.GetModelByName(modelname) +func (service *MembershipService) GetSubscriptionByName(modelname *string) (*models.SubscriptionModel, error) { + return repositories.GetSubscriptionByName(modelname) } func (service *MembershipService) GetSubscriptions(where map[string]interface{}) (*[]models.SubscriptionModel, error) { diff --git a/internal/services/user_service.go b/internal/services/user_service.go index af1f802..b5a648f 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -59,7 +59,7 @@ func (service *UserService) UpdateUser(user *models.User) (*models.User, error) } // Validate subscription model - selectedModel, err := repositories.GetModelByName(&user.Membership.SubscriptionModel.Name) + selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name) if err != nil { return nil, errors.ErrSubscriptionNotFound } diff --git a/internal/validation/subscription_validation.go b/internal/validation/subscription_validation.go index 7ba6ed3..a4a2cc2 100644 --- a/internal/validation/subscription_validation.go +++ b/internal/validation/subscription_validation.go @@ -38,7 +38,7 @@ func ValidateSubscription(sl validator.StructLevel) { } } else { // This is a nested probably user struct. We are only checking if the model exists - existingSubscription, err := repositories.GetModelByName(&subscription.Name) + existingSubscription, err := repositories.GetSubscriptionByName(&subscription.Name) if err != nil || existingSubscription == nil { sl.ReportError(subscription.Name, "Subscription_Name", "name", "exists", "") } diff --git a/internal/validation/user_validation.go b/internal/validation/user_validation.go index a05f18a..7c927e4 100644 --- a/internal/validation/user_validation.go +++ b/internal/validation/user_validation.go @@ -30,7 +30,7 @@ func validateUser(sl validator.StructLevel) { if user.Membership.SubscriptionModel.Name == "" { sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "required", "") } else { - selectedModel, err := repositories.GetModelByName(&user.Membership.SubscriptionModel.Name) + selectedModel, err := repositories.GetSubscriptionByName(&user.Membership.SubscriptionModel.Name) if err != nil { logger.Error.Printf("Error finding subscription model for user %v: %v", user.Email, err) sl.ReportError(user.Membership.SubscriptionModel.Name, "SubscriptionModel.Name", "name", "invalid", "")