Add: session handling

This commit is contained in:
$(pass /github/name)
2024-09-20 08:25:26 +02:00
parent 31c47270ab
commit 1e68e7d390
2 changed files with 181 additions and 82 deletions

View File

@@ -2,68 +2,75 @@ package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/models"
"GoMembership/internal/utils"
customerrors "GoMembership/pkg/errors"
"GoMembership/pkg/logger"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
var (
jwtKey = []byte(config.Auth.JWTSecret)
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
)
func GenerateToken(userID int64) (string, error) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": userID,
"exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes
})
logger.Error.Printf("token generated: %#v", token)
return token.SignedString(jwtKey)
type Session struct {
UserID int64
ExpiresAt time.Time
}
func verifyToken(tokenString string) (*jwt.Token, error) {
var (
sessionDuration = 5 * 24 * time.Hour
jwtSigningMethod = jwt.SigningMethodHS256
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
sessions = make(map[string]*Session)
)
func verifyAndRenewToken(tokenString string) (string, int64, error) {
if tokenString == "" {
return nil, fmt.Errorf("Authorization token is required")
logger.Error.Printf("empty tokenstring")
return "", -1, fmt.Errorf("Authorization token is required")
}
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
token, claims, err := ExtractContentFrom(tokenString)
if err != nil {
return nil, err
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
return "", -1, err
}
sessionID := (*claims)["session_id"].(string)
userID := int64((*claims)["user_id"].(float64))
roleID := int8((*claims)["role_id"].(float64))
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
session, ok := sessions[sessionID]
if !ok {
return nil, fmt.Errorf("invalid token claims")
logger.Error.Printf("session not found")
return "", -1, fmt.Errorf("session not found")
}
if userID != session.UserID {
return "", -1, fmt.Errorf("Cookie has been altered, aborting..")
}
if token.Valid {
// token is valid, so we can return the old tokenString
return tokenString, session.UserID, customerrors.ErrValidToken
}
exp, ok := claims["exp"].(float64)
if !ok {
return nil, fmt.Errorf("invalid expiration claim")
if time.Now().After(sessions[sessionID].ExpiresAt) {
delete(sessions, sessionID)
logger.Error.Printf("session expired")
return "", -1, fmt.Errorf("session expired")
}
session.ExpiresAt = time.Now().Add(sessionDuration)
logger.Error.Printf("Session still valid generating new token")
// Session is still valid, generate a new token
user := models.User{ID: userID, RoleID: roleID}
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
if err != nil {
return "", -1, err
}
userID, ok := claims["user_id"].(float64)
if !ok {
logger.Error.Printf("Invalid user ID: %v", userID)
return nil, fmt.Errorf("Invalid user ID")
}
if time.Now().Unix() > int64(exp) {
return nil, fmt.Errorf("token expired")
}
return token, nil
return newTokenString, session.UserID, nil
}
func AuthMiddleware() gin.HandlerFunc {
@@ -76,34 +83,89 @@ func AuthMiddleware() gin.HandlerFunc {
return
}
token, err := verifyToken(tokenString)
newToken, userID, err := verifyAndRenewToken(tokenString)
if err != nil {
logger.Error.Printf("Token is invalid: %v\n", err)
if err == customerrors.ErrValidToken {
c.Set("user_id", int64(userID))
c.Next()
return
}
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
c.Abort()
return
}
claims, _ := token.Claims.(jwt.MapClaims)
userID, _ := claims["user_id"].(float64)
// Generate a new token
newToken, err := GenerateToken(int64(userID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
c.Abort()
return
}
c.SetCookie(
"jwt",
newToken,
10*60, // 10 minutes
"/",
"",
true,
true,
)
c.Set("user_id", userID)
utils.SetCookie(c, newToken)
c.Set("user_id", int64(userID))
c.Next()
}
}
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
if sessionID == "" {
sessionID = uuid.New().String()
}
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": user.ID,
"role_id": user.RoleID,
"session_id": sessionID,
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
})
UpdateSession(sessionID, user.ID)
return token.SignedString([]byte(jwtKey))
}
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
return []byte(config.Auth.JWTSecret), nil
})
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
return nil, nil, err
}
// Token is expired, check if session is still valid
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logger.Error.Printf("Invalid Token Claims")
return nil, nil, fmt.Errorf("invalid token claims")
}
if !ok {
logger.Error.Printf("invalid session_id in token")
return nil, nil, fmt.Errorf("invalid session_id in token")
}
return token, &claims, nil
}
func UpdateSession(sessionID string, userID int64) {
sessions[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(sessionDuration),
}
}
func InvalidateSession(token string) (bool, error) {
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(
token,
claims,
func(token *jwt.Token) (interface{}, error) {
return config.Auth.JWTSecret, nil
},
)
if err != nil {
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
}
sessionID, ok := claims["session_id"].(string)
if !ok {
return false, fmt.Errorf("No SessionID found")
}
delete(sessions, sessionID)
return true, nil
}

