Browse Source

Second Commit

master
Gisle Aune 7 years ago
parent
commit
d754bfb686
  1. 10
      auth/authenticator.go
  2. 138
      auth/authenticator_test.go
  3. 108
      auth/handler.go
  4. 32
      auth/handler_test.go
  5. 43
      auth/list.go
  6. 98
      auth/session.go
  7. 44
      auth/session_test.go
  8. 19
      auth/user.go
  9. 28
      generate/id.go
  10. 72
      resource.go
  11. 396
      resource_test.go
  12. 12
      response/empty.go
  13. 12
      response/html.go
  14. 26
      response/json.go
  15. 12
      response/text.go
  16. 73
      router.go
  17. 82
      router_test.go

10
auth/authenticator.go

@ -0,0 +1,10 @@
package auth
type Authenticator interface {
ID() string
Name() string
Exists(username string) bool
Find(userid string) *User
Login(username, password string) (*User, error)
Register(username, password string, data map[string]string) (*User, error)
}

138
auth/authenticator_test.go

@ -0,0 +1,138 @@
package auth
import (
"errors"
"strings"
"testing"
"git.aiterp.net/gisle/wrouter/generate"
)
var ErrExists = errors.New("auth: user exists")
var ErrLogin = errors.New("auth: login failed")
type testAuther struct {
FullName string
users []*User
passwords map[string]string
}
func (ta *testAuther) ID() string {
return strings.ToLower(ta.FullName)
}
func (ta *testAuther) Name() string {
return ta.FullName
}
func (ta *testAuther) Exists(username string) bool {
for _, user := range ta.users {
if user.Name == username {
return true
}
}
return false
}
func (ta *testAuther) Find(userid string) *User {
for _, user := range ta.users {
if user.ID == userid {
return user
}
}
return nil
}
func (ta *testAuther) Login(username, password string) (*User, error) {
for _, user := range ta.users {
if user.Name == username && password == ta.passwords[user.ID] {
return user, nil
}
}
return nil, ErrLogin
}
func (ta *testAuther) Register(username, password string, data map[string]string) (*User, error) {
if ta.Exists(username) {
return nil, ErrExists
}
if ta.passwords == nil {
ta.passwords = make(map[string]string)
}
id := generate.ID()
ta.passwords[id] = password
user := NewUser(ta, id, username, "member", data)
ta.users = append(ta.users, user)
return user, nil
}
func TestList(t *testing.T) {
ta1 := testAuther{FullName: "Auth1"}
ta2 := testAuther{FullName: "Auth2"}
Register(&ta1)
Register(&ta2)
if ta1.ID() != "auth1" {
t.Errorf("ta1.ID() = %s", ta1.ID())
t.Fail()
}
if ta2.ID() != "auth2" {
t.Errorf("ta2.ID() = %s", ta2.ID())
t.Fail()
}
t.Run("Find", func(t *testing.T) {
fa1 := FindAuthenticator("auth1")
fa2 := FindAuthenticator("auth2")
if &ta1 != fa1 {
t.Errorf("%s != %s", ta1.ID(), fa1.ID())
t.Fail()
}
if &ta2 != fa2 {
t.Errorf("%s != %s", ta2.ID(), fa2.ID())
t.Fail()
}
})
t.Run("Register", func(t *testing.T) {
user, err := ta1.Register("Test", "CakesAndStuff", nil)
if err != nil || user.Name != "Test" {
t.Logf("err = %v; name = \"%s\"", err, user.Name)
t.Fail()
}
if !ta1.Exists("Test") {
t.Log("Registered user does not exist")
t.Fail()
}
user2, err := ta1.Register("Test", "CakesAndStuff", nil)
if err == nil || user2 != nil {
t.Logf("err = %s; name = \"%s\"", err, user2.Name)
t.Fail()
}
})
t.Run("Login", func(t *testing.T) {
user, err := ta1.Login("Test", "CakesAndStuff")
if err != nil || user.Name != "Test" {
t.Logf("err = %v; name = \"%s\"", err, user.Name)
t.Fail()
}
user2, err := ta1.Login("Test", "WrongPassword")
if err == nil || user2 != nil {
t.Logf("err = %v; name = \"%s\"", err, user.Name)
t.Fail()
}
})
}

