diff --git a/controllers/user-controller.go b/controllers/user-controller.go index 045a119..18edd15 100644 --- a/controllers/user-controller.go +++ b/controllers/user-controller.go @@ -3,6 +3,7 @@ package controllers import ( "encoding/json" "net/http" + "time" "git.aiterp.net/lucifer/lucifer/internal/respond" "git.aiterp.net/lucifer/lucifer/models" @@ -16,7 +17,10 @@ type UserController struct { // getUsers (`GET /`): List users func (c *UserController) getUsers(w http.ResponseWriter, r *http.Request) { - // TODO: Check session + if session := models.SessionFromContext(r.Context()); session == nil { + respond.Error(w, 403, "permission_denied", "You must log in") + return + } users, err := c.users.List(r.Context()) if err != nil { @@ -51,7 +55,13 @@ func (c *UserController) login(w http.ResponseWriter, r *http.Request) { return } - // TODO: Open session + session := models.Session{ + Expires: time.Now().Add(7 * 24 * time.Hour), + UserID: user.ID, + } + session.GenerateID() + + http.SetCookie(w, session.Cookie()) respond.JSON(w, 200, user) } @@ -60,8 +70,8 @@ func (c *UserController) login(w http.ResponseWriter, r *http.Request) { func (c *UserController) Mount(router *mux.Router, prefix string) { sub := router.PathPrefix(prefix).Subrouter() - sub.Handle("/", http.HandlerFunc(c.getUsers)).Methods("GET") - sub.Handle("/login", http.HandlerFunc(c.login)).Methods("POST") + sub.HandleFunc("/", c.getUsers).Methods("GET") + sub.HandleFunc("/login", c.login).Methods("POST") } // NewUserController creates a new UserController. diff --git a/middlewares/session.go b/middlewares/session.go new file mode 100644 index 0000000..ff727a6 --- /dev/null +++ b/middlewares/session.go @@ -0,0 +1,58 @@ +package middlewares + +import ( + "net/http" + "time" + + "git.aiterp.net/lucifer/lucifer/models" + "github.com/gorilla/mux" +) + +// Session is a middleware that adds a Session to the request context if there +// is one. +func Session(repo models.SessionRepository) mux.MiddlewareFunc { + clearCookie := &http.Cookie{ + Name: "lucifer_session", + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + + HttpOnly: true, + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Find cookie + cookie, err := r.Cookie("lucifer_session") + if err == nil && cookie != nil { + next.ServeHTTP(w, r) + return + } + + // Check cookie expiration + if cookie.Expires.IsZero() || time.Now().After(cookie.Expires) { + http.SetCookie(w, clearCookie) + next.ServeHTTP(w, r) + return + } + + // Check session existence + session, err := repo.FindSessionByID(cookie.Value) + if err != nil { + http.SetCookie(w, clearCookie) + next.ServeHTTP(w, r) + return + } + + // Check if session has expired + if session.Expired() { + http.SetCookie(w, clearCookie) + next.ServeHTTP(w, r) + return + } + + // Proceed. + next.ServeHTTP(w, r.WithContext(session.InContext(r.Context()))) + }) + } +} diff --git a/models/session.go b/models/session.go index 8fb780a..2cc7afa 100644 --- a/models/session.go +++ b/models/session.go @@ -1,12 +1,18 @@ package models import ( + "context" "crypto/rand" "encoding/binary" + "net/http" "strconv" "time" ) +type sessionCtxKeyType string + +const sessionCtxKey = sessionCtxKeyType("luciter_session") + // The Session model represents a user being logged in. type Session struct { ID string @@ -44,11 +50,37 @@ func (session *Session) GenerateID() error { return nil } +// Cookie gets a cookie that could be used to restore this session. +func (session *Session) Cookie() *http.Cookie { + return &http.Cookie{ + Name: "lucifer_session", + Value: session.ID, + Path: "/", + Expires: session.Expires, + HttpOnly: true, + } +} + +// Expired returns true if the session has expired. +func (session *Session) Expired() bool { + return time.Now().After(session.Expires) +} + +// InContext returns a child context with the value. +func (session *Session) InContext(parent context.Context) context.Context { + return context.WithValue(parent, sessionCtxKey, session) +} + +// SessionFromContext gets the session from context, or `nil` if no session is available. +func SessionFromContext(ctx context.Context) *Session { + return ctx.Value(sessionCtxKey).(*Session) +} + // SessionRepository is an interface for all database operations // the user model makes. type SessionRepository interface { // FindSessionByID finds a non-expired session by ID. - FindSessionByID(id int) (Session, error) + FindSessionByID(id string) (Session, error) InsertSession(session Session) error RemoveSession(session Session) error ClearUserSessions(user User) error