package middlewares import ( "GoMembership/internal/config" "GoMembership/internal/models" "GoMembership/internal/utils" customerrors "GoMembership/pkg/errors" "GoMembership/pkg/logger" "errors" "fmt" "net/http" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) type Session struct { UserID int64 ExpiresAt time.Time } var ( sessionDuration = 5 * 24 * time.Hour jwtSigningMethod = jwt.SigningMethodHS256 jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()})) sessions = make(map[string]*Session) ) func verifyAndRenewToken(tokenString string) (string, int64, error) { if tokenString == "" { logger.Error.Printf("empty tokenstring") return "", -1, fmt.Errorf("Authorization token is required") } token, claims, err := ExtractContentFrom(tokenString) if err != nil { logger.Error.Printf("Couldn't parse JWT token String: %v", err) return "", -1, err } sessionID := (*claims)["session_id"].(string) userID := int64((*claims)["user_id"].(float64)) roleID := int8((*claims)["role_id"].(float64)) session, ok := sessions[sessionID] if !ok { logger.Error.Printf("session not found") return "", -1, fmt.Errorf("session not found") } if userID != session.UserID { return "", -1, fmt.Errorf("Cookie has been altered, aborting..") } if token.Valid { // token is valid, so we can return the old tokenString return tokenString, session.UserID, customerrors.ErrValidToken } if time.Now().After(sessions[sessionID].ExpiresAt) { delete(sessions, sessionID) logger.Error.Printf("session expired") return "", -1, fmt.Errorf("session expired") } session.ExpiresAt = time.Now().Add(sessionDuration) logger.Error.Printf("Session still valid generating new token") // Session is still valid, generate a new token user := models.User{ID: userID, RoleID: roleID} newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID) if err != nil { return "", -1, err } return newTokenString, session.UserID, nil } func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { tokenString, err := c.Cookie("jwt") if err != nil { logger.Error.Printf("No Auth token: %v\n", err) c.JSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"}) c.Abort() return } newToken, userID, err := verifyAndRenewToken(tokenString) if err != nil { if err == customerrors.ErrValidToken { c.Set("user_id", int64(userID)) c.Next() return } logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err) c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"}) c.Abort() return } utils.SetCookie(c, newToken) c.Set("user_id", int64(userID)) c.Next() } } func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) { if sessionID == "" { sessionID = uuid.New().String() } token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ "user_id": user.ID, "role_id": user.RoleID, "session_id": sessionID, "exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes }) UpdateSession(sessionID, user.ID) return token.SignedString([]byte(jwtKey)) } func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) { token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) { return []byte(config.Auth.JWTSecret), nil }) if !errors.Is(err, jwt.ErrTokenExpired) && err != nil { logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err) return nil, nil, err } // Token is expired, check if session is still valid claims, ok := token.Claims.(jwt.MapClaims) if !ok { logger.Error.Printf("Invalid Token Claims") return nil, nil, fmt.Errorf("invalid token claims") } if !ok { logger.Error.Printf("invalid session_id in token") return nil, nil, fmt.Errorf("invalid session_id in token") } return token, &claims, nil } func UpdateSession(sessionID string, userID int64) { sessions[sessionID] = &Session{ UserID: userID, ExpiresAt: time.Now().Add(sessionDuration), } } func InvalidateSession(token string) (bool, error) { claims := jwt.MapClaims{} _, err := jwt.ParseWithClaims( token, claims, func(token *jwt.Token) (interface{}, error) { return config.Auth.JWTSecret, nil }, ) if err != nil { return false, fmt.Errorf("Couldn't get JWT claims: %#v", err) } sessionID, ok := claims["session_id"].(string) if !ok { return false, fmt.Errorf("No SessionID found") } delete(sessions, sessionID) return true, nil }