package middlewares

import (
	"GoMembership/internal/config"
	"GoMembership/internal/utils"
	"GoMembership/pkg/logger"
	"errors"
	"fmt"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
)

// Session is the server-side state of a login session, keyed by the
// session_id claim carried in the JWT.
type Session struct {
	UserID    uint
	ExpiresAt time.Time
}

var (
	// sessionDuration is how long a session may go without being renewed;
	// every successful renewal pushes its expiry out by this amount again.
	sessionDuration = 5 * 24 * time.Hour

	// tokenLifetime is how long a signed JWT is accepted before it has to be
	// renewed through its session.
	tokenLifetime = time.Minute

	jwtSigningMethod = jwt.SigningMethodHS256

	// jwtParser rejects any token whose alg header is not jwtSigningMethod.
	jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))

	// sessionsMu guards sessions. Gin serves requests on concurrent
	// goroutines, and unsynchronized map writes are a fatal runtime error.
	sessionsMu sync.Mutex

	// sessions maps session IDs to their state. It is process-local and an
	// entry is only removed when an expired session is presented for renewal
	// or when InvalidateSession is called.
	sessions = make(map[string]*Session)
)

// numericClaim returns claims[key] as a float64, the type the JWT parser
// decodes JSON numbers into (jwtParser is not configured to use json.Number).
func numericClaim(claims jwt.MapClaims, key string) (float64, error) {
	raw, ok := claims[key]
	if !ok {
		return 0, fmt.Errorf("missing %s claim", key)
	}
	value, ok := raw.(float64)
	if !ok {
		return 0, fmt.Errorf("invalid %s claim", key)
	}
	return value, nil
}

// verifyAndRenewToken validates tokenString and returns the token the client
// should hold from now on, together with the user ID it belongs to.
//
// A token that is still within its lifetime is returned unchanged (the
// session store is not consulted in that case). A token whose only problem
// is that it has expired is exchanged for a freshly signed one, provided its
// session_id refers to a live session owned by the same user; that session's
// expiry is then extended by sessionDuration. Any other failure returns a
// non-nil error and an empty token.
func verifyAndRenewToken(tokenString string) (string, uint, error) {
	if tokenString == "" {
		logger.Error.Printf("empty tokenstring")
		return "", 0, fmt.Errorf("Authorization token is required")
	}

	token, claims, err := ExtractContentFrom(tokenString)
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		logger.Error.Printf("Couldn't parse JWT token String: %v", err)
		return "", 0, err
	}

	rawUserID, err := numericClaim(*claims, "user_id")
	if err != nil {
		return "", 0, err
	}
	userID := uint(rawUserID)

	if token.Valid {
		// Still within its lifetime: the client keeps its current token.
		return tokenString, userID, nil
	}

	// Expired but authentic (jwt/v5 verifies the signature before validating
	// the exp claim), so try to renew it through its server-side session.
	sessionID, ok := (*claims)["session_id"].(string)
	if !ok || sessionID == "" {
		return "", 0, fmt.Errorf("invalid session ID")
	}
	rawRoleID, err := numericClaim(*claims, "role_id")
	if err != nil {
		return "", 0, err
	}
	roleID := int8(rawRoleID)

	// Hold the lock across lookup, signing and extension so that a concurrent
	// InvalidateSession cannot be undone by this renewal.
	sessionsMu.Lock()
	defer sessionsMu.Unlock()

	session, ok := sessions[sessionID]
	if !ok {
		logger.Error.Printf("session not found")
		return "", 0, fmt.Errorf("session not found")
	}
	if userID != session.UserID {
		return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
	}
	if time.Now().After(session.ExpiresAt) {
		delete(sessions, sessionID)
		logger.Error.Printf("session expired")
		return "", 0, fmt.Errorf("session expired")
	}

	logger.Error.Printf("Session still valid generating new token")
	user := map[string]interface{}{"user_id": userID, "role_id": roleID}
	newTokenString, err := signToken(config.Auth.JWTSecret, user, sessionID)
	if err != nil {
		return "", 0, err
	}
	// Extend the session only once a replacement token actually exists.
	session.ExpiresAt = time.Now().Add(sessionDuration)
	return newTokenString, session.UserID, nil
}

// AuthMiddleware authenticates requests via the "jwt" cookie.
//
// Requests without the cookie, or whose token cannot be verified or renewed,
// are aborted with 401 and an error payload keyed
// "server.error.no_auth_token". If the token was renewed, the new token is
// written back with utils.SetCookie. On success the user ID (uint) is stored
// in the context under "user_id" before the next handler runs.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		tokenString, err := c.Cookie("jwt")
		if err != nil {
			logger.Error.Printf("No Auth token: %v\n", err)
			abortUnauthorized(c)
			return
		}

		newToken, userID, err := verifyAndRenewToken(tokenString)
		if err != nil {
			// Never log the token itself: it is a bearer credential.
			logger.Error.Printf("Token is invalid: %v\n", err)
			abortUnauthorized(c)
			return
		}

		if newToken != tokenString {
			utils.SetCookie(c, newToken)
		}
		c.Set("user_id", userID)
		c.Next()
	}
}

