diff --git a/internal/config/config.go b/internal/config/config.go index 47dc9d5..157ca33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "github.com/kelseyhightower/envconfig" @@ -60,6 +61,7 @@ type Config struct { Recipients RecipientsConfig `json:"recipients"` ConfigFilePath string `json:"config_file_path" envconfig:"CONFIG_FILE_PATH"` BaseURL string `json:"BaseUrl" envconfig:"BASE_URL"` + Env string `json:"Environment" default:"development" envconfig:"ENV"` DB DatabaseConfig `json:"db"` SMTP SMTPConfig `json:"smtp"` Security SecurityConfig `json:"security"` @@ -74,8 +76,15 @@ var ( Templates TemplateConfig SMTP SMTPConfig Recipients RecipientsConfig + Env string Security SecurityConfig ) +var environmentOptions map[string]bool = map[string]bool{ + "development": true, + "production": true, + "dev": true, + "prod": true, +} // LoadConfig initializes the configuration by reading from a file and environment variables. // It also generates JWT and CSRF secrets. Returns a Config pointer or an error if any step fails. @@ -95,7 +104,11 @@ func LoadConfig() { } CFG.Auth.JWTSecret = jwtSecret CFG.Auth.CSRFSecret = csrfSecret - + if environmentOptions[CFG.Env] && strings.Contains("development", CFG.Env) { + CFG.Env = "development" + } else { + CFG.Env = "production" + } Auth = CFG.Auth DB = CFG.DB Templates = CFG.Templates @@ -103,6 +116,7 @@ func LoadConfig() { BaseURL = CFG.BaseURL Recipients = CFG.Recipients Security = CFG.Security + Env = CFG.Env logger.Info.Printf("Config loaded: %#v", CFG) } diff --git a/internal/middlewares/csp.go b/internal/middlewares/csp.go new file mode 100644 index 0000000..ddd435b --- /dev/null +++ b/internal/middlewares/csp.go @@ -0,0 +1,43 @@ +package middlewares + +import ( + "GoMembership/internal/config" + "GoMembership/pkg/logger" + "net/http" + + "github.com/gin-gonic/gin" +) + +func CSPMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + policy := "default-src 'self'; " + + "script-src 'self' 'unsafe-inline'" + + "style-src 'self' 'unsafe-inline'" + + "img-src 'self'" + + "font-src 'self'" + + "connect-src 'self'; " + + "frame-ancestors 'none'; " + + "form-action 'self'; " + + "base-uri 'self'; " + + "upgrade-insecure-requests;" + + if config.Env == "development" { + policy += " report-uri /csp-report;" + c.Header("Content-Security-Policy-Report-Only", policy) + } else { + c.Header("Content-Security-Policy", policy) + } + c.Next() + } +} + +func CSPReportHandling(c *gin.Context) { + var report map[string]interface{} + if err := c.BindJSON(&report); err != nil { + + logger.Error.Printf("Couldn't Bind JSON: %#v", err) + return + } + logger.Info.Printf("CSP Violation: %+v", report) + c.Status(http.StatusNoContent) +} diff --git a/internal/middlewares/csp_test.go b/internal/middlewares/csp_test.go new file mode 100644 index 0000000..55a7351 --- /dev/null +++ b/internal/middlewares/csp_test.go @@ -0,0 +1,81 @@ +package middlewares + +import ( + "GoMembership/internal/config" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestCSPMiddleware(t *testing.T) { + // Save the current environment and restore it after the test + originalEnv := config.Env + + tests := []struct { + name string + environment string + expectedHeader string + expectedPolicy string + }{ + { + name: "Development Environment", + environment: "development", + expectedHeader: "Content-Security-Policy-Report-Only", + expectedPolicy: "default-src 'self'; " + + "script-src 'self' 'unsafe-inline'" + + "style-src 'self' 'unsafe-inline'" + + "img-src 'self'" + + "font-src 'self'" + + "connect-src 'self'; " + + "frame-ancestors 'none'; " + + "form-action 'self'; " + + "base-uri 'self'; " + + "upgrade-insecure-requests; report-uri /csp-report;", + }, + { + name: "Production Environment", + environment: "production", + expectedHeader: "Content-Security-Policy", + expectedPolicy: "default-src 'self'; " + + "script-src 'self' 'unsafe-inline'" + + "style-src 'self' 'unsafe-inline'" + + "img-src 'self'" + + "font-src 'self'" + + "connect-src 'self'; " + + "frame-ancestors 'none'; " + + "form-action 'self'; " + + "base-uri 'self'; " + + "upgrade-insecure-requests;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the test environment + config.Env = tt.environment + + // Create a new Gin router with the middleware + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(CSPMiddleware()) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "test") + }) + + // Create a test request + req, _ := http.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + // Serve the request + router.ServeHTTP(w, req) + + // Check the response + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, tt.expectedPolicy, w.Header().Get(tt.expectedHeader)) + }) + } + config.Env = originalEnv +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 7a99eaa..7aa9640 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -2,6 +2,7 @@ package routes import ( "GoMembership/internal/controllers" + "GoMembership/internal/middlewares" "github.com/gin-gonic/gin" ) @@ -13,4 +14,5 @@ func RegisterRoutes(router *gin.Engine, userController *controllers.UserControll router.POST("/backend/api/contact", contactController.RelayContactRequest) // router.HandleFunc("/login", userController.LoginUser).Methods("POST") + router.POST("/csp-report", middlewares.CSPReportHandling) } diff --git a/internal/server/server.go b/internal/server/server.go index 613d23c..e565e09 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -54,9 +54,11 @@ func Run() { router.Use(gin.Logger()) router.Use(middlewares.CORSMiddleware()) + router.Use(middlewares.CSPMiddleware()) limiter := middlewares.NewIPRateLimiter(config.Security.Ratelimits.Limit, config.Security.Ratelimits.Burst) router.Use(middlewares.RateLimitMiddleware(limiter)) + routes.RegisterRoutes(router, userController, membershipController, contactController) // create subrouter for teh authenticated area /account // also pthprefix matches everything below /account