diff --git a/cli/server/server.go b/cli/server/server.go index 2393cd7..447c90d 100644 --- a/cli/server/server.go +++ b/cli/server/server.go @@ -62,8 +62,9 @@ func handlerSelector() (handler http.HandlerFunc) { handler = handle.IgnoreIndex(handler) } + // If configured, apply wildcard CORS support. if config.Get.Cors { - handler = handle.AddCorsHeaders(handler) + handler = handle.AddCorsWildcardHeaders(handler) } return diff --git a/handle/handle.go b/handle/handle.go index f0d3e91..e869e3b 100644 --- a/handle/handle.go +++ b/handle/handle.go @@ -109,12 +109,13 @@ func IgnoreIndex(serve http.HandlerFunc) http.HandlerFunc { } } -func AddCorsHeaders(serve http.HandlerFunc) http.HandlerFunc { - return func(writer http.ResponseWriter, request *http.Request) { - writer.Header().Set("Access-Control-Allow-Origin", "*") - writer.Header().Set("Access-Control-Allow-Headers", "*") - - serve(writer, request) +// AddCorsWildcardHeaders wraps an HTTP request to notify client browsers that +// resources should be allowed to be retrieved by any other domain. +func AddCorsWildcardHeaders(serve http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + serve(w, r) } } diff --git a/handle/handle_test.go b/handle/handle_test.go index 2ad5eb4..c165300 100644 --- a/handle/handle_test.go +++ b/handle/handle_test.go @@ -9,7 +9,6 @@ import ( "os" "path" "testing" - "strings" ) var ( @@ -460,29 +459,73 @@ func TestValidReferrer(t *testing.T) { } } -func TestAddsCorsHeaders(t *testing.T) { +func TestAddCorsWildcardHeaders(t *testing.T) { testCases := []struct { - name string - header string - value string + name string + corsEnabled bool }{ - {"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"}, - {"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"}, + {"CORS disabled", false}, + {"CORS enabled", true}, + } + + corsHeaders := map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", } for _, serveFile := range serveFileFuncs { - handler := AddCorsHeaders(Basic(serveFile, baseDir)) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://localhost/", nil) + var handler http.HandlerFunc + if tc.corsEnabled { + handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir)) + } else { + handler = Basic(serveFile, baseDir) + } + + fullpath := "http://localhost/" + tmpFileName + req := httptest.NewRequest("GET", fullpath, nil) w := httptest.NewRecorder() handler(w, req) resp := w.Result() - headerValue := strings.Join(resp.Header[tc.header], ", ") - if headerValue != tc.value { - t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value) + body, err := ioutil.ReadAll(resp.Body) + if nil != err { + t.Errorf("While reading body got %v", err) + } + contents := string(body) + if ok != resp.StatusCode { + t.Errorf( + "While retrieving %s expected status code of %d but got %d", + fullpath, ok, resp.StatusCode, + ) + } + if tmpFile != contents { + t.Errorf( + "While retrieving %s expected contents '%s' but got '%s'", + fullpath, tmpFile, contents, + ) + } + + if tc.corsEnabled { + for k, v := range corsHeaders { + if v != resp.Header.Get(k) { + t.Errorf( + "With CORS enabled expect header '%s' to return '%s' but got '%s'", + k, v, resp.Header.Get(k), + ) + } + } + } else { + for k := range corsHeaders { + if "" != resp.Header.Get(k) { + t.Errorf( + "With CORS disabled expected header '%s' to return '' but got '%s'", + k, resp.Header.Get(k), + ) + } + } } }) }