Added comments and renamed function to annotate that it will apply the wildcard to CORS (rather than list allowed domains). Modified CORS unit test.

This commit is contained in:
Jeromy Streets 2019-10-16 09:33:25 -07:00
parent 0d80a644ff
commit 9a679fb2e9
3 changed files with 64 additions and 19 deletions

View File

@ -62,8 +62,9 @@ func handlerSelector() (handler http.HandlerFunc) {
handler = handle.IgnoreIndex(handler)
}
// If configured, apply wildcard CORS support.
if config.Get.Cors {
handler = handle.AddCorsHeaders(handler)
handler = handle.AddCorsWildcardHeaders(handler)
}
return

View File

@ -109,12 +109,13 @@ func IgnoreIndex(serve http.HandlerFunc) http.HandlerFunc {
}
}
func AddCorsHeaders(serve http.HandlerFunc) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.Header().Set("Access-Control-Allow-Headers", "*")
serve(writer, request)
// AddCorsWildcardHeaders wraps an HTTP request to notify client browsers that
// resources should be allowed to be retrieved by any other domain.
func AddCorsWildcardHeaders(serve http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
serve(w, r)
}
}

View File

@ -9,7 +9,6 @@ import (
"os"
"path"
"testing"
"strings"
)
var (
@ -460,29 +459,73 @@ func TestValidReferrer(t *testing.T) {
}
}
func TestAddsCorsHeaders(t *testing.T) {
func TestAddCorsWildcardHeaders(t *testing.T) {
testCases := []struct {
name string
header string
value string
name string
corsEnabled bool
}{
{"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"},
{"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"},
{"CORS disabled", false},
{"CORS enabled", true},
}
corsHeaders := map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
}
for _, serveFile := range serveFileFuncs {
handler := AddCorsHeaders(Basic(serveFile, baseDir))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://localhost/", nil)
var handler http.HandlerFunc
if tc.corsEnabled {
handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir))
} else {
handler = Basic(serveFile, baseDir)
}
fullpath := "http://localhost/" + tmpFileName
req := httptest.NewRequest("GET", fullpath, nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
headerValue := strings.Join(resp.Header[tc.header], ", ")
if headerValue != tc.value {
t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value)
body, err := ioutil.ReadAll(resp.Body)
if nil != err {
t.Errorf("While reading body got %v", err)
}
contents := string(body)
if ok != resp.StatusCode {
t.Errorf(
"While retrieving %s expected status code of %d but got %d",
fullpath, ok, resp.StatusCode,
)
}
if tmpFile != contents {
t.Errorf(
"While retrieving %s expected contents '%s' but got '%s'",
fullpath, tmpFile, contents,
)
}
if tc.corsEnabled {
for k, v := range corsHeaders {
if v != resp.Header.Get(k) {
t.Errorf(
"With CORS enabled expect header '%s' to return '%s' but got '%s'",
k, v, resp.Header.Get(k),
)
}
}
} else {
for k := range corsHeaders {
if "" != resp.Header.Get(k) {
t.Errorf(
"With CORS disabled expected header '%s' to return '' but got '%s'",
k, resp.Header.Get(k),
)
}
}
}
})
}