diff --git a/auth/handler.go b/auth/handler.go index 7d72704..f53a4be 100644 --- a/auth/handler.go +++ b/auth/handler.go @@ -18,6 +18,14 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, } method := FindAuthenticator(req.Form.Get("method")) + if method == nil { + if user == nil { + response.Text(w, 400, "Invalid method: "+req.Form.Get("method")) + return true + } + + method = user.method + } switch strings.ToLower(subpath) { case "login": @@ -33,7 +41,7 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, w.Header().Set("X-Auth-Method", method.Name()) user, err := method.Login(username, password) - if err != nil && user != nil { + if err == nil && user != nil { sess := OpenSession(user) http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) @@ -60,7 +68,7 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, password := req.Form.Get("password") user, err := method.Register(username, password, data) - if err != nil && user != nil { + if err == nil && user != nil { sess := OpenSession(user) http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) diff --git a/auth/handler_test.go b/auth/handler_test.go index d67200b..37c24dd 100644 --- a/auth/handler_test.go +++ b/auth/handler_test.go @@ -1,10 +1,13 @@ package auth import ( + "encoding/json" "net/http" + "net/http/httptest" "net/url" "strings" "testing" + "time" ) type handlerStruct struct{} @@ -19,14 +22,89 @@ func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func TestHandler(t *testing.T) { + server := httptest.NewServer(&handlerStruct{}) + auther := testAuther{FullName: "Test"} Register(&auther) form := url.Values{} + form.Set("method", "test") form.Set("username", "Test") form.Set("password", "stuff'nthings") + form2 := url.Values{} + form2.Set("method", "test") + + form3 := url.Values{} + form3.Set("method", "test") + form3.Set("username", "Test2") + form3.Set("password", "stuff'nthings") + t.Run("Register", func(t *testing.T) { + resp, err := http.PostForm(server.URL+"/auth/register", form) + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected 200, got", resp.Status) + t.Fail() + } + + respSession := Session{} + json.NewDecoder(resp.Body).Decode(&respSession) + + if respSession.UserID == "" { + t.Errorf("No user ID in session") + t.Fail() + } + + if time.Since(respSession.Time) > time.Second { + t.Error("Session time is too low", time.Since(respSession.Time)) + t.Fail() + } + }) + + t.Run("Login", func(t *testing.T) { + resp, err := http.PostForm(server.URL+"/auth/login", form) + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected 200, got", resp.Status) + t.Fail() + } + + respSession := Session{} + json.NewDecoder(resp.Body).Decode(&respSession) + + if respSession.UserID == "" { + t.Errorf("No user ID in session") + t.Fail() + } + }) + + t.Run("Login_Fail", func(t *testing.T) { + resp, err := http.PostForm(server.URL+"/auth/login", form3) + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 401 { + t.Error("Expected 401, got", resp.Status) + t.Fail() + } + + respSession := Session{} + json.NewDecoder(resp.Body).Decode(&respSession) + if respSession.UserID != "" { + t.Errorf("A user ID in supposedly empty session") + t.Fail() + } }) } diff --git a/auth/user.go b/auth/user.go index a462822..933914a 100644 --- a/auth/user.go +++ b/auth/user.go @@ -6,7 +6,8 @@ type User struct { Level string Data map[string]string - method Authenticator + method Authenticator + loggedOut bool } // FullID is the userid prefixed with the method ID @@ -14,6 +15,17 @@ func (user *User) FullID() string { return user.method.ID() + ":" + user.ID } +// Logout flags the user for logout +func (user *User) Logout() { + user.loggedOut = true +} + +// LoggedOut returns whether the Logout() function has been called +func (user *User) LoggedOut() bool { + return user.loggedOut +} + +// NewUser creates a new User object func NewUser(method Authenticator, id, name, level string, data map[string]string) *User { - return &User{id, name, level, data, method} + return &User{id, name, level, data, method, false} } diff --git a/router.go b/router.go index 7ed8333..3ac0624 100644 --- a/router.go +++ b/router.go @@ -38,9 +38,10 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Resolve session cookies var user *auth.User + var sess *auth.Session cookie, err := req.Cookie(session.CookieName) if cookie != nil && err == nil { - sess := auth.FindSession(cookie.Value) + sess = auth.FindSession(cookie.Value) if sess != nil { user = auth.FindUser(sess.UserID) } @@ -59,6 +60,10 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.Header().Set("X-Route-Index", fmt.Sprint(index)) if route.Handle(path, w, req, user) { + if user != nil && user.LoggedOut() { + auth.CloseSession(sess.ID) + } + return } }