108
auth/handler.go

@ -0,0 +1,108 @@
package auth
import (
"net/http"
"strings"
"git.aiterp.net/gisle/wrouter/response"
)
type handler struct {
}
func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, user *User) bool {
// Get the subpath out of the path
subpath := req.URL.Path[len(path):]
if subpath[0] == '/' {
subpath = subpath[1:]
}
method := FindAuthenticator(req.Form.Get("method"))
switch strings.ToLower(subpath) {
case "login":
{
if req.Method != "POST" {
response.Text(w, 405, req.Method+" not allowed")
return true
}
username := req.Form.Get("username")
password := req.Form.Get("password")
w.Header().Set("X-Auth-Method", method.Name())
user, err := method.Login(username, password)
if err != nil && user != nil {
sess := OpenSession(user)
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)})
response.JSON(w, 200, sess)
} else {
response.Text(w, 401, "Login failed")
}
}
case "register":
{
if req.Method != "POST" {
response.Text(w, 405, req.Method+" not allowed")
return true
}
data := make(map[string]string)
for key, value := range req.Form {
if key != "username" && key != "password" && key != "method" {
data[key] = value[0]
}
}
username := req.Form.Get("username")
password := req.Form.Get("password")
user, err := method.Register(username, password, data)
if err != nil && user != nil {
sess := OpenSession(user)
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)})
response.JSON(w, 200, sess)
} else {
response.Text(w, 401, "Register failed")
}
}
case "logout-all":
{
if req.Method != "POST" {
response.Text(w, 405, req.Method+" not allowed")
return true
}
if user != nil {
ClearSessions(user)
response.Empty(w)
} else {
response.Text(w, 401, "Not logged in")
}
}
case "status":
{
if req.Method != "GET" {
response.Text(w, 405, req.Method+" not allowed")
return true
}
if user != nil {
response.JSON(w, 200, user)
} else {
response.Text(w, 401, "Not logged in")
}
}
default:
{
response.Text(w, 404, "Operation not found: "+subpath)
}
}
return true
}
var Handler = &handler{}

32
auth/handler_test.go

@ -0,0 +1,32 @@
package auth
import (
"net/http"
"net/url"
"strings"
"testing"
)
type handlerStruct struct{}
func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req.ParseForm() // Router does this in non-tests
if strings.HasPrefix(req.URL.Path, "/auth") {
Handler.Handle("/auth", w, req, nil)
return
}
}
func TestHandler(t *testing.T) {
auther := testAuther{FullName: "Test"}
Register(&auther)
form := url.Values{}
form.Set("username", "Test")
form.Set("password", "stuff'nthings")
t.Run("Register", func(t *testing.T) {
})
}

43
auth/list.go

@ -0,0 +1,43 @@
package auth
import "strings"
var methods []Authenticator
// Register a method
func Register(method Authenticator) {
methods = append(methods, method)
}
// FindAuthenticator finds the first Method that answers with
// the ID().
func FindAuthenticator(id string) Authenticator {
for _, method := range methods {
if method.ID() == id {
return method
}
}
return nil
}
// ListAuthenticators gets a copy of the method list
func ListAuthenticators() []Authenticator {
dst := make([]Authenticator, len(methods))
copy(dst, methods)
return dst
}
func FindUser(fullid string) *User {
split := strings.SplitN(fullid, ":", 2)
autherID := split[0]
userID := split[1]
auther := FindAuthenticator(autherID)
if auther == nil {
return nil
}
return auther.Find(userID)
}

98
auth/session.go

