You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
221 lines
5.5 KiB
221 lines
5.5 KiB
package services
|
|
|
|
import (
|
|
"context"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gisle/stufflog/config"
|
|
"github.com/gisle/stufflog/database"
|
|
"github.com/gisle/stufflog/internal/generate"
|
|
"github.com/gisle/stufflog/models"
|
|
"github.com/gisle/stufflog/slerrors"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"time"
|
|
)
|
|
|
|
// dummyPassword holds a bcrypt hash of a throwaway password. It is populated
// by init and compared against in Login when the user lookup fails.
var dummyPassword []byte

// ctxKey is a package-private type for context keys. Because context.Value
// compares keys with ==, values of this distinct type can never collide with
// keys of other types set by other packages.
type ctxKey string

const (
	sessionCookie = "stufflog_sessid"          // name of the cookie carrying the session ID
	sessionCtxKey = ctxKey("stufflog_session") // context key for the *models.UserSession
	userCtxKey    = ctxKey("stufflog_user")    // context key for the *models.User
)
|
|
|
|
type AuthService struct {
|
|
db database.Database
|
|
}
|
|
|
|
// UserFromContext gets the User object stored in the context.
|
|
func (s *AuthService) UserFromContext(ctx context.Context) *models.User {
|
|
if gctx, ok := ctx.(*gin.Context); ok {
|
|
ctx = gctx.Request.Context()
|
|
}
|
|
|
|
session := s.SessionFromContext(ctx)
|
|
if session == nil {
|
|
return nil
|
|
}
|
|
|
|
user, _ := ctx.Value(userCtxKey).(*models.User)
|
|
return user
|
|
}
|
|
|
|
// SessionFromContext gets the Session object stored in the context.
|
|
func (s *AuthService) SessionFromContext(ctx context.Context) *models.UserSession {
|
|
if gctx, ok := ctx.(*gin.Context); ok {
|
|
ctx = gctx.Request.Context()
|
|
}
|
|
|
|
session, _ := ctx.Value(sessionCtxKey).(*models.UserSession)
|
|
if session == nil || time.Now().After(session.Expires) {
|
|
return nil
|
|
}
|
|
|
|
return session
|
|
}
|
|
|
|
// Register registers a new user.
|
|
func (s *AuthService) Register(ctx context.Context, username, password string) (*models.User, error) {
|
|
if !config.Get().Users.AllowRegister {
|
|
return nil, &slerrors.SLError{Code: 403, Text: "Registration is disabled"}
|
|
}
|
|
|
|
if len(password) < 6 {
|
|
return nil, &slerrors.SLError{Code: 400, Text: "Password is too short"}
|
|
}
|
|
if len(username) == 0 {
|
|
return nil, &slerrors.SLError{Code: 400, Text: "No username provided"}
|
|
}
|
|
|
|
if _, err := s.db.Users().FindID(ctx, username); !slerrors.IsNotFound(err) {
|
|
return nil, &slerrors.SLError{Code: 409, Text: "Username is not available"}
|
|
}
|
|
|
|
user := models.User{
|
|
ID: username,
|
|
Name: username,
|
|
}
|
|
err := user.SetPassword(password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = s.db.Users().Save(ctx, user)
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// Login logs in the user and creates the session. The session is not saved to the request. For that you need to use
|
|
// GinSaveSession on the returned session.
|
|
func (s *AuthService) Login(ctx context.Context, userID, password string) (*models.User, *models.UserSession, error) {
|
|
user, err := s.db.Users().FindID(ctx, userID)
|
|
if err != nil {
|
|
_ = bcrypt.CompareHashAndPassword(dummyPassword, []byte("u r a dummy too"))
|
|
return nil, nil, slerrors.LoginFailed()
|
|
}
|
|
|
|
err = user.CheckPassword(password)
|
|
if err != nil {
|
|
return nil, nil, slerrors.LoginFailed()
|
|
}
|
|
|
|
session := models.UserSession{
|
|
ID: generate.ID("S", 64),
|
|
Expires: time.Now().Add(time.Hour * 24 * 30),
|
|
UserID: user.ID,
|
|
}
|
|
err = s.db.UserSessions().Save(ctx, session)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return user, &session, nil
|
|
}
|
|
|
|
// Logout deletes the session associated with the context.
|
|
func (s *AuthService) Logout(ctx context.Context) error {
|
|
session := s.SessionFromContext(ctx)
|
|
if session == nil {
|
|
return slerrors.Unauthorized().WithText("You're already logged out.")
|
|
}
|
|
|
|
session.Expires = time.Time{}
|
|
|
|
return s.db.UserSessions().Remove(ctx, *session)
|
|
}
|
|
|
|
// SaveSession saves a session to gin.
|
|
func (s *AuthService) GinSaveSession(c *gin.Context, session models.UserSession) {
|
|
c.SetCookie(sessionCookie, session.ID, 336*3600, "", "", false, true)
|
|
}
|
|
|
|
// SaveSession saves a session to gin.
|
|
func (s *AuthService) GinClearSession(c *gin.Context) {
|
|
c.SetCookie(sessionCookie, "", 0, "", "", false, true)
|
|
}
|
|
|
|
// Middleware is a middleware that adds a session to the context, or
|
|
// gives a 403 Unauthorized error if there is none.
|
|
func (s *AuthService) GinSessionMiddleware(required bool) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
|
|
// Find cookie
|
|
cookie, err := c.Cookie(sessionCookie)
|
|
if err != nil || cookie == "" {
|
|
if required {
|
|
slerrors.GinRespond(c, slerrors.Unauthorized())
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Check session existence
|
|
session, err := s.db.UserSessions().FindID(ctx, cookie)
|
|
if err != nil {
|
|
c.SetCookie(sessionCookie, "", 0, "", "", false, true)
|
|
|
|
if required {
|
|
slerrors.GinRespond(c, slerrors.Unauthorized())
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
return
|
|
}
|
|
ctx = context.WithValue(ctx, sessionCtxKey, session)
|
|
|
|
// Find the user associated with the session
|
|
user, err := s.db.Users().FindID(ctx, session.UserID)
|
|
if err != nil {
|
|
c.SetCookie(sessionCookie, "", 0, "", "", false, true)
|
|
_ = s.db.UserSessions().Remove(ctx, *session)
|
|
|
|
if required {
|
|
slerrors.GinRespond(c, slerrors.Unauthorized())
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
return
|
|
}
|
|
ctx = context.WithValue(ctx, userCtxKey, user)
|
|
|
|
// Check if session has expired
|
|
if time.Now().After(session.Expires) {
|
|
c.SetCookie(sessionCookie, "", 0, "", "", false, true)
|
|
_ = s.db.UserSessions().Remove(ctx, *session)
|
|
|
|
if required {
|
|
slerrors.GinRespond(c, slerrors.Unauthorized())
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// If there's less than 27.75 days, reset expiry.
|
|
if time.Until(session.Expires) < time.Hour*666 {
|
|
session.Expires = time.Now().Add(time.Hour * 672)
|
|
s.GinSaveSession(c, *session)
|
|
_ = s.db.UserSessions().Save(c.Request.Context(), *session)
|
|
}
|
|
|
|
// Proceed.
|
|
c.Request = c.Request.WithContext(ctx)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func NewAuthService(db database.Database) *AuthService {
|
|
return &AuthService{db: db}
|
|
}
|
|
|
|
func init() {
|
|
dummyPassword, _ = bcrypt.GenerateFromPassword([]byte("u r a dummy"), bcrypt.DefaultCost)
|
|
}
|