|
|
|
@ -9,7 +9,6 @@ import ( |
|
|
|
|
"os" |
|
|
|
|
"path" |
|
|
|
|
"testing" |
|
|
|
|
"strings" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
@ -460,29 +459,73 @@ func TestValidReferrer(t *testing.T) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestAddsCorsHeaders(t *testing.T) { |
|
|
|
|
func TestAddCorsWildcardHeaders(t *testing.T) { |
|
|
|
|
testCases := []struct { |
|
|
|
|
name string |
|
|
|
|
header string |
|
|
|
|
value string |
|
|
|
|
name string |
|
|
|
|
corsEnabled bool |
|
|
|
|
}{ |
|
|
|
|
{"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"}, |
|
|
|
|
{"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"}, |
|
|
|
|
{"CORS disabled", false}, |
|
|
|
|
{"CORS enabled", true}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
corsHeaders := map[string]string{ |
|
|
|
|
"Access-Control-Allow-Origin": "*", |
|
|
|
|
"Access-Control-Allow-Headers": "*", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, serveFile := range serveFileFuncs { |
|
|
|
|
handler := AddCorsHeaders(Basic(serveFile, baseDir)) |
|
|
|
|
for _, tc := range testCases { |
|
|
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
|
|
req := httptest.NewRequest("GET", "http://localhost/", nil) |
|
|
|
|
var handler http.HandlerFunc |
|
|
|
|
if tc.corsEnabled { |
|
|
|
|
handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir)) |
|
|
|
|
} else { |
|
|
|
|
handler = Basic(serveFile, baseDir) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fullpath := "http://localhost/" + tmpFileName |
|
|
|
|
req := httptest.NewRequest("GET", fullpath, nil) |
|
|
|
|
w := httptest.NewRecorder() |
|
|
|
|
|
|
|
|
|
handler(w, req) |
|
|
|
|
|
|
|
|
|
resp := w.Result() |
|
|
|
|
headerValue := strings.Join(resp.Header[tc.header], ", ") |
|
|
|
|
if headerValue != tc.value { |
|
|
|
|
t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value) |
|
|
|
|
body, err := ioutil.ReadAll(resp.Body) |
|
|
|
|
if nil != err { |
|
|
|
|
t.Errorf("While reading body got %v", err) |
|
|
|
|
} |
|
|
|
|
contents := string(body) |
|
|
|
|
if ok != resp.StatusCode { |
|
|
|
|
t.Errorf( |
|
|
|
|
"While retrieving %s expected status code of %d but got %d", |
|
|
|
|
fullpath, ok, resp.StatusCode, |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
if tmpFile != contents { |
|
|
|
|
t.Errorf( |
|
|
|
|
"While retrieving %s expected contents '%s' but got '%s'", |
|
|
|
|
fullpath, tmpFile, contents, |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if tc.corsEnabled { |
|
|
|
|
for k, v := range corsHeaders { |
|
|
|
|
if v != resp.Header.Get(k) { |
|
|
|
|
t.Errorf( |
|
|
|
|
"With CORS enabled expect header '%s' to return '%s' but got '%s'", |
|
|
|
|
k, v, resp.Header.Get(k), |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
for k := range corsHeaders { |
|
|
|
|
if "" != resp.Header.Get(k) { |
|
|
|
|
t.Errorf( |
|
|
|
|
"With CORS disabled expected header '%s' to return '' but got '%s'", |
|
|
|
|
k, resp.Header.Get(k), |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|