@ -0,0 +1,98 @@
package auth
import (
"log"
"sync"
"time"
"git.aiterp.net/gisle/wrouter/generate"
)
const SessionMaxTime = time.Hour * 72
var sessionMutex sync.RWMutex
var sessions = make(map[string]*Session, 512)
var lastCheck = time.Now()
// SessionCookieName for the session cookie
var SessionCookieName = "sessid"
// Session is a simple in-memory structure describing a suer session
type Session struct {
ID string `json:"id"`
UserID string `json:"user"`
Time time.Time `json:"time"`
}
// OpenSession creates a new session from the supplied user's ID
func OpenSession(user *User) *Session {
session := &Session{generate.SessionID(), user.FullID(), time.Now()}
sessionMutex.Lock()
sessions[session.ID] = session
sessionMutex.Unlock()
// No need to do these checks when there's no activity.
if time.Since(lastCheck) > time.Hour {
lastCheck = time.Now()
go cleanup()
}
return session
}
// FindSession returns a session if the id maps to a still valid session
func FindSession(id string) *Session {
sessionMutex.RLock()
defer sessionMutex.RUnlock()
session := sessions[id]
// Check expiry and update
if session != nil {
if time.Since(session.Time) > SessionMaxTime {
return nil
}
if time.Since(session.Time) > time.Hour {
session.Time = time.Now()
}
}
return session
}
// CloseSession deletes a session by the id
func CloseSession(id string) {
sessionMutex.Lock()
delete(sessions, id)
sessionMutex.Unlock()
}
// ClearSessions removes all sessions with the given user ID
func ClearSessions(user *User) {
sessionMutex.Lock()
for _, sess := range sessions {
if sess.UserID == user.FullID() {
delete(sessions, sess.ID)
}
}
sessionMutex.Unlock()
}
func cleanup() {
count := 0
sessionMutex.Lock()
for key, session := range sessions {
if time.Since(session.Time) > SessionMaxTime {
delete(sessions, key)
count++
}
}
sessionMutex.Unlock()
if count > 0 {
log.Println("Removed", count, "sessions.")
}
}

44
auth/session_test.go

@ -0,0 +1,44 @@
package auth
import "testing"
func TestSession(t *testing.T) {
auther := testAuther{FullName: "Test"}
user := NewUser(&auther, "Tester", "Tester", "member", nil)
sessions := []*Session{OpenSession(user), OpenSession(user), OpenSession(user)}
ids := []string{sessions[0].ID, sessions[1].ID, sessions[2].ID}
t.Run("Find", func(t *testing.T) {
for i, id := range ids {
found := FindSession(id)
if found != sessions[i] {
t.Errorf("Find(\"%s\") == %+v", id, found)
t.Fail()
}
}
})
t.Run("Close", func(t *testing.T) {
CloseSession(ids[2])
if FindSession(ids[0]) == nil || FindSession(ids[1]) == nil || FindSession(ids[2]) != nil {
t.Errorf("Find(\"%s\") == %+v", ids[0], FindSession(ids[0]))
t.Errorf("Find(\"%s\") == %+v", ids[1], FindSession(ids[1]))
t.Errorf("Find(\"%s\") == %+v", ids[2], FindSession(ids[2]))
t.Fail()
}
})
t.Run("Clear", func(t *testing.T) {
ClearSessions(user)
if FindSession(ids[0]) != nil || FindSession(ids[1]) != nil || FindSession(ids[2]) != nil {
t.Errorf("Find(\"%s\") == %+v", ids[0], FindSession(ids[0]))
t.Errorf("Find(\"%s\") == %+v", ids[1], FindSession(ids[1]))
t.Errorf("Find(\"%s\") == %+v", ids[2], FindSession(ids[2]))
t.Fail()
}
})
}

19
auth/user.go

