diff --git a/cmd/membership/main.go b/cmd/membership/main.go index 8fecd5c..97188c1 100644 --- a/cmd/membership/main.go +++ b/cmd/membership/main.go @@ -1,12 +1,54 @@ package main import ( + "GoMembership/internal/config" + "GoMembership/internal/database" "GoMembership/internal/server" "GoMembership/pkg/logger" + "context" + "os" + "os/signal" + "syscall" + "time" ) func main() { logger.Info.Println("startup...") - server.Run() + + config.LoadConfig() + logger.Info.Printf("Config loaded: %#v", config.CFG) + + err := database.Open(config.DB.Path) + if err != nil { + logger.Error.Fatalf("Couldn't init database: %v", err) + } + + defer func() { + if err := database.Close(); err != nil { + logger.Error.Fatalf("Failed to close database: %v", err) + } + }() + + go server.Run() + + gracefulShutdown() +} + +func gracefulShutdown() { + // Create a channel to listen for OS signals + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + + // Block until a signal is received + <-stop + logger.Info.Println("Received shutdown signal") + + // Create a context with a timeout for the shutdown process + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Call the server's shutdown function + server.Shutdown(ctx) + logger.Info.Println("Server gracefully stopped") } diff --git a/internal/controllers/controllers_test.go b/internal/controllers/controllers_test.go index 19d5615..7043011 100644 --- a/internal/controllers/controllers_test.go +++ b/internal/controllers/controllers_test.go @@ -44,7 +44,7 @@ var ( func TestSuite(t *testing.T) { _ = deleteTestDB("test.db") - if err := database.InitDB("test.db"); err != nil { + if err := database.Open("test.db"); err != nil { log.Fatalf("Failed to create DB: %#v", err) } diff --git a/internal/controllers/membershipController.go b/internal/controllers/membershipController.go index f944d30..c25b711 100644 --- a/internal/controllers/membershipController.go +++ b/internal/controllers/membershipController.go @@ -29,15 +29,15 @@ func (mc *MembershipController) RegisterSubscription(c *gin.Context) { } logger.Info.Printf("Using API key: %v", config.Auth.APIKEY) - if regData.APIKey == "" { logger.Error.Println("API Key is missing") - c.JSON(http.StatusBadRequest, "API Key is missing") + c.JSON(http.StatusUnauthorized, "API Key is missing") return } + if regData.APIKey != config.Auth.APIKEY { logger.Error.Printf("API Key not valid: %v", regData.APIKey) - c.JSON(http.StatusExpectationFailed, "API Key is missing") + c.JSON(http.StatusUnauthorized, "API Key is missing") return } diff --git a/internal/controllers/membershipController_test.go b/internal/controllers/membershipController_test.go index c234f59..8f836ff 100644 --- a/internal/controllers/membershipController_test.go +++ b/internal/controllers/membershipController_test.go @@ -80,6 +80,28 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member func getSubscriptionData() []RegisterSubscriptionTest { return []RegisterSubscriptionTest{ + { + Name: "No API Key should fail", + WantResponse: http.StatusUnauthorized, + WantDBData: map[string]interface{}{"name": "Just a Subscription"}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.APIKey = "" + return subscription + })), + }, + { + Name: "Wrong API Key should fail", + WantResponse: http.StatusUnauthorized, + WantDBData: map[string]interface{}{"name": "Just a Subscription"}, + Assert: false, + Input: GenerateInputJSON( + customizeSubscription(func(subscription MembershipData) MembershipData { + subscription.APIKey = "alskfdlkjsfjk23-dF" + return subscription + })), + }, { Name: "No Details should fail", WantResponse: http.StatusNotAcceptable, diff --git a/internal/database/db.go b/internal/database/db.go index 6daad24..07d8aab 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -3,13 +3,14 @@ package database import ( "GoMembership/internal/models" "GoMembership/pkg/logger" + "gorm.io/driver/sqlite" "gorm.io/gorm" ) var DB *gorm.DB -func InitDB(dbPath string) error { +func Open(dbPath string) error { db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) if err != nil { @@ -26,5 +27,16 @@ func InitDB(dbPath string) error { return err } DB = db + + logger.Info.Print("Opened DB") return nil } + +func Close() error { + logger.Info.Print("Closing DB") + db, err := DB.DB() + if err != nil { + return err + } + return db.Close() +} diff --git a/internal/middlewares/csrf_middleware.go b/internal/middlewares/csrf_middleware.go index 847fe71..5637aeb 100644 --- a/internal/middlewares/csrf_middleware.go +++ b/internal/middlewares/csrf_middleware.go @@ -8,7 +8,7 @@ import ( "strings" "GoMembership/internal/config" - "GoMembership/internal/server" + // "GoMembership/internal/server" "GoMembership/internal/utils" "GoMembership/pkg/logger" ) diff --git a/internal/server/server.go b/internal/server/server.go index 8573d50..fb9b69d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,12 +4,12 @@ package server import ( + "context" "net/http" "path/filepath" "GoMembership/internal/config" "GoMembership/internal/controllers" - "GoMembership/internal/database" "GoMembership/internal/repositories" // "GoMembership/internal/middlewares" @@ -20,15 +20,10 @@ import ( "github.com/gin-gonic/gin" ) +var shutdownChannel = make(chan struct{}) + // Run initializes the server configuration, sets up services and controllers, and starts the HTTP server. func Run() { - config.LoadConfig() - logger.Info.Printf("Config loaded: %#v", config.CFG) - - err := database.InitDB(config.DB.Path) - if err != nil { - logger.Error.Fatalf("Couldn't init database: %v", err) - } emailService := services.NewEmailService(config.SMTP.Host, config.SMTP.Port, config.SMTP.User, config.SMTP.Password) var consentRepo repositories.ConsentRepositoryInterface = &repositories.ConsentRepository{} @@ -63,7 +58,23 @@ func Run() { // accountRouter.Use(middlewares.AuthMiddleware) logger.Info.Println("Starting server on :8080") - if err := http.ListenAndServe(":8080", router); err != nil { - logger.Error.Fatalf("could not start server: %v", err) - } + go func() { + if err := http.ListenAndServe(":8080", router); err != nil && err != http.ErrServerClosed { + logger.Error.Fatalf("could not start server: %v", err) + } + }() + // Wait for the shutdown signal + <-shutdownChannel +} + +func Shutdown(ctx context.Context) { + // Signal the server to stop + close(shutdownChannel) + + // Optionally wait for a timeout or other cleanup operations + // ctx can be used to manage shutdown timeout or cleanup tasks + // select { + // case <-ctx.Done(): + // // handle context cancellation if needed + // } }