add: CORS
This commit is contained in:
21
internal/middlewares/cors.go
Normal file
21
internal/middlewares/cors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
logger.Info.Print("Applying CORS")
|
||||
return cors.New(cors.Config{
|
||||
AllowOrigins: []string{config.BaseURL}, // Add your frontend URL(s)
|
||||
AllowMethods: []string{"GET", "POST"}, // "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
|
||||
// ExposeHeaders: []string{"Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * 60 * 60, // 12 hours
|
||||
})
|
||||
}
|
||||
104
internal/middlewares/cors_test.go
Normal file
104
internal/middlewares/cors_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
Port int = 2525
|
||||
)
|
||||
|
||||
func TestCORSMiddleware(t *testing.T) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get current working directory: %v", err)
|
||||
}
|
||||
|
||||
configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json")
|
||||
templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html")
|
||||
templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email")
|
||||
|
||||
if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_HOST", Host); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil {
|
||||
log.Fatalf("Error setting environment variable: %v", err)
|
||||
}
|
||||
// Load your configuration
|
||||
config.LoadConfig()
|
||||
|
||||
// Create a gin router with the CORS middleware
|
||||
router := gin.New()
|
||||
router.Use(CORSMiddleware())
|
||||
|
||||
// Add a simple handler
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "test")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Allowed origin",
|
||||
origin: config.BaseURL,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": config.BaseURL,
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin",
|
||||
origin: "http://evil.com",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
logger.Info.Printf("Recieved Headers: %#v", w.Header())
|
||||
for key, value := range tt.expectedHeaders {
|
||||
assert.Equal(t, value, w.Header().Get(key))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"GoMembership/internal/config"
|
||||
"GoMembership/internal/controllers"
|
||||
"GoMembership/internal/middlewares"
|
||||
"GoMembership/internal/repositories"
|
||||
|
||||
// "GoMembership/internal/middlewares"
|
||||
@@ -50,8 +51,9 @@ func Run() {
|
||||
router.Static(config.Templates.StaticPath, "./style")
|
||||
// Load HTML templates
|
||||
router.LoadHTMLGlob(filepath.Join(config.Templates.HTMLPath, "*"))
|
||||
|
||||
router.Use(gin.Logger())
|
||||
// router.Use(middlewares.LoggerMiddleware())
|
||||
router.Use(middlewares.CORSMiddleware())
|
||||
|
||||
routes.RegisterRoutes(router, userController, membershipController, contactController)
|
||||
// create subrouter for teh authenticated area /account
|
||||
|
||||
Reference in New Issue
Block a user