Files
2025-03-11 20:52:11 +01:00

206 lines
5.4 KiB
Go

package middlewares
import (
	"errors"
	"fmt"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"

	"GoMembership/internal/config"
	"GoMembership/internal/utils"
	"GoMembership/pkg/logger"
)
// Session is the server-side record kept for each JWT session_id claim.
// When a token has expired, its Session decides whether a fresh token may be
// issued without the user logging in again.
type Session struct {
// UserID is the user the session was created for; a renewal attempt whose
// token carries a different user_id is rejected.
UserID uint
// ExpiresAt is the moment after which the session can no longer be used to
// renew an expired token; it is set to now+sessionDuration on creation/renewal.
ExpiresAt time.Time
}
var (
	// sessionDuration is how long a session may be used to renew expired
	// tokens; every successful renewal pushes the deadline forward again.
	sessionDuration = 5 * 24 * time.Hour
	// jwtSigningMethod is the only algorithm tokens are signed and accepted with.
	jwtSigningMethod = jwt.SigningMethodHS256
	// jwtParser rejects any token whose header names a different algorithm.
	jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
	// sessions maps session_id claims to their server-side Session.
	// Every access must hold sessionsMu.
	sessions = make(map[string]*Session)
	// sessionsMu guards sessions and the Session values stored in it. Gin runs
	// handlers concurrently and unsynchronized map writes crash the process.
	sessionsMu sync.Mutex
)

// verifyAndRenewToken checks tokenString and, when it is authentic but
// expired, issues a replacement token as long as its server-side session is
// still alive and belongs to the same user.
//
// It returns the token the client should hold from now on (tokenString itself
// when it is still valid), the user ID carried by the token, and an error when
// the token cannot be accepted.
//
// NOTE(review): a still-valid token is accepted without consulting sessions,
// so a logout only takes effect once the short-lived token expires.
func verifyAndRenewToken(tokenString string) (string, uint, error) {
	if tokenString == "" {
		logger.Error.Printf("empty tokenstring")
		return "", 0, fmt.Errorf("Authorization token is required")
	}

	token, claims, err := ExtractContentFrom(tokenString)
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		logger.Error.Printf("Couldn't parse JWT token String: %v", err)
		return "", 0, err
	}

	if token.Valid {
		// Token is valid, so the client keeps the current tokenString.
		userID, err := floatClaim(*claims, "user_id")
		if err != nil {
			return "", 0, err
		}
		return tokenString, uint(userID), nil
	}

	// The signature checked out but the token has expired: renewal is only
	// allowed through a live server-side session.
	sessionID, ok := (*claims)["session_id"].(string)
	if !ok || sessionID == "" {
		return "", 0, fmt.Errorf("invalid session ID")
	}
	rawUserID, err := floatClaim(*claims, "user_id")
	if err != nil {
		return "", 0, err
	}
	userID := uint(rawUserID)
	rawRoleID, err := floatClaim(*claims, "role_id")
	if err != nil {
		return "", 0, err
	}
	roleID := int8(rawRoleID)

	if err := renewSession(sessionID, userID); err != nil {
		return "", 0, err
	}

	logger.Error.Printf("Session still valid generating new token")
	// The session was already extended atomically above, so sign directly
	// instead of going through GenerateToken, which would re-register it and
	// could resurrect a session deleted by a concurrent logout.
	user := map[string]interface{}{"user_id": userID, "role_id": roleID}
	newTokenString, err := signToken(config.Auth.JWTSecret, user, sessionID)
	if err != nil {
		return "", 0, err
	}
	return newTokenString, userID, nil
}

// renewSession atomically verifies that sessionID is a live session owned by
// userID and extends it by sessionDuration. An expired session is removed.
func renewSession(sessionID string, userID uint) error {
	sessionsMu.Lock()
	defer sessionsMu.Unlock()

	session, ok := sessions[sessionID]
	if !ok {
		logger.Error.Printf("session not found")
		return fmt.Errorf("session not found")
	}
	if userID != session.UserID {
		return fmt.Errorf("Cookie has been altered, aborting..")
	}
	if time.Now().After(session.ExpiresAt) {
		delete(sessions, sessionID)
		logger.Error.Printf("session expired")
		return fmt.Errorf("session expired")
	}
	session.ExpiresAt = time.Now().Add(sessionDuration)
	return nil
}

// floatClaim returns claims[key] as a float64. The parser decodes JSON numbers
// into float64, so a missing key or any other type means a malformed token;
// this reports an error instead of panicking on an unchecked type assertion.
func floatClaim(claims jwt.MapClaims, key string) (float64, error) {
	raw, ok := claims[key]
	if !ok {
		return 0, fmt.Errorf("missing %s claim", key)
	}
	value, ok := raw.(float64)
	if !ok {
		return 0, fmt.Errorf("invalid %s claim", key)
	}
	return value, nil
}

