Added unit testing to referrer handler. Moved private function to bottom. Added function comments and removed commented-out code.

This commit is contained in:
Jeromy Streets 2019-01-21 21:49:58 -08:00
parent 3b5291c19e
commit eaa2bf565d
2 changed files with 154 additions and 36 deletions

View File

@ -1,6 +1,7 @@
package handle
import (
"fmt"
"log"
"net/http"
"strings"
@ -26,44 +27,21 @@ type ListenerFunc func(string, http.HandlerFunc) error
// requesting client.
type FileServerFunc func(http.ResponseWriter, *http.Request, string)
func validReferrer(s []string, e string) bool {
if (s == nil) {
// log.Printf("No referrers specified, all fine.")
return true
}
// log.Printf("Checking referrers " + strings.Join(s, ",") + " against " + e)
for _, a := range s {
// Handle blank HTTP Referer header, if configured
if (a == "") {
if (e == "") {
// log.Printf("No referrer in request. Allowing.");
return true;
}
// Continue loop (all strings start with "")
continue;
}
// Compare header with allowed prefixes
if strings.HasPrefix(e, a) {
// log.Printf(strings.Join([]string{ "Referrer match", e, a }, " "));
return true
}
}
return false
}
// WithReferrers returns a function that evaluates the HTTP 'Referer' header
// value and returns HTTP error 403 if the value is not found in the whitelist.
// If one of the whitelisted referrers are an empty string, then it is allowed
// for the 'Referer' HTTP header key to not be set.
func WithReferrers(serveFile FileServerFunc, referrers []string) FileServerFunc {
return func(w http.ResponseWriter, r *http.Request, name string) {
if (validReferrer(referrers, r.Referer())) {
// log.Printf("Serving file.")
serveFile(w, r, name)
} else {
// log.Printf(strings.Join([]string{"Invalid referrer", r.Referer(), "Not in", strings.Join(referrers, ",")}, " "))
http.Error(w, strings.Join([]string{ "Invalid source", r.Referer() }, " "), 403)
return
if !validReferrer(referrers, r.Referer()) {
http.Error(
w,
fmt.Sprintf("Invalid source '%s'", r.Referer()),
http.StatusForbidden,
)
return
}
serveFile(w, r, name)
}
}
@ -72,7 +50,7 @@ func WithReferrers(serveFile FileServerFunc, referrers []string) FileServerFunc
func WithLogging(serveFile FileServerFunc) FileServerFunc {
return func(w http.ResponseWriter, r *http.Request, name string) {
log.Printf(
"REQ from %s: %s %s %s%s -> %s\n",
"REQ from '%s': %s %s %s%s -> %s\n",
r.Referer(),
r.Method,
r.Proto,
@ -132,3 +110,29 @@ func TLSListening(tlsCert, tlsKey string) ListenerFunc {
return listenAndServeTLS(binding, tlsCert, tlsKey, nil)
}
}
// validReferrer returns true if the passed referrer can be resolved by the
// passed list of referrers.
func validReferrer(s []string, e string) bool {
// Whitelisted referer list is empty. All requests are allowed.
if 0 == len(s) {
return true
}
for _, a := range s {
// Handle blank HTTP Referer header, if configured
if a == "" {
if e == "" {
return true
}
// Continue loop (all strings start with "")
continue
}
// Compare header with allowed prefixes
if strings.HasPrefix(e, a) {
return true
}
}
return false
}

View File

@ -84,6 +84,73 @@ func teardown() (err error) {
return os.RemoveAll("tmp")
}
func TestWithReferrers(t *testing.T) {
forbidden := http.StatusForbidden
ok1 := "http://valid.com"
ok2 := "https://valid.com"
ok3 := "http://localhost"
bad := "http://other.pl"
var noRefer []string
emptyRefer := []string{}
onlyNoRefer := []string{""}
refer := []string{ok1, ok2, ok3}
noWithRefer := []string{"", ok1, ok2, ok3}
testCases := []struct {
name string
refers []string
refer string
code int
}{
{"Nil refer list", noRefer, bad, ok},
{"Empty refer list", emptyRefer, bad, ok},
{"Unassigned allowed & unassigned", onlyNoRefer, "", ok},
{"Unassigned allowed & assigned", onlyNoRefer, bad, forbidden},
{"Whitelist with unassigned", refer, "", forbidden},
{"Whitelist with bad", refer, bad, forbidden},
{"Whitelist with ok1", refer, ok1, ok},
{"Whitelist with ok2", refer, ok2, ok},
{"Whitelist with ok3", refer, ok3, ok},
{"Whitelist and none with unassigned", noWithRefer, "", ok},
{"Whitelist with bad", noWithRefer, bad, forbidden},
{"Whitelist with ok1", noWithRefer, ok1, ok},
{"Whitelist with ok2", noWithRefer, ok2, ok},
{"Whitelist with ok3", noWithRefer, ok3, ok},
}
success := func(w http.ResponseWriter, r *http.Request, name string) {
defer r.Body.Close()
w.WriteHeader(ok)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := WithReferrers(success, tc.refers)
fullpath := "http://localhost/" + tmpIndexName
req := httptest.NewRequest("GET", fullpath, nil)
req.Header.Add("Referer", tc.refer)
w := httptest.NewRecorder()
handler(w, req, "")
resp := w.Result()
_, err := ioutil.ReadAll(resp.Body)
if nil != err {
t.Errorf("While reading body got %v", err)
}
if tc.code != resp.StatusCode {
t.Errorf(
"With referer '%s' in '%v' expected status code %d but got %d",
tc.refer, tc.refers, tc.code, resp.StatusCode,
)
}
})
}
}
func TestBasicWithAndWithoutLogging(t *testing.T) {
testCases := []struct {
name string
@ -333,3 +400,50 @@ func TestTLSListening(t *testing.T) {
)
}
}
func TestValidReferrer(t *testing.T) {
ok1 := "http://valid.com"
ok2 := "https://valid.com"
ok3 := "http://localhost"
bad := "http://other.pl"
var noRefer []string
emptyRefer := []string{}
onlyNoRefer := []string{""}
refer := []string{ok1, ok2, ok3}
noWithRefer := []string{"", ok1, ok2, ok3}
testCases := []struct {
name string
refers []string
refer string
result bool
}{
{"Nil refer list", noRefer, bad, true},
{"Empty refer list", emptyRefer, bad, true},
{"Unassigned allowed & unassigned", onlyNoRefer, "", true},
{"Unassigned allowed & assigned", onlyNoRefer, bad, false},
{"Whitelist with unassigned", refer, "", false},
{"Whitelist with bad", refer, bad, false},
{"Whitelist with ok1", refer, ok1, true},
{"Whitelist with ok2", refer, ok2, true},
{"Whitelist with ok3", refer, ok3, true},
{"Whitelist and none with unassigned", noWithRefer, "", true},
{"Whitelist with bad", noWithRefer, bad, false},
{"Whitelist with ok1", noWithRefer, ok1, true},
{"Whitelist with ok2", noWithRefer, ok2, true},
{"Whitelist with ok3", noWithRefer, ok3, true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := validReferrer(tc.refers, tc.refer)
if result != tc.result {
t.Errorf(
"With referrers of '%v' and a value of '%s' expected %t but got %t",
tc.refers, tc.refer, tc.result, result,
)
}
})
}
}