mirror of
https://github.com/halverneus/static-file-server.git
synced 2024-11-24 09:05:30 +00:00
Added comments and renamed function to annotate that it will apply the wildcard to CORS (rather than list allowed domains). Modified CORS unit test.
This commit is contained in:
parent
0d80a644ff
commit
9a679fb2e9
@ -62,8 +62,9 @@ func handlerSelector() (handler http.HandlerFunc) {
|
||||
handler = handle.IgnoreIndex(handler)
|
||||
}
|
||||
|
||||
// If configured, apply wildcard CORS support.
|
||||
if config.Get.Cors {
|
||||
handler = handle.AddCorsHeaders(handler)
|
||||
handler = handle.AddCorsWildcardHeaders(handler)
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -109,12 +109,13 @@ func IgnoreIndex(serve http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func AddCorsHeaders(serve http.HandlerFunc) http.HandlerFunc {
|
||||
return func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
|
||||
serve(writer, request)
|
||||
// AddCorsWildcardHeaders wraps an HTTP request to notify client browsers that
|
||||
// resources should be allowed to be retrieved by any other domain.
|
||||
func AddCorsWildcardHeaders(serve http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
serve(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -460,29 +459,73 @@ func TestValidReferrer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddsCorsHeaders(t *testing.T) {
|
||||
func TestAddCorsWildcardHeaders(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
header string
|
||||
value string
|
||||
name string
|
||||
corsEnabled bool
|
||||
}{
|
||||
{"Add Access-Control-Allow-Origin header", "Access-Control-Allow-Origin", "*"},
|
||||
{"Add Access-Control-Allow-Headers header", "Access-Control-Allow-Headers", "*"},
|
||||
{"CORS disabled", false},
|
||||
{"CORS enabled", true},
|
||||
}
|
||||
|
||||
corsHeaders := map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": "*",
|
||||
}
|
||||
|
||||
for _, serveFile := range serveFileFuncs {
|
||||
handler := AddCorsHeaders(Basic(serveFile, baseDir))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://localhost/", nil)
|
||||
var handler http.HandlerFunc
|
||||
if tc.corsEnabled {
|
||||
handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir))
|
||||
} else {
|
||||
handler = Basic(serveFile, baseDir)
|
||||
}
|
||||
|
||||
fullpath := "http://localhost/" + tmpFileName
|
||||
req := httptest.NewRequest("GET", fullpath, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
headerValue := strings.Join(resp.Header[tc.header], ", ")
|
||||
if headerValue != tc.value {
|
||||
t.Errorf("Response header %q = %q, want %q", tc.header, headerValue, tc.value)
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if nil != err {
|
||||
t.Errorf("While reading body got %v", err)
|
||||
}
|
||||
contents := string(body)
|
||||
if ok != resp.StatusCode {
|
||||
t.Errorf(
|
||||
"While retrieving %s expected status code of %d but got %d",
|
||||
fullpath, ok, resp.StatusCode,
|
||||
)
|
||||
}
|
||||
if tmpFile != contents {
|
||||
t.Errorf(
|
||||
"While retrieving %s expected contents '%s' but got '%s'",
|
||||
fullpath, tmpFile, contents,
|
||||
)
|
||||
}
|
||||
|
||||
if tc.corsEnabled {
|
||||
for k, v := range corsHeaders {
|
||||
if v != resp.Header.Get(k) {
|
||||
t.Errorf(
|
||||
"With CORS enabled expect header '%s' to return '%s' but got '%s'",
|
||||
k, v, resp.Header.Get(k),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for k := range corsHeaders {
|
||||
if "" != resp.Header.Get(k) {
|
||||
t.Errorf(
|
||||
"With CORS disabled expected header '%s' to return '' but got '%s'",
|
||||
k, resp.Header.Get(k),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user