@ -0,0 +1,19 @@
package auth
type User struct {
ID string
Name string
Level string
Data map[string]string
method Authenticator
}
// FullID is the userid prefixed with the method ID
func (user *User) FullID() string {
return user.method.ID() + ":" + user.ID
}
func NewUser(method Authenticator, id, name, level string, data map[string]string) *User {
return &User{id, name, level, data, method}
}

28
generate/id.go

@ -0,0 +1,28 @@
package generate
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"strings"
"time"
)
// ID generates a 24 character hex string from 8 bytes of current time
// in ns and 4 bytes of crypto-random. In practical terms, that makes
// them orderable.
func ID() string {
bytes := make([]byte, 12)
binary.BigEndian.PutUint64(bytes, uint64(time.Now().UnixNano()))
rand.Read(bytes[8:])
return strings.ToLower(hex.EncodeToString(bytes))
}
// SessionID generates a 48 character hex string with crypto/rand
func SessionID() string {
bytes := make([]byte, 24)
rand.Read(bytes)
return strings.ToLower(hex.EncodeToString(bytes))
}

72
resource.go

@ -0,0 +1,72 @@
package wrouter
import (
"net/http"
"strings"
"git.aiterp.net/gisle/wrouter/auth"
)
type Func func(http.ResponseWriter, *http.Request, *auth.User)
type IDFunc func(http.ResponseWriter, *http.Request, string, *auth.User)
type Resource struct {
list Func
create Func
get IDFunc
update IDFunc
delete IDFunc
}
func NewResource(list, create Func, get, update, delete IDFunc) Resource {
return Resource{list, create, get, update, delete}
}
func (resource *Resource) Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool {
// Get the subpath out of the path
subpath := req.URL.Path[len(path):]
if subpath[0] == '/' {
subpath = subpath[1:]
}
// Error out on bad IDs which contains /es
if x := strings.Index(subpath, "/"); x != -1 {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(400)
w.Write([]byte("Invalid ID: " + subpath))
return true
}
// Route it to the resource
switch req.Method {
case "GET":
{
if subpath != "" {
resource.get(w, req, subpath, user)
} else {
resource.list(w, req, user)
}
}
case "POST":
{
if subpath != "" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(400)
w.Write([]byte("ID not allowed in POST"))
return true
}
resource.create(w, req, user)
}
case "PATCH", "PUT":
{
resource.update(w, req, subpath, user)
}
case "DELETE":
{
resource.delete(w, req, subpath, user)
}
}
return true
}

396
resource_test.go

