Browse Source

Fixed handler_test

master
Gisle Aune 7 years ago
parent
commit
25f0f1ece0
  1. 6
      auth/handler.go
  2. 44
      auth/handler_test.go
  3. 11
      auth/list.go
  4. 3
      auth/session.go
  5. 29
      auth/session_test.go
  6. 3
      router.go

6
auth/handler.go

@ -43,7 +43,7 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request,
user, err := method.Login(username, password) user, err := method.Login(username, password)
if err == nil && user != nil { if err == nil && user != nil {
sess := OpenSession(user) sess := OpenSession(user)
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)})
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime), Path: "/", HttpOnly: true})
response.JSON(w, 200, sess) response.JSON(w, 200, sess)
} else { } else {
@ -70,11 +70,11 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request,
user, err := method.Register(username, password, data) user, err := method.Register(username, password, data)
if err == nil && user != nil { if err == nil && user != nil {
sess := OpenSession(user) sess := OpenSession(user)
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)})
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime), Path: "/", HttpOnly: true})
response.JSON(w, 200, sess) response.JSON(w, 200, sess)
} else { } else {
response.Text(w, 401, "Register failed")
response.Text(w, 401, err.Error())
} }
} }
case "logout-all": case "logout-all":

44
auth/handler_test.go

@ -3,6 +3,7 @@ package auth
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/cookiejar"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
@ -22,8 +23,16 @@ func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
server := httptest.NewServer(&handlerStruct{})
cookieJar, err := cookiejar.New(nil)
if err != nil {
t.Error("Cookie Jar:", err)
t.Fail()
return
}
server := httptest.NewServer(&handlerStruct{})
url2, _ := url.Parse(server.URL)
client := &http.Client{Jar: cookieJar}
auther := testAuther{FullName: "Test"} auther := testAuther{FullName: "Test"}
Register(&auther) Register(&auther)
@ -41,7 +50,7 @@ func TestHandler(t *testing.T) {
form3.Set("password", "stuff'nthings") form3.Set("password", "stuff'nthings")
t.Run("Register", func(t *testing.T) { t.Run("Register", func(t *testing.T) {
resp, err := http.PostForm(server.URL+"/auth/register", form)
resp, err := client.PostForm(server.URL+"/auth/register", form)
if err != nil { if err != nil {
t.Error("Request:", err) t.Error("Request:", err)
t.Fail() t.Fail()
@ -67,7 +76,7 @@ func TestHandler(t *testing.T) {
}) })
t.Run("Login", func(t *testing.T) { t.Run("Login", func(t *testing.T) {
resp, err := http.PostForm(server.URL+"/auth/login", form)
resp, err := client.PostForm(server.URL+"/auth/login", form)
if err != nil { if err != nil {
t.Error("Request:", err) t.Error("Request:", err)
t.Fail() t.Fail()
@ -78,6 +87,11 @@ func TestHandler(t *testing.T) {
t.Fail() t.Fail()
} }
if len(resp.Cookies()) == 0 || len(client.Jar.Cookies(url2)) == 0 {
t.Error("No cookies set")
t.Fail()
}
respSession := Session{} respSession := Session{}
json.NewDecoder(resp.Body).Decode(&respSession) json.NewDecoder(resp.Body).Decode(&respSession)
@ -87,8 +101,30 @@ func TestHandler(t *testing.T) {
} }
}) })
// TODO: Move to router test
/* t.Run("Status", func(t *testing.T) {
resp, err := client.Get(server.URL + "/auth/status?method=test")
if err != nil {
t.Error("Request:", err)
t.Fail()
}
if resp.StatusCode != 200 {
t.Error("Expected 200, got", resp.Status)
t.Fail()
}
respSession := Session{}
json.NewDecoder(resp.Body).Decode(&respSession)
if respSession.UserID == "" {
t.Errorf("No user ID in session")
t.Fail()
}
}) */
t.Run("Login_Fail", func(t *testing.T) { t.Run("Login_Fail", func(t *testing.T) {
resp, err := http.PostForm(server.URL+"/auth/login", form3)
resp, err := client.PostForm(server.URL+"/auth/login", form3)
if err != nil { if err != nil {
t.Error("Request:", err) t.Error("Request:", err)
t.Fail() t.Fail()

11
auth/list.go

@ -1,6 +1,9 @@
package auth package auth
import "strings"
import (
"errors"
"strings"
)
var methods []Authenticator var methods []Authenticator
@ -29,15 +32,15 @@ func ListAuthenticators() []Authenticator {
return dst return dst
} }
func FindUser(fullid string) *User {
func FindUser(fullid string) (*User, error) {
split := strings.SplitN(fullid, ":", 2) split := strings.SplitN(fullid, ":", 2)
autherID := split[0] autherID := split[0]
userID := split[1] userID := split[1]
auther := FindAuthenticator(autherID) auther := FindAuthenticator(autherID)
if auther == nil { if auther == nil {
return nil
return nil, errors.New("auth: auther not found")
} }
return auther.Find(userID)
return auther.Find(userID), nil
} }

3
auth/session.go

@ -44,9 +44,8 @@ func OpenSession(user *User) *Session {
// FindSession returns a session if the id maps to a still valid session // FindSession returns a session if the id maps to a still valid session
func FindSession(id string) *Session { func FindSession(id string) *Session {
sessionMutex.RLock() sessionMutex.RLock()
defer sessionMutex.RUnlock()
session := sessions[id] session := sessions[id]
sessionMutex.RUnlock()
// Check expiry and update // Check expiry and update
if session != nil { if session != nil {

29
auth/session_test.go

@ -3,9 +3,16 @@ package auth
import "testing" import "testing"
func TestSession(t *testing.T) { func TestSession(t *testing.T) {
auther := testAuther{FullName: "Test"}
auther := testAuther{FullName: "Test64"}
Register(&auther)
user, err := auther.Register("Tester1401", "1234", nil)
if err != nil {
t.Error("Failed to register test user:", err)
t.Fail()
return
}
user := NewUser(&auther, "Tester", "member", nil)
sessions := []*Session{OpenSession(user), OpenSession(user), OpenSession(user)} sessions := []*Session{OpenSession(user), OpenSession(user), OpenSession(user)}
ids := []string{sessions[0].ID, sessions[1].ID, sessions[2].ID} ids := []string{sessions[0].ID, sessions[1].ID, sessions[2].ID}
@ -20,6 +27,24 @@ func TestSession(t *testing.T) {
} }
}) })
t.Run("Auth", func(t *testing.T) {
for _, id := range ids {
found := FindSession(id)
foundUser, err := FindUser(found.UserID)
if foundUser == nil {
t.Error("User", found.UserID, "not found", err)
t.Fail()
}
if foundUser != nil && foundUser != user {
t.Error("User", found.UserID, "is not correct", foundUser)
t.Fail()
}
}
})
t.Run("Close", func(t *testing.T) { t.Run("Close", func(t *testing.T) {
CloseSession(ids[2]) CloseSession(ids[2])

3
router.go

@ -40,10 +40,11 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var user *auth.User var user *auth.User
var sess *auth.Session var sess *auth.Session
cookie, err := req.Cookie(auth.SessionCookieName) cookie, err := req.Cookie(auth.SessionCookieName)
if cookie != nil && err == nil { if cookie != nil && err == nil {
sess = auth.FindSession(cookie.Value) sess = auth.FindSession(cookie.Value)
if sess != nil { if sess != nil {
user = auth.FindUser(sess.UserID)
user, _ = auth.FindUser(sess.UserID)
} }
} }

Loading…
Cancel
Save