diff --git a/resource.go b/resource.go index a5050a2..d9f0ad6 100644 --- a/resource.go +++ b/resource.go @@ -18,8 +18,8 @@ type Resource struct { delete IDFunc } -func NewResource(list, create Func, get, update, delete IDFunc) Resource { - return Resource{list, create, get, update, delete} +func NewResource(list, create Func, get, update, delete IDFunc) *Resource { + return &Resource{list, create, get, update, delete} } func (resource *Resource) Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool { diff --git a/router.go b/router.go index 3ac0624..1719c5d 100644 --- a/router.go +++ b/router.go @@ -73,6 +73,14 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.WriteHeader(404) } +func (router *Router) Resource(mount string, list, create Func, get, update, delete IDFunc) { + router.Route(mount, NewResource(list, create, get, update, delete)) +} + +func (router *Router) Static(mount string, filePath string) { + router.Route(mount, NewStatic(filePath)) +} + func (router *Router) Listen(host string, port int) error { return http.ListenAndServe(fmt.Sprintf("%s:%d", host, port), router) } diff --git a/static.go b/static.go new file mode 100644 index 0000000..520ee67 --- /dev/null +++ b/static.go @@ -0,0 +1,65 @@ +package wrouter + +import ( + "io" + "mime" + "net/http" + "os" + "path" + "strings" + + "git.aiterp.net/gisle/wrouter/response" + + "git.aiterp.net/gisle/wrouter/auth" +) + +type Static struct { + path string +} + +func NewStatic(path string) *Static { + return &Static{path} +} + +func (static *Static) Handle(urlPath string, w http.ResponseWriter, req *http.Request, user *auth.User) bool { + // Get the subpath out of the path + subpath := req.URL.Path[len(urlPath):] + if subpath[0] == '/' { + subpath = subpath[1:] + } + + // Disallow breaking out of the folder + if strings.Contains(subpath, "..") { + response.Text(w, 403, "No .. in paths allowed") + return true + } + + // Try loading the file + filepath := path.Join(static.path, subpath) + info, err := os.Stat(filepath) + if err != nil || info.IsDir() { + return false + } + file, err := os.Open(filepath) + if err != nil || file == nil { + return false + } + + // Find and convert extension + ep := strings.LastIndex(filepath, ".") + ext := "" + if ep != -1 { + ext = filepath[ep:] + } + mimeType := mime.TypeByExtension(ext) + if mimeType == "" { + mimeType = "text/plain" + } + w.Header().Set("Content-Type", mimeType) + + // Submit + w.WriteHeader(200) + io.Copy(w, file) + + return true +} diff --git a/static_test.go b/static_test.go new file mode 100644 index 0000000..6916564 --- /dev/null +++ b/static_test.go @@ -0,0 +1,66 @@ +package wrouter + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestStatic(t *testing.T) { + router := &Router{} + router.Static("/data", "./") + server := httptest.NewServer(router) + + t.Run("Download", func(t *testing.T) { + resp, err := http.Get(server.URL + "/data/README.md") + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected 200, got", resp.Status) + t.Fail() + } + + if resp.ContentLength == 0 { + t.Error("No content returned from server") + t.Fail() + } + + if !strings.Contains(resp.Header.Get("Content-Type"), "text/plain") { + t.Errorf("Content-Type %s != %s", resp.Header.Get("Content-Type"), "text/plain") + t.Fail() + } + }) + + t.Run("Download_Fail", func(t *testing.T) { + resp, err := http.Get(server.URL + "/data/f42klfk2kf2kfk.md") + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 404 { + t.Error("Expected 404, got", resp.Status) + t.Fail() + } + }) + + t.Run("Download_Fail2", func(t *testing.T) { + resp, err := http.Get(server.URL + "/data/../../../../../../../etc/passwd") + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 403 { + t.Error("Expected 403, got", resp.Status) + t.Fail() + } + }) +}