add: Login system

This commit is contained in:
$(pass /github/name)
2024-09-03 20:20:24 +02:00
parent f648b53fe1
commit 569c0acaee
8 changed files with 189 additions and 44 deletions

View File

@@ -6,7 +6,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"GoMembership/internal/config"
"GoMembership/internal/models" "GoMembership/internal/models"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -64,7 +63,6 @@ func validateSubscription(assert bool, wantDBData map[string]interface{}) error
func getBaseSubscription() MembershipData { func getBaseSubscription() MembershipData {
return MembershipData{ return MembershipData{
APIKey: config.Auth.APIKEY,
Model: models.SubscriptionModel{ Model: models.SubscriptionModel{
Name: "Just a Subscription", Name: "Just a Subscription",
Details: "A subscription detail", Details: "A subscription detail",
@@ -80,28 +78,6 @@ func customizeSubscription(customize func(MembershipData) MembershipData) Member
func getSubscriptionData() []RegisterSubscriptionTest { func getSubscriptionData() []RegisterSubscriptionTest {
return []RegisterSubscriptionTest{ return []RegisterSubscriptionTest{
{
Name: "No API Key should fail",
WantResponse: http.StatusUnauthorized,
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.APIKey = ""
return subscription
})),
},
{
Name: "Wrong API Key should fail",
WantResponse: http.StatusUnauthorized,
WantDBData: map[string]interface{}{"name": "Just a Subscription"},
Assert: false,
Input: GenerateInputJSON(
customizeSubscription(func(subscription MembershipData) MembershipData {
subscription.APIKey = "alskfdlkjsfjk23-dF"
return subscription
})),
},
{ {
Name: "No Details should fail", Name: "No Details should fail",
WantResponse: http.StatusNotAcceptable, WantResponse: http.StatusNotAcceptable,

View File

@@ -3,6 +3,8 @@ package controllers
import ( import (
"fmt" "fmt"
"GoMembership/internal/constants"
"GoMembership/internal/middlewares"
"GoMembership/internal/models" "GoMembership/internal/models"
"GoMembership/internal/services" "GoMembership/internal/services"
@@ -25,6 +27,51 @@ type RegistrationData struct {
User models.User `json:"user"` User models.User `json:"user"`
} }
func (uc *UserController) LoginUser(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&input); err != nil {
logger.Error.Printf("Couldn't decode input: %v", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"error": "Couldn't decode request data"})
return
}
user, err := uc.Service.GetUserByEmail(input.Email)
if err != nil {
logger.Error.Printf("Error during user(%v) retrieval: %v\n", input.Email, err)
c.JSON(http.StatusNotFound, gin.H{"error": "Couldn't find user"})
return
}
ok, err := user.PasswordMatches(input.Password)
if err != nil {
logger.Error.Printf("Error during Password comparison: %v", err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "couldn't calculate match"})
return
}
if !ok {
logger.Error.Printf("Wrong Password: %v %v", user.FirstName, user.LastName)
c.JSON(http.StatusNotAcceptable, gin.H{"error": "Wrong Password"})
return
}
token, err := middlewares.GenerateToken(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Login successful",
"token": token,
})
}
func (uc *UserController) RegisterUser(c *gin.Context) { func (uc *UserController) RegisterUser(c *gin.Context) {
var regData RegistrationData var regData RegistrationData
@@ -48,6 +95,9 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
} }
regData.User.Membership.SubscriptionModel = *selectedModel regData.User.Membership.SubscriptionModel = *selectedModel
// logger.Info.Printf("REGISTERING user: %#v", regData.User) // logger.Info.Printf("REGISTERING user: %#v", regData.User)
regData.User.RoleID = constants.Roles.Member
// Register User // Register User
id, token, err := uc.Service.RegisterUser(&regData.User) id, token, err := uc.Service.RegisterUser(&regData.User)
if err != nil { if err != nil {
@@ -93,8 +143,8 @@ func (uc *UserController) RegisterUser(c *gin.Context) {
// Proceed without returning error since user registration is successful // Proceed without returning error since user registration is successful
} }
c.JSON(http.StatusCreated, gin.H{ c.JSON(http.StatusCreated, gin.H{
"status": "success", "message": "Registration successuful",
"id": regData.User.ID, "id": regData.User.ID,
}) })
} }

View File

@@ -1,6 +1,7 @@
package controllers package controllers
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -12,6 +13,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config" "GoMembership/internal/config"
"GoMembership/internal/constants" "GoMembership/internal/constants"
@@ -57,8 +59,74 @@ func TestUserController(t *testing.T) {
} }
}) })
} }
testLoginUser(t)
} }
func testLoginUser(t *testing.T) {
// This test should run after the user registration test
t.Run("LoginUser", func(t *testing.T) {
// Test cases
tests := []struct {
name string
input string
wantStatusCode int
wantToken bool
}{
{
name: "Valid login",
input: `{
"email": "john.doe@example.com",
"password": "password123"
}`,
wantStatusCode: http.StatusOK,
wantToken: true,
},
{
name: "Invalid email",
input: `{
"email": "nonexistent@example.com",
"password": "password123"
}`,
wantStatusCode: http.StatusNotFound,
wantToken: false,
},
{
name: "Invalid password",
input: `{
"email": "john.doe@example.com",
"password": "wrongpassword"
}`,
wantStatusCode: http.StatusNotAcceptable,
wantToken: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup
c, w, _ := GetMockedJSONContext([]byte(tt.input), "/login")
// Execute
Uc.LoginUser(c)
// Assert
assert.Equal(t, tt.wantStatusCode, w.Code)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
if tt.wantToken {
logger.Info.Printf("Response: %#v", response)
assert.Contains(t, response, "token")
assert.NotEmpty(t, response["token"])
} else {
assert.NotContains(t, response, "token")
}
})
}
})
}
func validateUser(assert bool, wantDBData map[string]interface{}) error { func validateUser(assert bool, wantDBData map[string]interface{}) error {
users, err := Uc.Service.GetUsers(wantDBData) users, err := Uc.Service.GetUsers(wantDBData)
if err != nil { if err != nil {
@@ -223,7 +291,7 @@ func verifyMail(verificationURL string) error {
router := gin.New() router := gin.New()
router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*")) router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*"))
router.GET("/backend/verify", Uc.VerifyMailHandler) router.GET("/verify", Uc.VerifyMailHandler)
wv := httptest.NewRecorder() wv := httptest.NewRecorder()
cv, _ := gin.CreateTestContext(wv) cv, _ := gin.CreateTestContext(wv)
var err error var err error

View File

@@ -17,7 +17,7 @@ var (
jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()})) jwtParser = jwt.NewParser(jwt.WithValidMethods([]string{jwtSigningMethod.Alg()}))
) )
func GenerateToken(userID string) (string, error) { func GenerateToken(userID int64) (string, error) {
token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{ token := jwt.NewWithClaims(jwtSigningMethod, jwt.MapClaims{
"user_id": userID, "user_id": userID,
"exp": time.Now().Add(time.Minute * 15).Unix(), // Token expires in 15 Minutes "exp": time.Now().Add(time.Minute * 15).Unix(), // Token expires in 15 Minutes

View File

@@ -19,7 +19,7 @@ func TestAuthMiddleware(t *testing.T) {
name string name string
setupAuth func(r *http.Request) setupAuth func(r *http.Request)
expectedStatus int expectedStatus int
expectedUserID string expectedUserID int64
}{ }{
{ {
name: "Valid Token", name: "Valid Token",
@@ -28,13 +28,13 @@ func TestAuthMiddleware(t *testing.T) {
r.Header.Set("Authorization", "Bearer "+token) r.Header.Set("Authorization", "Bearer "+token)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedUserID: "user123", expectedUserID: 12,
}, },
{ {
name: "Missing Auth Header", name: "Missing Auth Header",
setupAuth: func(r *http.Request) {}, setupAuth: func(r *http.Request) {},
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedUserID: "", expectedUserID: 0,
}, },
{ {
name: "Invalid Token Format", name: "Invalid Token Format",
@@ -42,7 +42,7 @@ func TestAuthMiddleware(t *testing.T) {
r.Header.Set("Authorization", "InvalidFormat") r.Header.Set("Authorization", "InvalidFormat")
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedUserID: "", expectedUserID: 0,
}, },
{ {
name: "Expired Token", name: "Expired Token",
@@ -55,7 +55,7 @@ func TestAuthMiddleware(t *testing.T) {
r.Header.Set("Authorization", "Bearer "+tokenString) r.Header.Set("Authorization", "Bearer "+tokenString)
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedUserID: "", expectedUserID: 0,
}, },
{ {
name: "Invalid Signature", name: "Invalid Signature",
@@ -68,13 +68,13 @@ func TestAuthMiddleware(t *testing.T) {
r.Header.Set("Authorization", "Bearer "+tokenString) r.Header.Set("Authorization", "Bearer "+tokenString)
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedUserID: "", expectedUserID: 0,
}, },
{ {
name: "Missing Auth Header", name: "Missing Auth Header",
setupAuth: func(r *http.Request) {}, setupAuth: func(r *http.Request) {},
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedUserID: "", expectedUserID: 0,
}, },
} }
@@ -88,7 +88,7 @@ func TestAuthMiddleware(t *testing.T) {
if exists { if exists {
c.JSON(http.StatusOK, gin.H{"user_id": userID}) c.JSON(http.StatusOK, gin.H{"user_id": userID})
} else { } else {
c.JSON(http.StatusUnauthorized, gin.H{"user_id": ""}) c.JSON(http.StatusUnauthorized, gin.H{"user_id": 0})
} }
}) })

View File

@@ -2,20 +2,21 @@ package models
import ( import (
"GoMembership/internal/constants" "GoMembership/internal/constants"
"gorm.io/gorm"
"time" "time"
"github.com/alexedwards/argon2id"
"gorm.io/gorm"
) )
type User struct { type User struct {
UpdatedAt time.Time UpdatedAt time.Time
DateOfBirth time.Time `gorm:"not null" json:"date_of_birth" validate:"required,age"` DateOfBirth time.Time `gorm:"not null" json:"date_of_birth" validate:"required,age"`
CreatedAt time.Time CreatedAt time.Time
Salt *string `json:"-"`
Company string `json:"company" validate:"omitempty,omitnil"` Company string `json:"company" validate:"omitempty,omitnil"`
Phone string `json:"phone" validate:"omitempty,omitnil"` Phone string `json:"phone" validate:"omitempty,omitnil"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
FirstName string `gorm:"not null" json:"first_name" validate:"required"` FirstName string `gorm:"not null" json:"first_name" validate:"required"`
Password string `json:"password"` Password string `json:"password" required_unless=RoleID 0`
Email string `gorm:"unique;not null" json:"email" validate:"required,email"` Email string `gorm:"unique;not null" json:"email" validate:"required,email"`
LastName string `gorm:"not null" json:"last_name" validate:"required"` LastName string `gorm:"not null" json:"last_name" validate:"required"`
ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image"` ProfilePicture string `json:"profile_picture" validate:"omitempty,omitnil,image"`
@@ -40,3 +41,12 @@ func (u *User) BeforeCreate(tx *gorm.DB) (err error) {
} }
return return
} }
func (u *User) PasswordMatches(plaintextPassword string) (bool, error) {
match, err := argon2id.ComparePasswordAndHash(plaintextPassword, u.Password)
if err != nil {
return false, err
}
return match, nil
}

View File

@@ -18,8 +18,8 @@ type UserRepositoryInterface interface {
CreateUser(user *models.User) (int64, error) CreateUser(user *models.User) (int64, error)
UpdateUser(userID int64, user *models.User) error UpdateUser(userID int64, user *models.User) error
GetUsers(where map[string]interface{}) (*[]models.User, error) GetUsers(where map[string]interface{}) (*[]models.User, error)
FindUserByID(id int64) (*models.User, error) GetUserByID(id int64) (*models.User, error)
FindUserByEmail(email string) (*models.User, error) GetUserByEmail(email string) (*models.User, error)
SetVerificationToken(user *models.User, token *string) (int64, error) SetVerificationToken(user *models.User, token *string) (int64, error)
IsVerified(userID *int64) (bool, error) IsVerified(userID *int64) (bool, error)
VerifyUserOfToken(token *string) (*models.User, error) VerifyUserOfToken(token *string) (*models.User, error)
@@ -70,7 +70,7 @@ func (ur *UserRepository) GetUsers(where map[string]interface{}) (*[]models.User
return &users, nil return &users, nil
} }
func (ur *UserRepository) FindUserByID(id int64) (*models.User, error) { func (ur *UserRepository) GetUserByID(id int64) (*models.User, error) {
var user models.User var user models.User
result := database.DB. result := database.DB.
Preload("Consents"). Preload("Consents").
@@ -88,7 +88,7 @@ func (ur *UserRepository) FindUserByID(id int64) (*models.User, error) {
return &user, nil return &user, nil
} }
func (ur *UserRepository) FindUserByEmail(email string) (*models.User, error) { func (ur *UserRepository) GetUserByEmail(email string) (*models.User, error) {
var user models.User var user models.User
result := database.DB.Where("email = ?", email).First(&user) result := database.DB.Where("email = ?", email).First(&user)
if result.Error != nil { if result.Error != nil {
@@ -127,7 +127,7 @@ func (ur *UserRepository) VerifyUserOfToken(token *string) (*models.User, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := ur.FindUserByID(emailVerification.UserID) user, err := ur.GetUserByID(emailVerification.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -69,6 +69,47 @@ func (service *UserService) RegisterUser(user *models.User) (int64, string, erro
func (service *UserService) FindUserByEmail(email string) (*models.User, error) { func (service *UserService) FindUserByEmail(email string) (*models.User, error) {
return service.Repo.FindUserByEmail(email) return service.Repo.FindUserByEmail(email)
func (service *UserService) Update(user *models.User) (int64, string, error) {
if err := validateRegistrationData(user); err != nil {
return http.StatusNotAcceptable, "", err
}
if user.Password == "" && user.RoleID != constants.Roles.Member {
return http.StatusNotAcceptable, "", fmt.Errorf("No password provided")
}
hash, err := utils.HashPassword(user.Password)
if err != nil {
return http.StatusInternalServerError, "", err
}
user.Password = hash
user.Status = constants.UnverifiedStatus
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
id, err := service.Repo.CreateUser(user)
if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") {
return http.StatusConflict, "", err
} else if err != nil {
return http.StatusInternalServerError, "", err
}
user.ID = id
token, err := utils.GenerateVerificationToken()
if err != nil {
return http.StatusInternalServerError, "", err
}
logger.Info.Printf("TOKEN: %v", token)
_, err = service.Repo.SetVerificationToken(user, &token)
if err != nil {
return http.StatusInternalServerError, "", err
}
return id, token, nil
} }
func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) { func (service *UserService) GetUsers(where map[string]interface{}) (*[]models.User, error) {