|
|
package services
import ( "context" "github.com/gin-gonic/gin" "github.com/gisle/stufflog/config" "github.com/gisle/stufflog/database" "github.com/gisle/stufflog/internal/generate" "github.com/gisle/stufflog/models" "github.com/gisle/stufflog/slerrors" "golang.org/x/crypto/bcrypt" "time" )
// dummyPassword holds a bcrypt hash that Login compares against when the
// user lookup fails. It is filled in by this package's init function.
var dummyPassword []byte

// ctxKey is the type of the keys this package stores in a context.Context.
// A dedicated type keeps these keys distinct from plain string keys.
type ctxKey string

const (
	// sessionCookie is the name of the cookie carrying the session ID.
	sessionCookie = "stufflog_sessid"

	// sessionCtxKey is the context key under which the session is stored.
	sessionCtxKey = ctxKey("stufflog_session")

	// userCtxKey is the context key under which the user is stored.
	userCtxKey = ctxKey("stufflog_user")
)
// AuthService provides registration, login, logout and session lookup on top
// of a database.Database.
type AuthService struct {
	// db is the store used to read and write users and user sessions.
	db database.Database
}
// UserFromContext gets the User object stored in the context.
func (s *AuthService) UserFromContext(ctx context.Context) *models.User { if gctx, ok := ctx.(*gin.Context); ok { ctx = gctx.Request.Context() }
session := s.SessionFromContext(ctx) if session == nil { return nil }
user, _ := ctx.Value(userCtxKey).(*models.User) return user }
// SessionFromContext gets the Session object stored in the context.
func (s *AuthService) SessionFromContext(ctx context.Context) *models.UserSession { if gctx, ok := ctx.(*gin.Context); ok { ctx = gctx.Request.Context() }
session, _ := ctx.Value(sessionCtxKey).(*models.UserSession) if session == nil || time.Now().After(session.Expires) { return nil }
return session }
// Register registers a new user.
func (s *AuthService) Register(ctx context.Context, username, password string) (*models.User, error) { if !config.Get().Users.AllowRegister { return nil, &slerrors.SLError{Code: 403, Text: "Registration is disabled"} }
if len(password) < 6 { return nil, &slerrors.SLError{Code: 400, Text: "Password is too short"} } if len(username) == 0 { return nil, &slerrors.SLError{Code: 400, Text: "No username provided"} }
if _, err := s.db.Users().FindID(ctx, username); !slerrors.IsNotFound(err) { return nil, &slerrors.SLError{Code: 409, Text: "Username is not available"} }
user := models.User{ ID: username, Name: username, } err := user.SetPassword(password) if err != nil { return nil, err } err = s.db.Users().Save(ctx, user)
return &user, nil }
// Login logs in the user and creates the session. The session is not saved to the request. For that you need to use
// GinSaveSession on the returned session.
func (s *AuthService) Login(ctx context.Context, userID, password string) (*models.User, *models.UserSession, error) { user, err := s.db.Users().FindID(ctx, userID) if err != nil { _ = bcrypt.CompareHashAndPassword(dummyPassword, []byte("u r a dummy too")) return nil, nil, slerrors.LoginFailed() }
err = user.CheckPassword(password) if err != nil { return nil, nil, slerrors.LoginFailed() }
session := models.UserSession{ ID: generate.ID("S", 64), Expires: time.Now().Add(time.Hour * 24 * 30), UserID: user.ID, } err = s.db.UserSessions().Save(ctx, session) if err != nil { return nil, nil, err }
return user, &session, nil }
// Logout deletes the session associated with the context.
func (s *AuthService) Logout(ctx context.Context) error { session := s.SessionFromContext(ctx) if session == nil { return slerrors.Unauthorized().WithText("You're already logged out.") }
session.Expires = time.Time{}
return s.db.UserSessions().Remove(ctx, *session) }
// GinSaveSession writes the session ID to the session cookie on the response.
// The cookie gets a max-age of 336 hours (14 days) and is marked HTTP-only
// but not secure.
func (s *AuthService) GinSaveSession(c *gin.Context, session models.UserSession) {
	const maxAgeSeconds = 336 * 3600

	c.SetCookie(
		sessionCookie,
		session.ID,
		maxAgeSeconds,
		"",    // path
		"",    // domain
		false, // secure
		true,  // httpOnly
	)
}
// GinClearSession tells the client to delete the session cookie.
func (s *AuthService) GinClearSession(c *gin.Context) {
	// A negative max-age makes net/http emit "Max-Age=0", which deletes the
	// cookie. A max-age of 0 would emit no Max-Age attribute at all and leave
	// an empty-valued cookie on the client.
	c.SetCookie(sessionCookie, "", -1, "", "", false, true)
}
// Middleware is a middleware that adds a session to the context, or
// gives a 403 Unauthorized error if there is none.
func (s *AuthService) GinSessionMiddleware(required bool) gin.HandlerFunc { return func(c *gin.Context) { ctx := c.Request.Context()
// Find cookie
cookie, err := c.Cookie(sessionCookie) if err != nil || cookie == "" { if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return }
c.Next() return }
// Check session existence
session, err := s.db.UserSessions().FindID(ctx, cookie) if err != nil { c.SetCookie(sessionCookie, "", 0, "", "", false, true)
if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return }
c.Next() return } ctx = context.WithValue(ctx, sessionCtxKey, session)
// Find the user associated with the session
user, err := s.db.Users().FindID(ctx, session.UserID) if err != nil { c.SetCookie(sessionCookie, "", 0, "", "", false, true) _ = s.db.UserSessions().Remove(ctx, *session)
if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return }
c.Next() return } ctx = context.WithValue(ctx, userCtxKey, user)
// Check if session has expired
if time.Now().After(session.Expires) { c.SetCookie(sessionCookie, "", 0, "", "", false, true) _ = s.db.UserSessions().Remove(ctx, *session)
if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return }
c.Next() return }
// If there's less than 27.75 days, reset expiry.
if time.Until(session.Expires) < time.Hour*666 { session.Expires = time.Now().Add(time.Hour * 672) s.GinSaveSession(c, *session) _ = s.db.UserSessions().Save(c.Request.Context(), *session) }
// Proceed.
c.Request = c.Request.WithContext(ctx) c.Next() } }
// NewAuthService returns an AuthService that uses db for users and sessions.
func NewAuthService(db database.Database) *AuthService {
	svc := new(AuthService)
	svc.db = db
	return svc
}
// init computes the bcrypt hash stored in dummyPassword, which Login compares
// against when the user lookup fails.
func init() {
	// NOTE(review): the error is discarded; if hashing ever failed,
	// dummyPassword would presumably be left without a usable hash — confirm
	// whether that should be surfaced.
	hash, _ := bcrypt.GenerateFromPassword([]byte("u r a dummy"), bcrypt.DefaultCost)
	dummyPassword = hash
}
|