Files
GoMembership/internal/middlewares/cors_test.go
$(pass /github/name) 682c50574b add: CORS
2024-08-22 11:27:27 +02:00

105 lines
2.8 KiB
Go

package middlewares
import (
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"GoMembership/internal/config"
"GoMembership/pkg/logger"
)
// SMTP endpoint used to build the test environment for TestCORSMiddleware.
const (
	// Host is exported as SMTP_HOST and is also the host part of BASE_URL.
	Host = "127.0.0.1"
	// Port is exported as SMTP_PORT (formatted with strconv.Itoa).
	Port = int(2525)
)
// TestCORSMiddleware exercises CORSMiddleware on a minimal gin router.
//
// It expects a request whose Origin equals config.BaseURL to succeed (200)
// with Access-Control-Allow-Origin echoing that origin and credentials
// allowed. It expects a request from any other origin to be rejected with
// 403 and no Access-Control-Allow-Origin header.
//
// Before loading the configuration, the test points the config loader at the
// repository's configs/ and templates/ directories. These are resolved two
// levels above the working directory, which is the package directory under
// `go test`. The test also sets SMTP_HOST, SMTP_PORT and BASE_URL.
// Variables are set with t.Setenv so they are restored when the test ends
// and do not leak into other tests in the package.
func TestCORSMiddleware(t *testing.T) {
	cwd, err := os.Getwd()
	if err != nil {
		log.Fatalf("Failed to get current working directory: %v", err)
	}
	repoRoot := filepath.Join(cwd, "..", "..")

	// t.Setenv fails the test itself if os.Setenv errors, so no manual check.
	t.Setenv("TEMPLATE_MAIL_PATH", filepath.Join(repoRoot, "templates", "email"))
	t.Setenv("TEMPLATE_HTML_PATH", filepath.Join(repoRoot, "templates", "html"))
	t.Setenv("CONFIG_FILE_PATH", filepath.Join(repoRoot, "configs", "config.json"))
	t.Setenv("SMTP_HOST", Host)
	t.Setenv("SMTP_PORT", strconv.Itoa(Port))
	// Derive the port from the Port constant so BASE_URL and SMTP_PORT cannot drift apart.
	t.Setenv("BASE_URL", "http://"+Host+":"+strconv.Itoa(Port))

	// Reads the environment prepared above; config.BaseURL is used below.
	config.LoadConfig()

	// Router with only the CORS middleware and a trivial handler, so any
	// header or status difference comes from CORSMiddleware.
	router := gin.New()
	router.Use(CORSMiddleware())
	router.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "test")
	})

	tests := []struct {
		name            string
		origin          string
		expectedStatus  int
		expectedHeaders map[string]string
	}{
		{
			name:           "Allowed origin",
			origin:         config.BaseURL,
			expectedStatus: http.StatusOK,
			expectedHeaders: map[string]string{
				"Access-Control-Allow-Origin":      config.BaseURL,
				"Content-Type":                     "text/plain; charset=utf-8",
				"Access-Control-Allow-Credentials": "true",
			},
		},
		{
			name:           "Disallowed origin",
			origin:         "http://evil.com",
			expectedStatus: http.StatusForbidden,
			expectedHeaders: map[string]string{
				// An empty value asserts the header is absent.
				"Access-Control-Allow-Origin": "",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// httptest.NewRequest panics on a bad target instead of
			// returning an error that could be silently dropped.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req.Header.Set("Origin", tt.origin)
			w := httptest.NewRecorder()

			router.ServeHTTP(w, req)

			assert.Equal(t, tt.expectedStatus, w.Code)
			logger.Info.Printf("Received Headers: %#v", w.Header())
			for key, want := range tt.expectedHeaders {
				assert.Equal(t, want, w.Header().Get(key), "header %q", key)
			}
		})
	}
}