Files
GoMembership/internal/middlewares/auth.go
2024-09-29 20:58:42 +02:00

172 lines
4.5 KiB
Go

package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Session is the server-side record kept for one login session. Entries are
// stored in the package-level sessions map, keyed by session ID.
type Session struct {
	// UserID is the ID of the user the session was created for; it is compared
	// against the user_id claim of incoming tokens.
	UserID uint
	// ExpiresAt is the instant after which the session is treated as expired.
	ExpiresAt time.Time
}
// Package-level authentication settings and state.
var (
	// sessionDuration is how long a session stays alive after it has been
	// created or refreshed (5 days).
	sessionDuration = 5 * 24 * time.Hour

	// jwtSigningMethod is the algorithm used to sign issued tokens.
	jwtSigningMethod = jwt.SigningMethodHS256

	// jwtParser only accepts tokens whose signing algorithm matches
	// jwtSigningMethod.
	jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))

	// sessions holds the live sessions, keyed by session ID.
	// NOTE(review): nothing in this file guards the map with a lock although
	// it is read and written from request handlers — confirm whether
	// concurrent requests can reach it and add synchronisation if so.
	sessions = map[string]*Session{}
)
// verifyAndRenewToken checks tokenString against the in-memory session store
// and, when the token has expired while its session is still alive, issues a
// replacement token.
//
// It returns (token, userID, err) where:
//   - err is customerrors.ErrValidToken: tokenString is still valid and is
//     returned unchanged; nothing was renewed.
//   - err is nil: the token had expired, the session lifetime was extended and
//     a newly signed token is returned.
//   - any other err: the token cannot be used (empty, unparsable, missing or
//     mistyped claims, unknown session, user mismatch, or expired session).
//     The returned token is "" and the user ID is 0.
func verifyAndRenewToken(tokenString string) (string, uint, error) {
	if tokenString == "" {
		logger.Error.Printf("empty tokenstring")
		return "", 0, fmt.Errorf("Authorization token is required")
	}

	token, claims, err := ExtractContentFrom(tokenString)
	if err != nil {
		logger.Error.Printf("Couldn't parse JWT token String: %v", err)
		return "", 0, err
	}

	// Use checked type assertions: a signed token whose claims are missing or
	// of an unexpected type must produce an error, not a panic.
	sessionID, ok := (*claims)["session_id"].(string)
	if !ok {
		logger.Error.Printf("invalid session_id in token")
		return "", 0, fmt.Errorf("invalid session_id in token")
	}
	// Numeric claims arrive as float64 after JSON decoding.
	rawUserID, ok := (*claims)["user_id"].(float64)
	if !ok {
		logger.Error.Printf("invalid user_id in token")
		return "", 0, fmt.Errorf("invalid user_id in token")
	}
	rawRoleID, ok := (*claims)["role_id"].(float64)
	if !ok {
		logger.Error.Printf("invalid role_id in token")
		return "", 0, fmt.Errorf("invalid role_id in token")
	}
	userID := uint(rawUserID)
	roleID := int8(rawRoleID)

	session, ok := sessions[sessionID]
	if !ok {
		logger.Error.Printf("session not found")
		return "", 0, fmt.Errorf("session not found")
	}
	if userID != session.UserID {
		return "", 0, fmt.Errorf("Cookie has been altered, aborting..")
	}

	if token.Valid {
		// The token has not expired yet: hand back the old token string and
		// signal via the sentinel that no renewal took place.
		return tokenString, session.UserID, customerrors.ErrValidToken
	}

	// From here on the token itself is expired; renewal is only allowed while
	// the server-side session is still alive.
	if time.Now().After(session.ExpiresAt) {
		delete(sessions, sessionID)
		logger.Error.Printf("session expired")
		return "", 0, fmt.Errorf("session expired")
	}

	session.ExpiresAt = time.Now().Add(sessionDuration)
	logger.Error.Printf("Session still valid generating new token")

	// Re-issue a token for the same session with the identity from the claims.
	user := models.User{ID: userID, RoleID: roleID}
	newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
	if err != nil {
		return "", 0, err
	}
	return newTokenString, session.UserID, nil
}
// AuthMiddleware returns a Gin handler that authenticates requests using the
// "jwt" cookie.
//
// Requests without the cookie, or whose token is rejected by
// verifyAndRenewToken, are aborted with 401 and a JSON error body. Otherwise
// the authenticated user's ID is stored in the context under "user_id" and the
// next handler runs. When verifyAndRenewToken issued a replacement token, it
// is written back to the client through utils.SetCookie first.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		tokenString, err := c.Cookie("jwt")
		if err != nil {
			logger.Error.Printf("No Auth token: %v\n", err)
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "No Auth token"})
			return
		}

		newToken, userID, err := verifyAndRenewToken(tokenString)
		switch {
		case err == nil:
			// The token was renewed: send the replacement to the client.
			utils.SetCookie(c, newToken)
		case errors.Is(err, customerrors.ErrValidToken):
			// The existing token is still valid; the cookie stays as it is.
		default:
			// NOTE(review): this writes the raw cookie value to the log —
			// confirm that logging bearer tokens is acceptable here.
			logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
			return
		}

		c.Set("user_id", userID)
		c.Next()
	}
}
// GenerateToken creates a signed JWT for user and registers (or refreshes) the
// matching server-side session.
//
// jwtKey is the HMAC secret used for signing. When sessionID is empty a new
// random UUID is generated; otherwise the given session ID is reused. The
// token carries the claims user_id, role_id, session_id and exp.
//
// It returns the signed token string. If signing fails the error is returned
// and the session store is left untouched, so no session exists for a token
// that was never issued.
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
	if sessionID == "" {
		sessionID = uuid.New().String()
	}

	// The token expires one minute after it is issued; verifyAndRenewToken
	// issues a fresh one as long as the session itself is still alive.
	expiresAt := time.Now().Add(time.Minute).Unix()
	token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
		"user_id":    user.ID,
		"role_id":    user.RoleID,
		"session_id": sessionID,
		"exp":        expiresAt,
	})

	signed, err := token.SignedString([]byte(jwtKey))
	if err != nil {
		return "", err
	}

	// Only record the session once a usable token exists.
	UpdateSession(sessionID, user.ID)
	return signed, nil
}
// ExtractContentFrom parses tokenString with the package parser, verifying it
// with config.Auth.JWTSecret, and returns the token together with its claims.
//
// An expired token is not treated as a failure: jwt.ErrTokenExpired is
// tolerated so that callers can still read the claims and decide whether to
// renew. Callers must therefore inspect token.Valid themselves. Any other
// parse error is logged and returned, with nil token and claims.
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
	token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
		return []byte(config.Auth.JWTSecret), nil
	})
	if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
		logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
		return nil, nil, err
	}

	// Defensive: the claims are read from token below, so never dereference a
	// nil token even if the parser reported only an expiry error.
	if token == nil {
		logger.Error.Printf("Invalid Token Claims")
		return nil, nil, fmt.Errorf("invalid token claims")
	}

	// Reached for both valid and expired tokens.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		logger.Error.Printf("Invalid Token Claims")
		return nil, nil, fmt.Errorf("invalid token claims")
	}
	return token, &claims, nil
}
func UpdateSession(sessionID string, userID uint) {
sessions[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(sessionDuration),
}
}
func InvalidateSession(token string) (bool, error) {
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(
token,
claims,
func(token *jwt.Token) (interface{}, error) {
return config.Auth.JWTSecret, nil
},
)
if err != nil {
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
}
sessionID, ok := claims["session_id"].(string)
if !ok {
return false, fmt.Errorf("No SessionID found")
}
delete(sessions, sessionID)
return true, nil
}