diff --git a/auth/handler.go b/auth/handler.go index f53a4be..491e94b 100644 --- a/auth/handler.go +++ b/auth/handler.go @@ -43,7 +43,7 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, user, err := method.Login(username, password) if err == nil && user != nil { sess := OpenSession(user) - http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) + http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime), Path: "/", HttpOnly: true}) response.JSON(w, 200, sess) } else { @@ -70,11 +70,11 @@ func (h *handler) Handle(path string, w http.ResponseWriter, req *http.Request, user, err := method.Register(username, password, data) if err == nil && user != nil { sess := OpenSession(user) - http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime)}) + http.SetCookie(w, &http.Cookie{Name: SessionCookieName, Value: sess.ID, Expires: sess.Time.Add(SessionMaxTime), Path: "/", HttpOnly: true}) response.JSON(w, 200, sess) } else { - response.Text(w, 401, "Register failed") + response.Text(w, 401, err.Error()) } } case "logout-all": diff --git a/auth/handler_test.go b/auth/handler_test.go index 37c24dd..1d44e7c 100644 --- a/auth/handler_test.go +++ b/auth/handler_test.go @@ -3,6 +3,7 @@ package auth import ( "encoding/json" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" "strings" @@ -22,8 +23,16 @@ func (hs *handlerStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func TestHandler(t *testing.T) { - server := httptest.NewServer(&handlerStruct{}) + cookieJar, err := cookiejar.New(nil) + if err != nil { + t.Error("Cookie Jar:", err) + t.Fail() + return + } + server := httptest.NewServer(&handlerStruct{}) + url2, _ := url.Parse(server.URL) + client := &http.Client{Jar: cookieJar} auther := testAuther{FullName: "Test"} Register(&auther) @@ -41,7 +50,7 @@ func TestHandler(t *testing.T) { form3.Set("password", "stuff'nthings") t.Run("Register", func(t *testing.T) { - resp, err := http.PostForm(server.URL+"/auth/register", form) + resp, err := client.PostForm(server.URL+"/auth/register", form) if err != nil { t.Error("Request:", err) t.Fail() @@ -67,7 +76,7 @@ func TestHandler(t *testing.T) { }) t.Run("Login", func(t *testing.T) { - resp, err := http.PostForm(server.URL+"/auth/login", form) + resp, err := client.PostForm(server.URL+"/auth/login", form) if err != nil { t.Error("Request:", err) t.Fail() @@ -78,6 +87,11 @@ func TestHandler(t *testing.T) { t.Fail() } + if len(resp.Cookies()) == 0 || len(client.Jar.Cookies(url2)) == 0 { + t.Error("No cookies set") + t.Fail() + } + respSession := Session{} json.NewDecoder(resp.Body).Decode(&respSession) @@ -87,8 +101,30 @@ func TestHandler(t *testing.T) { } }) + // TODO: Move to router test + /* t.Run("Status", func(t *testing.T) { + resp, err := client.Get(server.URL + "/auth/status?method=test") + if err != nil { + t.Error("Request:", err) + t.Fail() + } + + if resp.StatusCode != 200 { + t.Error("Expected 200, got", resp.Status) + t.Fail() + } + + respSession := Session{} + json.NewDecoder(resp.Body).Decode(&respSession) + + if respSession.UserID == "" { + t.Errorf("No user ID in session") + t.Fail() + } + }) */ + t.Run("Login_Fail", func(t *testing.T) { - resp, err := http.PostForm(server.URL+"/auth/login", form3) + resp, err := client.PostForm(server.URL+"/auth/login", form3) if err != nil { t.Error("Request:", err) t.Fail() diff --git a/auth/list.go b/auth/list.go index 28cb3c4..258e977 100644 --- a/auth/list.go +++ b/auth/list.go @@ -1,6 +1,9 @@ package auth -import "strings" +import ( + "errors" + "strings" +) var methods []Authenticator @@ -29,15 +32,15 @@ func ListAuthenticators() []Authenticator { return dst } -func FindUser(fullid string) *User { +func FindUser(fullid string) (*User, error) { split := strings.SplitN(fullid, ":", 2) autherID := split[0] userID := split[1] auther := FindAuthenticator(autherID) if auther == nil { - return nil + return nil, errors.New("auth: auther not found") } - return auther.Find(userID) + return auther.Find(userID), nil } diff --git a/auth/session.go b/auth/session.go index b388584..ec3f68b 100644 --- a/auth/session.go +++ b/auth/session.go @@ -44,9 +44,8 @@ func OpenSession(user *User) *Session { // FindSession returns a session if the id maps to a still valid session func FindSession(id string) *Session { sessionMutex.RLock() - defer sessionMutex.RUnlock() - session := sessions[id] + sessionMutex.RUnlock() // Check expiry and update if session != nil { diff --git a/auth/session_test.go b/auth/session_test.go index f2089d0..2b7964d 100644 --- a/auth/session_test.go +++ b/auth/session_test.go @@ -3,9 +3,16 @@ package auth import "testing" func TestSession(t *testing.T) { - auther := testAuther{FullName: "Test"} + auther := testAuther{FullName: "Test64"} + Register(&auther) + + user, err := auther.Register("Tester1401", "1234", nil) + if err != nil { + t.Error("Failed to register test user:", err) + t.Fail() + return + } - user := NewUser(&auther, "Tester", "member", nil) sessions := []*Session{OpenSession(user), OpenSession(user), OpenSession(user)} ids := []string{sessions[0].ID, sessions[1].ID, sessions[2].ID} @@ -20,6 +27,24 @@ func TestSession(t *testing.T) { } }) + t.Run("Auth", func(t *testing.T) { + for _, id := range ids { + found := FindSession(id) + + foundUser, err := FindUser(found.UserID) + + if foundUser == nil { + t.Error("User", found.UserID, "not found", err) + t.Fail() + } + + if foundUser != nil && foundUser != user { + t.Error("User", found.UserID, "is not correct", foundUser) + t.Fail() + } + } + }) + t.Run("Close", func(t *testing.T) { CloseSession(ids[2]) diff --git a/router.go b/router.go index 8632884..2c045e4 100644 --- a/router.go +++ b/router.go @@ -40,10 +40,11 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { var user *auth.User var sess *auth.Session cookie, err := req.Cookie(auth.SessionCookieName) + if cookie != nil && err == nil { sess = auth.FindSession(cookie.Value) if sess != nil { - user = auth.FindUser(sess.UserID) + user, _ = auth.FindUser(sess.UserID) } }