frontend: disabled button while processing password reset
This commit is contained in:
31
go-backend/internal/middlewares/api.go
Normal file
31
go-backend/internal/middlewares/api.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
)
|
||||
|
||||
func APIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientAPIKey := c.GetHeader("X-API-Key")
|
||||
|
||||
if clientAPIKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "API key is missing"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Using subtle.ConstantTimeCompare to mitigate timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(clientAPIKey), []byte(config.Auth.APIKEY)) != 1 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
61
go-backend/internal/middlewares/api_test.go
Normal file
61
go-backend/internal/middlewares/api_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
)
|
||||
|
||||
func TestAPIKeyMiddleware(t *testing.T) {
|
||||
// Set up a test API key
|
||||
testAPIKey := "test-api-key-12345"
|
||||
config.Auth.APIKEY = testAPIKey
|
||||
|
||||
// Set Gin to Test Mode
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Tests table
|
||||
tests := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
wantStatus int
|
||||
}{
|
||||
{"Valid API Key", testAPIKey, http.StatusOK},
|
||||
{"Missing API Key", "", http.StatusUnauthorized},
|
||||
{"Invalid API Key", "wrong-key", http.StatusUnauthorized},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up a new test router and handler
|
||||
router := gin.New()
|
||||
router.Use(APIKeyMiddleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create a test request
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
if tt.apiKey != "" {
|
||||
req.Header.Set("X-API-Key", tt.apiKey)
|
||||
}
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Assert the response
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
|
||||
// Additional assertions for specific cases
|
||||
if tt.wantStatus == http.StatusUnauthorized {
|
||||
assert.Contains(t, w.Body.String(), "API key")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
179
go-backend/internal/middlewares/auth.go
Normal file
179
go-backend/internal/middlewares/auth.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/internal/utils"
|
||||
customerrors "GoMembership/pkg/errors"
|
||||
"GoMembership/pkg/logger"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
UserID uint
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
sessionDuration = 5 * 24 * time.Hour
|
||||
jwtSigningMethod = jwt.SigningMethodHS256
|
||||
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
||||
sessions = make(map[string]*Session)
|
||||
)
|
||||
|
||||
func verifyAndRenewToken(tokenString string) (string, uint, error) {
|
||||
if tokenString == "" {
|
||||
logger.Error.Printf("empty tokenstring")
|
||||
return "", 0, fmt.Errorf("Authorization token is required")
|
||||
}
|
||||
token, claims, err := ExtractContentFrom(tokenString)
|
||||
if err != nil {
|
||||
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
|
||||
return "", 0, err
|
||||
}
|
||||
sessionID := (*claims)["session_id"].(string)
|
||||
userID := uint((*claims)["user_id"].(float64))
|
||||
roleID := int8((*claims)["role_id"].(float64))
|
||||
|
||||
session, ok := sessions[sessionID]
|
||||
if !ok {
|
||||
logger.Error.Printf("session not found")
|
||||
return "", 0, fmt.Errorf("session not found")
|
||||
}
|
||||
if userID != session.UserID {
|
||||
return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
|
||||
}
|
||||
if token.Valid {
|
||||
// token is valid, so we can return the old tokenString
|
||||
return tokenString, session.UserID, customerrors.ErrValidToken
|
||||
}
|
||||
|
||||
if time.Now().After(sessions[sessionID].ExpiresAt) {
|
||||
delete(sessions, sessionID)
|
||||
logger.Error.Printf("session expired")
|
||||
return "", 0, fmt.Errorf("session expired")
|
||||
}
|
||||
session.ExpiresAt = time.Now().Add(sessionDuration)
|
||||
|
||||
logger.Error.Printf("Session still valid generating new token")
|
||||
// Session is still valid, generate a new token
|
||||
user := models.User{ID: userID, RoleID: roleID}
|
||||
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return newTokenString, session.UserID, nil
|
||||
}
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString, err := c.Cookie("jwt")
|
||||
if err != nil {
|
||||
logger.Error.Printf("No Auth token: %v\n", err)
|
||||
c.JSON(http.StatusUnauthorized,
|
||||
gin.H{"errors": []gin.H{{
|
||||
"field": "general",
|
||||
"key": "server.error.no_auth_token",
|
||||
}}})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
newToken, userID, err := verifyAndRenewToken(tokenString)
|
||||
if err != nil {
|
||||
if err == customerrors.ErrValidToken {
|
||||
c.Set("user_id", uint(userID))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
|
||||
c.JSON(http.StatusUnauthorized,
|
||||
gin.H{"errors": []gin.H{{
|
||||
"field": "general",
|
||||
"key": "server.error.no_auth_token",
|
||||
}}})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
utils.SetCookie(c, newToken)
|
||||
c.Set("user_id", uint(userID))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
|
||||
if sessionID == "" {
|
||||
sessionID = uuid.New().String()
|
||||
}
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": user.ID,
|
||||
"role_id": user.RoleID,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
|
||||
})
|
||||
UpdateSession(sessionID, user.ID)
|
||||
return token.SignedString([]byte(jwtKey))
|
||||
}
|
||||
|
||||
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
|
||||
|
||||
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
||||
return []byte(config.Auth.JWTSecret), nil
|
||||
})
|
||||
|
||||
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
|
||||
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Token is expired, check if session is still valid
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
logger.Error.Printf("Invalid Token Claims")
|
||||
return nil, nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
logger.Error.Printf("invalid session_id in token")
|
||||
return nil, nil, fmt.Errorf("invalid session_id in token")
|
||||
}
|
||||
return token, &claims, nil
|
||||
}
|
||||
|
||||
func UpdateSession(sessionID string, userID uint) {
|
||||
sessions[sessionID] = &Session{
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(sessionDuration),
|
||||
}
|
||||
}
|
||||
|
||||
func InvalidateSession(token string) (bool, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
_, err := jwt.ParseWithClaims(
|
||||
token,
|
||||
claims,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
return config.Auth.JWTSecret, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
|
||||
}
|
||||
|
||||
sessionID, ok := claims["session_id"].(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("No SessionID found")
|
||||
}
|
||||
|
||||
delete(sessions, sessionID)
|
||||
return true, nil
|
||||
}
|
||||
191
go-backend/internal/middlewares/auth_test.go
Normal file
191
go-backend/internal/middlewares/auth_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/constants"
|
||||
"GoMembership/internal/models"
|
||||
"GoMembership/pkg/logger"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func setupTestEnvironment() {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
config.LoadConfig()
|
||||
logger.Info.Printf("Config: %#v", config.CFG)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
setupTestEnvironment()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAuth func(r *http.Request)
|
||||
expectedStatus int
|
||||
expectNewCookie bool
|
||||
expectedUserID uint
|
||||
}{
|
||||
{
|
||||
name: "Valid Token",
|
||||
setupAuth: func(r *http.Request) {
|
||||
user := models.User{ID: 123, RoleID: constants.Roles.Member}
|
||||
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedUserID: 123,
|
||||
},
|
||||
{
|
||||
name: "Missing Cookie",
|
||||
setupAuth: func(r *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Token",
|
||||
setupAuth: func(r *http.Request) {
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: "InvalidToken"})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Expired Token with Valid Session",
|
||||
setupAuth: func(r *http.Request) {
|
||||
sessionID := "test-session"
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"role_id": constants.Roles.Member,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
UpdateSession(sessionID, 123) // Add a valid session
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectNewCookie: true,
|
||||
expectedUserID: 123,
|
||||
},
|
||||
{
|
||||
name: "Expired Token with Expired Session",
|
||||
setupAuth: func(r *http.Request) {
|
||||
sessionID := "expired-session"
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"role_id": constants.Roles.Member,
|
||||
"session_id": sessionID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
// Don't add a session, simulating an expired session
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Signature",
|
||||
setupAuth: func(r *http.Request) {
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"session_id": "some-session",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid Signing Method",
|
||||
setupAuth: func(r *http.Request) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
||||
"user_id": 123,
|
||||
"session_id": "some-session",
|
||||
"role_id": constants.Roles.Member,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
logger.Error.Print("==============================================================")
|
||||
logger.Error.Printf("Testing : %v", tt.name)
|
||||
logger.Error.Print("==============================================================")
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, gin.H{"user_id": userID})
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"user_id": 0})
|
||||
}
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
tt.setupAuth(req)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var response map[string]uint
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedUserID, response["user_id"])
|
||||
|
||||
// Check if a new cookie was set
|
||||
cookies := w.Result().Cookies()
|
||||
if tt.expectNewCookie {
|
||||
assert.GreaterOrEqual(t, len(cookies), 1)
|
||||
assert.Equal(t, "jwt", cookies[0].Name)
|
||||
assert.NotEmpty(t, cookies[0].Value)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
|
||||
}
|
||||
} else {
|
||||
assert.Equal(t, 0, len(w.Result().Cookies()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
go-backend/internal/middlewares/cors.go
Normal file
22
go-backend/internal/middlewares/cors.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
logger.Info.Print("Applying CORS")
|
||||
return cors.New(cors.Config{
|
||||
AllowOrigins: strings.Split(config.Site.AllowOrigins, ","),
|
||||
AllowMethods: []string{"GET", "POST", "PATCH", "PUT", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-CSRF-Token"},
|
||||
ExposeHeaders: []string{"Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * 60 * 60, // 12 hours
|
||||
})
|
||||
}
|
||||
104
go-backend/internal/middlewares/cors_test.go
Normal file
104
go-backend/internal/middlewares/cors_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
Port int = 2525
|
||||
)
|
||||
|
||||
func TestCORSMiddleware(t *testing.T) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_HOST", Host); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
// Load your configuration
|
||||
config.LoadConfig()
|
||||
|
||||
// Create a gin router with the CORS middleware
|
||||
router := gin.New()
|
||||
router.Use(CORSMiddleware())
|
||||
|
||||
// Add a simple handler
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "test")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Allowed origin",
|
||||
origin: config.Site.AllowOrigins,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": config.Site.AllowOrigins,
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin",
|
||||
origin: "http://evil.com",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
logger.Info.Printf("Recieved Headers: %#v", w.Header())
|
||||
for key, value := range tt.expectedHeaders {
|
||||
assert.Equal(t, value, w.Header().Get(key))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
44
go-backend/internal/middlewares/csp.go
Normal file
44
go-backend/internal/middlewares/csp.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CSPMiddleware() gin.HandlerFunc {
|
||||
logger.Error.Printf("applying CSP")
|
||||
return func(c *gin.Context) {
|
||||
policy := "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests;"
|
||||
|
||||
if config.Env == "development" {
|
||||
policy += " report-uri /csp-report;"
|
||||
c.Header("Content-Security-Policy-Report-Only", policy)
|
||||
} else {
|
||||
c.Header("Content-Security-Policy", policy)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CSPReportHandling(c *gin.Context) {
|
||||
var report map[string]interface{}
|
||||
if err := c.BindJSON(&report); err != nil {
|
||||
|
||||
logger.Error.Printf("Couldn't Bind JSON: %#v", err)
|
||||
return
|
||||
}
|
||||
logger.Info.Printf("CSP Violation: %+v", report)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
81
go-backend/internal/middlewares/csp_test.go
Normal file
81
go-backend/internal/middlewares/csp_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCSPMiddleware(t *testing.T) {
|
||||
// Save the current environment and restore it after the test
|
||||
originalEnv := config.Env
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
environment string
|
||||
expectedHeader string
|
||||
expectedPolicy string
|
||||
}{
|
||||
{
|
||||
name: "Development Environment",
|
||||
environment: "development",
|
||||
expectedHeader: "Content-Security-Policy-Report-Only",
|
||||
expectedPolicy: "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests; report-uri /csp-report;",
|
||||
},
|
||||
{
|
||||
name: "Production Environment",
|
||||
environment: "production",
|
||||
expectedHeader: "Content-Security-Policy",
|
||||
expectedPolicy: "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'" +
|
||||
"style-src 'self' 'unsafe-inline'" +
|
||||
"img-src 'self'" +
|
||||
"font-src 'self'" +
|
||||
"connect-src 'self'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"upgrade-insecure-requests;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up the test environment
|
||||
config.Env = tt.environment
|
||||
|
||||
// Create a new Gin router with the middleware
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(CSPMiddleware())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Create a test request
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Check the response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, tt.expectedPolicy, w.Header().Get(tt.expectedHeader))
|
||||
})
|
||||
}
|
||||
config.Env = originalEnv
|
||||
}
|
||||
21
go-backend/internal/middlewares/headers.go
Normal file
21
go-backend/internal/middlewares/headers.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SecurityHeadersMiddleware() gin.HandlerFunc {
|
||||
logger.Error.Printf("applying headers")
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Feature-Policy", "geolocation 'none'; midi 'none'; sync-xhr 'none'; microphone 'none'; camera 'none'; magnetometer 'none'; gyroscope 'none'; speaker 'none'; fullscreen 'self'; payment 'none'")
|
||||
c.Header("Permissions-Policy", "geolocation=(), midi=(), sync-xhr=(), microphone=(), camera=(), magnetometer=(), gyroscope=(), fullscreen=(self), payment=()")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
83
go-backend/internal/middlewares/rate_limit.go
Normal file
83
go-backend/internal/middlewares/rate_limit.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/pkg/logger"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rate.Limiter
|
||||
mu *sync.RWMutex
|
||||
r rate.Limit
|
||||
b int
|
||||
}
|
||||
|
||||
func NewIPRateLimiter(r int, b int) *IPRateLimiter {
|
||||
return &IPRateLimiter{
|
||||
ips: make(map[string]*rate.Limiter),
|
||||
mu: &sync.RWMutex{},
|
||||
r: rate.Limit(r),
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
limiter, exists := i.ips[ip]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(i.r, i.b)
|
||||
i.ips[ip] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// func RateLimitMiddleware() gin.HandlerFunc {
|
||||
// if iPLimiter == nil {
|
||||
|
||||
// iPLimiter := NewIPRateLimiter(
|
||||
// rate.Limit(config.Security.Ratelimits.Limit),
|
||||
// config.Security.Ratelimits.Burst)
|
||||
// }
|
||||
|
||||
// return func(c *gin.Context) {
|
||||
// ip := c.ClientIP()
|
||||
// l := iPLimiter.GetLimiter(ip)
|
||||
// if !l.Allow() {
|
||||
// c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
// "error": "Too many requests",
|
||||
// })
|
||||
// c.Abort()
|
||||
// return
|
||||
// }
|
||||
// c.Next()
|
||||
// }
|
||||
// }
|
||||
|
||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
logger.Info.Printf("Limiter with Limit: %v, Burst: %v", limiter.r, limiter.b)
|
||||
return func(c *gin.Context) {
|
||||
if limiter == nil {
|
||||
logger.Error.Println("Limiter missing")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
l := limiter.GetLimiter(ip)
|
||||
if !l.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
143
go-backend/internal/middlewares/rate_limit_test.go
Normal file
143
go-backend/internal/middlewares/rate_limit_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
limit = 1
|
||||
burst = 60
|
||||
)
|
||||
|
||||
func TestRateLimitMiddleware(t *testing.T) {
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a new rate limiter that allows 2 requests per second with a burst of 4
|
||||
|
||||
// Create a new Gin router with the rate limit middleware
|
||||
router := setupRouter()
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
requests int
|
||||
expectedStatus int
|
||||
sleep time.Duration
|
||||
}{
|
||||
{"Allow first request", 1, http.StatusOK, 0},
|
||||
{"Allow up to burst limit", burst - 1, http.StatusOK, 0},
|
||||
{"Block after burst limit", burst + 20, http.StatusTooManyRequests, 0},
|
||||
{"Allow after rate limit replenishes", 1, http.StatusOK, time.Second},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
time.Sleep(tt.sleep)
|
||||
var status int
|
||||
for i := 0; i < tt.requests; i++ {
|
||||
status = makeRequest(router, "192.168.0.2")
|
||||
|
||||
}
|
||||
if status != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRateLimiter(t *testing.T) {
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
|
||||
limiter1 := limiter.GetLimiter("127.0.0.1")
|
||||
limiter2 := limiter.GetLimiter("192.168.0.1")
|
||||
|
||||
if limiter1 == limiter2 {
|
||||
t.Error("Expected different limiters for different IPs")
|
||||
}
|
||||
|
||||
limiter3 := limiter.GetLimiter("127.0.0.1")
|
||||
if limiter1 != limiter3 {
|
||||
t.Error("Expected the same limiter for the same IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentRateLimits(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
requests int
|
||||
expectedRequests int
|
||||
ip string
|
||||
}{
|
||||
{"Low rate", 5 * time.Second, burst + 5*limit, burst + 5*limit, "192.168.23.3"},
|
||||
{"Low rate with limiting", 4 * time.Second, burst + 4*limit + 4, burst + 4*limit, "192.168.23.4"},
|
||||
{"High rate", 1 * time.Second, burst + 5*limit, burst + limit, "192.168.23.5"},
|
||||
{"Fractional rate", 10 * time.Second, (burst + limit) / 2, (burst + limit) / 2, "192.168.23.6"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
router := setupRouter()
|
||||
testRateLimit(t, router, tc.duration, tc.requests, tc.expectedRequests, tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRateLimit(t *testing.T, router *gin.Engine, duration time.Duration, requests int, expectedRequests int, ip string) {
|
||||
start := time.Now()
|
||||
successCount := 0
|
||||
totalRequests := 0
|
||||
|
||||
t.Logf("Sleeping for: %v", time.Duration(duration.Nanoseconds()/int64(requests)))
|
||||
for time.Since(start) < duration {
|
||||
status := makeRequest(router, ip)
|
||||
totalRequests++
|
||||
|
||||
if status == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(duration.Nanoseconds() / int64(requests)))
|
||||
}
|
||||
|
||||
actualDuration := time.Since(start)
|
||||
|
||||
t.Logf("limit: %v, burst: %v", limit, burst)
|
||||
t.Logf("Test duration: %v", actualDuration)
|
||||
t.Logf("Successful requests: %d", successCount)
|
||||
t.Logf("Expected successful requests: %d", expectedRequests)
|
||||
t.Logf("Total requests: %d", totalRequests)
|
||||
|
||||
if successCount < int(expectedRequests)-4 || successCount > int(expectedRequests)+4 {
|
||||
t.Errorf("Expected around %d successful requests, got %d", expectedRequests, successCount)
|
||||
}
|
||||
|
||||
if requests-expectedRequests != 0 && totalRequests <= successCount {
|
||||
t.Errorf("Expected some requests to be rate limited")
|
||||
}
|
||||
}
|
||||
|
||||
func setupRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
limiter := NewIPRateLimiter(limit, burst)
|
||||
router.Use(RateLimitMiddleware(limiter))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
func makeRequest(router *gin.Engine, ip string) int {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", ip) // Set a consistent IP
|
||||
router.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
Reference in New Issue
Block a user