View File

@@ -2,6 +2,8 @@ package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/constants"
"GoMembership/internal/models"
"GoMembership/pkg/logger"
"encoding/json"
"log"
@@ -17,9 +19,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
func setupTestEnvironment() {
cwd, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current working directory: %v", err)
@@ -39,17 +39,25 @@ func TestAuthMiddleware(t *testing.T) {
}
config.LoadConfig()
logger.Info.Printf("Config: %#v", config.CFG)
}
func TestAuthMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
setupTestEnvironment()
tests := []struct {
name string
setupAuth func(r *http.Request)
expectedStatus int
expectedUserID int64
name string
setupAuth func(r *http.Request)
expectedStatus int
expectNewCookie bool
expectedUserID int64
}{
{
name: "Valid Token",
setupAuth: func(r *http.Request) {
token, _ := GenerateToken(123)
user := models.User{ID: 123, RoleID: constants.Roles.Member}
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
},
expectedStatus: http.StatusOK,
@@ -70,14 +78,36 @@ func TestAuthMiddleware(t *testing.T) {
expectedUserID: 0,
},
{
name: "Expired Token",
name: "Expired Token with Valid Session",
setupAuth: func(r *http.Request) {
sessionID := "test-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString(jwtKey)
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
UpdateSession(sessionID, 123) // Add a valid session
},
expectedStatus: http.StatusOK,
expectNewCookie: true,
expectedUserID: 123,
},
{
name: "Expired Token with Expired Session",
setupAuth: func(r *http.Request) {
sessionID := "expired-session"
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"role_id": constants.Roles.Member,
"session_id": sessionID,
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
// Don't add a session, simulating an expired session
},
expectedStatus: http.StatusUnauthorized,
expectedUserID: 0,
@@ -86,8 +116,9 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signature",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(time.Hour).Unix(),
"user_id": 123,
"session_id": "some-session",
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte("wrong_secret"))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
@@ -99,8 +130,10 @@ func TestAuthMiddleware(t *testing.T) {
name: "Invalid Signing Method",
setupAuth: func(r *http.Request) {
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"user_id": 123,
"exp": time.Now().Add(time.Hour).Unix(),
"user_id": 123,
"session_id": "some-session",
"role_id": constants.Roles.Member,
"exp": time.Now().Add(time.Hour).Unix(),
})
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
@@ -143,9 +176,13 @@ func TestAuthMiddleware(t *testing.T) {
// Check if a new cookie was set
cookies := w.Result().Cookies()
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
if tt.expectNewCookie {
assert.GreaterOrEqual(t, len(cookies), 1)
assert.Equal(t, "jwt", cookies[0].Name)
assert.NotEmpty(t, cookies[0].Value)
} else {
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
}
} else {
assert.Equal(t, 0, len(w.Result().Cookies()))
}