Browse Source

Merge remote-tracking branch 'origin/master' into webui

webui
Stian Aune 5 years ago
parent
commit
195738d6ed
  1. 18
      controllers/user-controller.go
  2. 58
      middlewares/session.go
  3. 34
      models/session.go

18
controllers/user-controller.go

@ -3,6 +3,7 @@ package controllers
import (
"encoding/json"
"net/http"
"time"
"git.aiterp.net/lucifer/lucifer/internal/respond"
"git.aiterp.net/lucifer/lucifer/models"
@ -16,7 +17,10 @@ type UserController struct {
// getUsers (`GET /`): List users
func (c *UserController) getUsers(w http.ResponseWriter, r *http.Request) {
// TODO: Check session
if session := models.SessionFromContext(r.Context()); session == nil {
respond.Error(w, 403, "permission_denied", "You must log in")
return
}
users, err := c.users.List(r.Context())
if err != nil {
@ -51,7 +55,13 @@ func (c *UserController) login(w http.ResponseWriter, r *http.Request) {
return
}
// TODO: Open session
session := models.Session{
Expires: time.Now().Add(7 * 24 * time.Hour),
UserID: user.ID,
}
session.GenerateID()
http.SetCookie(w, session.Cookie())
respond.JSON(w, 200, user)
}
@ -60,8 +70,8 @@ func (c *UserController) login(w http.ResponseWriter, r *http.Request) {
func (c *UserController) Mount(router *mux.Router, prefix string) {
sub := router.PathPrefix(prefix).Subrouter()
sub.Handle("/", http.HandlerFunc(c.getUsers)).Methods("GET")
sub.Handle("/login", http.HandlerFunc(c.login)).Methods("POST")
sub.HandleFunc("/", c.getUsers).Methods("GET")
sub.HandleFunc("/login", c.login).Methods("POST")
}
// NewUserController creates a new UserController.

58
middlewares/session.go

@ -0,0 +1,58 @@
package middlewares
import (
"net/http"
"time"
"git.aiterp.net/lucifer/lucifer/models"
"github.com/gorilla/mux"
)
// Session is a middleware that adds a Session to the request context if there
// is one.
func Session(repo models.SessionRepository) mux.MiddlewareFunc {
clearCookie := &http.Cookie{
Name: "lucifer_session",
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
HttpOnly: true,
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Find cookie
cookie, err := r.Cookie("lucifer_session")
if err == nil && cookie != nil {
next.ServeHTTP(w, r)
return
}
// Check cookie expiration
if cookie.Expires.IsZero() || time.Now().After(cookie.Expires) {
http.SetCookie(w, clearCookie)
next.ServeHTTP(w, r)
return
}
// Check session existence
session, err := repo.FindSessionByID(cookie.Value)
if err != nil {
http.SetCookie(w, clearCookie)
next.ServeHTTP(w, r)
return
}
// Check if session has expired
if session.Expired() {
http.SetCookie(w, clearCookie)
next.ServeHTTP(w, r)
return
}
// Proceed.
next.ServeHTTP(w, r.WithContext(session.InContext(r.Context())))
})
}
}

34
models/session.go

@ -1,12 +1,18 @@
package models
import (
"context"
"crypto/rand"
"encoding/binary"
"net/http"
"strconv"
"time"
)
type sessionCtxKeyType string
const sessionCtxKey = sessionCtxKeyType("luciter_session")
// The Session model represents a user being logged in.
type Session struct {
ID string
@ -44,11 +50,37 @@ func (session *Session) GenerateID() error {
return nil
}
// Cookie gets a cookie that could be used to restore this session.
func (session *Session) Cookie() *http.Cookie {
return &http.Cookie{
Name: "lucifer_session",
Value: session.ID,
Path: "/",
Expires: session.Expires,
HttpOnly: true,
}
}
// Expired returns true if the session has expired.
func (session *Session) Expired() bool {
return time.Now().After(session.Expires)
}
// InContext returns a child context with the value.
func (session *Session) InContext(parent context.Context) context.Context {
return context.WithValue(parent, sessionCtxKey, session)
}
// SessionFromContext gets the session from context, or `nil` if no session is available.
func SessionFromContext(ctx context.Context) *Session {
return ctx.Value(sessionCtxKey).(*Session)
}
// SessionRepository is an interface for all database operations
// the user model makes.
type SessionRepository interface {
// FindSessionByID finds a non-expired session by ID.
FindSessionByID(id int) (Session, error)
FindSessionByID(id string) (Session, error)
InsertSession(session Session) error
RemoveSession(session Session) error
ClearUserSessions(user User) error

Loading…
Cancel
Save