package middlewares import ( "log" "net/http" "net/http/httptest" "os" "path/filepath" "strconv" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "GoMembership/internal/config" "GoMembership/pkg/logger" ) const ( Host = "127.0.0.1" Port int = 2525 ) func TestCORSMiddleware(t *testing.T) { cwd, err := os.Getwd() if err != nil { log.Fatalf("Failed to get current working directory: %v", err) } configFilePath := filepath.Join(cwd, "..", "..", "configs", "config.json") templateHTMLPath := filepath.Join(cwd, "..", "..", "templates", "html") templateMailPath := filepath.Join(cwd, "..", "..", "templates", "email") if err := os.Setenv("TEMPLATE_MAIL_PATH", templateMailPath); err != nil { log.Fatalf("Error setting environment variable: %v", err) } if err := os.Setenv("TEMPLATE_HTML_PATH", templateHTMLPath); err != nil { log.Fatalf("Error setting environment variable: %v", err) } if err := os.Setenv("CONFIG_FILE_PATH", configFilePath); err != nil { log.Fatalf("Error setting environment variable: %v", err) } if err := os.Setenv("SMTP_HOST", Host); err != nil { log.Fatalf("Error setting environment variable: %v", err) } if err := os.Setenv("SMTP_PORT", strconv.Itoa(Port)); err != nil { log.Fatalf("Error setting environment variable: %v", err) } if err := os.Setenv("BASE_URL", "http://"+Host+":2525"); err != nil { log.Fatalf("Error setting environment variable: %v", err) } // Load your configuration config.LoadConfig() // Create a gin router with the CORS middleware router := gin.New() router.Use(CORSMiddleware()) // Add a simple handler router.GET("/test", func(c *gin.Context) { c.String(200, "test") }) tests := []struct { name string origin string expectedStatus int expectedHeaders map[string]string }{ { name: "Allowed origin", origin: config.Site.AllowOrigins, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ "Access-Control-Allow-Origin": config.Site.AllowOrigins, "Content-Type": "text/plain; charset=utf-8", "Access-Control-Allow-Credentials": "true", }, }, { name: "Disallowed origin", origin: "http://evil.com", expectedStatus: http.StatusForbidden, expectedHeaders: map[string]string{ "Access-Control-Allow-Origin": "", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Origin", tt.origin) router.ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) logger.Info.Printf("Recieved Headers: %#v", w.Header()) for key, value := range tt.expectedHeaders { assert.Equal(t, value, w.Header().Get(key)) } }) } }