// AuthMiddleware authenticates requests using the "jwt" cookie. On success it
// stores the user ID (uint) under "user_id" in the gin context, refreshing the
// cookie when the token had to be renewed. Otherwise it aborts with 401.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		tokenString, err := c.Cookie("jwt")
		if err != nil {
			logger.Error.Printf("No Auth token: %v\n", err)
			abortUnauthorized(c)
			return
		}

		newToken, userID, err := verifyAndRenewToken(tokenString)
		if err != nil {
			// Never log the token itself: it is a bearer credential.
			logger.Error.Printf("Token is invalid: %v\n", err)
			abortUnauthorized(c)
			return
		}

		if newToken != tokenString {
			utils.SetCookie(c, newToken)
		}
		c.Set("user_id", userID)
		c.Next()
	}
}

// abortUnauthorized stops the handler chain and answers 401 with the
// generic "no auth token" error payload.
func abortUnauthorized(c *gin.Context) {
	c.AbortWithStatusJSON(http.StatusUnauthorized,
		gin.H{"errors": []gin.H{{
			"field": "server.general",
			"key":   "server.error.no_auth_token",
		}}})
}

// GenerateToken signs a new JWT containing claims plus "session_id" and "exp",
// and registers (or refreshes) the server-side session for it.
//
// claims is expected to hold "user_id": user.ID (uint) and "role_id":
// user.RoleID; it is not modified. If sessionID is empty a new random one is
// created. The session is only recorded once signing has succeeded.
func GenerateToken(jwtKey *string, claims map[string]interface{}, sessionID string) (string, error) {
	if jwtKey == nil {
		return "", errors.New("missing JWT signing key")
	}
	userID, ok := claims["user_id"].(uint)
	if !ok {
		return "", fmt.Errorf("invalid user_id in claims")
	}
	if sessionID == "" {
		sessionID = uuid.New().String()
	}

	tokenString, err := signToken(*jwtKey, claims, sessionID)
	if err != nil {
		return "", err
	}
	UpdateSession(sessionID, userID)
	return tokenString, nil
}

// signToken builds and signs a JWT from a copy of claims with "session_id"
// and "exp" added. It does not touch the session store.
func signToken(jwtKey string, claims map[string]interface{}, sessionID string) (string, error) {
	// Tokens are deliberately short-lived; the longer-lived session is what
	// allows silent renewal. NOTE(review): an earlier comment claimed a
	// 10-minute lifetime while the code used 1 minute — confirm the intent.
	const tokenLifetime = time.Minute

	mapClaims := make(jwt.MapClaims, len(claims)+2)
	for key, value := range claims {
		mapClaims[key] = value
	}
	mapClaims["session_id"] = sessionID
	mapClaims["exp"] = time.Now().Add(tokenLifetime).Unix()

	token := jwt.NewWithClaims(jwtSigningMethod, mapClaims)
	return token.SignedString([]byte(jwtKey))
}

// ExtractContentFrom parses and verifies tokenString with the HS256-only
// parser and returns the token together with its claims.
//
// An expired but otherwise correctly signed token is still returned, along
// with an error matching jwt.ErrTokenExpired, so callers can decide whether
// to renew it. Any other problem yields (nil, nil, err). Tokens without a
// "session_id" claim are rejected.
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
	token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
		return []byte(config.Auth.JWTSecret), nil
	})

	// Handle parsing errors (excluding expiration error).
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		logger.Error.Printf("Error parsing token: %v", err)
		return nil, nil, err
	}

	// Ensure token is not nil (e.g., malformed tokens).
	if token == nil {
		logger.Error.Print("Token is nil after parsing")
		return nil, nil, fmt.Errorf("invalid token")
	}

	// Extract and validate claims.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		logger.Error.Print("Invalid token claims structure")
		return nil, nil, fmt.Errorf("invalid token claims format")
	}

	// Validate required session_id claim.
	if _, exists := claims["session_id"]; !exists {
		logger.Error.Print("Missing session_id in token claims")
		return nil, nil, fmt.Errorf("missing session_id claim")
	}

	// Return token, claims, and original error (might be expiration).
	return token, &claims, err
}

// UpdateSession creates or replaces the session sessionID for userID,
// valid for sessionDuration from now.
//
// NOTE(review): expired sessions are only removed when presented again during
// renewal; consider a periodic sweep if the map grows large.
func UpdateSession(sessionID string, userID uint) {
	sessionsMu.Lock()
	defer sessionsMu.Unlock()
	sessions[sessionID] = &Session{
		UserID:    userID,
		ExpiresAt: time.Now().Add(sessionDuration),
	}
}

// InvalidateSession deletes the server-side session referenced by token's
// "session_id" claim, e.g. on logout.
//
// The token's signature must verify, but an expired token is accepted: tokens
// are short-lived, while the session they point to can stay renewable for days.
// It returns true once the session has been removed (or was already absent).
func InvalidateSession(token string) (bool, error) {
	// Reuse the algorithm-restricted parser with a []byte key; HMAC
	// verification rejects string keys with ErrInvalidKeyType.
	_, claims, err := ExtractContentFrom(token)
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		return false, fmt.Errorf("Couldn't get JWT claims: %w", err)
	}

	sessionID, ok := (*claims)["session_id"].(string)
	if !ok {
		return false, fmt.Errorf("No SessionID found")
	}

	sessionsMu.Lock()
	delete(sessions, sessionID)
	sessionsMu.Unlock()
	return true, nil
}