diff --git a/internal/middlewares/auth.go b/internal/middlewares/auth.go index 7051e76..bdad303 100644 --- a/internal/middlewares/auth.go +++ b/internal/middlewares/auth.go @@ -2,68 +2,75 @@ package middlewares import ( "GoMembership/internal/config" + "GoMembership/internal/models" + "GoMembership/internal/utils" + customerrors "GoMembership/pkg/errors" "GoMembership/pkg/logger" + "errors" "fmt" "net/http" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" ) -var ( - jwtKey = []byte(config.Auth.JWTSecret) - jwtSigningMethod = jwt.SigningMethodHS256 - jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()})) -) - -func GenerateToken(userID int64) (string, error) { - token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ - "user_id": userID, - "exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes - }) - - logger.Error.Printf("token generated: %#v", token) - return token.SignedString(jwtKey) +type Session struct { + UserID int64 + ExpiresAt time.Time } -func verifyToken(tokenString string) (*jwt.Token, error) { +var ( + sessionDuration = 5 * 24 * time.Hour + jwtSigningMethod = jwt.SigningMethodHS256 + jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()})) + sessions = make(map[string]*Session) +) + +func verifyAndRenewToken(tokenString string) (string, int64, error) { if tokenString == "" { - return nil, fmt.Errorf("Authorization token is required") + logger.Error.Printf("empty tokenstring") + return "", -1, fmt.Errorf("Authorization token is required") } - token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) { - return jwtKey, nil - }) - + token, claims, err := ExtractContentFrom(tokenString) if err != nil { - return nil, err + logger.Error.Printf("Couldn't parse JWT token String: %v", err) + return "", -1, err } + sessionID := (*claims)["session_id"].(string) + userID := int64((*claims)["user_id"].(float64)) + roleID := int8((*claims)["role_id"].(float64)) - if !token.Valid { - return nil, fmt.Errorf("invalid token") - } - - claims, ok := token.Claims.(jwt.MapClaims) + session, ok := sessions[sessionID] if !ok { - return nil, fmt.Errorf("invalid token claims") + logger.Error.Printf("session not found") + return "", -1, fmt.Errorf("session not found") + } + if userID != session.UserID { + return "", -1, fmt.Errorf("Cookie has been altered, aborting..") + } + if token.Valid { + // token is valid, so we can return the old tokenString + return tokenString, session.UserID, customerrors.ErrValidToken } - exp, ok := claims["exp"].(float64) - if !ok { - return nil, fmt.Errorf("invalid expiration claim") + if time.Now().After(sessions[sessionID].ExpiresAt) { + delete(sessions, sessionID) + logger.Error.Printf("session expired") + return "", -1, fmt.Errorf("session expired") + } + session.ExpiresAt = time.Now().Add(sessionDuration) + + logger.Error.Printf("Session still valid generating new token") + // Session is still valid, generate a new token + user := models.User{ID: userID, RoleID: roleID} + newTokenString, err := GenerateToken(config.Auth.JWTSecret, &user, sessionID) + if err != nil { + return "", -1, err } - userID, ok := claims["user_id"].(float64) - if !ok { - logger.Error.Printf("Invalid user ID: %v", userID) - return nil, fmt.Errorf("Invalid user ID") - } - - if time.Now().Unix() > int64(exp) { - return nil, fmt.Errorf("token expired") - } - - return token, nil + return newTokenString, session.UserID, nil } func AuthMiddleware() gin.HandlerFunc { @@ -76,34 +83,89 @@ func AuthMiddleware() gin.HandlerFunc { return } - token, err := verifyToken(tokenString) + newToken, userID, err := verifyAndRenewToken(tokenString) if err != nil { - logger.Error.Printf("Token is invalid: %v\n", err) + if err == customerrors.ErrValidToken { + c.Set("user_id", int64(userID)) + c.Next() + return + } + logger.Error.Printf("Token(%v) is invalid: %v\n", tokenString, err) c.JSON(http.StatusUnauthorized, gin.H{"error": "Auth token invalid"}) c.Abort() return } - claims, _ := token.Claims.(jwt.MapClaims) - userID, _ := claims["user_id"].(float64) - // Generate a new token - newToken, err := GenerateToken(int64(userID)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"}) - c.Abort() - return - } - - c.SetCookie( - "jwt", - newToken, - 10*60, // 10 minutes - "/", - "", - true, - true, - ) - c.Set("user_id", userID) + utils.SetCookie(c, newToken) + c.Set("user_id", int64(userID)) c.Next() } } + +func GenerateToken(jwtKey string, user *models.User, sessionID string) (string, error) { + if sessionID == "" { + sessionID = uuid.New().String() + } + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": user.ID, + "role_id": user.RoleID, + "session_id": sessionID, + "exp": time.Now().Add(time.Minute * 1).Unix(), // Token expires in 10 Minutes + }) + UpdateSession(sessionID, user.ID) + return token.SignedString([]byte(jwtKey)) +} + +func ExtractContentFrom(tokenString string) (*jwt.Token, *jwt.MapClaims, error) { + + token, err := jwtParser.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) { + return []byte(config.Auth.JWTSecret), nil + }) + + if !errors.Is(err, jwt.ErrTokenExpired) && err != nil { + logger.Error.Printf("Error during token(%v) parsing: %#v", tokenString, err) + return nil, nil, err + } + + // Token is expired, check if session is still valid + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + logger.Error.Printf("Invalid Token Claims") + return nil, nil, fmt.Errorf("invalid token claims") + } + + if !ok { + logger.Error.Printf("invalid session_id in token") + return nil, nil, fmt.Errorf("invalid session_id in token") + } + return token, &claims, nil +} + +func UpdateSession(sessionID string, userID int64) { + sessions[sessionID] = &Session{ + UserID: userID, + ExpiresAt: time.Now().Add(sessionDuration), + } +} + +func InvalidateSession(token string) (bool, error) { + claims := jwt.MapClaims{} + _, err := jwt.ParseWithClaims( + token, + claims, + func(token *jwt.Token) (interface{}, error) { + return config.Auth.JWTSecret, nil + }, + ) + if err != nil { + return false, fmt.Errorf("Couldn't get JWT claims: %#v", err) + } + + sessionID, ok := claims["session_id"].(string) + if !ok { + return false, fmt.Errorf("No SessionID found") + } + + delete(sessions, sessionID) + return true, nil +} diff --git a/internal/middlewares/auth_test.go b/internal/middlewares/auth_test.go index 05cec57..a6c9a71 100644 --- a/internal/middlewares/auth_test.go +++ b/internal/middlewares/auth_test.go @@ -2,6 +2,8 @@ package middlewares import ( "GoMembership/internal/config" + "GoMembership/internal/constants" + "GoMembership/internal/models" "GoMembership/pkg/logger" "encoding/json" "log" @@ -17,9 +19,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAuthMiddleware(t *testing.T) { - gin.SetMode(gin.TestMode) - +func setupTestEnvironment() { cwd, err := os.Getwd() if err != nil { log.Fatalf("Failed to get current working directory: %v", err) @@ -39,17 +39,25 @@ func TestAuthMiddleware(t *testing.T) { } config.LoadConfig() logger.Info.Printf("Config: %#v", config.CFG) +} + +func TestAuthMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + setupTestEnvironment() tests := []struct { - name string - setupAuth func(r *http.Request) - expectedStatus int - expectedUserID int64 + name string + setupAuth func(r *http.Request) + expectedStatus int + expectNewCookie bool + expectedUserID int64 }{ { name: "Valid Token", setupAuth: func(r *http.Request) { - token, _ := GenerateToken(123) + user := models.User{ID: 123, RoleID: constants.Roles.Member} + token, _ := GenerateToken(config.Auth.JWTSecret, &user, "") r.AddCookie(&http.Cookie{Name: "jwt", Value: token}) }, expectedStatus: http.StatusOK, @@ -70,14 +78,36 @@ func TestAuthMiddleware(t *testing.T) { expectedUserID: 0, }, { - name: "Expired Token", + name: "Expired Token with Valid Session", setupAuth: func(r *http.Request) { + sessionID := "test-session" token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ - "user_id": 123, - "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago + "user_id": 123, + "role_id": constants.Roles.Member, + "session_id": sessionID, + "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago }) - tokenString, _ := token.SignedString(jwtKey) + tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret)) r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) + UpdateSession(sessionID, 123) // Add a valid session + }, + expectedStatus: http.StatusOK, + expectNewCookie: true, + expectedUserID: 123, + }, + { + name: "Expired Token with Expired Session", + setupAuth: func(r *http.Request) { + sessionID := "expired-session" + token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ + "user_id": 123, + "role_id": constants.Roles.Member, + "session_id": sessionID, + "exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago + }) + tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret)) + r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) + // Don't add a session, simulating an expired session }, expectedStatus: http.StatusUnauthorized, expectedUserID: 0, @@ -86,8 +116,9 @@ func TestAuthMiddleware(t *testing.T) { name: "Invalid Signature", setupAuth: func(r *http.Request) { token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ - "user_id": 123, - "exp": time.Now().Add(time.Hour).Unix(), + "user_id": 123, + "session_id": "some-session", + "exp": time.Now().Add(time.Hour).Unix(), }) tokenString, _ := token.SignedString([]byte("wrong_secret")) r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) @@ -99,8 +130,10 @@ func TestAuthMiddleware(t *testing.T) { name: "Invalid Signing Method", setupAuth: func(r *http.Request) { token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ - "user_id": 123, - "exp": time.Now().Add(time.Hour).Unix(), + "user_id": 123, + "session_id": "some-session", + "role_id": constants.Roles.Member, + "exp": time.Now().Add(time.Hour).Unix(), }) tokenString, _ := token.SignedString([]byte(config.Auth.JWTSecret)) r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString}) @@ -143,9 +176,13 @@ func TestAuthMiddleware(t *testing.T) { // Check if a new cookie was set cookies := w.Result().Cookies() - assert.GreaterOrEqual(t, len(cookies), 1) - assert.Equal(t, "jwt", cookies[0].Name) - assert.NotEmpty(t, cookies[0].Value) + if tt.expectNewCookie { + assert.GreaterOrEqual(t, len(cookies), 1) + assert.Equal(t, "jwt", cookies[0].Name) + assert.NotEmpty(t, cookies[0].Value) + } else { + assert.Equal(t, 0, len(cookies), "Unexpected cookie set") + } } else { assert.Equal(t, 0, len(w.Result().Cookies())) }