diff --git a/config/config.go b/config/config.go index 0bcf18e..4e970bc 100644 --- a/config/config.go +++ b/config/config.go @@ -77,8 +77,9 @@ func Load(filename string) (err error) { if err = yaml.Unmarshal(contents, &Get); nil != err { return } + overrideWithEnvVars() - return + return validate() } // overrideWithEnvVars the default values and the configuration file values. @@ -93,6 +94,38 @@ func overrideWithEnvVars() { Get.URLPrefix = envAsStr(urlPrefixKey, Get.URLPrefix) } +// validate the configuration. +func validate() error { + // If HTTPS is to be used, verify both TLS_* environment variables are set. + if 0 < len(Get.TLSCert) || 0 < len(Get.TLSKey) { + if 0 == len(Get.TLSCert) || 0 == len(Get.TLSKey) { + msg := "if value for either 'TLS_CERT' or 'TLS_KEY' is set then " + + "then value for the other must also be set (values are " + + "currently '%s' and '%s', respectively)" + return fmt.Errorf(msg, Get.TLSCert, Get.TLSKey) + } + if _, err := os.Stat(Get.TLSCert); nil != err { + msg := "value of TLS_CERT is set with filename '%s' that returns %v" + return fmt.Errorf(msg, err) + } + if _, err := os.Stat(Get.TLSKey); nil != err { + msg := "value of TLS_KEY is set with filename '%s' that returns %v" + return fmt.Errorf(msg, err) + } + } + + // If the URL path prefix is to be used, verify it is properly formatted. + if 0 < len(Get.URLPrefix) && + (!strings.HasPrefix(Get.URLPrefix, "/") || strings.HasSuffix(Get.URLPrefix, "/")) { + msg := "if value for 'URL_PREFIX' is set then the value must start " + + "with '/' and not end with '/' (current value of '%s' vs valid " + + "example of '/my/prefix'" + return fmt.Errorf(msg, Get.URLPrefix) + } + + return nil +} + // envAsStr returns the value of the environment variable as a string if set. func envAsStr(key, fallback string) string { if value := os.Getenv(key); "" != value { diff --git a/config/config_test.go b/config/config_test.go index af0ac65..1aa9e24 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -135,6 +135,51 @@ func TestOverrideWithEnvvars(t *testing.T) { equalStrings(t, phase, urlPrefixKey, testURLPrefix, Get.URLPrefix) } +func TestValidate(t *testing.T) { + validPath := "config.go" + invalidPath := "should/never/exist.txt" + empty := "" + prefix := "/my/prefix" + + testCases := []struct { + name string + cert string + key string + prefix string + isError bool + }{ + {"Valid paths w/prefix", validPath, validPath, prefix, false}, + {"Valid paths wo/prefix", validPath, validPath, empty, false}, + {"Empty paths w/prefix", empty, empty, prefix, false}, + {"Empty paths wo/prefix", empty, empty, empty, false}, + {"Mixed paths w/prefix", empty, validPath, prefix, true}, + {"Alt mixed paths w/prefix", validPath, empty, prefix, true}, + {"Mixed paths wo/prefix", empty, validPath, empty, true}, + {"Alt mixed paths wo/prefix", validPath, empty, empty, true}, + {"Invalid cert w/prefix", invalidPath, validPath, prefix, true}, + {"Invalid key w/prefix", validPath, invalidPath, prefix, true}, + {"Invalid cert & key w/prefix", invalidPath, invalidPath, prefix, true}, + {"Prefix missing leading /", empty, empty, "my/prefix", true}, + {"Prefix with trailing /", empty, empty, "/my/prefix/", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + Get.TLSCert = tc.cert + Get.TLSKey = tc.key + Get.URLPrefix = tc.prefix + err := validate() + hasError := nil != err + if hasError && !tc.isError { + t.Errorf("Expected no error but got %v", err) + } + if !hasError && tc.isError { + t.Error("Expected an error but got no error") + } + }) + } +} + func TestEnvAsStr(t *testing.T) { sv := "STRING_VALUE" fv := "FLOAT_VALUE" diff --git a/handle/handle.go b/handle/handle.go index b4d41e3..ed04c44 100644 --- a/handle/handle.go +++ b/handle/handle.go @@ -5,6 +5,20 @@ import ( "strings" ) +var ( + // These assignments are for unit testing. + listenAndServe = http.ListenAndServe + listenAndServeTLS = http.ListenAndServeTLS + setHandler = http.HandleFunc +) + +var ( + server http.Server +) + +// ListenerFunc accepts the {hostname:port} binding string required by HTTP +// listeners and the handler (router) function and returns any errors that +// occur. type ListenerFunc func(string, http.HandlerFunc) error // Basic file handler servers files from the passed folder. @@ -43,15 +57,15 @@ func IgnoreIndex(serve http.HandlerFunc) http.HandlerFunc { // Listening function for serving the handler function. func Listening() ListenerFunc { return func(binding string, handler http.HandlerFunc) error { - http.HandleFunc("/", handler) - return http.ListenAndServe(binding, nil) + setHandler("/", handler) + return listenAndServe(binding, nil) } } // TLSListening function for serving the handler function with encryption. func TLSListening(tlsCert, tlsKey string) ListenerFunc { return func(binding string, handler http.HandlerFunc) error { - http.HandleFunc("/", handler) - return http.ListenAndServeTLS(binding, tlsCert, tlsKey, nil) + setHandler("/", handler) + return listenAndServeTLS(binding, tlsCert, tlsKey, nil) } } diff --git a/handle/handle_test.go b/handle/handle_test.go new file mode 100644 index 0000000..20df1f3 --- /dev/null +++ b/handle/handle_test.go @@ -0,0 +1,351 @@ +package handle + +import ( + "errors" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" +) + +var ( + baseDir = "tmp/" + subDir = "sub/" + subDeepDir = "sub/deep/" + tmpIndexName = "index.html" + tmpFileName = "file.txt" + tmpBadName = "bad.txt" + tmpSubIndexName = "sub/index.html" + tmpSubFileName = "sub/file.txt" + tmpSubBadName = "sub/bad.txt" + tmpSubDeepIndexName = "sub/deep/index.html" + tmpSubDeepFileName = "sub/deep/file.txt" + tmpSubDeepBadName = "sub/deep/bad.txt" + + tmpIndex = "Space: the final frontier" + tmpFile = "These are the voyages of the starship Enterprise." + tmpSubIndex = "Its continuing mission:" + tmpSubFile = "To explore strange new worlds" + tmpSubDeepIndex = "To seek out new life and new civilizations" + tmpSubDeepFile = "To boldly go where no one has gone before" + + nothing = "" + ok = http.StatusOK + missing = http.StatusNotFound + redirect = http.StatusMovedPermanently + notFound = "404 page not found\n" + + files = map[string]string{ + baseDir + tmpIndexName: tmpIndex, + baseDir + tmpFileName: tmpFile, + baseDir + tmpSubIndexName: tmpSubIndex, + baseDir + tmpSubFileName: tmpSubFile, + baseDir + tmpSubDeepIndexName: tmpSubDeepIndex, + baseDir + tmpSubDeepFileName: tmpSubDeepFile, + } +) + +func TestMain(m *testing.M) { + code := func(m *testing.M) int { + if err := setup(); nil != err { + log.Fatalf("While setting up test got: %v\n", err) + } + defer teardown() + return m.Run() + }(m) + os.Exit(code) +} + +func setup() (err error) { + for filename, contents := range files { + if err = os.MkdirAll(path.Dir(filename), 0700); nil != err { + return + } + if err = ioutil.WriteFile( + filename, + []byte(contents), + 0600, + ); nil != err { + return + } + } + return +} + +func teardown() (err error) { + return os.RemoveAll("tmp") +} + +func TestBasic(t *testing.T) { + testCases := []struct { + name string + path string + code int + contents string + }{ + {"Good base dir", "", ok, tmpIndex}, + {"Good base index", tmpIndexName, redirect, nothing}, + {"Good base file", tmpFileName, ok, tmpFile}, + {"Bad base file", tmpBadName, missing, notFound}, + {"Good subdir dir", subDir, ok, tmpSubIndex}, + {"Good subdir index", tmpSubIndexName, redirect, nothing}, + {"Good subdir file", tmpSubFileName, ok, tmpSubFile}, + } + + handler := Basic(baseDir) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fullpath := "http://localhost/" + tc.path + req := httptest.NewRequest("GET", fullpath, nil) + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + body, err := ioutil.ReadAll(resp.Body) + if nil != err { + t.Errorf("While reading body got %v", err) + } + contents := string(body) + if tc.code != resp.StatusCode { + t.Errorf( + "While retrieving %s expected status code of %d but got %d", + fullpath, tc.code, resp.StatusCode, + ) + } + if tc.contents != contents { + t.Errorf( + "While retrieving %s expected contents '%s' but got '%s'", + fullpath, tc.contents, contents, + ) + } + }) + } +} + +func TestPrefix(t *testing.T) { + prefix := "/my/prefix/path/" + + testCases := []struct { + name string + path string + code int + contents string + }{ + {"Good base dir", prefix, ok, tmpIndex}, + {"Good base index", prefix + tmpIndexName, redirect, nothing}, + {"Good base file", prefix + tmpFileName, ok, tmpFile}, + {"Bad base file", prefix + tmpBadName, missing, notFound}, + {"Good subdir dir", prefix + subDir, ok, tmpSubIndex}, + {"Good subdir index", prefix + tmpSubIndexName, redirect, nothing}, + {"Good subdir file", prefix + tmpSubFileName, ok, tmpSubFile}, + {"Unknown prefix", tmpFileName, missing, notFound}, + } + + handler := Prefix(baseDir, prefix) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fullpath := "http://localhost" + tc.path + req := httptest.NewRequest("GET", fullpath, nil) + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + body, err := ioutil.ReadAll(resp.Body) + if nil != err { + t.Errorf("While reading body got %v", err) + } + contents := string(body) + if tc.code != resp.StatusCode { + t.Errorf( + "While retrieving %s expected status code of %d but got %d", + fullpath, tc.code, resp.StatusCode, + ) + } + if tc.contents != contents { + t.Errorf( + "While retrieving %s expected contents '%s' but got '%s'", + fullpath, tc.contents, contents, + ) + } + }) + } +} + +func TestIgnoreIndex(t *testing.T) { + testCases := []struct { + name string + path string + code int + contents string + }{ + {"Good base dir", "", missing, notFound}, + {"Good base index", tmpIndexName, redirect, nothing}, + {"Good base file", tmpFileName, ok, tmpFile}, + {"Bad base file", tmpBadName, missing, notFound}, + {"Good subdir dir", subDir, missing, notFound}, + {"Good subdir index", tmpSubIndexName, redirect, nothing}, + {"Good subdir file", tmpSubFileName, ok, tmpSubFile}, + } + + handler := IgnoreIndex(Basic(baseDir)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fullpath := "http://localhost/" + tc.path + req := httptest.NewRequest("GET", fullpath, nil) + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + body, err := ioutil.ReadAll(resp.Body) + if nil != err { + t.Errorf("While reading body got %v", err) + } + contents := string(body) + if tc.code != resp.StatusCode { + t.Errorf( + "While retrieving %s expected status code of %d but got %d", + fullpath, tc.code, resp.StatusCode, + ) + } + if tc.contents != contents { + t.Errorf( + "While retrieving %s expected contents '%s' but got '%s'", + fullpath, tc.contents, contents, + ) + } + }) + } +} + +// func TestIgnoreIndex(t *testing.T) { +// handler := IgnoreIndex(Basic("tmp")) +// testCases := []struct { +// name string +// path string +// code int +// contents string +// }{} + +// // Build test cases for directories. +// var dirs []string +// for filename, contents := range files { +// dir := path.Dir(filename) +// found := false +// for _, other := range dirs { +// if other == dir { +// found = true +// break +// } +// } +// if !found { +// dirs = append(dirs, dir) +// } +// } + +// } + +func TestListening(t *testing.T) { + // Choose values for testing. + called := false + testBinding := "host:port" + testError := errors.New("random problem") + + // Create an empty placeholder router function. + handler := func(http.ResponseWriter, *http.Request) {} + + // Override setHandler so that multiple calls to 'http.HandleFunc' doesn't + // panic. + setHandler = func(string, func(http.ResponseWriter, *http.Request)) {} + + // Override listenAndServe with a function with more introspection and + // control than 'http.ListenAndServe'. + listenAndServe = func( + binding string, handler http.Handler, + ) error { + if testBinding != binding { + t.Errorf( + "While serving expected binding of %s but got %s", + testBinding, binding, + ) + } + called = !called + if called { + return nil + } + return testError + } + + // Perform test. + listener := Listening() + if err := listener(testBinding, handler); nil != err { + t.Errorf("While serving first expected nil error but got %v", err) + } + if err := listener(testBinding, handler); nil == err { + t.Errorf( + "While serving second got nil while expecting %v", testError, + ) + } +} + +func TestTLSListening(t *testing.T) { + // Choose values for testing. + called := false + testBinding := "host:port" + testTLSCert := "test/file.pem" + testTLSKey := "test/file.key" + testError := errors.New("random problem") + + // Create an empty placeholder router function. + handler := func(http.ResponseWriter, *http.Request) {} + + // Override setHandler so that multiple calls to 'http.HandleFunc' doesn't + // panic. + setHandler = func(string, func(http.ResponseWriter, *http.Request)) {} + + // Override listenAndServeTLS with a function with more introspection and + // control than 'http.ListenAndServeTLS'. + listenAndServeTLS = func( + binding, tlsCert, tlsKey string, handler http.Handler, + ) error { + if testBinding != binding { + t.Errorf( + "While serving TLS expected binding of %s but got %s", + testBinding, binding, + ) + } + if testTLSCert != tlsCert { + t.Errorf( + "While serving TLS expected TLS cert of %s but got %s", + testTLSCert, tlsCert, + ) + } + if testTLSKey != tlsKey { + t.Errorf( + "While serving TLS expected TLS key of %s but got %s", + testTLSKey, tlsKey, + ) + } + called = !called + if called { + return nil + } + return testError + } + + // Perform test. + listener := TLSListening(testTLSCert, testTLSKey) + if err := listener(testBinding, handler); nil != err { + t.Errorf("While serving first TLS expected nil error but got %v", err) + } + if err := listener(testBinding, handler); nil == err { + t.Errorf( + "While serving second TLS got nil while expecting %v", testError, + ) + } +}