diff --git a/.gitignore b/.gitignore index 3a9c6c6..6c2fc61 100644 --- a/.gitignore +++ b/.gitignore @@ -41,8 +41,9 @@ go.work !README.md !LICENSE +# all template files: +!*.template* # !Makefile # ...even if they are in subdirectories !*/ - diff --git a/configs/config.template.json b/configs/config.template.json new file mode 100644 index 0000000..c1d3346 --- /dev/null +++ b/configs/config.template.json @@ -0,0 +1,13 @@ +{ + "db": + { + "DBPath": "data/db.sqlite3" + }, + "smtp": { + "server": "mail.server.com", + "user": "username", + "password": "password", + "port": 465, + "mailtype": "html" + } +} diff --git a/go.mod b/go.mod index 20a147f..54319cd 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,8 @@ module GoMembership go 1.22.2 require ( - github.com/go-sql-driver/mysql v1.8.1 + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gorilla/mux v1.8.1 github.com/mattn/go-sqlite3 v1.14.22 golang.org/x/crypto v0.24.0 ) - -require filippo.io/edwards25519 v1.1.0 // indirect diff --git a/go.sum b/go.sum index 2708fe6..6a11972 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= diff --git a/internal/config/config.go b/internal/config/config.go index b47def0..c135c4c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,38 +1,69 @@ package config import ( + "GoMembership/internal/utils" + "GoMembership/pkg/logger" "encoding/json" - "log" "os" "path/filepath" + "sync" ) type DatabaseConfig struct { DBPath string `json:"DBPath"` } -type Config struct { - DB DatabaseConfig `json:"db"` +type AuthenticationConfig struct { + JWTSecret string + CSRFSecret string } +type Config struct { + DB DatabaseConfig `json:"db"` + Auth AuthenticationConfig +} + +var ( + pConfig Config + once sync.Once + loaded bool +) + func LoadConfig() *Config { path, err := os.Getwd() if err != nil { - log.Fatalf("could not get working directory: %v", err) + logger.Error.Fatalf("could not get working directory: %v", err) } configFile, err := os.Open(filepath.Join(path, "configs", "config.json")) if err != nil { - log.Fatalf("could not open config file: %v", err) + logger.Error.Fatalf("could not open config file: %v", err) } defer configFile.Close() decoder := json.NewDecoder(configFile) - config := &Config{} - err = decoder.Decode(config) + // pConfig = &Config{} + err = decoder.Decode(&pConfig) if err != nil { - log.Fatalf("could not decode config file: %v", err) + logger.Error.Fatalf("could not decode config file: %v", err) + } + if !loaded { + once.Do( + func() { + csrfSecret, err := utils.GenerateRandomString(32) + if err != nil { + logger.Error.Fatalf("could not generate CSRF secret: %v", err) + } + + jwtSecret, err := utils.GenerateRandomString(32) + if err != nil { + logger.Error.Fatalf("could not generate JWT secret: %v", err) + } + pConfig.Auth.JWTSecret = jwtSecret + pConfig.Auth.CSRFSecret = csrfSecret + loaded = true + }) } - return config + return &pConfig } diff --git a/internal/controllers/user_controller.go b/internal/controllers/user_controller.go index 3ef74d5..70b3d57 100644 --- a/internal/controllers/user_controller.go +++ b/internal/controllers/user_controller.go @@ -7,6 +7,7 @@ import ( // "github.com/gorilla/mux" "net/http" // "strconv" + "GoMembership/pkg/logger" ) type UserController struct { @@ -18,13 +19,17 @@ func NewUserController(service services.UserService) *UserController { } func (uc *UserController) RegisterUser(w http.ResponseWriter, r *http.Request) { + + logger.Info.Println("registering user") var user models.User if err := json.NewDecoder(r.Body).Decode(&user); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) + logger.Error.Printf("Couldn't decode Userdata: %v", err) return } if err := uc.service.RegisterUser(&user); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + logger.Error.Printf("Couldn't register User: %v", err) return } w.WriteHeader(http.StatusCreated) @@ -41,8 +46,8 @@ func (uc *UserController) RegisterUser(w http.ResponseWriter, r *http.Request) { return } json.NewEncoder(w).Encode(user) -} */ - +} +*/ /* func (uc *UserController) GetUserID(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, err := strconv.Atoi(vars["id"]) diff --git a/internal/database/db.go b/internal/database/db.go index 1f03ff9..a02bf0b 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -2,8 +2,8 @@ package database import ( "GoMembership/internal/config" + "GoMembership/pkg/logger" "database/sql" - "log" "os" _ "github.com/mattn/go-sqlite3" @@ -29,13 +29,21 @@ func initializeDB(dbPath string, schemaPath string) error { func Connect() *sql.DB { cfg := config.LoadConfig() - dsn := cfg.DB.DBPath - db, err := sql.Open("sqlite3", dsn) + _, err := os.Stat(cfg.DB.DBPath) + if os.IsNotExist(err) { + initErr := initializeDB(cfg.DB.DBPath, "internal/database/schema.sql") + if initErr != nil { + logger.Error.Fatalf("Couldn't create database: %v", initErr) + } + logger.Info.Println("Created new database") + } + + db, err := sql.Open("sqlite3", cfg.DB.DBPath) if err != nil { - log.Fatal(err) + logger.Error.Fatal(err) } if err := db.Ping(); err != nil { - log.Fatal(err) + logger.Error.Fatal(err) } return db } diff --git a/internal/middlewares/auth_middleware.go b/internal/middlewares/auth_middleware.go index bbcffa7..b272758 100644 --- a/internal/middlewares/auth_middleware.go +++ b/internal/middlewares/auth_middleware.go @@ -1,13 +1,17 @@ package middlewares import ( - "net/http" + "net/http" ) func AuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Authentication logic here - next.ServeHTTP(w, r) - }) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token != "your-secret-token" { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) } - diff --git a/internal/middlewares/csrf_middleware.go b/internal/middlewares/csrf_middleware.go new file mode 100644 index 0000000..217ca19 --- /dev/null +++ b/internal/middlewares/csrf_middleware.go @@ -0,0 +1,114 @@ +package middlewares + +import ( + "GoMembership/internal/config" + "GoMembership/internal/utils" + "GoMembership/pkg/logger" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "net/http" + "strings" +) + +// GenerateCSRFToken generates HMAC-signed CSRF token +func GenerateCSRFToken(sessionID string, secretKey string) string { + // Create message to be signed (e.g., combining sessionID with some random value) + randomString, err := utils.GenerateRandomString(8) + if err != nil { + logger.Error.Fatalf("Could not create random string: %v", err) + return "" + } + + message := sessionID + "!" + randomString + + // Create HMAC hash using SHA-256 + h := hmac.New(sha256.New, []byte(secretKey)) + h.Write([]byte(message)) + signature := h.Sum(nil) + + // Encode signature and message into a CSRF token + csrfToken := base64.StdEncoding.EncodeToString(signature) + "." + message + return csrfToken +} + +func ComputeHMAC(message string, secretKey string) []byte { + h := hmac.New(sha256.New, []byte(secretKey)) + h.Write([]byte(message)) + return h.Sum(nil) +} + +// CSRFMiddleware verifies HMAC-signed CSRF token +func CSRFMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + csrfSecret := config.LoadConfig().Auth.CSRFSecret + // Retrieve CSRF token from request (e.g., from cookie, header, or form data) + csrfToken := r.Header.Get("X-CSRF-Token") + + // Extract signature and message from CSRF token + parts := strings.SplitN(csrfToken, ".", 2) + if len(parts) != 2 { + http.Error(w, "Invalid CSRF token", http.StatusForbidden) + return + } + receivedSignature := parts[0] + receivedMessage := parts[1] + + // Compute HMAC using the received message and the CSRF secret key + computedSignature := ComputeHMAC(receivedMessage, csrfSecret) + + // Compare computed HMAC with received signature + if !hmac.Equal([]byte(receivedSignature), computedSignature) { + http.Error(w, "CSRF Token validation failed", http.StatusForbidden) + return + } + + // CSRF token is valid, proceed to the next handler + next.ServeHTTP(w, r) + }) +} + +func GenerateCSRFTokenHandler(w http.ResponseWriter, r *http.Request) { + // Simulate getting session ID from authenticated session + sessionID := "exampleSessionID123" + + // Generate HMAC-signed CSRF token + csrfToken := GenerateCSRFToken(sessionID, config.LoadConfig().Auth.CSRFSecret) + + // Set CSRF token in a cookie (example) + http.SetCookie(w, &http.Cookie{ + Name: "csrf_token", + Value: csrfToken, + Path: "/", + HttpOnly: true, + Secure: true, + }) +} + +/* func GenerateCSRFTokenHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + token, err := GenerateCSRFToken() + if err != nil { + http.Error(w, "Could not generate CSRF token", http.StatusInternalServerError) + return + } + + // Set CSRF token in cookie + http.SetCookie(w, &http.Cookie{ + Name: "csrf_token", + Value: token, + Path: "/", + }) + + logger.Info.Printf("generated token: %v", token) + // Return CSRF token in response + w.Header().Set("X-CSRF-Token", token) + w.WriteHeader(http.StatusOK) + }) +} */ diff --git a/internal/middlewares/logger_middleware.go b/internal/middlewares/logger_middleware.go new file mode 100644 index 0000000..bd09348 --- /dev/null +++ b/internal/middlewares/logger_middleware.go @@ -0,0 +1,17 @@ +package middlewares + +import ( + "GoMembership/pkg/logger" + "net/http" + "time" +) + +// LoggerMiddleware logs each incoming HTTP request +func LoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + logger.Info.Printf("%s %s %s", r.Method, r.RequestURI, time.Since(start)) + }) +} diff --git a/internal/models/user.go b/internal/models/user.go index a25fc11..eaff4a4 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -9,7 +9,8 @@ type User struct { FirstName string `json:"first_name"` LastName string `json:"last_name"` Email string `json:"email"` - Password string `json:"password"` + Password string `json:"-"` + Salt string `json:"-"` IBAN string `json:"iban"` BIC string `json:"bic"` MandateReference string `json:"mandate_reference"` diff --git a/internal/repositories/user_repository.go b/internal/repositories/user_repository.go index e22a32e..dd9eb2c 100644 --- a/internal/repositories/user_repository.go +++ b/internal/repositories/user_repository.go @@ -9,7 +9,7 @@ import ( type UserRepository interface { CreateUser(user *models.User) error FindUserByID(id int) (*models.User, error) - // FindUserByEmail(email string) (*models.User, error) + FindUserByEmail(email string) (*models.User, error) } type userRepository struct { @@ -20,16 +20,29 @@ func NewUserRepository(db *sql.DB) UserRepository { return &userRepository{db} } -func (r *userRepository) CreateUser(user *models.User) error { - query := "INSERT INTO users (first_name, last_name, email, password, iban, bic, mandate_reference, mandate_date_signed, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" - _, err := r.db.Exec(query, user.FirstName, user.LastName, user.Email, user.Password, user.CreatedAt, user.UpdatedAt) +func (repo *userRepository) CreateUser(user *models.User) error { + query := "INSERT INTO users (first_name, last_name, email, password, salt, iban, bic, mandate_reference, mandate_date_signed, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + _, err := repo.db.Exec(query, user.FirstName, user.LastName, user.Email, user.Password, user.Salt, user.IBAN, user.BIC, user.MandateReference, user.MandateDateSigned, user.CreatedAt, user.UpdatedAt) return err } -func (r *userRepository) FindUserByID(id int) (*models.User, error) { +func (repo *userRepository) FindUserByID(id int) (*models.User, error) { var user models.User query := "SELECT id, first_name, last_name, email, iban, bic, mandate_reference FROM users WHERE id = ?" - err := r.db.QueryRow(query, id).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference) + err := repo.db.QueryRow(query, id).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference) + if err != nil { + if err == sql.ErrNoRows { + return nil, errors.ErrUserNotFound + } + return nil, err + } + return &user, nil +} + +func (repo *userRepository) FindUserByEmail(email string) (*models.User, error) { + var user models.User + query := "SELECT id, first_name, last_name, email, iban, bic, mandate_reference FROM users WHERE email = ?" + err := repo.db.QueryRow(query, email).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.IBAN, &user.BIC, &user.MandateReference) if err != nil { if err == sql.ErrNoRows { return nil, errors.ErrUserNotFound diff --git a/internal/routes/routes.go b/internal/routes/routes.go index ffbb3c1..f02558e 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -1,13 +1,17 @@ package routes import ( - "net/http" - "GoMembership/internal/controllers" + "GoMembership/internal/controllers" + // "GoMembership/internal/middlewares" + "GoMembership/pkg/logger" + // "net/http" - "github.com/gorilla/mux" + "github.com/gorilla/mux" ) func RegisterRoutes(router *mux.Router, userController *controllers.UserController) { - router.HandleFunc("/register", userController.RegisterUser).Methods("POST") -} + logger.Info.Println("Registering /register route") + router.HandleFunc("/register", userController.RegisterUser).Methods("POST") + // router.HandleFunc("/login", userController.LoginUser).Methods("POST") +} diff --git a/internal/server/server.go b/internal/server/server.go index d15b404..732a2a8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,20 +1,19 @@ package server import ( - // "GoMembership/internal/config" "GoMembership/internal/controllers" "GoMembership/internal/database" + "GoMembership/internal/middlewares" "GoMembership/internal/repositories" "GoMembership/internal/routes" "GoMembership/internal/services" - "log" + "GoMembership/pkg/logger" "net/http" "github.com/gorilla/mux" ) func Run() { - // cfg := config.LoadConfig() db := database.Connect() defer db.Close() @@ -23,10 +22,20 @@ func Run() { userController := controllers.NewUserController(userService) router := mux.NewRouter() - routes.RegisterRoutes(router, userController) + // router.Handle("/csrf-token", middlewares.GenerateCSRFTokenHandler()).Methods("GET") - log.Println("Starting server on :8080") + // Apply CSRF middleware + // router.Use(middlewares.CSRFMiddleware) + router.Use(middlewares.LoggerMiddleware) + + routes.RegisterRoutes(router, userController) + // create subrouter for teh authenticated area /account + // also pthprefix matches everything below /account + // accountRouter := router.PathPrefix("/account").Subrouter() + // accountRouter.Use(middlewares.AuthMiddleware) + + logger.Info.Println("Starting server on :8080") if err := http.ListenAndServe(":8080", router); err != nil { - log.Fatalf("could not start server: %v", err) + logger.Error.Fatalf("could not start server: %v", err) } } diff --git a/internal/services/user_service.go b/internal/services/user_service.go index bf2e685..4b4fe5c 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -3,12 +3,16 @@ package services import ( "GoMembership/internal/models" "GoMembership/internal/repositories" + // "GoMembership/pkg/errors" + "crypto/rand" + "encoding/base64" "golang.org/x/crypto/bcrypt" "time" ) type UserService interface { RegisterUser(user *models.User) error + // AuthenticateUser(email, password string) (*models.User, error) } type userService struct { @@ -19,8 +23,14 @@ func NewUserService(repo repositories.UserRepository) UserService { return &userService{repo} } -func (s *userService) RegisterUser(user *models.User) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) +func (service *userService) RegisterUser(user *models.User) error { + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return err + } + user.Salt = base64.StdEncoding.EncodeToString(salt) + + hashedPassword, err := HashPassword(user.Password, user.Salt) if err != nil { return err } @@ -28,5 +38,38 @@ func (s *userService) RegisterUser(user *models.User) error { user.CreatedAt = time.Now() user.UpdatedAt = time.Now() user.MandateDateSigned = time.Now() - return s.repo.CreateUser(user) + return service.repo.CreateUser(user) } + +func HashPassword(password string, salt string) (string, error) { + saltedPassword := password + salt + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(hashedPassword), nil +} + +/* func (s *userService) AuthenticateUser(email, password string) (*models.User, error) { + user, err := s.repo.FindUserByEmail(email) + if err != nil { + return nil, errors.ErrUserNotFound + } + + if !verifyPassword(password, user.Password, user.Salt) { + return nil, errors.ErrInvalidCredentials + } + + return user, nil +} +*/ +/* func verifyPassword(password string, storedPassword string, salt string) bool { + + saltedPassword := password + salt + decodedStoredPassword, err := base64.StdEncoding.DecodeString(storedPassword) + if err != nil { + return false + } + err = bcrypt.CompareHashAndPassword([]byte(decodedStoredPassword), []byte(saltedPassword)) + return err == nil +} */ diff --git a/internal/utils/crypto.go b/internal/utils/crypto.go new file mode 100644 index 0000000..abd6066 --- /dev/null +++ b/internal/utils/crypto.go @@ -0,0 +1,15 @@ +package utils + +import ( + "crypto/rand" + "encoding/base64" +) + +func GenerateRandomString(length int) (string, error) { + bytes := make([]byte, length) + _, err := rand.Read(bytes) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(bytes), nil +} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index ef9e07c..52d8455 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -3,8 +3,7 @@ package errors import "errors" var ( - ErrUserNotFound = errors.New("user not found") - ErrInvalidEmail = errors.New("invalid email") - // Add other custom errors here + ErrUserNotFound = errors.New("user not found") + ErrInvalidEmail = errors.New("invalid email") + ErrInvalidCredentials = errors.New("invalid credentials: unauthorized") ) - diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 4b7b1c8..dd06150 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -1,24 +1,23 @@ package logger import ( - "log" - "os" + "log" + "os" ) var ( - Info *log.Logger - Warning *log.Logger - Error *log.Logger + Info *log.Logger + Warning *log.Logger + Error *log.Logger ) func init() { - file, err := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - log.Fatal(err) - } + file, err := os.OpenFile("gomember.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + log.Fatal(err) + } - Info = log.New(file, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) - Warning = log.New(file, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) - Error = log.New(file, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) + Info = log.New(file, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) + Warning = log.New(file, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) + Error = log.New(file, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) } - diff --git a/src/go.mod b/src/go.mod deleted file mode 100644 index 74edfe1..0000000 --- a/src/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module git.stoelti.land/Alex/GoMembership - -go 1.22.4 diff --git a/src/main.go b/src/main.go deleted file mode 100644 index 58e050a..0000000 --- a/src/main.go +++ /dev/null @@ -1,57 +0,0 @@ -// main.go - -package main - -import ( - "strconv" - - "github.com/astaxie/beego" -) - - -func main() { - /* This would match routes like the following: - /sum/3/5 - /product/6/23 - ... - */ - beego.Router("/:operation/:num1:int/:num2:int", &mainController{}) - beego.Run() -} - -type mainController struct { - beego.Controller -} - - -func (c *mainController) Get() { - - //Obtain the values of the route parameters defined in the route above - operation := c.Ctx.Input.Param(":operation") - num1, _ := strconv.Atoi(c.Ctx.Input.Param(":num1")) - num2, _ := strconv.Atoi(c.Ctx.Input.Param(":num2")) - - //Set the values for use in the template - c.Data["operation"] = operation - c.Data["num1"] = num1 - c.Data["num2"] = num2 - c.TplName = "result.html" - - // Perform the calculation depending on the 'operation' route parameter - switch operation { - case "sum": - c.Data["result"] = add(num1, num2) - case "product": - c.Data["result"] = multiply(num1, num2) - default: - c.TplName = "invalid-route.html" - } -} - -func add(n1, n2 int) int { - return n1 + n2 -} - -func multiply(n1, n2 int) int { - return n1 * n2 -}