@ -0,0 +1,396 @@
package wrouter
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"git.aiterp.net/gisle/wrouter/generate"
"git.aiterp.net/gisle/wrouter/response"
"git.aiterp.net/gisle/wrouter/auth"
)
var pages = []Page{
Page{generate.ID(), "Test Page", "Blurg blurg"},
Page{generate.ID(), "Test Page 2", "Blurg blurg 2"},
Page{generate.ID(), "Stuff", "And things"},
}
type Page struct {
ID string `json:"id"`
Title string `json:"title"`
Text string `json:"text"`
}
type Header struct {
ID string `json:"id"`
Title string `json:"title"`
}
type PageForm struct {
Title string `json:"title"`
Text string `json:"text"`
}
func listPage(w http.ResponseWriter, req *http.Request, user *auth.User) {
headers := make([]Header, len(pages))
for i, page := range pages {
headers[i] = Header{page.ID, page.Title}
}
response.JSON(w, 200, headers)
}
func createPage(w http.ResponseWriter, req *http.Request, user *auth.User) {
title := req.Form.Get("title")
text := req.Form.Get("text")
if title == "" {
response.Text(w, 400, "No title")
return
}
for _, page := range pages {
if page.Title == title {
response.Text(w, 400, "Title already exists")
return
}
}
page := Page{generate.ID(), title, text}
pages = append(pages, page)
response.JSON(w, 200, page)
}
func getPage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) {
for _, page := range pages {
if page.ID == id {
response.JSON(w, 200, page)
return
}
}
response.Text(w, 404, "Page not found")
}
func updatePage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) {
for _, page := range pages {
if page.ID == id {
title := req.Form.Get("title")
text := req.Form.Get("text")
if title != "" {
page.Title = title
}
page.Text = text
response.JSON(w, 200, page)
return
}
}
response.Text(w, 404, "Page not found")
}
func deletePage(w http.ResponseWriter, req *http.Request, id string, user *auth.User) {
for i, page := range pages {
if page.ID == id {
pages = append(pages[:i], pages[i+1:]...)
response.Empty(w)
return
}
}
response.Text(w, 404, "Page not found")
}
var resource = Resource{listPage, createPage, getPage, updatePage, deletePage}
type handlerStruct struct{}
func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req.ParseForm() // Router does this in non-tests
if strings.HasPrefix(req.URL.Path, "/page") {
resource.Handle("/page", w, req, nil)
return
}
}
func runForm(method, url string, data url.Values) (*http.Response, error) {
body := strings.NewReader(data.Encode())
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return http.DefaultClient.Do(req)
}
func TestResource(t *testing.T) {
server := httptest.NewServer(&handlerStruct{})
t.Run("List", func(t *testing.T) {
resp, err := http.Get(server.URL + "/page/")
if err != nil {
t.Error(err)
t.Fail()
}
if resp.StatusCode != 200 {
t.Error("Expected status 200, got", resp.StatusCode, resp.Status)
t.Fail()
}
headers := []Header{}
err = json.NewDecoder(resp.Body).Decode(&headers)
if err != nil {
t.Error(err)
t.Fail()
}
if len(headers) < 3 {
t.Error("Expected 3 headers, got", len(headers))
t.Fail()
}
for i, header := range headers {
page := pages[i]
if header.ID != page.ID {
t.Error(header.ID, "!=", page.ID)
t.Fail()
}
if header.Title != page.Title {
t.Error(header.Title, "!=", page.Title)
t.Fail()
}
}
})
t.Run("Get", func(t *testing.T) {
page := pages[1]
resp, err := http.Get(server.URL + "/page/" + page.ID)
if err != nil {
t.Error(err)
t.Fail()
}
if resp.StatusCode != 200 {
t.Error("Expected status 200, got", resp.StatusCode, resp.Status)
t.Fail()
}
respPage := Page{}
err = json.NewDecoder(resp.Body).Decode(&respPage)
if err != nil {
t.Error(err)
t.Fail()
}
if respPage.ID == "" {
t.Error("No ID in response page")
t.Fail()
}
if respPage.ID != page.ID {
t.Errorf("ID %s != %s", respPage.ID, page.ID)
t.Fail()
}
if respPage.Title != page.Title {
t.Errorf("Title %s != %s", respPage.Title, page.Title)
t.Fail()
}
if respPage.Text != page.Text {
t.Errorf("Text %s != %s", respPage.Text, page.Text)
t.Fail()
}
})
t.Run("Get_Fail", func(t *testing.T) {
resp, err := http.Get(server.URL + "/page/" + generate.ID())
if err != nil {
t.Error(err)
t.Fail()
}
if resp.StatusCode != 404 {
t.Error("Expected status 404, got", resp.StatusCode, resp.Status)
t.Fail()
}
respPage := Page{}
err = json.NewDecoder(resp.Body).Decode(&respPage)
if err == nil {
t.Error("Expected encoder error, got:", respPage)
t.Fail()
}
if respPage.ID != "" {
t.Error("Ad ID in response page", respPage.ID)
t.Fail()
}
})
t.Run("Create", func(t *testing.T) {
form := url.Values{}
form.Set("title", "Hello World")
form.Set("text", "Sei Gegrüßt, Erde")
resp, err := http.PostForm(server.URL+"/page/", form)
if err != nil {
t.Error(err)
t.Fail()
}
if resp.StatusCode != 200 {
t.Error("Expected status 200, got", resp.StatusCode, resp.Status)
t.Fail()
}
respPage := Page{}
err = json.NewDecoder(resp.Body).Decode(&respPage)
if err != nil {
t.Error(err)
t.Fail()
}
if respPage.ID == "" {
t.Error("No ID in response page")
t.Fail()
}
if respPage.Title != form.Get("title") {
t.Errorf("Title %s != %s", respPage.Title, form.Get("title"))
t.Fail()
}
if respPage.Text != form.Get("text") {
t.Errorf("Text %s != %s", respPage.Text, form.Get("text"))
t.Fail()
}
if len(pages) != 4 {
t.Errorf("Page was not added")
t.Fail()
}
})
t.Run("Update", func(t *testing.T) {
page := pages[0]
form := url.Values{}
form.Set("text", "Edits and stuff")
resp, err := runForm("PUT", server.URL+"/page/"+page.ID, form)
if err != nil {
t.Error("Request:", err)
t.Fail()
}
if resp.StatusCode != 200 {
t.Error("Expected status 200, got", resp.StatusCode, resp.Status)
t.Fail()
}
respPage := Page{}
err = json.NewDecoder(resp.Body).Decode(&respPage)
if err != nil {
t.Error("Decode:", err)
t.Fail()
}
if respPage.ID == "" {
t.Error("No ID in response page")
t.Fail()
}
if respPage.Title != page.Title {
t.Errorf("Title %s != %s", respPage.Title, form.Get("title"))
t.Fail()
}
if respPage.Text != form.Get("text") {
t.Errorf("Text %s != %s", respPage.Text, form.Get("text"))
t.Fail()
}
})
t.Run("Update_Fail", func(t *testing.T) {
form := url.Values{}
form.Set("text", "Edits and stuff")
resp, err := runForm("PUT", server.URL+"/page/NONEXISTENT-ID-GOES-HERE", form)
if err != nil {
t.Error("Request:", err)
t.Fail()
}
if resp.StatusCode != 404 {
t.Error("Expected status 404, got", resp.StatusCode, resp.Status)
t.Fail()
}
respPage := Page{}
err = json.NewDecoder(resp.Body).Decode(&respPage)
if err == nil {
t.Error("A page was returned:", respPage)
t.Fail()
}
if respPage.ID != "" {
t.Error("ID in response page", respPage.ID)
t.Fail()
}
})
t.Run("Delete", func(t *testing.T) {
page := pages[3]
resp, err := runForm("DELETE", server.URL+"/page/"+page.ID, url.Values{})
if err != nil {
t.Error("Request:", err)
t.Fail()
}
if resp.StatusCode != 204 {
t.Error("Expected status 204, got", resp.Status)
t.Fail()
}
if len(pages) != 3 {
t.Error("Page was not deleted")
t.Fail()
}
})
t.Run("Delete_Fail", func(t *testing.T) {
resp, err := runForm("DELETE", server.URL+"/page/NONEXISTENT-ID-GOES-HERE", url.Values{})
if err != nil {
t.Error("Request:", err)
t.Fail()
}
if resp.StatusCode != 404 {
t.Error("Expected status 404, got", resp.Status)
t.Fail()
}
if len(pages) != 3 {
t.Error("A page was deleted despite the non-existent ID")
t.Fail()
}
})
}

