mirror of
https://github.com/halverneus/static-file-server.git
synced 2024-11-13 21:55:31 +00:00
557 lines
15 KiB
Go
557 lines
15 KiB
Go
package handle
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path"
|
|
"testing"
|
|
)
|
|
|
|
var (
|
|
baseDir = "tmp/"
|
|
subDir = "sub/"
|
|
subDeepDir = "sub/deep/"
|
|
tmpIndexName = "index.html"
|
|
tmpFileName = "file.txt"
|
|
tmpBadName = "bad.txt"
|
|
tmpSubIndexName = "sub/index.html"
|
|
tmpSubFileName = "sub/file.txt"
|
|
tmpSubBadName = "sub/bad.txt"
|
|
tmpSubDeepIndexName = "sub/deep/index.html"
|
|
tmpSubDeepFileName = "sub/deep/file.txt"
|
|
tmpSubDeepBadName = "sub/deep/bad.txt"
|
|
|
|
tmpIndex = "Space: the final frontier"
|
|
tmpFile = "These are the voyages of the starship Enterprise."
|
|
tmpSubIndex = "Its continuing mission:"
|
|
tmpSubFile = "To explore strange new worlds"
|
|
tmpSubDeepIndex = "To seek out new life and new civilizations"
|
|
tmpSubDeepFile = "To boldly go where no one has gone before"
|
|
|
|
nothing = ""
|
|
ok = http.StatusOK
|
|
missing = http.StatusNotFound
|
|
redirect = http.StatusMovedPermanently
|
|
notFound = "404 page not found\n"
|
|
|
|
files = map[string]string{
|
|
baseDir + tmpIndexName: tmpIndex,
|
|
baseDir + tmpFileName: tmpFile,
|
|
baseDir + tmpSubIndexName: tmpSubIndex,
|
|
baseDir + tmpSubFileName: tmpSubFile,
|
|
baseDir + tmpSubDeepIndexName: tmpSubDeepIndex,
|
|
baseDir + tmpSubDeepFileName: tmpSubDeepFile,
|
|
}
|
|
|
|
serveFileFuncs = []FileServerFunc{
|
|
http.ServeFile,
|
|
WithLogging(http.ServeFile),
|
|
}
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
code := func(m *testing.M) int {
|
|
if err := setup(); nil != err {
|
|
log.Fatalf("While setting up test got: %v\n", err)
|
|
}
|
|
defer teardown()
|
|
return m.Run()
|
|
}(m)
|
|
os.Exit(code)
|
|
}
|
|
|
|
func setup() (err error) {
|
|
for filename, contents := range files {
|
|
if err = os.MkdirAll(path.Dir(filename), 0700); nil != err {
|
|
return
|
|
}
|
|
if err = ioutil.WriteFile(
|
|
filename,
|
|
[]byte(contents),
|
|
0600,
|
|
); nil != err {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func teardown() (err error) {
|
|
return os.RemoveAll("tmp")
|
|
}
|
|
|
|
func TestSetMinimumTLSVersion(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
value uint16
|
|
expected uint16
|
|
}{
|
|
{"Too low", tls.VersionTLS10 - 1, tls.VersionTLS10},
|
|
{"Lower bounds", tls.VersionTLS10, tls.VersionTLS10},
|
|
{"Upper bounds", tls.VersionTLS13, tls.VersionTLS13},
|
|
{"Too high", tls.VersionTLS13 + 1, tls.VersionTLS13},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
SetMinimumTLSVersion(tc.value)
|
|
if tc.expected != minTLSVersion {
|
|
t.Errorf("Expected %d but got %d", tc.expected, minTLSVersion)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWithReferrers(t *testing.T) {
|
|
forbidden := http.StatusForbidden
|
|
|
|
ok1 := "http://valid.com"
|
|
ok2 := "https://valid.com"
|
|
ok3 := "http://localhost"
|
|
bad := "http://other.pl"
|
|
|
|
var noRefer []string
|
|
emptyRefer := []string{}
|
|
onlyNoRefer := []string{""}
|
|
refer := []string{ok1, ok2, ok3}
|
|
noWithRefer := []string{"", ok1, ok2, ok3}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
refers []string
|
|
refer string
|
|
code int
|
|
}{
|
|
{"Nil refer list", noRefer, bad, ok},
|
|
{"Empty refer list", emptyRefer, bad, ok},
|
|
{"Unassigned allowed & unassigned", onlyNoRefer, "", ok},
|
|
{"Unassigned allowed & assigned", onlyNoRefer, bad, forbidden},
|
|
{"Whitelist with unassigned", refer, "", forbidden},
|
|
{"Whitelist with bad", refer, bad, forbidden},
|
|
{"Whitelist with ok1", refer, ok1, ok},
|
|
{"Whitelist with ok2", refer, ok2, ok},
|
|
{"Whitelist with ok3", refer, ok3, ok},
|
|
{"Whitelist and none with unassigned", noWithRefer, "", ok},
|
|
{"Whitelist with bad", noWithRefer, bad, forbidden},
|
|
{"Whitelist with ok1", noWithRefer, ok1, ok},
|
|
{"Whitelist with ok2", noWithRefer, ok2, ok},
|
|
{"Whitelist with ok3", noWithRefer, ok3, ok},
|
|
}
|
|
|
|
success := func(w http.ResponseWriter, r *http.Request, name string) {
|
|
defer r.Body.Close()
|
|
w.WriteHeader(ok)
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
handler := WithReferrers(success, tc.refers)
|
|
|
|
fullpath := "http://localhost/" + tmpIndexName
|
|
req := httptest.NewRequest("GET", fullpath, nil)
|
|
req.Header.Add("Referer", tc.refer)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req, "")
|
|
|
|
resp := w.Result()
|
|
_, err := ioutil.ReadAll(resp.Body)
|
|
if nil != err {
|
|
t.Errorf("While reading body got %v", err)
|
|
}
|
|
if tc.code != resp.StatusCode {
|
|
t.Errorf(
|
|
"With referer '%s' in '%v' expected status code %d but got %d",
|
|
tc.refer, tc.refers, tc.code, resp.StatusCode,
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBasicWithAndWithoutLogging(t *testing.T) {
|
|
referer := "http://localhost"
|
|
noReferer := ""
|
|
testCases := []struct {
|
|
name string
|
|
path string
|
|
code int
|
|
refer string
|
|
contents string
|
|
}{
|
|
{"Good base dir", "", ok, referer, tmpIndex},
|
|
{"Good base index", tmpIndexName, redirect, referer, nothing},
|
|
{"Good base file", tmpFileName, ok, referer, tmpFile},
|
|
{"Bad base file", tmpBadName, missing, referer, notFound},
|
|
{"Good subdir dir", subDir, ok, referer, tmpSubIndex},
|
|
{"Good subdir index", tmpSubIndexName, redirect, referer, nothing},
|
|
{"Good subdir file", tmpSubFileName, ok, referer, tmpSubFile},
|
|
{"Good base dir", "", ok, noReferer, tmpIndex},
|
|
{"Good base index", tmpIndexName, redirect, noReferer, nothing},
|
|
{"Good base file", tmpFileName, ok, noReferer, tmpFile},
|
|
{"Bad base file", tmpBadName, missing, noReferer, notFound},
|
|
{"Good subdir dir", subDir, ok, noReferer, tmpSubIndex},
|
|
{"Good subdir index", tmpSubIndexName, redirect, noReferer, nothing},
|
|
{"Good subdir file", tmpSubFileName, ok, noReferer, tmpSubFile},
|
|
}
|
|
|
|
for _, serveFile := range serveFileFuncs {
|
|
handler := Basic(serveFile, baseDir)
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
fullpath := "http://localhost/" + tc.path
|
|
req := httptest.NewRequest("GET", fullpath, nil)
|
|
req.Header.Add("Referer", tc.refer)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if nil != err {
|
|
t.Errorf("While reading body got %v", err)
|
|
}
|
|
contents := string(body)
|
|
if tc.code != resp.StatusCode {
|
|
t.Errorf(
|
|
"While retrieving %s expected status code of %d but got %d",
|
|
fullpath, tc.code, resp.StatusCode,
|
|
)
|
|
}
|
|
if tc.contents != contents {
|
|
t.Errorf(
|
|
"While retrieving %s expected contents '%s' but got '%s'",
|
|
fullpath, tc.contents, contents,
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPrefix(t *testing.T) {
|
|
prefix := "/my/prefix/path/"
|
|
|
|
testCases := []struct {
|
|
name string
|
|
path string
|
|
code int
|
|
contents string
|
|
}{
|
|
{"Good base dir", prefix, ok, tmpIndex},
|
|
{"Good base index", prefix + tmpIndexName, redirect, nothing},
|
|
{"Good base file", prefix + tmpFileName, ok, tmpFile},
|
|
{"Bad base file", prefix + tmpBadName, missing, notFound},
|
|
{"Good subdir dir", prefix + subDir, ok, tmpSubIndex},
|
|
{"Good subdir index", prefix + tmpSubIndexName, redirect, nothing},
|
|
{"Good subdir file", prefix + tmpSubFileName, ok, tmpSubFile},
|
|
{"Unknown prefix", tmpFileName, missing, notFound},
|
|
}
|
|
|
|
for _, serveFile := range serveFileFuncs {
|
|
handler := Prefix(serveFile, baseDir, prefix)
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
fullpath := "http://localhost" + tc.path
|
|
req := httptest.NewRequest("GET", fullpath, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if nil != err {
|
|
t.Errorf("While reading body got %v", err)
|
|
}
|
|
contents := string(body)
|
|
if tc.code != resp.StatusCode {
|
|
t.Errorf(
|
|
"While retrieving %s expected status code of %d but got %d",
|
|
fullpath, tc.code, resp.StatusCode,
|
|
)
|
|
}
|
|
if tc.contents != contents {
|
|
t.Errorf(
|
|
"While retrieving %s expected contents '%s' but got '%s'",
|
|
fullpath, tc.contents, contents,
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIgnoreIndex(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
path string
|
|
code int
|
|
contents string
|
|
}{
|
|
{"Good base dir", "", missing, notFound},
|
|
{"Good base index", tmpIndexName, redirect, nothing},
|
|
{"Good base file", tmpFileName, ok, tmpFile},
|
|
{"Bad base file", tmpBadName, missing, notFound},
|
|
{"Good subdir dir", subDir, missing, notFound},
|
|
{"Good subdir index", tmpSubIndexName, redirect, nothing},
|
|
{"Good subdir file", tmpSubFileName, ok, tmpSubFile},
|
|
}
|
|
|
|
for _, serveFile := range serveFileFuncs {
|
|
handler := IgnoreIndex(Basic(serveFile, baseDir))
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
fullpath := "http://localhost/" + tc.path
|
|
req := httptest.NewRequest("GET", fullpath, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if nil != err {
|
|
t.Errorf("While reading body got %v", err)
|
|
}
|
|
contents := string(body)
|
|
if tc.code != resp.StatusCode {
|
|
t.Errorf(
|
|
"While retrieving %s expected status code of %d but got %d",
|
|
fullpath, tc.code, resp.StatusCode,
|
|
)
|
|
}
|
|
if tc.contents != contents {
|
|
t.Errorf(
|
|
"While retrieving %s expected contents '%s' but got '%s'",
|
|
fullpath, tc.contents, contents,
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestListening(t *testing.T) {
|
|
// Choose values for testing.
|
|
called := false
|
|
testBinding := "host:port"
|
|
testError := errors.New("random problem")
|
|
|
|
// Create an empty placeholder router function.
|
|
handler := func(http.ResponseWriter, *http.Request) {}
|
|
|
|
// Override setHandler so that multiple calls to 'http.HandleFunc' doesn't
|
|
// panic.
|
|
setHandler = func(string, func(http.ResponseWriter, *http.Request)) {}
|
|
|
|
// Override listenAndServe with a function with more introspection and
|
|
// control than 'http.ListenAndServe'.
|
|
listenAndServe = func(
|
|
binding string, handler http.Handler,
|
|
) error {
|
|
if testBinding != binding {
|
|
t.Errorf(
|
|
"While serving expected binding of %s but got %s",
|
|
testBinding, binding,
|
|
)
|
|
}
|
|
called = !called
|
|
if called {
|
|
return nil
|
|
}
|
|
return testError
|
|
}
|
|
|
|
// Perform test.
|
|
listener := Listening()
|
|
if err := listener(testBinding, handler); nil != err {
|
|
t.Errorf("While serving first expected nil error but got %v", err)
|
|
}
|
|
if err := listener(testBinding, handler); nil == err {
|
|
t.Errorf(
|
|
"While serving second got nil while expecting %v", testError,
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestTLSListening(t *testing.T) {
|
|
// Choose values for testing.
|
|
called := false
|
|
testBinding := "host:port"
|
|
testTLSCert := "test/file.pem"
|
|
testTLSKey := "test/file.key"
|
|
testError := errors.New("random problem")
|
|
|
|
// Create an empty placeholder router function.
|
|
handler := func(http.ResponseWriter, *http.Request) {}
|
|
|
|
// Override setHandler so that multiple calls to 'http.HandleFunc' doesn't
|
|
// panic.
|
|
setHandler = func(string, func(http.ResponseWriter, *http.Request)) {}
|
|
|
|
// Override listenAndServeTLS with a function with more introspection and
|
|
// control than 'http.ListenAndServeTLS'.
|
|
listenAndServeTLS = func(
|
|
binding, tlsCert, tlsKey string, handler http.Handler,
|
|
) error {
|
|
if testBinding != binding {
|
|
t.Errorf(
|
|
"While serving TLS expected binding of %s but got %s",
|
|
testBinding, binding,
|
|
)
|
|
}
|
|
if testTLSCert != tlsCert {
|
|
t.Errorf(
|
|
"While serving TLS expected TLS cert of %s but got %s",
|
|
testTLSCert, tlsCert,
|
|
)
|
|
}
|
|
if testTLSKey != tlsKey {
|
|
t.Errorf(
|
|
"While serving TLS expected TLS key of %s but got %s",
|
|
testTLSKey, tlsKey,
|
|
)
|
|
}
|
|
called = !called
|
|
if called {
|
|
return nil
|
|
}
|
|
return testError
|
|
}
|
|
|
|
// Perform test.
|
|
listener := TLSListening(testTLSCert, testTLSKey)
|
|
if err := listener(testBinding, handler); nil != err {
|
|
t.Errorf("While serving first TLS expected nil error but got %v", err)
|
|
}
|
|
if err := listener(testBinding, handler); nil == err {
|
|
t.Errorf(
|
|
"While serving second TLS got nil while expecting %v", testError,
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestValidReferrer(t *testing.T) {
|
|
ok1 := "http://valid.com"
|
|
ok2 := "https://valid.com"
|
|
ok3 := "http://localhost"
|
|
bad := "http://other.pl"
|
|
|
|
var noRefer []string
|
|
emptyRefer := []string{}
|
|
onlyNoRefer := []string{""}
|
|
refer := []string{ok1, ok2, ok3}
|
|
noWithRefer := []string{"", ok1, ok2, ok3}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
refers []string
|
|
refer string
|
|
result bool
|
|
}{
|
|
{"Nil refer list", noRefer, bad, true},
|
|
{"Empty refer list", emptyRefer, bad, true},
|
|
{"Unassigned allowed & unassigned", onlyNoRefer, "", true},
|
|
{"Unassigned allowed & assigned", onlyNoRefer, bad, false},
|
|
{"Whitelist with unassigned", refer, "", false},
|
|
{"Whitelist with bad", refer, bad, false},
|
|
{"Whitelist with ok1", refer, ok1, true},
|
|
{"Whitelist with ok2", refer, ok2, true},
|
|
{"Whitelist with ok3", refer, ok3, true},
|
|
{"Whitelist and none with unassigned", noWithRefer, "", true},
|
|
{"Whitelist with bad", noWithRefer, bad, false},
|
|
{"Whitelist with ok1", noWithRefer, ok1, true},
|
|
{"Whitelist with ok2", noWithRefer, ok2, true},
|
|
{"Whitelist with ok3", noWithRefer, ok3, true},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result := validReferrer(tc.refers, tc.refer)
|
|
if result != tc.result {
|
|
t.Errorf(
|
|
"With referrers of '%v' and a value of '%s' expected %t but got %t",
|
|
tc.refers, tc.refer, tc.result, result,
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAddCorsWildcardHeaders(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
corsEnabled bool
|
|
}{
|
|
{"CORS disabled", false},
|
|
{"CORS enabled", true},
|
|
}
|
|
|
|
corsHeaders := map[string]string{
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Headers": "*",
|
|
}
|
|
|
|
for _, serveFile := range serveFileFuncs {
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var handler http.HandlerFunc
|
|
if tc.corsEnabled {
|
|
handler = AddCorsWildcardHeaders(Basic(serveFile, baseDir))
|
|
} else {
|
|
handler = Basic(serveFile, baseDir)
|
|
}
|
|
|
|
fullpath := "http://localhost/" + tmpFileName
|
|
req := httptest.NewRequest("GET", fullpath, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if nil != err {
|
|
t.Errorf("While reading body got %v", err)
|
|
}
|
|
contents := string(body)
|
|
if ok != resp.StatusCode {
|
|
t.Errorf(
|
|
"While retrieving %s expected status code of %d but got %d",
|
|
fullpath, ok, resp.StatusCode,
|
|
)
|
|
}
|
|
if tmpFile != contents {
|
|
t.Errorf(
|
|
"While retrieving %s expected contents '%s' but got '%s'",
|
|
fullpath, tmpFile, contents,
|
|
)
|
|
}
|
|
|
|
if tc.corsEnabled {
|
|
for k, v := range corsHeaders {
|
|
if v != resp.Header.Get(k) {
|
|
t.Errorf(
|
|
"With CORS enabled expect header '%s' to return '%s' but got '%s'",
|
|
k, v, resp.Header.Get(k),
|
|
)
|
|
}
|
|
}
|
|
} else {
|
|
for k := range corsHeaders {
|
|
if "" != resp.Header.Get(k) {
|
|
t.Errorf(
|
|
"With CORS disabled expected header '%s' to return '' but got '%s'",
|
|
k, resp.Header.Get(k),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|