backend moved to separate directory

backend: deleted the old structure
This commit is contained in:
Alex
2025-02-28 08:52:04 +01:00
parent ad599ae3f4
commit 2ffd1f439f
88 changed files with 112 additions and 9 deletions

View File

@@ -0,0 +1,31 @@
package middlewares
import (
"crypto/subtle"
"net/http"
"github.com/gin-gonic/gin"
"GoMembership/internal/config"
)
func APIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
clientAPIKey := c.GetHeader("X-API-Key")
if clientAPIKey == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "API key is missing"})
c.Abort()
return
}
// Using subtle.ConstantTimeCompare to mitigate timing attacks
if subtle.ConstantTimeCompare([]byte(clientAPIKey), []byte(config.Auth.APIKEY)) != 1 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,61 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config"
)
func TestAPIKeyMiddleware(t *testing.T) {
// Set up a test API key
testAPIKey := "test-api-key-12345"
config.Auth.APIKEY = testAPIKey
// Set Gin to Test Mode
gin.SetMode(gin.TestMode)
// Tests table
tests := []struct {
name string
apiKey string
wantStatus int
}{
{"Valid API Key", testAPIKey, http.StatusOK},
{"Missing API Key", "", http.StatusUnauthorized},
{"Invalid API Key", "wrong-key", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up a new test router and handler
router := gin.New()
router.Use(APIKeyMiddleware())
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Create a test request
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
if tt.apiKey != "" {
req.Header.Set("X-API-Key", tt.apiKey)
}
// Serve the request
router.ServeHTTP(w, req)
// Assert the response
assert.Equal(t, tt.wantStatus, w.Code)
// Additional assertions for specific cases
if tt.wantStatus == http.StatusUnauthorized {
assert.Contains(t, w.Body.String(), "API key")
}
})
}
}

View File

@@ -0,0 +1,179 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
type Session struct {
UserID uint
ExpiresAt time.Time
}
var (
sessionDuration = 5 * 24 * time.Hour
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
sessions = make(map[string]*Session)
)
func verifyAndRenewToken(tokenString string) (string, uint, error) {
if tokenString == "" {
logger.Error.Printf("empty tokenstring")
return "", 0, fmt.Errorf("Authorization token is required")
}
token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", 0, err
}
sessionID := (*claims)["session_id"].(string)
userID := uint((*claims)["user_id"].(float64))
roleID := int8((*claims)["role_id"].(float64))
session, ok := sessions[sessionID]
if !ok {
logger.Error.Printf("session not found")
return "", 0, fmt.Errorf("session not found")
}
if userID != session.UserID {
return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
}
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID)
logger.Error.Printf("session expired")
return "", 0, fmt.Errorf("session expired")
}
session.ExpiresAt = time.Now().Add(sessionDuration)
logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
if err != nil {
return "", 0, err
}
return newTokenString, session.UserID, nil
}
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tokenString, err := c.Cookie("jwt")
if err != nil {
logger.Error.Printf("No Auth token: %v\n", err)
c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{
"field": "general",
"key": "server.error.no_auth_token",
}}})
c.Abort()
return
}
newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil {
if err == customerrors.ErrValidToken {
c.Set("user_id", uint(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{
"field": "general",
"key": "server.error.no_auth_token",
}}})
c.Abort()
return
}
utils.SetCookie(c, newToken)
c.Set("user_id", uint(userID))
c.Next()
}
}
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
if sessionID == "" {
sessionID = uuid.New().String()
}
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": user.ID,
"role_id": user.RoleID,
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
})
UpdateSession(sessionID, user.ID)
return token.SignedString([]byte(jwtKey))
}
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.Auth.JWTSecret), nil
})
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
return nil, nil, err
}
// Token is expired, check if session is still valid
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.Error.Printf("Invalid Token Claims")
return nil, nil, fmt.Errorf("invalid token claims")
}
if !ok {
logger.Error.Printf("invalid session_id in token")
return nil, nil, fmt.Errorf("invalid session_id in token")
}
return token, &claims, nil
}
func UpdateSession(sessionID string, userID uint) {
sessions[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(sessionDuration),
}
}
func InvalidateSession(token string) (bool, error) {
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(
token,
claims,
func(token *jwt.Token) (interface{}, error) {
return config.Auth.JWTSecret, nil
},
)
if err != nil {
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
}
sessionID, ok := claims["session_id"].(string)
if !ok {
return false, fmt.Errorf("No SessionID found")
}
delete(sessions, sessionID)
return true, nil
}

View File

