From f648b53fe16f700482a0c1614a043643f262ea94 Mon Sep 17 00:00:00 2001 From: "$(pass /github/name)" <$(pass /github/email)> Date: Mon, 2 Sep 2024 22:20:58 +0200 Subject: [PATCH] added jwt auth --- internal/middlewares/auth.go | 81 +++++++++++++++ internal/middlewares/auth_middleware.go | 17 ---- internal/middlewares/auth_test.go | 110 ++++++++++++++++++++ internal/middlewares/csrf_middleware.go | 116 ---------------------- internal/middlewares/logger_middleware.go | 33 ------ 5 files changed, 191 insertions(+), 166 deletions(-) create mode 100644 internal/middlewares/auth.go delete mode 100644 internal/middlewares/auth_middleware.go create mode 100644 internal/middlewares/auth_test.go delete mode 100644 internal/middlewares/csrf_middleware.go delete mode 100644 internal/middlewares/logger_middleware.go diff --git a/internal/middlewares/auth.go b/internal/middlewares/auth.go new file mode 100644 index 0000000..30e6233 --- /dev/null +++ b/internal/middlewares/auth.go @@ -0,0 +1,81 @@ +package middlewares + +import ( + "GoMembership/internal/config" + "GoMembership/pkg/logger" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +var ( + jwtKey = []byte(config.Auth.JWTSecret) + jwtSigningMethod = jwt.SigningMethodHS256 + jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()})) +) + +func GenerateToken(userID string) (string, error) { + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": userID, + "exp": time.Now().Add(time.Minute * 15).Unix(), // Token expires in 15 Minutes + }) + + return token.SignedString(jwtKey) +} + +func verifyToken(tokenString string) (*jwt.Token, error) { + token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) { + return jwtKey, nil + }) + + if err != nil { + return nil, err + } + + return token, nil +} + +func AuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"}) + c.Abort() + return + } + + bearerToken := strings.Split(authHeader, " ") + if len(bearerToken) != 2 { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"}) + c.Abort() + return + } + + tokenString := bearerToken[1] + + token, err := verifyToken(tokenString) + if err != nil { + if err == jwt.ErrTokenSignatureInvalid { + logger.Error.Printf("JWT NULL ATTACK: %#v", err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token signing method"}) + } else { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) + } + c.Abort() + return + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + userID := claims["user_id"].(string) + c.Set("user_id", userID) + c.Next() + } else { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"}) + c.Abort() + return + } + } +} diff --git a/internal/middlewares/auth_middleware.go b/internal/middlewares/auth_middleware.go deleted file mode 100644 index b272758..0000000 --- a/internal/middlewares/auth_middleware.go +++ /dev/null @@ -1,17 +0,0 @@ -package middlewares - -import ( - "net/http" -) - -func AuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get("Authorization") - if token != "your-secret-token" { - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/internal/middlewares/auth_test.go b/internal/middlewares/auth_test.go new file mode 100644 index 0000000..6d5cb73 --- /dev/null +++ b/internal/middlewares/auth_test.go @@ -0,0 +1,110 @@ +package middlewares + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestAuthMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + setupAuth func(r *http.Request) + expectedStatus int + expectedUserID string + }{ + { + name: "Valid Token", + setupAuth: func(r *http.Request) { + token, _ := GenerateToken("user123") + r.Header.Set("Authorization", "Bearer "+token) + }, + expectedStatus: http.StatusOK, + expectedUserID: "user123", + }, + { + name: "Missing Auth Header", + setupAuth: func(r *http.Request) {}, + expectedStatus: http.StatusUnauthorized, + expectedUserID: "", + }, + { + name: "Invalid Token Format", + setupAuth: func(r *http.Request) { + r.Header.Set("Authorization", "InvalidFormat") + }, + expectedStatus: http.StatusUnauthorized, + expectedUserID: "", + }, + { + name: "Expired Token", + setupAuth: func(r *http.Request) { + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": "user123", + "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago + }) + tokenString, _ := token.SignedString(jwtKey) + r.Header.Set("Authorization", "Bearer "+tokenString) + }, + expectedStatus: http.StatusUnauthorized, + expectedUserID: "", + }, + { + name: "Invalid Signature", + setupAuth: func(r *http.Request) { + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": "user123", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenString, _ := token.SignedString([]byte("wrong_secret")) + r.Header.Set("Authorization", "Bearer "+tokenString) + }, + expectedStatus: http.StatusUnauthorized, + expectedUserID: "", + }, + { + name: "Missing Auth Header", + setupAuth: func(r *http.Request) {}, + expectedStatus: http.StatusUnauthorized, + expectedUserID: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + r := gin.New() + r.Use(AuthMiddleware()) + r.GET("/test", func(c *gin.Context) { + userID, exists := c.Get("user_id") + if exists { + c.JSON(http.StatusOK, gin.H{"user_id": userID}) + } else { + c.JSON(http.StatusUnauthorized, gin.H{"user_id": ""}) + } + }) + + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + tt.setupAuth(req) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + var response map[string]string + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + assert.Equal(t, tt.expectedUserID, response["user_id"]) + }) + } +} diff --git a/internal/middlewares/csrf_middleware.go b/internal/middlewares/csrf_middleware.go deleted file mode 100644 index 5637aeb..0000000 --- a/internal/middlewares/csrf_middleware.go +++ /dev/null @@ -1,116 +0,0 @@ -package middlewares - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "net/http" - "strings" - - "GoMembership/internal/config" - // "GoMembership/internal/server" - "GoMembership/internal/utils" - "GoMembership/pkg/logger" -) - -// GenerateCSRFToken generates HMAC-signed CSRF token -func GenerateCSRFToken(sessionID string, secretKey string) string { - // Create message to be signed (e.g., combining sessionID with some random value) - randomString, err := utils.GenerateRandomString(8) - if err != nil { - logger.Error.Fatalf("Could not create random string: %v", err) - return "" - } - - message := sessionID + "!" + randomString - - // Create HMAC hash using SHA-256 - h := hmac.New(sha256.New, []byte(secretKey)) - h.Write([]byte(message)) - signature := h.Sum(nil) - - // Encode signature and message into a CSRF token - csrfToken := base64.StdEncoding.EncodeToString(signature) + "." + message - return csrfToken -} - -func ComputeHMAC(message string, secretKey string) []byte { - h := hmac.New(sha256.New, []byte(secretKey)) - h.Write([]byte(message)) - return h.Sum(nil) -} - -// CSRFMiddleware verifies HMAC-signed CSRF token -func CSRFMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions { - next.ServeHTTP(w, r) - return - } - csrfSecret := config.Auth.CSRFSecret - // Retrieve CSRF token from request (e.g., from cookie, header, or form data) - csrfToken := r.Header.Get("X-CSRF-Token") - - // Extract signature and message from CSRF token - parts := strings.SplitN(csrfToken, ".", 2) - if len(parts) != 2 { - http.Error(w, "Invalid CSRF token", http.StatusForbidden) - return - } - receivedSignature := parts[0] - receivedMessage := parts[1] - - // Compute HMAC using the received message and the CSRF secret key - computedSignature := ComputeHMAC(receivedMessage, csrfSecret) - - // Compare computed HMAC with received signature - if !hmac.Equal([]byte(receivedSignature), computedSignature) { - http.Error(w, "CSRF Token validation failed", http.StatusForbidden) - return - } - - // CSRF token is valid, proceed to the next handler - next.ServeHTTP(w, r) - }) -} - -func GenerateCSRFTokenHandler(w http.ResponseWriter, r *http.Request) { - // Simulate getting session ID from authenticated session - sessionID := "exampleSessionID123" - - // Generate HMAC-signed CSRF token - csrfToken := GenerateCSRFToken(sessionID, config.Auth.CSRFSecret) - - // Set CSRF token in a cookie (example) - http.SetCookie(w, &http.Cookie{ - Name: "csrf_token", - Value: csrfToken, - Path: "/", - HttpOnly: true, - Secure: true, - }) -} - -/* func GenerateCSRFTokenHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - token, err := GenerateCSRFToken() - if err != nil { - http.Error(w, "Could not generate CSRF token", http.StatusInternalServerError) - return - } - - // Set CSRF token in cookie - http.SetCookie(w, &http.Cookie{ - Name: "csrf_token", - Value: token, - Path: "/", - }) - - logger.Info.Printf("generated token: %v", token) - // Return CSRF token in response - w.Header().Set("X-CSRF-Token", token) - w.WriteHeader(http.StatusOK) - }) -} */ diff --git a/internal/middlewares/logger_middleware.go b/internal/middlewares/logger_middleware.go deleted file mode 100644 index 294377b..0000000 --- a/internal/middlewares/logger_middleware.go +++ /dev/null @@ -1,33 +0,0 @@ -package middlewares - -import ( - "time" - - "GoMembership/pkg/logger" - "github.com/gin-gonic/gin" -) - -// LoggerMiddleware logs the incoming requests. -func LoggerMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - startTime := time.Now() - - // Process the request - c.Next() - - // Calculate the latency - latency := time.Since(startTime) - - // Get the status code - statusCode := c.Writer.Status() - - // Log the details - logger.Info.Printf("| %3d | %13v | %15s | %-7s %#v\n", - statusCode, - latency, - c.ClientIP(), - c.Request.Method, - c.Request.URL.Path, - ) - } -}