|
|
|
@ -9,6 +9,7 @@ import ( |
|
|
|
|
"os" |
|
|
|
|
"path" |
|
|
|
|
"testing" |
|
|
|
|
"strings" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
@ -458,3 +459,32 @@ func TestValidReferrer(t *testing.T) { |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestAddsCorsHeaders(t *testing.T) { |
|
|
|
|
testCases := []struct { |
|
|
|
|
name string |
|
|
|
|
header string |
|
|
|
|
value string |
|
|
|
|
}{ |
|
|
|
|
{"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"}, |
|
|
|
|
{"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, serveFile := range serveFileFuncs { |
|
|
|
|
handler := AddCorsHeaders(Basic(serveFile, baseDir)) |
|
|
|
|
for _, tc := range testCases { |
|
|
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
|
|
req := httptest.NewRequest("GET", "http://localhost/", nil) |
|
|
|
|
w := httptest.NewRecorder() |
|
|
|
|
|
|
|
|
|
handler(w, req) |
|
|
|
|
|
|
|
|
|
resp := w.Result() |
|
|
|
|
headerValue := strings.Join(resp.Header[tc.header], ", ") |
|
|
|
|
if headerValue != tc.value { |
|
|
|
|
t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value) |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|