Add: session handling
This commit is contained in:
@@ -2,68 +2,75 @@ package middlewares
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"GoMembership/internal/config"
|
"GoMembership/internal/config"
|
||||||
|
"GoMembership/internal/models"
|
||||||
|
"GoMembership/internal/utils"
|
||||||
|
customerrors "GoMembership/pkg/errors"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type Session struct {
|
||||||
jwtKey = []byte(config.Auth.JWTSecret)
|
UserID int64
|
||||||
jwtSigningMethod = jwt.SigningMethodHS256
|
ExpiresAt time.Time
|
||||||
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
|
||||||
)
|
|
||||||
|
|
||||||
func GenerateToken(userID int64) (string, error) {
|
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
|
||||||
"user_id": userID,
|
|
||||||
"exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.Error.Printf("token generated: %#v", token)
|
|
||||||
return token.SignedString(jwtKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyToken(tokenString string) (*jwt.Token, error) {
|
var (
|
||||||
|
sessionDuration = 5 * 24 * time.Hour
|
||||||
|
jwtSigningMethod = jwt.SigningMethodHS256
|
||||||
|
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
|
||||||
|
sessions = make(map[string]*Session)
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyAndRenewToken(tokenString string) (string, int64, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
return nil, fmt.Errorf("Authorization token is required")
|
logger.Error.Printf("empty tokenstring")
|
||||||
|
return "", -1, fmt.Errorf("Authorization token is required")
|
||||||
}
|
}
|
||||||
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
token, claims, err := ExtractContentFrom(tokenString)
|
||||||
return jwtKey, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
logger.Error.Printf("Couldn't parse JWT token String: %v", err)
|
||||||
|
return "", -1, err
|
||||||
}
|
}
|
||||||
|
sessionID := (*claims)["session_id"].(string)
|
||||||
|
userID := int64((*claims)["user_id"].(float64))
|
||||||
|
roleID := int8((*claims)["role_id"].(float64))
|
||||||
|
|
||||||
if !token.Valid {
|
session, ok := sessions[sessionID]
|
||||||
return nil, fmt.Errorf("invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid token claims")
|
logger.Error.Printf("session not found")
|
||||||
|
return "", -1, fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
if userID != session.UserID {
|
||||||
|
return "", -1, fmt.Errorf("Cookie has been altered, aborting..")
|
||||||
|
}
|
||||||
|
if token.Valid {
|
||||||
|
// token is valid, so we can return the old tokenString
|
||||||
|
return tokenString, session.UserID, customerrors.ErrValidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
exp, ok := claims["exp"].(float64)
|
if time.Now().After(sessions[sessionID].ExpiresAt) {
|
||||||
if !ok {
|
delete(sessions, sessionID)
|
||||||
return nil, fmt.Errorf("invalid expiration claim")
|
logger.Error.Printf("session expired")
|
||||||
|
return "", -1, fmt.Errorf("session expired")
|
||||||
|
}
|
||||||
|
session.ExpiresAt = time.Now().Add(sessionDuration)
|
||||||
|
|
||||||
|
logger.Error.Printf("Session still valid generating new token")
|
||||||
|
// Session is still valid, generate a new token
|
||||||
|
user := models.User{ID: userID, RoleID: roleID}
|
||||||
|
newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", -1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, ok := claims["user_id"].(float64)
|
return newTokenString, session.UserID, nil
|
||||||
if !ok {
|
|
||||||
logger.Error.Printf("Invalid user ID: %v", userID)
|
|
||||||
return nil, fmt.Errorf("Invalid user ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().Unix() > int64(exp) {
|
|
||||||
return nil, fmt.Errorf("token expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthMiddleware() gin.HandlerFunc {
|
func AuthMiddleware() gin.HandlerFunc {
|
||||||
@@ -76,34 +83,89 @@ func AuthMiddleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := verifyToken(tokenString)
|
newToken, userID, err := verifyAndRenewToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Printf("Token is invalid: %v\n", err)
|
if err == customerrors.ErrValidToken {
|
||||||
|
c.Set("user_id", int64(userID))
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
claims, _ := token.Claims.(jwt.MapClaims)
|
|
||||||
userID, _ := claims["user_id"].(float64)
|
|
||||||
|
|
||||||
// Generate a new token
|
utils.SetCookie(c, newToken)
|
||||||
newToken, err := GenerateToken(int64(userID))
|
c.Set("user_id", int64(userID))
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SetCookie(
|
|
||||||
"jwt",
|
|
||||||
newToken,
|
|
||||||
10*60, // 10 minutes
|
|
||||||
"/",
|
|
||||||
"",
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
c.Set("user_id", userID)
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) {
|
||||||
|
if sessionID == "" {
|
||||||
|
sessionID = uuid.New().String()
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
|
"user_id": user.ID,
|
||||||
|
"role_id": user.RoleID,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes
|
||||||
|
})
|
||||||
|
UpdateSession(sessionID, user.ID)
|
||||||
|
return token.SignedString([]byte(jwtKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) {
|
||||||
|
|
||||||
|
token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(config.Auth.JWTSecret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if !errors.Is(err, jwt.ErrTokenExpired) && err != nil {
|
||||||
|
logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is expired, check if session is still valid
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
logger.Error.Printf("Invalid Token Claims")
|
||||||
|
return nil, nil, fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logger.Error.Printf("invalid session_id in token")
|
||||||
|
return nil, nil, fmt.Errorf("invalid session_id in token")
|
||||||
|
}
|
||||||
|
return token, &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSession(sessionID string, userID int64) {
|
||||||
|
sessions[sessionID] = &Session{
|
||||||
|
UserID: userID,
|
||||||
|
ExpiresAt: time.Now().Add(sessionDuration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func InvalidateSession(token string) (bool, error) {
|
||||||
|
claims := jwt.MapClaims{}
|
||||||
|
_, err := jwt.ParseWithClaims(
|
||||||
|
token,
|
||||||
|
claims,
|
||||||
|
func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return config.Auth.JWTSecret, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("Couldn't get JWT claims: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID, ok := claims["session_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("No SessionID found")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sessions, sessionID)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package middlewares
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"GoMembership/internal/config"
|
"GoMembership/internal/config"
|
||||||
|
"GoMembership/internal/constants"
|
||||||
|
"GoMembership/internal/models"
|
||||||
"GoMembership/pkg/logger"
|
"GoMembership/pkg/logger"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
@@ -17,9 +19,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthMiddleware(t *testing.T) {
|
func setupTestEnvironment() {
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
cwd, err := os.Getwd()
|
cwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to get current working directory: %v", err)
|
log.Fatalf("Failed to get current working directory: %v", err)
|
||||||
@@ -39,17 +39,25 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
config.LoadConfig()
|
config.LoadConfig()
|
||||||
logger.Info.Printf("Config: %#v", config.CFG)
|
logger.Info.Printf("Config: %#v", config.CFG)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
setupTestEnvironment()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAuth func(r *http.Request)
|
setupAuth func(r *http.Request)
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
expectedUserID int64
|
expectNewCookie bool
|
||||||
|
expectedUserID int64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Valid Token",
|
name: "Valid Token",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token, _ := GenerateToken(123)
|
user := models.User{ID: 123, RoleID: constants.Roles.Member}
|
||||||
|
token, _ := GenerateToken(config.Auth.JWTSecret, &user, "")
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
@@ -70,14 +78,36 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
expectedUserID: 0,
|
expectedUserID: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Expired Token",
|
name: "Expired Token with Valid Session",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
|
sessionID := "test-session"
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
"role_id": constants.Roles.Member,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString(jwtKey)
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
|
UpdateSession(sessionID, 123) // Add a valid session
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectNewCookie: true,
|
||||||
|
expectedUserID: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired Token with Expired Session",
|
||||||
|
setupAuth: func(r *http.Request) {
|
||||||
|
sessionID := "expired-session"
|
||||||
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
|
"user_id": 123,
|
||||||
|
"role_id": constants.Roles.Member,
|
||||||
|
"session_id": sessionID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||||
|
})
|
||||||
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
|
// Don't add a session, simulating an expired session
|
||||||
},
|
},
|
||||||
expectedStatus: http.StatusUnauthorized,
|
expectedStatus: http.StatusUnauthorized,
|
||||||
expectedUserID: 0,
|
expectedUserID: 0,
|
||||||
@@ -86,8 +116,9 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
name: "Invalid Signature",
|
name: "Invalid Signature",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"session_id": "some-session",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
@@ -99,8 +130,10 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
name: "Invalid Signing Method",
|
name: "Invalid Signing Method",
|
||||||
setupAuth: func(r *http.Request) {
|
setupAuth: func(r *http.Request) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
|
||||||
"user_id": 123,
|
"user_id": 123,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"session_id": "some-session",
|
||||||
|
"role_id": constants.Roles.Member,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
})
|
})
|
||||||
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret))
|
||||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||||
@@ -143,9 +176,13 @@ func TestAuthMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
// Check if a new cookie was set
|
// Check if a new cookie was set
|
||||||
cookies := w.Result().Cookies()
|
cookies := w.Result().Cookies()
|
||||||
assert.GreaterOrEqual(t, len(cookies), 1)
|
if tt.expectNewCookie {
|
||||||
assert.Equal(t, "jwt", cookies[0].Name)
|
assert.GreaterOrEqual(t, len(cookies), 1)
|
||||||
assert.NotEmpty(t, cookies[0].Value)
|
assert.Equal(t, "jwt", cookies[0].Name)
|
||||||
|
assert.NotEmpty(t, cookies[0].Value)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, len(cookies), "Unexpected cookie set")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, 0, len(w.Result().Cookies()))
|
assert.Equal(t, 0, len(w.Result().Cookies()))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user