From d754bfb686d991bc1761bfce23733666b1837c9c Mon Sep 17 00:00:00 2001 From: Gisle Aune Date: Sun, 6 Aug 2017 17:48:01 +0200 Subject: [PATCH] Second Commit --- auth/authenticator.go | 10 + auth/authenticator_test.go | 138 +++++++++++++ auth/handler.go | 108 ++++++++++ auth/handler_test.go | 32 +++ auth/list.go | 43 ++++ auth/session.go | 98 +++++++++ auth/session_test.go | 44 +++++ auth/user.go | 19 ++ generate/id.go | 28 +++ resource.go | 72 +++++++ resource_test.go | 396 +++++++++++++++++++++++++++++++++++++ response/empty.go | 12 ++ response/html.go | 12 ++ response/json.go | 26 +++ response/text.go | 12 ++ router.go | 73 +++++++ router_test.go | 82 ++++++++ 17 files changed, 1205 insertions(+) create mode 100644 auth/authenticator.go create mode 100644 auth/authenticator_test.go create mode 100644 auth/handler.go create mode 100644 auth/handler_test.go create mode 100644 auth/list.go create mode 100644 auth/session.go create mode 100644 auth/session_test.go create mode 100644 auth/user.go create mode 100644 generate/id.go create mode 100644 resource.go create mode 100644 resource_test.go create mode 100644 response/empty.go create mode 100644 response/html.go create mode 100644 response/json.go create mode 100644 response/text.go create mode 100644 router.go create mode 100644 router_test.go diff --git a/auth/authenticator.go b/auth/authenticator.go new file mode 100644 index 0000000..8c164de --- /dev/null +++ b/auth/authenticator.go @@ -0,0 +1,10 @@ +package auth + +type Authenticator interface { + ID() string + Name() string + Exists(username string) bool + Find(userid string) *User + Login(username, password string) (*User, error) + Register(username, password string, data map[string]string) (*User, error) +} diff --git a/auth/authenticator_test.go b/auth/authenticator_test.go new file mode 100644 index 0000000..c98b9fb --- /dev/null +++ b/auth/authenticator_test.go @@ -0,0 +1,138 @@ +package auth + +import ( + "errors" + "strings" + "testing" + + "git.aiterp.net/gisle/wrouter/generate" +) + +var ErrExists = errors.New("auth: user exists") +var ErrLogin = errors.New("auth: login failed") + +type testAuther struct { + FullName string + users []*User + passwords map[string]string +} + +func (ta *testAuther) ID() string { + return strings.ToLower(ta.FullName) +} + +func (ta *testAuther) Name() string { + return ta.FullName +} + +func (ta *testAuther) Exists(username string) bool { + for _, user := range ta.users { + if user.Name == username { + return true + } + } + + return false +} + +func (ta *testAuther) Find(userid string) *User { + for _, user := range ta.users { + if user.ID == userid { + return user + } + } + + return nil +} + +func (ta *testAuther) Login(username, password string) (*User, error) { + for _, user := range ta.users { + if user.Name == username && password == ta.passwords[user.ID] { + return user, nil + } + } + + return nil, ErrLogin +} + +func (ta *testAuther) Register(username, password string, data map[string]string) (*User, error) { + if ta.Exists(username) { + return nil, ErrExists + } + + if ta.passwords == nil { + ta.passwords = make(map[string]string) + } + + id := generate.ID() + ta.passwords[id] = password + + user := NewUser(ta, id, username, "member", data) + ta.users = append(ta.users, user) + return user, nil +} + +func TestList(t *testing.T) { + ta1 := testAuther{FullName: "Auth1"} + ta2 := testAuther{FullName: "Auth2"} + Register(&ta1) + Register(&ta2) + + if ta1.ID() != "auth1" { + t.Errorf("ta1.ID() = %s", ta1.ID()) + t.Fail() + } + + if ta2.ID() != "auth2" { + t.Errorf("ta2.ID() = %s", ta2.ID()) + t.Fail() + } + + t.Run("Find", func(t *testing.T) { + fa1 := FindAuthenticator("auth1") + fa2 := FindAuthenticator("auth2") + + if &ta1 != fa1 { + t.Errorf("%s != %s", ta1.ID(), fa1.ID()) + t.Fail() + } + + if &ta2 != fa2 { + t.Errorf("%s != %s", ta2.ID(), fa2.ID()) + t.Fail() + } + }) + + t.Run("Register", func(t *testing.T) { + user, err := ta1.Register("Test", "CakesAndStuff", nil) + if err != nil || user.Name != "Test" { + t.Logf("err = %v; name = \"%s\"", err, user.Name) + t.Fail() + } + + if !ta1.Exists("Test") { + t.Log("Registered user does not exist") + t.Fail() + } + + user2, err := ta1.Register("Test", "CakesAndStuff", nil) + if err == nil || user2 != nil { + t.Logf("err = %s; name = \"%s\"", err, user2.Name) + t.Fail() + } + }) + + t.Run("Login", func(t *testing.T) { + user, err := ta1.Login("Test", "CakesAndStuff") + if err != nil || user.Name != "Test" { + t.Logf("err = %v; name = \"%s\"", err, user.Name) + t.Fail() + } + + user2, err := ta1.Login("Test", "WrongPassword") + if err == nil || user2 != nil { + t.Logf("err = %v; name = \"%s\"", err, user.Name) + t.Fail() + } + }) +} diff --git a/auth/handler.go b/auth/handler.go new file mode 100644 index 0000000..7d72704 --- /dev/null +++ b/auth/handler.go @@ -0,0 +1,108 @@ +package auth + +import ( + "net/http" + "strings" + + "git.aiterp.net/gisle/wrouter/response" +) + +type handler struct { +} + +func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, user *User) bool { + // Get the subpath out of the path + subpath := req.URL.Path[len(path):] + if subpath[0] == '/' { + subpath = subpath[1:] + } + + method := FindAuthenticator(req.Form.Get("method")) + + switch strings.ToLower(subpath) { + case "login": + { + if req.Method != "POST" { + response.Text(w, 405, req.Method+" not allowed") + return true + } + + username := req.Form.Get("username") + password := req.Form.Get("password") + + w.Header().Set("X-Auth-Method", method.Name()) + + user, err := method.Login(username, password) + if err != nil && user != nil { + sess := OpenSession(user) + http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) + + response.JSON(w, 200, sess) + } else { + response.Text(w, 401, "Login failed") + } + } + case "register": + { + if req.Method != "POST" { + response.Text(w, 405, req.Method+" not allowed") + return true + } + + data := make(map[string]string) + for key, value := range req.Form { + if key != "username" && key != "password" && key != "method" { + data[key] = value[0] + } + } + + username := req.Form.Get("username") + password := req.Form.Get("password") + + user, err := method.Register(username, password, data) + if err != nil && user != nil { + sess := OpenSession(user) + http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) + + response.JSON(w, 200, sess) + } else { + response.Text(w, 401, "Register failed") + } + } + case "logout-all": + { + if req.Method != "POST" { + response.Text(w, 405, req.Method+" not allowed") + return true + } + + if user != nil { + ClearSessions(user) + response.Empty(w) + } else { + response.Text(w, 401, "Not logged in") + } + } + case "status": + { + if req.Method != "GET" { + response.Text(w, 405, req.Method+" not allowed") + return true + } + + if user != nil { + response.JSON(w, 200, user) + } else { + response.Text(w, 401, "Not logged in") + } + } + default: + { + response.Text(w, 404, "Operation not found: "+subpath) + } + } + + return true +} + +var Handler = &handler{} diff --git a/auth/handler_test.go b/auth/handler_test.go new file mode 100644 index 0000000..d67200b --- /dev/null +++ b/auth/handler_test.go @@ -0,0 +1,32 @@ +package auth + +import ( + "net/http" + "net/url" + "strings" + "testing" +) + +type handlerStruct struct{} + +func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { + req.ParseForm() // Router does this in non-tests + + if strings.HasPrefix(req.URL.Path, "/auth") { + Handler.Handle("/auth", w, req, nil) + return + } +} + +func TestHandler(t *testing.T) { + auther := testAuther{FullName: "Test"} + Register(&auther) + + form := url.Values{} + form.Set("username", "Test") + form.Set("password", "stuff'nthings") + + t.Run("Register", func(t *testing.T) { + + }) +} diff --git a/auth/list.go b/auth/list.go new file mode 100644 index 0000000..28cb3c4 --- /dev/null +++ b/auth/list.go @@ -0,0 +1,43 @@ +package auth + +import "strings" + +var methods []Authenticator + +// Register a method +func Register(method Authenticator) { + methods = append(methods, method) +} + +// FindAuthenticator finds the first Method that answers with +// the ID(). +func FindAuthenticator(id string) Authenticator { + for _, method := range methods { + if method.ID() == id { + return method + } + } + + return nil +} + +// ListAuthenticators gets a copy of the method list +func ListAuthenticators() []Authenticator { + dst := make([]Authenticator, len(methods)) + copy(dst, methods) + + return dst +} + +func FindUser(fullid string) *User { + split := strings.SplitN(fullid, ":", 2) + autherID := split[0] + userID := split[1] + + auther := FindAuthenticator(autherID) + if auther == nil { + return nil + } + + return auther.Find(userID) +} diff --git a/auth/session.go b/auth/session.go new file mode 100644 index 0000000..b388584 --- /dev/null +++ b/auth/session.go @@ -0,0 +1,98 @@ +package auth + +import ( + "log" + "sync" + "time" + + "git.aiterp.net/gisle/wrouter/generate" +) + +const SessionMaxTime = time.Hour * 72 + +var sessionMutex sync.RWMutex +var sessions = make(map[string]*Session, 512) +var lastCheck = time.Now() + +// SessionCookieName for the session cookie +var SessionCookieName = "sessid" + +// Session is a simple in-memory structure describing a suer session +type Session struct { + ID string `json:"id"` + UserID string `json:"user"` + Time time.Time `json:"time"` +} + +// OpenSession creates a new session from the supplied user's ID +func OpenSession(user *User) *Session { + session := &Session{generate.SessionID(), user.FullID(), time.Now()} + + sessionMutex.Lock() + sessions[session.ID] = session + sessionMutex.Unlock() + + // No need to do these checks when there's no activity. + if time.Since(lastCheck) > time.Hour { + lastCheck = time.Now() + go cleanup() + } + + return session +} + +// FindSession returns a session if the id maps to a still valid session +func FindSession(id string) *Session { + sessionMutex.RLock() + defer sessionMutex.RUnlock() + + session := sessions[id] + + // Check expiry and update + if session != nil { + if time.Since(session.Time) > SessionMaxTime { + return nil + } + + if time.Since(session.Time) > time.Hour { + session.Time = time.Now() + } + } + + return session +} + +// CloseSession deletes a session by the id +func CloseSession(id string) { + sessionMutex.Lock() + delete(sessions, id) + sessionMutex.Unlock() +} + +// ClearSessions removes all sessions with the given user ID +func ClearSessions(user *User) { + sessionMutex.Lock() + for _, sess := range sessions { + if sess.UserID == user.FullID() { + delete(sessions, sess.ID) + } + } + sessionMutex.Unlock() +} + +func cleanup() { + count := 0 + + sessionMutex.Lock() + for key, session := range sessions { + if time.Since(session.Time) > SessionMaxTime { + delete(sessions, key) + count++ + } + } + sessionMutex.Unlock() + + if count > 0 { + log.Println("Removed", count, "sessions.") + } +} diff --git a/auth/session_test.go b/auth/session_test.go new file mode 100644 index 0000000..2059ec8 --- /dev/null +++ b/auth/session_test.go @@ -0,0 +1,44 @@ +package auth + +import "testing" + +func TestSession(t *testing.T) { + auther := testAuther{FullName: "Test"} + + user := NewUser(&auther, "Tester", "Tester", "member", nil) + sessions := []*Session{OpenSession(user), OpenSession(user), OpenSession(user)} + ids := []string{sessions[0].ID, sessions[1].ID, sessions[2].ID} + + t.Run("Find", func(t *testing.T) { + for i, id := range ids { + found := FindSession(id) + + if found != sessions[i] { + t.Errorf("Find(\"%s\") == %+v", id, found) + t.Fail() + } + } + }) + + t.Run("Close", func(t *testing.T) { + CloseSession(ids[2]) + + if FindSession(ids[0]) == nil || FindSession(ids[1]) == nil || FindSession(ids[2]) != nil { + t.Errorf("Find(\"%s\") == %+v", ids[0], FindSession(ids[0])) + t.Errorf("Find(\"%s\") == %+v", ids[1], FindSession(ids[1])) + t.Errorf("Find(\"%s\") == %+v", ids[2], FindSession(ids[2])) + t.Fail() + } + }) + + t.Run("Clear", func(t *testing.T) { + ClearSessions(user) + + if FindSession(ids[0]) != nil || FindSession(ids[1]) != nil || FindSession(ids[2]) != nil { + t.Errorf("Find(\"%s\") == %+v", ids[0], FindSession(ids[0])) + t.Errorf("Find(\"%s\") == %+v", ids[1], FindSession(ids[1])) + t.Errorf("Find(\"%s\") == %+v", ids[2], FindSession(ids[2])) + t.Fail() + } + }) +} diff --git a/auth/user.go b/auth/user.go new file mode 100644 index 0000000..a462822 --- /dev/null +++ b/auth/user.go @@ -0,0 +1,19 @@ +package auth + +type User struct { + ID string + Name string + Level string + Data map[string]string + + method Authenticator +} + +// FullID is the userid prefixed with the method ID +func (user *User) FullID() string { + return user.method.ID() + ":" + user.ID +} + +func NewUser(method Authenticator, id, name, level string, data map[string]string) *User { + return &User{id, name, level, data, method} +} diff --git a/generate/id.go b/generate/id.go new file mode 100644 index 0000000..e31403f --- /dev/null +++ b/generate/id.go @@ -0,0 +1,28 @@ +package generate + +import ( + "crypto/rand" + "encoding/binary" + "encoding/hex" + "strings" + "time" +) + +// ID generates a 24 character hex string from 8 bytes of current time +// in ns and 4 bytes of crypto-random. In practical terms, that makes +// them orderable. +func ID() string { + bytes := make([]byte, 12) + binary.BigEndian.PutUint64(bytes, uint64(time.Now().UnixNano())) + rand.Read(bytes[8:]) + + return strings.ToLower(hex.EncodeToString(bytes)) +} + +// SessionID generates a 48 character hex string with crypto/rand +func SessionID() string { + bytes := make([]byte, 24) + rand.Read(bytes) + + return strings.ToLower(hex.EncodeToString(bytes)) +} diff --git a/resource.go b/resource.go new file mode 100644 index 0000000..a5050a2 --- /dev/null +++ b/resource.go @@ -0,0 +1,72 @@ +package wrouter + +import ( + "net/http" + "strings" + + "git.aiterp.net/gisle/wrouter/auth" +) + +type Func func(http.ResponseWriter, *http.Request, *auth.User) +type IDFunc func(http.ResponseWriter, *http.Request, string, *auth.User) + +type Resource struct { + list Func + create Func + get IDFunc + update IDFunc + delete IDFunc +} + +func NewResource(list, create Func, get, update, delete IDFunc) Resource { + return Resource{list, create, get, update, delete} +} + +func (resource *Resource) Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool { + // Get the subpath out of the path + subpath := req.URL.Path[len(path):] + if subpath[0] == '/' { + subpath = subpath[1:] + } + + // Error out on bad IDs which contains /es + if x := strings.Index(subpath, "/"); x != -1 { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(400) + w.Write([]byte("Invalid ID: " + subpath)) + return true + } + + // Route it to the resource + switch req.Method { + case "GET": + { + if subpath != "" { + resource.get(w, req, subpath, user) + } else { + resource.list(w, req, user) + } + } + case "POST": + { + if subpath != "" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(400) + w.Write([]byte("ID not allowed in POST")) + return true + } + + resource.create(w, req, user) + } + case "PATCH", "PUT": + { + resource.update(w, req, subpath, user) + } + case "DELETE": + { + resource.delete(w, req, subpath, user) + } + } + + return true +} diff --git a/resource_test.go b/resource_test.go new file mode 100644 index 0000000..74f57d4 --- /dev/null +++ b/resource_test.go @@ -0,0 +1,396 @@ +package wrouter + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "git.aiterp.net/gisle/wrouter/generate" + "git.aiterp.net/gisle/wrouter/response" + + "git.aiterp.net/gisle/wrouter/auth" +) + +var pages = []Page{ + Page{generate.ID(), "Test Page", "Blurg blurg"}, + Page{generate.ID(), "Test Page 2", "Blurg blurg 2"}, + Page{generate.ID(), "Stuff", "And things"}, +} + +type Page struct { + ID string `json:"id"` + Title string `json:"title"` + Text string `json:"text"` +} + +type Header struct { + ID string `json:"id"` + Title string `json:"title"` +} + +type PageForm struct { + Title string `json:"title"` + Text string `json:"text"` +} + +func listPage(w http.ResponseWriter, req *http.Request, user *auth.User) { + headers := make([]Header, len(pages)) + for i, page := range pages { + headers[i] = Header{page.ID, page.Title} + } + + response.JSON(w, 200, headers) +} + +func createPage(w http.ResponseWriter, req *http.Request, user *auth.User) { + title := req.Form.Get("title") + text := req.Form.Get("text") + + if title == "" { + response.Text(w, 400, "No title") + return + } + + for _, page := range pages { + if page.Title == title { + response.Text(w, 400, "Title already exists") + return + } + } + + page := Page{generate.ID(), title, text} + pages = append(pages, page) + + response.JSON(w, 200, page) +} + +func getPage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) { + for _, page := range pages { + if page.ID == id { + response.JSON(w, 200, page) + return + } + } + + response.Text(w, 404, "Page not found") +} + +func updatePage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) { + for _, page := range pages { + if page.ID == id { + title := req.Form.Get("title") + text := req.Form.Get("text") + + if title != "" { + page.Title = title + } + page.Text = text + + response.JSON(w, 200, page) + return + } + } + + response.Text(w, 404, "Page not found") +} + +func deletePage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) { + for i, page := range pages { + if page.ID == id { + pages = append(pages[:i], pages[i+1:]...) + response.Empty(w) + return + } + } + + response.Text(w, 404, "Page not found") +} + +var resource = Resource{listPage, createPage, getPage, updatePage, deletePage} + +type handlerStruct struct{} + +func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { + req.ParseForm() // Router does this in non-tests + + if strings.HasPrefix(req.URL.Path, "/page") { + resource.Handle("/page", w, req, nil) + return + } +} + +func runForm(method, url string, data url.Values) (*http.Response, error) { + body := strings.NewReader(data.Encode()) + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + return http.DefaultClient.Do(req) +} + +func TestResource(t *testing.T) { + server := httptest.NewServer(&handlerStruct{}) + + t.Run("List", func(t *testing.T) { + resp, err := http.Get(server.URL + "/page/") + if err != nil { + t.Error(err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected status 200, got", resp.StatusCode, resp.Status) + t.Fail() + } + + headers := []Header{} + err = json.NewDecoder(resp.Body).Decode(&headers) + if err != nil { + t.Error(err) + t.Fail() + } + + if len(headers) < 3 { + t.Error("Expected 3 headers, got", len(headers)) + t.Fail() + } + + for i, header := range headers { + page := pages[i] + + if header.ID != page.ID { + t.Error(header.ID, "!=", page.ID) + t.Fail() + } + + if header.Title != page.Title { + t.Error(header.Title, "!=", page.Title) + t.Fail() + } + } + }) + + t.Run("Get", func(t *testing.T) { + page := pages[1] + resp, err := http.Get(server.URL + "/page/" + page.ID) + if err != nil { + t.Error(err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected status 200, got", resp.StatusCode, resp.Status) + t.Fail() + } + + respPage := Page{} + err = json.NewDecoder(resp.Body).Decode(&respPage) + if err != nil { + t.Error(err) + t.Fail() + } + + if respPage.ID == "" { + t.Error("No ID in response page") + t.Fail() + } + + if respPage.ID != page.ID { + t.Errorf("ID %s != %s", respPage.ID, page.ID) + t.Fail() + } + + if respPage.Title != page.Title { + t.Errorf("Title %s != %s", respPage.Title, page.Title) + t.Fail() + } + + if respPage.Text != page.Text { + t.Errorf("Text %s != %s", respPage.Text, page.Text) + t.Fail() + } + }) + + t.Run("Get_Fail", func(t *testing.T) { + resp, err := http.Get(server.URL + "/page/" + generate.ID()) + if err != nil { + t.Error(err) + t.Fail() + } + + if resp.StatusCode != 404 { + t.Error("Expected status 404, got", resp.StatusCode, resp.Status) + t.Fail() + } + + respPage := Page{} + err = json.NewDecoder(resp.Body).Decode(&respPage) + if err == nil { + t.Error("Expected encoder error, got:", respPage) + t.Fail() + } + + if respPage.ID != "" { + t.Error("Ad ID in response page", respPage.ID) + t.Fail() + } + }) + + t.Run("Create", func(t *testing.T) { + form := url.Values{} + form.Set("title", "Hello World") + form.Set("text", "Sei Gegrüßt, Erde") + + resp, err := http.PostForm(server.URL+"/page/", form) + + if err != nil { + t.Error(err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected status 200, got", resp.StatusCode, resp.Status) + t.Fail() + } + + respPage := Page{} + err = json.NewDecoder(resp.Body).Decode(&respPage) + if err != nil { + t.Error(err) + t.Fail() + } + + if respPage.ID == "" { + t.Error("No ID in response page") + t.Fail() + } + + if respPage.Title != form.Get("title") { + t.Errorf("Title %s != %s", respPage.Title, form.Get("title")) + t.Fail() + } + + if respPage.Text != form.Get("text") { + t.Errorf("Text %s != %s", respPage.Text, form.Get("text")) + t.Fail() + } + + if len(pages) != 4 { + t.Errorf("Page was not added") + t.Fail() + } + }) + + t.Run("Update", func(t *testing.T) { + page := pages[0] + form := url.Values{} + form.Set("text", "Edits and stuff") + + resp, err := runForm("PUT", server.URL+"/page/"+page.ID, form) + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected status 200, got", resp.StatusCode, resp.Status) + t.Fail() + } + + respPage := Page{} + err = json.NewDecoder(resp.Body).Decode(&respPage) + if err != nil { + t.Error("Decode:", err) + t.Fail() + } + + if respPage.ID == "" { + t.Error("No ID in response page") + t.Fail() + } + + if respPage.Title != page.Title { + t.Errorf("Title %s != %s", respPage.Title, form.Get("title")) + t.Fail() + } + + if respPage.Text != form.Get("text") { + t.Errorf("Text %s != %s", respPage.Text, form.Get("text")) + t.Fail() + } + }) + + t.Run("Update_Fail", func(t *testing.T) { + form := url.Values{} + form.Set("text", "Edits and stuff") + + resp, err := runForm("PUT", server.URL+"/page/NONEXISTENT-ID-GOES-HERE", form) + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 404 { + t.Error("Expected status 404, got", resp.StatusCode, resp.Status) + t.Fail() + } + + respPage := Page{} + err = json.NewDecoder(resp.Body).Decode(&respPage) + if err == nil { + t.Error("A page was returned:", respPage) + t.Fail() + } + + if respPage.ID != "" { + t.Error("ID in response page", respPage.ID) + t.Fail() + } + }) + + t.Run("Delete", func(t *testing.T) { + page := pages[3] + resp, err := runForm("DELETE", server.URL+"/page/"+page.ID, url.Values{}) + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 204 { + t.Error("Expected status 204, got", resp.Status) + t.Fail() + } + + if len(pages) != 3 { + t.Error("Page was not deleted") + t.Fail() + } + }) + + t.Run("Delete_Fail", func(t *testing.T) { + resp, err := runForm("DELETE", server.URL+"/page/NONEXISTENT-ID-GOES-HERE", url.Values{}) + + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 404 { + t.Error("Expected status 404, got", resp.Status) + t.Fail() + } + + if len(pages) != 3 { + t.Error("A page was deleted despite the non-existent ID") + t.Fail() + } + }) +} diff --git a/response/empty.go b/response/empty.go new file mode 100644 index 0000000..e275fd1 --- /dev/null +++ b/response/empty.go @@ -0,0 +1,12 @@ +package response + +import ( + "net/http" +) + +// Empty makes a 204 response without a body +func Empty(writer http.ResponseWriter) { + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(204) + writer.Write([]byte{}) +} diff --git a/response/html.go b/response/html.go new file mode 100644 index 0000000..9e6e3c5 --- /dev/null +++ b/response/html.go @@ -0,0 +1,12 @@ +package response + +import ( + "net/http" +) + +// HTML makes an HTML response +func HTML(writer http.ResponseWriter, status int, data string) { + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(status) + writer.Write([]byte(data)) +} diff --git a/response/json.go b/response/json.go new file mode 100644 index 0000000..53cabc2 --- /dev/null +++ b/response/json.go @@ -0,0 +1,26 @@ +package response + +import ( + "encoding/json" + "fmt" + "log" + "net/http" +) + +// JSON makes a JSON response +func JSON(writer http.ResponseWriter, status int, data interface{}) { + jsonData, err := json.Marshal(data) + + if err != nil { + log.Println("JSON Marshal failed: ", err.Error()) + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(503) + fmt.Fprint(writer, "JSON marshalling failed:", err.Error()) + return + } + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(status) + writer.Write(jsonData) +} diff --git a/response/text.go b/response/text.go new file mode 100644 index 0000000..b16a65d --- /dev/null +++ b/response/text.go @@ -0,0 +1,12 @@ +package response + +import ( + "net/http" +) + +// Text makes a textual response +func Text(writer http.ResponseWriter, status int, data string) { + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(status) + writer.Write([]byte(data)) +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..7ed8333 --- /dev/null +++ b/router.go @@ -0,0 +1,73 @@ +package wrouter + +import ( + "fmt" + "net/http" + "strings" + + "git.aiterp.net/gisle/notebook3/session" + + "git.aiterp.net/gisle/wrouter/auth" +) + +type Route interface { + Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool +} + +type Router struct { + paths map[Route]string + routes []Route +} + +func (router *Router) Route(path string, route Route) { + if router.paths == nil { + router.paths = make(map[Route]string, 16) + } + + router.paths[route] = path + router.routes = append(router.routes, route) +} + +func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + defer req.Body.Close() + + // Allow REST for clients of yore + if req.Header.Get("X-Method") != "" { + req.Method = strings.ToUpper(req.Header.Get("X-Method")) + } + + // Resolve session cookies + var user *auth.User + cookie, err := req.Cookie(session.CookieName) + if cookie != nil && err == nil { + sess := auth.FindSession(cookie.Value) + if sess != nil { + user = auth.FindUser(sess.UserID) + } + } + + for index, route := range router.routes { + path := router.paths[route] + + if strings.HasPrefix(strings.ToLower(req.URL.Path), path) { + // Just so the handler can replace the path properly in case of case + // insensitive clients getting fancy on it. + path = req.URL.Path[:len(path)] + + // Attach a little something for testing + w.Header().Set("X-Route-Path", path) + w.Header().Set("X-Route-Index", fmt.Sprint(index)) + + if route.Handle(path, w, req, user) { + return + } + } + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(404) +} + +func (router *Router) Listen(host string, port int) error { + return http.ListenAndServe(fmt.Sprintf("%s:%d", host, port), router) +} diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..4239804 --- /dev/null +++ b/router_test.go @@ -0,0 +1,82 @@ +package wrouter + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "git.aiterp.net/gisle/wrouter/auth" +) + +type testRoute struct { + Name string +} + +func (tr *testRoute) Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool { + w.WriteHeader(200) + w.Write([]byte(tr.Name)) + + return true +} + +func TestPaths(t *testing.T) { + tr1 := testRoute{"Test Route 1"} + tr2 := testRoute{"Test Route 2"} + + router := Router{} + router.Route("/test1", &tr1) + router.Route("/test2", &tr2) + + t.Run("It finds /test1", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://test.aiterp.net/test1", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + data, _ := ioutil.ReadAll(w.Body) + + if string(data) != "Test Route 1" { + t.Error("Wrong content:", string(data)) + t.Fail() + } + + if w.Code != 200 { + t.Error("Wrong code:", w.Code) + t.Fail() + } + }) + + t.Run("It finds /test2", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://test.aiterp.net/test2", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + data, _ := ioutil.ReadAll(w.Body) + + if string(data) != "Test Route 2" { + t.Error("Wrong content:", string(data)) + t.Fail() + } + + if w.Header().Get("X-Route-Index") != "1" { + t.Error("Wrong index:", w.Code) + t.Fail() + } + + if w.Code != 200 { + t.Error("Wrong code:", w.Code) + t.Fail() + } + }) + + t.Run("It does not find /test3", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://test.aiterp.net/test3", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != 404 { + t.Error("Wrong code:", w.Code) + t.Fail() + } + }) +}