chg: auth handling to jwt cookies
This commit is contained in:
@@ -3,8 +3,8 @@ package middlewares
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -20,9 +20,10 @@ var (
|
||||
func GenerateToken(userID int64) (string, error) {
|
||||
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
|
||||
"user_id": userID,
|
||||
"exp": time.Now().Add(time.Minute * 15).Unix(), // Token expires in 15 Minutes
|
||||
"exp": time.Now().Add(time.Minute * 10).Unix(), // Token expires in 10 Minutes
|
||||
})
|
||||
|
||||
logger.Error.Printf("token generated: %#v", token)
|
||||
return token.SignedString(jwtKey)
|
||||
}
|
||||
|
||||
@@ -37,45 +38,71 @@ func verifyToken(tokenString string) (*jwt.Token, error) {
|
||||
|
||||
return token, nil
|
||||
}
|
||||
func GetUserIDFromContext(c *gin.Context) (int64, error) {
|
||||
|
||||
tokenString, err := c.Cookie("jwt")
|
||||
if err != nil {
|
||||
|
||||
logger.Error.Printf("Error getting cookie: %v\n", err)
|
||||
return 0, err
|
||||
}
|
||||
if tokenString == "" {
|
||||
logger.Error.Printf("Token is empty: %v\n", err)
|
||||
return 0, fmt.Errorf("Authorization token is required")
|
||||
}
|
||||
|
||||
token, err := verifyToken(tokenString)
|
||||
if err != nil || !token.Valid {
|
||||
|
||||
logger.Error.Printf("Token is invalid: %v\n", err)
|
||||
return 0, fmt.Errorf("Token not valid!")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
|
||||
logger.Error.Printf("claims userid: %v", claims["user_id"].(float64))
|
||||
if !ok {
|
||||
logger.Error.Printf("Invalid Token claims")
|
||||
return 0, fmt.Errorf("Invalid token claims")
|
||||
}
|
||||
userID, ok := claims["user_id"].(float64)
|
||||
|
||||
if !ok {
|
||||
logger.Error.Printf("Invalid user ID: %v", userID)
|
||||
return 0, fmt.Errorf("Invalid user ID")
|
||||
}
|
||||
|
||||
return int64(userID), nil
|
||||
}
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
bearerToken := strings.Split(authHeader, " ")
|
||||
if len(bearerToken) != 2 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := bearerToken[1]
|
||||
|
||||
token, err := verifyToken(tokenString)
|
||||
userID, err := GetUserIDFromContext(c)
|
||||
if err != nil {
|
||||
if err == jwt.ErrTokenSignatureInvalid {
|
||||
logger.Error.Printf("JWT NULL ATTACK: %#v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token signing method"})
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
userID := claims["user_id"].(string)
|
||||
c.Set("user_id", userID)
|
||||
c.Next()
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
|
||||
// Generate a new token
|
||||
newToken, err := GenerateToken(int64(userID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(
|
||||
"jwt",
|
||||
newToken,
|
||||
10*60, // 10 minutes
|
||||
"/",
|
||||
"",
|
||||
true,
|
||||
true,
|
||||
)
|
||||
c.Set("user_id", userID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,11 +24,11 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
{
|
||||
name: "Valid Token",
|
||||
setupAuth: func(r *http.Request) {
|
||||
token, _ := GenerateToken("user123")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
token, _ := GenerateToken(123)
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: token})
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedUserID: 12,
|
||||
expectedUserID: 123,
|
||||
},
|
||||
{
|
||||
name: "Missing Auth Header",
|
||||
@@ -52,7 +52,7 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
"exp": time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago
|
||||
})
|
||||
tokenString, _ := token.SignedString(jwtKey)
|
||||
r.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
@@ -65,17 +65,11 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte("wrong_secret"))
|
||||
r.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
r.AddCookie(&http.Cookie{Name: "jwt", Value: tokenString})
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
{
|
||||
name: "Missing Auth Header",
|
||||
setupAuth: func(r *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedUserID: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -100,11 +94,20 @@ func TestAuthMiddleware(t *testing.T) {
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var response map[string]int64
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedUserID, response["user_id"])
|
||||
|
||||
assert.Equal(t, tt.expectedUserID, response["user_id"])
|
||||
// Check if a new cookie was set
|
||||
cookies := w.Result().Cookies()
|
||||
assert.GreaterOrEqual(t, len(cookies), 1)
|
||||
assert.Equal(t, "jwt", cookies[0].Name)
|
||||
assert.NotEmpty(t, cookies[0].Value)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(w.Result().Cookies()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user