// abortUnauthorized ends the request with 401 and the generic
// "no auth token" error body.
func abortUnauthorized(c *gin.Context) {
	c.JSON(http.StatusUnauthorized, gin.H{"errors": []gin.H{{
		"field": "server.general",
		"key":   "server.error.no_auth_token",
	}}})
	c.Abort()
}

// signToken signs a copy of claims, stamped with sessionID and an exp of
// now+tokenLifetime, using jwtSigningMethod and jwtKey. It does not touch
// the session store and never modifies claims.
func signToken(jwtKey string, claims map[string]interface{}, sessionID string) (string, error) {
	tokenClaims := make(jwt.MapClaims, len(claims)+2)
	for key, value := range claims {
		tokenClaims[key] = value
	}
	tokenClaims["session_id"] = sessionID
	tokenClaims["exp"] = time.Now().Add(tokenLifetime).Unix()

	return jwt.NewWithClaims(jwtSigningMethod, tokenClaims).SignedString([]byte(jwtKey))
}

// GenerateToken signs a new JWT for claims and registers (or refreshes) the
// session it belongs to.
//
// claims must contain "user_id" as a uint; callers typically pass
// map[string]interface{}{"user_id": user.ID, "role_id": user.RoleID}.
// The caller's map is not modified. An empty sessionID starts a new session
// with a random UUID. The token expires after tokenLifetime; the session
// after sessionDuration.
func GenerateToken(jwtKey *string, claims map[string]interface{}, sessionID string) (string, error) {
	if jwtKey == nil {
		return "", fmt.Errorf("missing JWT signing key")
	}
	userID, ok := claims["user_id"].(uint)
	if !ok {
		return "", fmt.Errorf("invalid user_id in claims")
	}
	if sessionID == "" {
		sessionID = uuid.New().String()
	}

	tokenString, err := signToken(*jwtKey, claims, sessionID)
	if err != nil {
		return "", err
	}
	// Register the session only after signing succeeded so a failure cannot
	// leave an orphaned session behind.
	UpdateSession(sessionID, userID)
	return tokenString, nil
}

// ExtractContentFrom parses tokenString with jwtParser and the configured
// HMAC secret, returning the token and its claims.
//
// An expired token is not treated as a failure here: the token and claims
// are returned together with the jwt.ErrTokenExpired error so callers can
// decide whether to renew. Any other parse/verification error, a nil token,
// non-MapClaims claims, or a missing session_id claim yields (nil, nil, err).
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
	token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
		return []byte(config.Auth.JWTSecret), nil
	})
	// Expiry is reported to the caller below instead of aborting here.
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		logger.Error.Printf("Error parsing token: %v", err)
		return nil, nil, err
	}
	if token == nil {
		logger.Error.Print("Token is nil after parsing")
		return nil, nil, fmt.Errorf("invalid token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		logger.Error.Print("Invalid token claims structure")
		return nil, nil, fmt.Errorf("invalid token claims format")
	}
	if _, exists := claims["session_id"]; !exists {
		logger.Error.Print("Missing session_id in token claims")
		return nil, nil, fmt.Errorf("missing session_id claim")
	}

	// err is nil here or wraps jwt.ErrTokenExpired.
	return token, &claims, err
}

// UpdateSession creates or replaces the session for sessionID, owned by
// userID and expiring sessionDuration from now.
func UpdateSession(sessionID string, userID uint) {
	sessionsMu.Lock()
	defer sessionsMu.Unlock()

	sessions[sessionID] = &Session{
		UserID:    userID,
		ExpiresAt: time.Now().Add(sessionDuration),
	}
}

// InvalidateSession removes the session referenced by the token's session_id
// claim so that it can no longer be renewed, and reports true on success.
//
// The token's signature must verify, but an expired token is accepted:
// tokens only live tokenLifetime, so logout usually happens after expiry.
// NOTE: an unexpired token is still accepted by AuthMiddleware until its exp,
// because verifyAndRenewToken does not consult the session store for it.
func InvalidateSession(token string) (bool, error) {
	_, claims, err := ExtractContentFrom(token)
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		return false, fmt.Errorf("Couldn't get JWT claims: %w", err)
	}

	sessionID, ok := (*claims)["session_id"].(string)
	if !ok {
		return false, fmt.Errorf("No SessionID found")
	}

	sessionsMu.Lock()
	delete(sessions, sessionID)
	sessionsMu.Unlock()
	return true, nil
}