first working server

This commit is contained in:
$(pass /github/name)
2024-07-03 09:40:45 +02:00
parent 9bd8d48243
commit 6d34d99835
20 changed files with 340 additions and 128 deletions

View File

@@ -1,38 +1,69 @@
package config
import (
"GoMembership/internal/utils"
"GoMembership/pkg/logger"
"encoding/json"
"log"
"os"
"path/filepath"
"sync"
)
type DatabaseConfig struct {
DBPath string `json:"DBPath"`
}
type Config struct {
DB DatabaseConfig `json:"db"`
type AuthenticationConfig struct {
JWTSecret string
CSRFSecret string
}
type Config struct {
DB DatabaseConfig `json:"db"`
Auth AuthenticationConfig
}
var (
pConfig Config
once sync.Once
loaded bool
)
func LoadConfig() *Config {
path, err := os.Getwd()
if err != nil {
log.Fatalf("could not get working directory: %v", err)
logger.Error.Fatalf("could not get working directory: %v", err)
}
configFile, err := os.Open(filepath.Join(path, "configs", "config.json"))
if err != nil {
log.Fatalf("could not open config file: %v", err)
logger.Error.Fatalf("could not open config file: %v", err)
}
defer configFile.Close()
decoder := json.NewDecoder(configFile)
config := &Config{}
err = decoder.Decode(config)
// pConfig = &Config{}
err = decoder.Decode(&pConfig)
if err != nil {
log.Fatalf("could not decode config file: %v", err)
logger.Error.Fatalf("could not decode config file: %v", err)
}
if !loaded {
once.Do(
func() {
csrfSecret, err := utils.GenerateRandomString(32)
if err != nil {
logger.Error.Fatalf("could not generate CSRF secret: %v", err)
}
jwtSecret, err := utils.GenerateRandomString(32)
if err != nil {
logger.Error.Fatalf("could not generate JWT secret: %v", err)
}
pConfig.Auth.JWTSecret = jwtSecret
pConfig.Auth.CSRFSecret = csrfSecret
loaded = true
})
}
return config
return &pConfig
}

View File

@@ -7,6 +7,7 @@ import (
// "github.com/gorilla/mux"
"net/http"
// "strconv"
"GoMembership/pkg/logger"
)
type UserController struct {
@@ -18,13 +19,17 @@ func NewUserController(service services.UserService) *UserController {
}
func (uc *UserController) RegisterUser(w http.ResponseWriter, r *http.Request) {
logger.Info.Println("registering user")
var user models.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
logger.Error.Printf("Couldn't decode Userdata: %v", err)
return
}
if err := uc.service.RegisterUser(&user); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
logger.Error.Printf("Couldn't register User: %v", err)
return
}
w.WriteHeader(http.StatusCreated)
@@ -41,8 +46,8 @@ func (uc *UserController) RegisterUser(w http.ResponseWriter, r *http.Request) {
return
}
json.NewEncoder(w).Encode(user)
} */
}
*/
/* func (uc *UserController) GetUserID(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, err := strconv.Atoi(vars["id"])

View File

@@ -2,8 +2,8 @@ package database
import (
"GoMembership/internal/config"
"GoMembership/pkg/logger"
"database/sql"
"log"
"os"
_ "github.com/mattn/go-sqlite3"
@@ -29,13 +29,21 @@ func initializeDB(dbPath string, schemaPath string) error {
func Connect() *sql.DB {
cfg := config.LoadConfig()
dsn := cfg.DB.DBPath
db, err := sql.Open("sqlite3", dsn)
_, err := os.Stat(cfg.DB.DBPath)
if os.IsNotExist(err) {
initErr := initializeDB(cfg.DB.DBPath, "internal/database/schema.sql")
if initErr != nil {
logger.Error.Fatalf("Couldn't create database: %v", initErr)
}
logger.Info.Println("Created new database")
}
db, err := sql.Open("sqlite3", cfg.DB.DBPath)
if err != nil {
log.Fatal(err)
logger.Error.Fatal(err)
}
if err := db.Ping(); err != nil {
log.Fatal(err)
logger.Error.Fatal(err)
}
return db
}

View File

@@ -1,13 +1,17 @@
package middlewares
import (
"net/http"
"net/http"
)
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Authentication logic here
next.ServeHTTP(w, r)
})
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token != "your-secret-token" {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,114 @@
package middlewares
import (
"GoMembership/internal/config"
"GoMembership/internal/utils"
"GoMembership/pkg/logger"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"net/http"
"strings"
)
// GenerateCSRFToken generates HMAC-signed CSRF token
func GenerateCSRFToken(sessionID string, secretKey string) string {
// Create message to be signed (e.g., combining sessionID with some random value)
randomString, err := utils.GenerateRandomString(8)
if err != nil {
logger.Error.Fatalf("Could not create random string: %v", err)
return ""
}
message := sessionID + "!" + randomString
// Create HMAC hash using SHA-256
h := hmac.New(sha256.New, []byte(secretKey))
h.Write([]byte(message))
signature := h.Sum(nil)
// Encode signature and message into a CSRF token
csrfToken := base64.StdEncoding.EncodeToString(signature) + "." + message
return csrfToken
}
func ComputeHMAC(message string, secretKey string) []byte {
h := hmac.New(sha256.New, []byte(secretKey))
h.Write([]byte(message))
return h.Sum(nil)
}
// CSRFMiddleware verifies HMAC-signed CSRF token
func CSRFMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
next.ServeHTTP(w, r)
return
}
csrfSecret := config.LoadConfig().Auth.CSRFSecret
// Retrieve CSRF token from request (e.g., from cookie, header, or form data)
csrfToken := r.Header.Get("X-CSRF-Token")
// Extract signature and message from CSRF token
parts := strings.SplitN(csrfToken, ".", 2)
if len(parts) != 2 {
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
return
}
receivedSignature := parts[0]
receivedMessage := parts[1]
// Compute HMAC using the received message and the CSRF secret key
computedSignature := ComputeHMAC(receivedMessage, csrfSecret)
// Compare computed HMAC with received signature
if !hmac.Equal([]byte(receivedSignature), computedSignature) {
http.Error(w, "CSRF Token validation failed", http.StatusForbidden)
return
}
// CSRF token is valid, proceed to the next handler
next.ServeHTTP(w, r)
})
}
func GenerateCSRFTokenHandler(w http.ResponseWriter, r *http.Request) {
// Simulate getting session ID from authenticated session
sessionID := "exampleSessionID123"
// Generate HMAC-signed CSRF token
csrfToken := GenerateCSRFToken(sessionID, config.LoadConfig().Auth.CSRFSecret)
// Set CSRF token in a cookie (example)
http.SetCookie(w, &http.Cookie{
Name: "csrf_token",
Value: csrfToken,
Path: "/",
HttpOnly: true,
Secure: true,
})
}
/* func GenerateCSRFTokenHandler() http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
token, err := GenerateCSRFToken()
if err != nil {
http.Error(w, "Could not generate CSRF token", http.StatusInternalServerError)
return
}
// Set CSRF token in cookie
http.SetCookie(w, &http.Cookie{
Name: "csrf_token",
Value: token,
Path: "/",
})
logger.Info.Printf("generated token: %v", token)
// Return CSRF token in response
w.Header().Set("X-CSRF-Token", token)
w.WriteHeader(http.StatusOK)
})
} */