@@ -0,0 +1,191 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"encoding/json"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)
func setupTestEnvironment() {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
}
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
config.LoadConfig()
logger.Info.Printf("Config: %#v", config.CFG)
}
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
setupTestEnvironment()
tests := []struct {
name string
setupAuth func(r *http.Request)
expectedStatus int
expectNewCookie bool
expectedUserID uint
}{
{
name: "Valid Token",
setupAuth: func(r *http.Request) {
user := models.User{ID: 123, RoleID: constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
},
expectedStatus: http.StatusOK,
expectedUserID: 123,
},
{
name: "Missing Cookie",
setupAuth: func(r *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Token",
setupAuth: func(r *http.Request) {
r.AddCookie(&http.Cookie{Name: "jwt", Value: "InvalidToken"})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Expired Token with Valid Session",
setupAuth: func(r *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
UpdateSession(sessionID, 123) // Add a valid session
},
expectedStatus: http.StatusOK,
expectNewCookie: true,
expectedUserID: 123,
},
{
name: "Expired Token with Expired Session",
setupAuth: func(r *http.Request) {
sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
// Don't add a session, simulating an expired session
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Signature",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte("wrong_secret"))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
{
name: "Invalid Signing Method",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123,
"session_id": "some-session",
"role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
},
}
for _, tt := range tests {
logger.Error.Print("==============================================================")
logger.Error.Printf("Testing : %v", tt.name)
logger.Error.Print("==============================================================")
t.Run(tt.name, func(t *testing.T) {
// Setup
r := gin.New()
r.Use(AuthMiddleware())
r.GET("/test", func(c *gin.Context) {
userID, exists := c.Get("user_id")
if exists {
c.JSON(http.StatusOK, gin.H{"user_id": userID})
} else {
c.JSON(http.StatusUnauthorized, gin.H{"user_id": 0})
}
})
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
tt.setupAuth(req)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedStatus == http.StatusOK {
var response map[string]uint
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, tt.expectedUserID, response["user_id"])
// Check if a new cookie was set
cookies := w.Result().Cookies()
if tt.expectNewCookie {
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
} else {
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
}
} else {
assert.Equal(t, 0, len(w.Result().Cookies()))
}
})
}
}

View File

@@ -0,0 +1,22 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/pkg/logger"
"strings"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
func CORSMiddleware() gin.HandlerFunc {
logger.Info.Print("Applying CORS")
return cors.New(cors.Config{
AllowOrigins: strings.Split(config.Site.AllowOrigins, ","),
AllowMethods: []string{"GET", "POST", "PATCH", "PUT", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-CSRF-Token"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * 60 * 60, // 12 hours
})
}

View File

@@ -0,0 +1,104 @@
package middlewares
import (
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config"
"GoMembership/pkg/logger"
)
const (
Host = "127.0.0.1"
Port int = 2525
)
func TestCORSMiddleware(t *testing.T) {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
}
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_HOST", Host); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
log.Fatalf("Error setting environment variable: %v", err)
}
// Load your configuration
config.LoadConfig()
// Create a gin router with the CORS middleware
router := gin.New()
router.Use(CORSMiddleware())
// Add a simple handler
router.GET("/test", func(c *gin.Context) {
c.String(200, "test")
})
tests := []struct {
name string
origin string
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "Allowed origin",
origin: config.Site.AllowOrigins,
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": config.Site.AllowOrigins,
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Allow-Credentials": "true",
},
},
{
name: "Disallowed origin",
origin: "http://evil.com",
expectedStatus: http.StatusForbidden,
expectedHeaders: map[string]string{
"Access-Control-Allow-Origin": "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", tt.origin)
router.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
logger.Info.Printf("Recieved Headers: %#v", w.Header())
for key, value := range tt.expectedHeaders {
assert.Equal(t, value, w.Header().Get(key))
}
})
}
}

View File

@@ -0,0 +1,44 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/pkg/logger"
"net/http"
"github.com/gin-gonic/gin"
)
func CSPMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying CSP")
return func(c *gin.Context) {
policy := "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests;"
if config.Env == "development" {
policy += " report-uri /csp-report;"
c.Header("Content-Security-Policy-Report-Only", policy)
} else {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}
func CSPReportHandling(c *gin.Context) {
var report map[string]interface{}
if err := c.BindJSON(&report); err != nil {
logger.Error.Printf("Couldn't Bind JSON: %#v", err)
return
}
logger.Info.Printf("CSP Violation: %+v", report)
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,81 @@
package middlewares
import (
"GoMembership/internal/config"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestCSPMiddleware(t *testing.T) {
// Save the current environment and restore it after the test
originalEnv := config.Env
tests := []struct {
name string
environment string
expectedHeader string
expectedPolicy string
}{
{
name: "Development Environment",
environment: "development",
expectedHeader: "Content-Security-Policy-Report-Only",
expectedPolicy: "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests; report-uri /csp-report;",
},
{
name: "Production Environment",
environment: "production",
expectedHeader: "Content-Security-Policy",
expectedPolicy: "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'" +
"style-src 'self' 'unsafe-inline'" +
"img-src 'self'" +
"font-src 'self'" +
"connect-src 'self'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"base-uri 'self'; " +
"upgrade-insecure-requests;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up the test environment
config.Env = tt.environment
// Create a new Gin router with the middleware
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CSPMiddleware())
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "test")
})
// Create a test request
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Serve the request
router.ServeHTTP(w, req)
// Check the response
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, tt.expectedPolicy, w.Header().Get(tt.expectedHeader))
})
}
config.Env = originalEnv
}

