add: Rate limiting & tests
This commit is contained in:
@@ -48,6 +48,12 @@ type RecipientsConfig struct {
|
||||
UserRegistration string `json:"UserRegistration" envconfig:"RECIPIENT_USER_REGISTRATION"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Ratelimits struct {
|
||||
Limit int `json:"Limit" default:"1" envconfig:"RATE_LIMIT"`
|
||||
Burst int `json:"Burst" default:"60" envconfig:"BURST_LIMIT"`
|
||||
} `json:"RateLimits"`
|
||||
}
|
||||
type Config struct {
|
||||
Auth AuthenticationConfig `json:"auth"`
|
||||
Templates TemplateConfig `json:"templates"`
|
||||
@@ -56,6 +62,7 @@ type Config struct {
|
||||
BaseURL string `json:"BaseUrl" envconfig:"BASE_URL"`
|
||||
DB DatabaseConfig `json:"db"`
|
||||
SMTP SMTPConfig `json:"smtp"`
|
||||
Security SecurityConfig `json:"security"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -67,6 +74,7 @@ var (
|
||||
Templates TemplateConfig
|
||||
SMTP SMTPConfig
|
||||
Recipients RecipientsConfig
|
||||
Security SecurityConfig
|
||||
)
|
||||
|
||||
// LoadConfig initializes the configuration by reading from a file and environment variables.
|
||||
@@ -94,6 +102,7 @@ func LoadConfig() {
|
||||
SMTP = CFG.SMTP
|
||||
BaseURL = CFG.BaseURL
|
||||
Recipients = CFG.Recipients
|
||||
Security = CFG.Security
|
||||
}
|
||||
|
||||
// readFile reads the configuration from the specified file path into the provided Config struct.
|
||||
|
||||
83
internal/middlewares/rate_limit.go
Normal file
83
internal/middlewares/rate_limit.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rate.Limiter
|
||||
mu *sync.RWMutex
|
||||
r rate.Limit
|
||||
b int
|
||||
}
|
||||
|
||||
func NewIPRateLimiter(r int, b int) *IPRateLimiter {
|
||||
return &IPRateLimiter{
|
||||
ips: make(map[string]*rate.Limiter),
|
||||
mu: &sync.RWMutex{},
|
||||
r: rate.Limit(r),
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
limiter, exists := i.ips[ip]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(i.r, i.b)
|
||||
i.ips[ip] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// func RateLimitMiddleware() gin.HandlerFunc {
|
||||
// if iPLimiter == nil {
|
||||
|
||||
// iPLimiter := NewIPRateLimiter(
|
||||
// rate.Limit(config.Security.Ratelimits.Limit),
|
||||
// config.Security.Ratelimits.Burst)
|
||||
// }
|
||||
|
||||
// return func(c *gin.Context) {
|
||||
// ip := c.ClientIP()
|
||||
// l := iPLimiter.GetLimiter(ip)
|
||||
// if !l.Allow() {
|
||||
// c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
// "error": "Too many requests",
|
||||
// })
|
||||
// c.Abort()
|
||||
// return
|
||||
// }
|
||||
// c.Next()
|
||||
// }
|
||||
// }
|
||||
|
||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b)
|
||||
return func(c *gin.Context) {
|
||||
if limiter == nil {
|
||||
logger.Error.Println("Limiter missing")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
l := limiter.GetLimiter(ip)
|
||||
if !l.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
143
internal/middlewares/rate_limit_test.go
Normal file
143
internal/middlewares/rate_limit_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
limit = 1
|
||||
burst = 60
|
||||
)
|
||||
|
||||
func TestRateLimitMiddleware(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a new rate limiter that allows 2 requests per second with a burst of 4
|
||||
|
||||
// Create a new Gin router with the rate limit middleware
|
||||
router := setupRouter()
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
requests int
|
||||
expectedStatus int
|
||||
sleep time.Duration
|
||||
}{
|
||||
{"Allow first request", 1, http.StatusOK, 0},
|
||||
{"Allow up to burst limit", burst - 1, http.StatusOK, 0},
|
||||
{"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0},
|
||||
{"Allow after rate limit replenishes", 1, http.StatusOK, time.Second},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
time.Sleep(tt.sleep)
|
||||
var status int
|
||||
for i := 0; i < tt.requests; i++ {
|
||||
status = makeRequest(router, "192.168.0.2")
|
||||
|
||||
}
|
||||
if status != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRateLimiter(t *testing.T) {
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
|
||||
limiter1 := limiter.GetLimiter("127.0.0.1")
|
||||
limiter2 := limiter.GetLimiter("192.168.0.1")
|
||||
|
||||
if limiter1 == limiter2 {
|
||||
t.Error("Expected different limiters for different IPs")
|
||||
}
|
||||
|
||||
limiter3 := limiter.GetLimiter("127.0.0.1")
|
||||
if limiter1 != limiter3 {
|
||||
t.Error("Expected the same limiter for the same IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentRateLimits(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
requests int
|
||||
expectedRequests int
|
||||
ip string
|
||||
}{
|
||||
{"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"},
|
||||
{"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"},
|
||||
{"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"},
|
||||
{"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
router := setupRouter()
|
||||
testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) {
|
||||
start := time.Now()
|
||||
successCount := 0
|
||||
totalRequests := 0
|
||||
|
||||
t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests)))
|
||||
for time.Since(start) < duration {
|
||||
status := makeRequest(router, ip)
|
||||
totalRequests++
|
||||
|
||||
if status == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests)))
|
||||
}
|
||||
|
||||
actualDuration := time.Since(start)
|
||||
|
||||
t.Logf("limit: %v, burst: %v", limit, burst)
|
||||
t.Logf("Test duration: %v", actualDuration)
|
||||
t.Logf("Successful requests: %d", successCount)
|
||||
t.Logf("Expected successful requests: %d", expectedRequests)
|
||||
t.Logf("Total requests: %d", totalRequests)
|
||||
|
||||
if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 {
|
||||
t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount)
|
||||
}
|
||||
|
||||
if requests-expectedRequests != 0 && totalRequests <= successCount {
|
||||
t.Errorf("Expected some requests to be rate limited")
|
||||
}
|
||||
}
|
||||
|
||||
func setupRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
router.Use(RateLimitMiddleware(limiter))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
func makeRequest(router *gin.Engine, ip string) int {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP
|
||||
router.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"GoMembership/internal/middlewares"
|
||||
"GoMembership/internal/repositories"
|
||||
|
||||
// "GoMembership/internal/middlewares"
|
||||
"GoMembership/internal/routes"
|
||||
"GoMembership/internal/services"
|
||||
"GoMembership/pkg/logger"
|
||||
@@ -46,6 +45,7 @@ func Run() {
|
||||
membershipController := &controllers.MembershipController{Service: *membershipService}
|
||||
|
||||
contactController := &controllers.ContactController{EmailService: emailService}
|
||||
|
||||
router := gin.Default()
|
||||
// gin.SetMode(gin.ReleaseMode)
|
||||
router.Static(config.Templates.StaticPath, "./style")
|
||||
@@ -55,6 +55,8 @@ func Run() {
|
||||
router.Use(gin.Logger())
|
||||
router.Use(middlewares.CORSMiddleware())
|
||||
|
||||
limiter := middlewares.NewIPRateLimiter(config.Security.Ratelimits.Limit, config.Security.Ratelimits.Burst)
|
||||
router.Use(middlewares.RateLimitMiddleware(limiter))
|
||||
routes.RegisterRoutes(router, userController, membershipController, contactController)
|
||||
// create subrouter for teh authenticated area /account
|
||||
// also pthprefix matches everything below /account
|
||||
|
||||
Reference in New Issue
Block a user