View File

@@ -0,0 +1,17 @@
package middlewares
import (
"GoMembership/pkg/logger"
"net/http"
"time"
)
// LoggerMiddleware logs each incoming HTTP request
func LoggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
next.ServeHTTP(w, r)
logger.Info.Printf("%s %s %s", r.Method, r.RequestURI, time.Since(start))
})
}

View File

@@ -9,7 +9,8 @@ type User struct {
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
Email string `json:"email"`
Password string `json:"password"`
Password string `json:"-"`
Salt string `json:"-"`
IBAN string `json:"iban"`
BIC string `json:"bic"`
MandateReference string `json:"mandate_reference"`

View File

@@ -9,7 +9,7 @@ import (
type UserRepository interface {
CreateUser(user *models.User) error
FindUserByID(id int) (*models.User, error)
// FindUserByEmail(email string) (*models.User, error)
FindUserByEmail(email string) (*models.User, error)
}
type userRepository struct {
@@ -20,16 +20,29 @@ func NewUserRepository(db *sql.DB) UserRepository {
return &userRepository{db}
}
func (r *userRepository) CreateUser(user *models.User) error {
query := "INSERT INTO users (first_name, last_name, email, password, iban, bic, mandate_reference, mandate_date_signed, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
_, err := r.db.Exec(query, user.FirstName, user.LastName, user.Email, user.Password, user.CreatedAt, user.UpdatedAt)
func (repo *userRepository) CreateUser(user *models.User) error {
query := "INSERT INTO users (first_name, last_name, email, password, salt, iban, bic, mandate_reference, mandate_date_signed, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
_, err := repo.db.Exec(query, user.FirstName, user.LastName, user.Email, user.Password, user.Salt, user.IBAN, user.BIC, user.MandateReference, user.MandateDateSigned, user.CreatedAt, user.UpdatedAt)
return err
}
func (r *userRepository) FindUserByID(id int) (*models.User, error) {
func (repo *userRepository) FindUserByID(id int) (*models.User, error) {
var user models.User
query := "SELECT id, first_name, last_name, email, iban, bic, mandate_reference FROM users WHERE id = ?"
err := r.db.QueryRow(query, id).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference)
err := repo.db.QueryRow(query, id).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.ErrUserNotFound
}
return nil, err
}
return &user, nil
}
func (repo *userRepository) FindUserByEmail(email string) (*models.User, error) {
var user models.User
query := "SELECT id, first_name, last_name, email, iban, bic, mandate_reference FROM users WHERE email = ?"
err := repo.db.QueryRow(query, email).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.ErrUserNotFound

View File

@@ -1,13 +1,17 @@
package routes
import (
"net/http"
"GoMembership/internal/controllers"
"GoMembership/internal/controllers"
// "GoMembership/internal/middlewares"
"GoMembership/pkg/logger"
// "net/http"
"github.com/gorilla/mux"
"github.com/gorilla/mux"
)
func RegisterRoutes(router *mux.Router, userController *controllers.UserController) {
router.HandleFunc("/register", userController.RegisterUser).Methods("POST")
}
logger.Info.Println("Registering /register route")
router.HandleFunc("/register", userController.RegisterUser).Methods("POST")
// router.HandleFunc("/login", userController.LoginUser).Methods("POST")
}

View File

@@ -1,20 +1,19 @@
package server
import (
// "GoMembership/internal/config"
"GoMembership/internal/controllers"
"GoMembership/internal/database"
"GoMembership/internal/middlewares"
"GoMembership/internal/repositories"
"GoMembership/internal/routes"
"GoMembership/internal/services"
"log"
"GoMembership/pkg/logger"
"net/http"
"github.com/gorilla/mux"
)
func Run() {
// cfg := config.LoadConfig()
db := database.Connect()
defer db.Close()
@@ -23,10 +22,20 @@ func Run() {
userController := controllers.NewUserController(userService)
router := mux.NewRouter()
routes.RegisterRoutes(router, userController)
// router.Handle("/csrf-token", middlewares.GenerateCSRFTokenHandler()).Methods("GET")
log.Println("Starting server on :8080")
// Apply CSRF middleware
// router.Use(middlewares.CSRFMiddleware)
router.Use(middlewares.LoggerMiddleware)
routes.RegisterRoutes(router, userController)
// create subrouter for teh authenticated area /account
// also pthprefix matches everything below /account
// accountRouter := router.PathPrefix("/account").Subrouter()
// accountRouter.Use(middlewares.AuthMiddleware)
logger.Info.Println("Starting server on :8080")
if err := http.ListenAndServe(":8080", router); err != nil {
log.Fatalf("could not start server: %v", err)
logger.Error.Fatalf("could not start server: %v", err)
}
}

View File

@@ -3,12 +3,16 @@ package services
import (
"GoMembership/internal/models"
"GoMembership/internal/repositories"
// "GoMembership/pkg/errors"
"crypto/rand"
"encoding/base64"
"golang.org/x/crypto/bcrypt"
"time"
)
type UserService interface {
RegisterUser(user *models.User) error
// AuthenticateUser(email, password string) (*models.User, error)
}
type userService struct {
@@ -19,8 +23,14 @@ func NewUserService(repo repositories.UserRepository) UserService {
return &userService{repo}
}
func (s *userService) RegisterUser(user *models.User) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
func (service *userService) RegisterUser(user *models.User) error {
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
return err
}
user.Salt = base64.StdEncoding.EncodeToString(salt)
hashedPassword, err := HashPassword(user.Password, user.Salt)
if err != nil {
return err
}
@@ -28,5 +38,38 @@ func (s *userService) RegisterUser(user *models.User) error {
user.CreatedAt = time.Now()
user.UpdatedAt = time.Now()
user.MandateDateSigned = time.Now()
return s.repo.CreateUser(user)
return service.repo.CreateUser(user)
}
func HashPassword(password string, salt string) (string, error) {
saltedPassword := password + salt
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(hashedPassword), nil
}
/* func (s *userService) AuthenticateUser(email, password string) (*models.User, error) {
user, err := s.repo.FindUserByEmail(email)
if err != nil {
return nil, errors.ErrUserNotFound
}
if !verifyPassword(password, user.Password, user.Salt) {
return nil, errors.ErrInvalidCredentials
}
return user, nil
}
*/
/* func verifyPassword(password string, storedPassword string, salt string) bool {
saltedPassword := password + salt
decodedStoredPassword, err := base64.StdEncoding.DecodeString(storedPassword)
if err != nil {
return false
}
err = bcrypt.CompareHashAndPassword([]byte(decodedStoredPassword), []byte(saltedPassword))
return err == nil
} */

15
internal/utils/crypto.go Normal file
View File

@@ -0,0 +1,15 @@
package utils
import (
"crypto/rand"
"encoding/base64"
)
func GenerateRandomString(length int) (string, error) {
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes), nil
}