mirror of
https://github.com/halverneus/static-file-server.git
synced 2024-11-24 09:05:30 +00:00
Added unit testing to referrer handler. Moved private function to bottom. Added function comments and removed commented-out code.
This commit is contained in:
parent
3b5291c19e
commit
eaa2bf565d
@ -1,6 +1,7 @@
|
||||
package handle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -26,44 +27,21 @@ type ListenerFunc func(string, http.HandlerFunc) error
|
||||
// requesting client.
|
||||
type FileServerFunc func(http.ResponseWriter, *http.Request, string)
|
||||
|
||||
func validReferrer(s []string, e string) bool {
|
||||
if (s == nil) {
|
||||
// log.Printf("No referrers specified, all fine.")
|
||||
return true
|
||||
}
|
||||
|
||||
// log.Printf("Checking referrers " + strings.Join(s, ",") + " against " + e)
|
||||
|
||||
for _, a := range s {
|
||||
// Handle blank HTTP Referer header, if configured
|
||||
if (a == "") {
|
||||
if (e == "") {
|
||||
// log.Printf("No referrer in request. Allowing.");
|
||||
return true;
|
||||
}
|
||||
// Continue loop (all strings start with "")
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compare header with allowed prefixes
|
||||
if strings.HasPrefix(e, a) {
|
||||
// log.Printf(strings.Join([]string{ "Referrer match", e, a }, " "));
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithReferrers returns a function that evaluates the HTTP 'Referer' header
|
||||
// value and returns HTTP error 403 if the value is not found in the whitelist.
|
||||
// If one of the whitelisted referrers are an empty string, then it is allowed
|
||||
// for the 'Referer' HTTP header key to not be set.
|
||||
func WithReferrers(serveFile FileServerFunc, referrers []string) FileServerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if (validReferrer(referrers, r.Referer())) {
|
||||
// log.Printf("Serving file.")
|
||||
serveFile(w, r, name)
|
||||
} else {
|
||||
// log.Printf(strings.Join([]string{"Invalid referrer", r.Referer(), "Not in", strings.Join(referrers, ",")}, " "))
|
||||
http.Error(w, strings.Join([]string{ "Invalid source", r.Referer() }, " "), 403)
|
||||
return
|
||||
if !validReferrer(referrers, r.Referer()) {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Invalid source '%s'", r.Referer()),
|
||||
http.StatusForbidden,
|
||||
)
|
||||
return
|
||||
}
|
||||
serveFile(w, r, name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,7 +50,7 @@ func WithReferrers(serveFile FileServerFunc, referrers []string) FileServerFunc
|
||||
func WithLogging(serveFile FileServerFunc) FileServerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, name string) {
|
||||
log.Printf(
|
||||
"REQ from %s: %s %s %s%s -> %s\n",
|
||||
"REQ from '%s': %s %s %s%s -> %s\n",
|
||||
r.Referer(),
|
||||
r.Method,
|
||||
r.Proto,
|
||||
@ -132,3 +110,29 @@ func TLSListening(tlsCert, tlsKey string) ListenerFunc {
|
||||
return listenAndServeTLS(binding, tlsCert, tlsKey, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// validReferrer returns true if the passed referrer can be resolved by the
|
||||
// passed list of referrers.
|
||||
func validReferrer(s []string, e string) bool {
|
||||
// Whitelisted referer list is empty. All requests are allowed.
|
||||
if 0 == len(s) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, a := range s {
|
||||
// Handle blank HTTP Referer header, if configured
|
||||
if a == "" {
|
||||
if e == "" {
|
||||
return true
|
||||
}
|
||||
// Continue loop (all strings start with "")
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare header with allowed prefixes
|
||||
if strings.HasPrefix(e, a) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -84,6 +84,73 @@ func teardown() (err error) {
|
||||
return os.RemoveAll("tmp")
|
||||
}
|
||||
|
||||
func TestWithReferrers(t *testing.T) {
|
||||
forbidden := http.StatusForbidden
|
||||
|
||||
ok1 := "http://valid.com"
|
||||
ok2 := "https://valid.com"
|
||||
ok3 := "http://localhost"
|
||||
bad := "http://other.pl"
|
||||
|
||||
var noRefer []string
|
||||
emptyRefer := []string{}
|
||||
onlyNoRefer := []string{""}
|
||||
refer := []string{ok1, ok2, ok3}
|
||||
noWithRefer := []string{"", ok1, ok2, ok3}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
refers []string
|
||||
refer string
|
||||
code int
|
||||
}{
|
||||
{"Nil refer list", noRefer, bad, ok},
|
||||
{"Empty refer list", emptyRefer, bad, ok},
|
||||
{"Unassigned allowed & unassigned", onlyNoRefer, "", ok},
|
||||
{"Unassigned allowed & assigned", onlyNoRefer, bad, forbidden},
|
||||
{"Whitelist with unassigned", refer, "", forbidden},
|
||||
{"Whitelist with bad", refer, bad, forbidden},
|
||||
{"Whitelist with ok1", refer, ok1, ok},
|
||||
{"Whitelist with ok2", refer, ok2, ok},
|
||||
{"Whitelist with ok3", refer, ok3, ok},
|
||||
{"Whitelist and none with unassigned", noWithRefer, "", ok},
|
||||
{"Whitelist with bad", noWithRefer, bad, forbidden},
|
||||
{"Whitelist with ok1", noWithRefer, ok1, ok},
|
||||
{"Whitelist with ok2", noWithRefer, ok2, ok},
|
||||
{"Whitelist with ok3", noWithRefer, ok3, ok},
|
||||
}
|
||||
|
||||
success := func(w http.ResponseWriter, r *http.Request, name string) {
|
||||
defer r.Body.Close()
|
||||
w.WriteHeader(ok)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handler := WithReferrers(success, tc.refers)
|
||||
|
||||
fullpath := "http://localhost/" + tmpIndexName
|
||||
req := httptest.NewRequest("GET", fullpath, nil)
|
||||
req.Header.Add("Referer", tc.refer)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req, "")
|
||||
|
||||
resp := w.Result()
|
||||
_, err := ioutil.ReadAll(resp.Body)
|
||||
if nil != err {
|
||||
t.Errorf("While reading body got %v", err)
|
||||
}
|
||||
if tc.code != resp.StatusCode {
|
||||
t.Errorf(
|
||||
"With referer '%s' in '%v' expected status code %d but got %d",
|
||||
tc.refer, tc.refers, tc.code, resp.StatusCode,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicWithAndWithoutLogging(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -333,3 +400,50 @@ func TestTLSListening(t *testing.T) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidReferrer(t *testing.T) {
|
||||
ok1 := "http://valid.com"
|
||||
ok2 := "https://valid.com"
|
||||
ok3 := "http://localhost"
|
||||
bad := "http://other.pl"
|
||||
|
||||
var noRefer []string
|
||||
emptyRefer := []string{}
|
||||
onlyNoRefer := []string{""}
|
||||
refer := []string{ok1, ok2, ok3}
|
||||
noWithRefer := []string{"", ok1, ok2, ok3}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
refers []string
|
||||
refer string
|
||||
result bool
|
||||
}{
|
||||
{"Nil refer list", noRefer, bad, true},
|
||||
{"Empty refer list", emptyRefer, bad, true},
|
||||
{"Unassigned allowed & unassigned", onlyNoRefer, "", true},
|
||||
{"Unassigned allowed & assigned", onlyNoRefer, bad, false},
|
||||
{"Whitelist with unassigned", refer, "", false},
|
||||
{"Whitelist with bad", refer, bad, false},
|
||||
{"Whitelist with ok1", refer, ok1, true},
|
||||
{"Whitelist with ok2", refer, ok2, true},
|
||||
{"Whitelist with ok3", refer, ok3, true},
|
||||
{"Whitelist and none with unassigned", noWithRefer, "", true},
|
||||
{"Whitelist with bad", noWithRefer, bad, false},
|
||||
{"Whitelist with ok1", noWithRefer, ok1, true},
|
||||
{"Whitelist with ok2", noWithRefer, ok2, true},
|
||||
{"Whitelist with ok3", noWithRefer, ok3, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := validReferrer(tc.refers, tc.refer)
|
||||
if result != tc.result {
|
||||
t.Errorf(
|
||||
"With referrers of '%v' and a value of '%s' expected %t but got %t",
|
||||
tc.refers, tc.refer, tc.result, result,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user