Added comments and renamed function to annotate that it will apply the wildcard to CORS (rather than list allowed domains). Modified CORS unit test.

pull/37/head
Jeromy Streets 5 years ago
parent 0d80a644ff
commit 9a679fb2e9
  1. 3
      cli/server/server.go
  2. 13
      handle/handle.go
  3. 67
      handle/handle_test.go

@ -62,8 +62,9 @@ func handlerSelector() (handler http.HandlerFunc) {
handler = handle.IgnoreIndex(handler) handler = handle.IgnoreIndex(handler)
} }
// If configured, apply wildcard CORS support.
if config.Get.Cors { if config.Get.Cors {
handler = handle.AddCorsHeaders(handler) handler = handle.AddCorsWildcardHeaders(handler)
} }
return return

@ -109,12 +109,13 @@ func IgnoreIndex(serve http.HandlerFunc) http.HandlerFunc {
} }
} }
func AddCorsHeaders(serve http.HandlerFunc) http.HandlerFunc { // AddCorsWildcardHeaders wraps an HTTP request to notify client browsers that
return func(writer http.ResponseWriter, request *http.Request) { // resources should be allowed to be retrieved by any other domain.
writer.Header().Set("Access-Control-Allow-Origin", "*") func AddCorsWildcardHeaders(serve http.HandlerFunc) http.HandlerFunc {
writer.Header().Set("Access-Control-Allow-Headers", "*") return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
serve(writer, request) w.Header().Set("Access-Control-Allow-Headers", "*")
serve(w, r)
} }
} }

@ -9,7 +9,6 @@ import (
"os" "os"
"path" "path"
"testing" "testing"
"strings"
) )
var ( var (
@ -460,29 +459,73 @@ func TestValidReferrer(t *testing.T) {
} }
} }
func TestAddsCorsHeaders(t *testing.T) { func TestAddCorsWildcardHeaders(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
header string corsEnabled bool
value string
}{ }{
{"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"}, {"CORS disabled", false},
{"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"}, {"CORS enabled", true},
}
corsHeaders := map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
} }
for _, serveFile := range serveFileFuncs { for _, serveFile := range serveFileFuncs {
handler := AddCorsHeaders(Basic(serveFile, baseDir))
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://localhost/", nil) var handler http.HandlerFunc
if tc.corsEnabled {
handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir))
} else {
handler = Basic(serveFile, baseDir)
}
fullpath := "http://localhost/" + tmpFileName
req := httptest.NewRequest("GET", fullpath, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler(w, req) handler(w, req)
resp := w.Result() resp := w.Result()
headerValue := strings.Join(resp.Header[tc.header], ", ") body, err := ioutil.ReadAll(resp.Body)
if headerValue != tc.value { if nil != err {
t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value) t.Errorf("While reading body got %v", err)
}
contents := string(body)
if ok != resp.StatusCode {
t.Errorf(
"While retrieving %s expected status code of %d but got %d",
fullpath, ok, resp.StatusCode,
)
}
if tmpFile != contents {
t.Errorf(
"While retrieving %s expected contents '%s' but got '%s'",
fullpath, tmpFile, contents,
)
}
if tc.corsEnabled {
for k, v := range corsHeaders {
if v != resp.Header.Get(k) {
t.Errorf(
"With CORS enabled expect header '%s' to return '%s' but got '%s'",
k, v, resp.Header.Get(k),
)
}
}
} else {
for k := range corsHeaders {
if "" != resp.Header.Get(k) {
t.Errorf(
"With CORS disabled expected header '%s' to return '' but got '%s'",
k, resp.Header.Get(k),
)
}
}
} }
}) })
} }

Loading…
Cancel
Save