diff --git a/internal/middlewares/auth.go b/internal/middlewares/auth.go index 57a6311..522d38f 100644 --- a/internal/middlewares/auth.go +++ b/internal/middlewares/auth.go @@ -3,8 +3,8 @@ package middlewares import ( "GoMembership/internal/config" "GoMembership/pkg/logger" + "fmt" "net/http" - "strings" "time" "github.com/gin-gonic/gin" @@ -20,9 +20,10 @@ var ( func GenerateToken(userID int64) (string, error) { token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ "user_id": userID, - "exp": time.Now().Add(time.Minute * 15).Unix(), // Token expires in 15 Minutes + "exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes }) + logger.Error.Printf("token generated: %#v", token) return token.SignedString(jwtKey) } @@ -37,45 +38,71 @@ func verifyToken(tokenString string) (*jwt.Token, error) { return token, nil } +func GetUserIDFromContext(c *gin.Context) (int64, error) { + + tokenString, err := c.Cookie("jwt") + if err != nil { + + logger.Error.Printf("Error getting cookie: %v\n", err) + return 0, err + } + if tokenString == "" { + logger.Error.Printf("Token is empty: %v\n", err) + return 0, fmt.Errorf("Authorization token is required") + } + + token, err := verifyToken(tokenString) + if err != nil || !token.Valid { + + logger.Error.Printf("Token is invalid: %v\n", err) + return 0, fmt.Errorf("Token not valid!") + } + + claims, ok := token.Claims.(jwt.MapClaims) + + logger.Error.Printf("claims userid: %v", claims["user_id"].(float64)) + if !ok { + logger.Error.Printf("Invalid Token claims") + return 0, fmt.Errorf("Invalid token claims") + } + userID, ok := claims["user_id"].(float64) + + if !ok { + logger.Error.Printf("Invalid user ID: %v", userID) + return 0, fmt.Errorf("Invalid user ID") + } + + return int64(userID), nil +} func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"}) - c.Abort() - return - } - bearerToken := strings.Split(authHeader, " ") - if len(bearerToken) != 2 { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"}) - c.Abort() - return - } - - tokenString := bearerToken[1] - - token, err := verifyToken(tokenString) + userID, err := GetUserIDFromContext(c) if err != nil { - if err == jwt.ErrTokenSignatureInvalid { - logger.Error.Printf("JWT NULL ATTACK: %#v", err) - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token signing method"}) - } else { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) - } + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) c.Abort() return } - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - userID := claims["user_id"].(string) - c.Set("user_id", userID) - c.Next() - } else { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"}) + // Generate a new token + newToken, err := GenerateToken(int64(userID)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"}) c.Abort() return } + + c.SetCookie( + "jwt", + newToken, + 10*60, // 10 minutes + "/", + "", + true, + true, + ) + c.Set("user_id", userID) + c.Next() } } diff --git a/internal/middlewares/auth_test.go b/internal/middlewares/auth_test.go index 531354c..7f503f1 100644 --- a/internal/middlewares/auth_test.go +++ b/internal/middlewares/auth_test.go @@ -24,11 +24,11 @@ func TestAuthMiddleware(t *testing.T) { { name: "Valid Token", setupAuth: func(r *http.Request) { - token, _ := GenerateToken("user123") - r.Header.Set("Authorization", "Bearer "+token) + token, _ := GenerateToken(123) + r.AddCookie(&http.Cookie{Name: "jwt", Value: token}) }, expectedStatus: http.StatusOK, - expectedUserID: 12, + expectedUserID: 123, }, { name: "Missing Auth Header", @@ -52,7 +52,7 @@ func TestAuthMiddleware(t *testing.T) { "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago }) tokenString, _ := token.SignedString(jwtKey) - r.Header.Set("Authorization", "Bearer "+tokenString) + r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) }, expectedStatus: http.StatusUnauthorized, expectedUserID: 0, @@ -65,17 +65,11 @@ func TestAuthMiddleware(t *testing.T) { "exp": time.Now().Add(time.Hour).Unix(), }) tokenString, _ := token.SignedString([]byte("wrong_secret")) - r.Header.Set("Authorization", "Bearer "+tokenString) + r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) }, expectedStatus: http.StatusUnauthorized, expectedUserID: 0, }, - { - name: "Missing Auth Header", - setupAuth: func(r *http.Request) {}, - expectedStatus: http.StatusUnauthorized, - expectedUserID: 0, - }, } for _, tt := range tests { @@ -100,11 +94,20 @@ func TestAuthMiddleware(t *testing.T) { assert.Equal(t, tt.expectedStatus, w.Code) - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - assert.NoError(t, err) + if tt.expectedStatus == http.StatusOK { + var response map[string]int64 + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, tt.expectedUserID, response["user_id"]) - assert.Equal(t, tt.expectedUserID, response["user_id"]) + // Check if a new cookie was set + cookies := w.Result().Cookies() + assert.GreaterOrEqual(t, len(cookies), 1) + assert.Equal(t, "jwt", cookies[0].Name) + assert.NotEmpty(t, cookies[0].Value) + } else { + assert.Equal(t, 0, len(w.Result().Cookies())) + } }) } }