refactored auth.go & tests

This commit is contained in:
Alex
2025-03-11 20:52:11 +01:00
parent ca99e28433
commit e60aaa1d69
2 changed files with 72 additions and 44 deletions

View File

@@ -2,9 +2,7 @@ package middlewares
import ( import (
"GoMembership/internal/config" "GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils" "GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"errors" "errors"
"fmt" "fmt"
@@ -34,26 +32,43 @@ func verifyAndRenewToken(tokenString string) (string, uint, error) {
return "", 0, fmt.Errorf("Authorization token is required") return "", 0, fmt.Errorf("Authorization token is required")
} }
token, claims, err := ExtractContentFrom(tokenString) token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
logger.Error.Printf("Couldn't parse JWT token String: %v", err) logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", 0, err return "", 0, err
} }
sessionID := (*claims)["session_id"].(string)
userID := uint((*claims)["user_id"].(float64)) if token.Valid {
roleID := int8((*claims)["role_id"].(float64)) // token is valid, so we can return the old tokenString
return tokenString, uint((*claims)["user_id"].(float64)), nil
}
// Token is expired but valid
sessionID, ok := (*claims)["session_id"].(string)
if !ok || sessionID == "" {
return "", 0, fmt.Errorf("invalid session ID")
}
id, ok := (*claims)["user_id"]
if !ok {
return "", 0, fmt.Errorf("missing user_id claim")
}
userID := uint(id.(float64))
id, ok = (*claims)["role_id"]
if !ok {
return "", 0, fmt.Errorf("missing role_id claim")
}
roleID := int8(id.(float64))
session, ok := sessions[sessionID] session, ok := sessions[sessionID]
if !ok { if !ok {
logger.Error.Printf("session not found") logger.Error.Printf("session not found")
return "", 0, fmt.Errorf("session not found") return "", 0, fmt.Errorf("session not found")
} }
if userID != session.UserID { if userID != session.UserID {
return "", 0, fmt.Errorf("Cookie has been altered, aborting..") return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
} }
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
if time.Now().After(sessions[sessionID].ExpiresAt) { if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID) delete(sessions, sessionID)
@@ -64,8 +79,8 @@ func verifyAndRenewToken(tokenString string) (string, uint, error) {
logger.Error.Printf("Session still valid generating new token") logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token // Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID} user := map[string]interface{}{"user_id": userID, "role_id": roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID) newTokenString, err := GenerateToken(&config.Auth.JWTSecret, user, sessionID)
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
@@ -89,11 +104,6 @@ func AuthMiddleware() gin.HandlerFunc {
newToken, userID, err := verifyAndRenewToken(tokenString) newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil { if err != nil {
if err == customerrors.ErrValidToken {
c.Set("user_id", uint(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err) logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized, c.JSON(http.StatusUnauthorized,
gin.H{"errors": []gin.H{{ gin.H{"errors": []gin.H{{
@@ -104,24 +114,30 @@ func AuthMiddleware() gin.HandlerFunc {
return return
} }
utils.SetCookie(c, newToken) if newToken != tokenString {
utils.SetCookie(c, newToken)
}
c.Set("user_id", uint(userID)) c.Set("user_id", uint(userID))
c.Next() c.Next()
} }
} }
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) { // GenerateToken generates a new JWT token with the given claims and session ID.
// "user_id": user.ID, "role_id": user.RoleID
func GenerateToken(jwtKey *string, claims map[string]interface{}, sessionID string) (string, error) {
if sessionID == "" { if sessionID == "" {
sessionID = uuid.New().String() sessionID = uuid.New().String()
} }
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ claims["session_id"] = sessionID
"user_id": user.ID, claims["exp"] = time.Now().Add(time.Minute * 1).Unix() // Token expires in 10 Minutes
"role_id": user.RoleID, token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims(claims))
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes userID, ok := claims["user_id"].(uint)
}) if !ok {
UpdateSession(sessionID, user.ID) return "", fmt.Errorf("invalid user_id in claims")
return token.SignedString([]byte(jwtKey)) }
UpdateSession(sessionID, userID)
return token.SignedString([]byte(*jwtKey))
} }
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) { func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
@@ -130,23 +146,33 @@ func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error)
return []byte(config.Auth.JWTSecret), nil return []byte(config.Auth.JWTSecret), nil
}) })
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil { // Handle parsing errors (excluding expiration error)
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err) if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
logger.Error.Printf("Error parsing token: %v", err)
return nil, nil, err return nil, nil, err
} }
// Token is expired, check if session is still valid // Ensure token is not nil (e.g., malformed tokens)
claims, ok := token.Claims.(jwt.MapClaims) if token == nil {
if !ok { logger.Error.Print("Token is nil after parsing")
logger.Error.Printf("Invalid Token Claims") return nil, nil, fmt.Errorf("invalid token")
return nil, nil, fmt.Errorf("invalid token claims")
} }
// Extract and validate claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok { if !ok {
logger.Error.Printf("invalid session_id in token") logger.Error.Print("Invalid token claims structure")
return nil, nil, fmt.Errorf("invalid session_id in token") return nil, nil, fmt.Errorf("invalid token claims format")
} }
return token, &claims, nil
// Validate required session_id claim
if _, exists := claims["session_id"]; !exists {
logger.Error.Print("Missing session_id in token claims")
return nil, nil, fmt.Errorf("missing session_id claim")
}
// Return token, claims, and original error (might be expiration)
return token, &claims, err
} }
func UpdateSession(sessionID string, userID uint) { func UpdateSession(sessionID string, userID uint) {

View File

@@ -3,7 +3,6 @@ package middlewares
import ( import (
"GoMembership/internal/config" "GoMembership/internal/config"
"GoMembership/internal/constants" "GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger" "GoMembership/pkg/logger"
"encoding/json" "encoding/json"
"log" "log"
@@ -56,8 +55,11 @@ func TestAuthMiddleware(t *testing.T) {
{ {
name: "Valid Token", name: "Valid Token",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
user := models.User{ID: 123, RoleID: constants.Roles.Member} claims := map[string]interface{}{"user_id": uint(123), "role_id": constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "") token, err := GenerateToken(&config.Auth.JWTSecret, claims, "")
if err != nil {
t.Fatal(err)
}
r.AddCookie(&http.Cookie{Name: "jwt", Value: token}) r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
@@ -82,7 +84,7 @@ func TestAuthMiddleware(t *testing.T) {
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
sessionID := "test-session" sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"session_id": sessionID, "session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
@@ -100,7 +102,7 @@ func TestAuthMiddleware(t *testing.T) {
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
sessionID := "expired-session" sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"session_id": sessionID, "session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
@@ -116,7 +118,7 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signature", name: "Invalid Signature",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"session_id": "some-session", "session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),
}) })
@@ -130,7 +132,7 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signing Method", name: "Invalid Signing Method",
setupAuth: func(r *http.Request) { setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123, "user_id": uint(123),
"session_id": "some-session", "session_id": "some-session",
"role_id": constants.Roles.Member, "role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),