12
response/empty.go

@ -0,0 +1,12 @@
package response
import (
"net/http"
)
// Empty makes a 204 response without a body
func Empty(writer http.ResponseWriter) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(204)
writer.Write([]byte{})
}

12
response/html.go

@ -0,0 +1,12 @@
package response
import (
"net/http"
)
// HTML makes an HTML response
func HTML(writer http.ResponseWriter, status int, data string) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(status)
writer.Write([]byte(data))
}

26
response/json.go

@ -0,0 +1,26 @@
package response
import (
"encoding/json"
"fmt"
"log"
"net/http"
)
// JSON makes a JSON response
func JSON(writer http.ResponseWriter, status int, data interface{}) {
jsonData, err := json.Marshal(data)
if err != nil {
log.Println("JSON Marshal failed: ", err.Error())
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(503)
fmt.Fprint(writer, "JSON marshalling failed:", err.Error())
return
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(status)
writer.Write(jsonData)
}

12
response/text.go

@ -0,0 +1,12 @@
package response
import (
"net/http"
)
// Text makes a textual response
func Text(writer http.ResponseWriter, status int, data string) {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(status)
writer.Write([]byte(data))
}

73
router.go

@ -0,0 +1,73 @@
package wrouter
import (
"fmt"
"net/http"
"strings"
"git.aiterp.net/gisle/notebook3/session"
"git.aiterp.net/gisle/wrouter/auth"
)
type Route interface {
Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool
}
type Router struct {
paths map[Route]string
routes []Route
}
func (router *Router) Route(path string, route Route) {
if router.paths == nil {
router.paths = make(map[Route]string, 16)
}
router.paths[route] = path
router.routes = append(router.routes, route)
}
func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
// Allow REST for clients of yore
if req.Header.Get("X-Method") != "" {
req.Method = strings.ToUpper(req.Header.Get("X-Method"))
}
// Resolve session cookies
var user *auth.User
cookie, err := req.Cookie(session.CookieName)
if cookie != nil && err == nil {
sess := auth.FindSession(cookie.Value)
if sess != nil {
user = auth.FindUser(sess.UserID)
}
}
for index, route := range router.routes {
path := router.paths[route]
if strings.HasPrefix(strings.ToLower(req.URL.Path), path) {
// Just so the handler can replace the path properly in case of case
// insensitive clients getting fancy on it.
path = req.URL.Path[:len(path)]
// Attach a little something for testing
w.Header().Set("X-Route-Path", path)
w.Header().Set("X-Route-Index", fmt.Sprint(index))
if route.Handle(path, w, req, user) {
return
}
}
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(404)
}
func (router *Router) Listen(host string, port int) error {
return http.ListenAndServe(fmt.Sprintf("%s:%d", host, port), router)
}

82
router_test.go

@ -0,0 +1,82 @@
package wrouter
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"git.aiterp.net/gisle/wrouter/auth"
)
type testRoute struct {
Name string
}
func (tr *testRoute) Handle(path string, w http.ResponseWriter, req *http.Request, user *auth.User) bool {
w.WriteHeader(200)
w.Write([]byte(tr.Name))
return true
}
func TestPaths(t *testing.T) {
tr1 := testRoute{"Test Route 1"}
tr2 := testRoute{"Test Route 2"}
router := Router{}
router.Route("/test1", &tr1)
router.Route("/test2", &tr2)
t.Run("It finds /test1", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://test.aiterp.net/test1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
data, _ := ioutil.ReadAll(w.Body)
if string(data) != "Test Route 1" {
t.Error("Wrong content:", string(data))
t.Fail()
}
if w.Code != 200 {
t.Error("Wrong code:", w.Code)
t.Fail()
}
})
t.Run("It finds /test2", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://test.aiterp.net/test2", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
data, _ := ioutil.ReadAll(w.Body)
if string(data) != "Test Route 2" {
t.Error("Wrong content:", string(data))
t.Fail()
}
if w.Header().Get("X-Route-Index") != "1" {
t.Error("Wrong index:", w.Code)
t.Fail()
}
if w.Code != 200 {
t.Error("Wrong code:", w.Code)
t.Fail()
}
})
t.Run("It does not find /test3", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://test.aiterp.net/test3", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != 404 {
t.Error("Wrong code:", w.Code)
t.Fail()
}
})
}
Loading…
Cancel
Save