View File

@@ -0,0 +1,21 @@
package middlewares
import (
"GoMembership/pkg/logger"
"github.com/gin-gonic/gin"
)
func SecurityHeadersMiddleware() gin.HandlerFunc {
logger.Error.Printf("applying headers")
return func(c *gin.Context) {
c.Header("X-Frame-Options", "DENY")
c.Header("X-Content-Type-Options", "nosniff")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Feature-Policy", "geolocation 'none'; midi 'none'; sync-xhr 'none'; microphone 'none'; camera 'none'; magnetometer 'none'; gyroscope 'none'; speaker 'none'; fullscreen 'self'; payment 'none'")
c.Header("Permissions-Policy", "geolocation=(), midi=(), sync-xhr=(), microphone=(), camera=(), magnetometer=(), gyroscope=(), fullscreen=(self), payment=()")
c.Next()
}
}

View File

@@ -0,0 +1,83 @@
package middlewares
import (
"GoMembership/pkg/logger"
"net/http"
"sync"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type IPRateLimiter struct {
ips map[string]*rate.Limiter
mu *sync.RWMutex
r rate.Limit
b int
}
func NewIPRateLimiter(r int, b int) *IPRateLimiter {
return &IPRateLimiter{
ips: make(map[string]*rate.Limiter),
mu: &sync.RWMutex{},
r: rate.Limit(r),
b: b,
}
}
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()
limiter, exists := i.ips[ip]
if !exists {
limiter = rate.NewLimiter(i.r, i.b)
i.ips[ip] = limiter
}
return limiter
}
// func RateLimitMiddleware() gin.HandlerFunc {
// if iPLimiter == nil {
// iPLimiter := NewIPRateLimiter(
// rate.Limit(config.Security.Ratelimits.Limit),
// config.Security.Ratelimits.Burst)
// }
// return func(c *gin.Context) {
// ip := c.ClientIP()
// l := iPLimiter.GetLimiter(ip)
// if !l.Allow() {
// c.JSON(http.StatusTooManyRequests, gin.H{
// "error": "Too many requests",
// })
// c.Abort()
// return
// }
// c.Next()
// }
// }
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b)
return func(c *gin.Context) {
if limiter == nil {
logger.Error.Println("Limiter missing")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
ip := c.ClientIP()
l := limiter.GetLimiter(ip)
if !l.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "Too many requests",
})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,143 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
)
var (
limit = 1
burst = 60
)
func TestRateLimitMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a new rate limiter that allows 2 requests per second with a burst of 4
// Create a new Gin router with the rate limit middleware
router := setupRouter()
// Test cases
tests := []struct {
name string
requests int
expectedStatus int
sleep time.Duration
}{
{"Allow first request", 1, http.StatusOK, 0},
{"Allow up to burst limit", burst - 1, http.StatusOK, 0},
{"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0},
{"Allow after rate limit replenishes", 1, http.StatusOK, time.Second},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
time.Sleep(tt.sleep)
var status int
for i := 0; i < tt.requests; i++ {
status = makeRequest(router, "192.168.0.2")
}
if status != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, status)
}
})
}
}
func TestIPRateLimiter(t *testing.T) {
limiter := NewIPRateLimiter(limit, burst)
limiter1 := limiter.GetLimiter("127.0.0.1")
limiter2 := limiter.GetLimiter("192.168.0.1")
if limiter1 == limiter2 {
t.Error("Expected different limiters for different IPs")
}
limiter3 := limiter.GetLimiter("127.0.0.1")
if limiter1 != limiter3 {
t.Error("Expected the same limiter for the same IP")
}
}
func TestDifferentRateLimits(t *testing.T) {
testCases := []struct {
name string
duration time.Duration
requests int
expectedRequests int
ip string
}{
{"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"},
{"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"},
{"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"},
{"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
router := setupRouter()
testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip)
})
}
}
func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) {
start := time.Now()
successCount := 0
totalRequests := 0
t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests)))
for time.Since(start) < duration {
status := makeRequest(router, ip)
totalRequests++
if status == http.StatusOK {
successCount++
}
time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests)))
}
actualDuration := time.Since(start)
t.Logf("limit: %v, burst: %v", limit, burst)
t.Logf("Test duration: %v", actualDuration)
t.Logf("Successful requests: %d", successCount)
t.Logf("Expected successful requests: %d", expectedRequests)
t.Logf("Total requests: %d", totalRequests)
if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 {
t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount)
}
if requests-expectedRequests != 0 && totalRequests <= successCount {
t.Errorf("Expected some requests to be rate limited")
}
}
func setupRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
limiter := NewIPRateLimiter(limit, burst)
router.Use(RateLimitMiddleware(limiter))
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "success")
})
return router
}
func makeRequest(router *gin.Engine, ip string) int {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP
router.ServeHTTP(w, req)
return w.Code
}