diff --git a/configs/config.template.json b/configs/config.template.json index b2a5c80..0dc9e8d 100644 --- a/configs/config.template.json +++ b/configs/config.template.json @@ -22,5 +22,11 @@ "recipients": { "ContactForm": "contacts@server.com", "UserRegistration": "registration@server.com" + }, + "security": { + "RateLimits": { + "Limit": 1, + "Burst": 60 + } } } diff --git a/go.mod b/go.mod index a285db8..3b7b48c 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( golang.org/x/net v0.27.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect + golang.org/x/time v0.6.0 google.golang.org/protobuf v1.34.2 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 1a314ca..1e811f5 100644 --- a/go.sum +++ b/go.sum @@ -98,6 +98,8 @@ golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= diff --git a/internal/config/config.go b/internal/config/config.go index f28ffaf..ec49fbd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,12 @@ type RecipientsConfig struct { UserRegistration string `json:"UserRegistration" envconfig:"RECIPIENT_USER_REGISTRATION"` } +type SecurityConfig struct { + Ratelimits struct { + Limit int `json:"Limit" default:"1" envconfig:"RATE_LIMIT"` + Burst int `json:"Burst" default:"60" envconfig:"BURST_LIMIT"` + } `json:"RateLimits"` +} type Config struct { Auth AuthenticationConfig `json:"auth"` Templates TemplateConfig `json:"templates"` @@ -56,6 +62,7 @@ type Config struct { BaseURL string `json:"BaseUrl" envconfig:"BASE_URL"` DB DatabaseConfig `json:"db"` SMTP SMTPConfig `json:"smtp"` + Security SecurityConfig `json:"security"` } var ( @@ -67,6 +74,7 @@ var ( Templates TemplateConfig SMTP SMTPConfig Recipients RecipientsConfig + Security SecurityConfig ) // LoadConfig initializes the configuration by reading from a file and environment variables. @@ -94,6 +102,7 @@ func LoadConfig() { SMTP = CFG.SMTP BaseURL = CFG.BaseURL Recipients = CFG.Recipients + Security = CFG.Security } // readFile reads the configuration from the specified file path into the provided Config struct. diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go new file mode 100644 index 0000000..cced867 --- /dev/null +++ b/internal/middlewares/rate_limit.go @@ -0,0 +1,83 @@ +package middlewares + +import ( + "GoMembership/pkg/logger" + "net/http" + "sync" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu *sync.RWMutex + r rate.Limit + b int +} + +func NewIPRateLimiter(r int, b int) *IPRateLimiter { + return &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + mu: &sync.RWMutex{}, + r: rate.Limit(r), + b: b, + } +} + +func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + limiter, exists := i.ips[ip] + if !exists { + limiter = rate.NewLimiter(i.r, i.b) + i.ips[ip] = limiter + } + + return limiter +} + +// func RateLimitMiddleware() gin.HandlerFunc { +// if iPLimiter == nil { + +// iPLimiter := NewIPRateLimiter( +// rate.Limit(config.Security.Ratelimits.Limit), +// config.Security.Ratelimits.Burst) +// } + +// return func(c *gin.Context) { +// ip := c.ClientIP() +// l := iPLimiter.GetLimiter(ip) +// if !l.Allow() { +// c.JSON(http.StatusTooManyRequests, gin.H{ +// "error": "Too many requests", +// }) +// c.Abort() +// return +// } +// c.Next() +// } +// } + +func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { + logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b) + return func(c *gin.Context) { + if limiter == nil { + logger.Error.Println("Limiter missing") + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + ip := c.ClientIP() + l := limiter.GetLimiter(ip) + if !l.Allow() { + c.JSON(http.StatusTooManyRequests, gin.H{ + "error": "Too many requests", + }) + c.Abort() + return + } + c.Next() + } +} diff --git a/internal/middlewares/rate_limit_test.go b/internal/middlewares/rate_limit_test.go new file mode 100644 index 0000000..c285e4b --- /dev/null +++ b/internal/middlewares/rate_limit_test.go @@ -0,0 +1,143 @@ +package middlewares + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +var ( + limit = 1 + burst = 60 +) + +func TestRateLimitMiddleware(t *testing.T) { + + gin.SetMode(gin.TestMode) + + // Create a new rate limiter that allows 2 requests per second with a burst of 4 + + // Create a new Gin router with the rate limit middleware + router := setupRouter() + + // Test cases + tests := []struct { + name string + requests int + expectedStatus int + sleep time.Duration + }{ + {"Allow first request", 1, http.StatusOK, 0}, + {"Allow up to burst limit", burst - 1, http.StatusOK, 0}, + {"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0}, + {"Allow after rate limit replenishes", 1, http.StatusOK, time.Second}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + time.Sleep(tt.sleep) + var status int + for i := 0; i < tt.requests; i++ { + status = makeRequest(router, "192.168.0.2") + + } + if status != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, status) + } + }) + } +} + +func TestIPRateLimiter(t *testing.T) { + limiter := NewIPRateLimiter(limit, burst) + + limiter1 := limiter.GetLimiter("127.0.0.1") + limiter2 := limiter.GetLimiter("192.168.0.1") + + if limiter1 == limiter2 { + t.Error("Expected different limiters for different IPs") + } + + limiter3 := limiter.GetLimiter("127.0.0.1") + if limiter1 != limiter3 { + t.Error("Expected the same limiter for the same IP") + } +} + +func TestDifferentRateLimits(t *testing.T) { + testCases := []struct { + name string + duration time.Duration + requests int + expectedRequests int + ip string + }{ + {"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"}, + {"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"}, + {"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"}, + {"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + router := setupRouter() + testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip) + }) + } +} + +func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) { + start := time.Now() + successCount := 0 + totalRequests := 0 + + t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests))) + for time.Since(start) < duration { + status := makeRequest(router, ip) + totalRequests++ + + if status == http.StatusOK { + successCount++ + } + + time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests))) + } + + actualDuration := time.Since(start) + + t.Logf("limit: %v, burst: %v", limit, burst) + t.Logf("Test duration: %v", actualDuration) + t.Logf("Successful requests: %d", successCount) + t.Logf("Expected successful requests: %d", expectedRequests) + t.Logf("Total requests: %d", totalRequests) + + if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 { + t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount) + } + + if requests-expectedRequests != 0 && totalRequests <= successCount { + t.Errorf("Expected some requests to be rate limited") + } +} + +func setupRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + limiter := NewIPRateLimiter(limit, burst) + router.Use(RateLimitMiddleware(limiter)) + router.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "success") + }) + return router +} + +func makeRequest(router *gin.Engine, ip string) int { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP + router.ServeHTTP(w, req) + return w.Code +} diff --git a/internal/server/server.go b/internal/server/server.go index 5035aa1..613d23c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,7 +14,6 @@ import ( "GoMembership/internal/middlewares" "GoMembership/internal/repositories" - // "GoMembership/internal/middlewares" "GoMembership/internal/routes" "GoMembership/internal/services" "GoMembership/pkg/logger" @@ -46,6 +45,7 @@ func Run() { membershipController := &controllers.MembershipController{Service: *membershipService} contactController := &controllers.ContactController{EmailService: emailService} + router := gin.Default() // gin.SetMode(gin.ReleaseMode) router.Static(config.Templates.StaticPath, "./style") @@ -55,6 +55,8 @@ func Run() { router.Use(gin.Logger()) router.Use(middlewares.CORSMiddleware()) + limiter := middlewares.NewIPRateLimiter(config.Security.Ratelimits.Limit, config.Security.Ratelimits.Burst) + router.Use(middlewares.RateLimitMiddleware(limiter)) routes.RegisterRoutes(router, userController, membershipController, contactController) // create subrouter for teh authenticated area /account // also pthprefix matches everything below /account