package services

import (
	"context"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gisle/stufflog/config"
	"github.com/gisle/stufflog/database"
	"github.com/gisle/stufflog/internal/generate"
	"github.com/gisle/stufflog/models"
	"github.com/gisle/stufflog/slerrors"
	"golang.org/x/crypto/bcrypt"
)

// dummyPassword is a bcrypt hash that Login compares against when the
// requested user cannot be found. It is filled in by init.
var dummyPassword []byte

// ctxKey is a package-private type for context keys, so values stored by this
// package cannot collide with keys defined by other packages.
type ctxKey string

const (
	// sessionCookie is the name of the cookie that carries the session ID.
	sessionCookie = "stufflog_sessid"
	// sessionCtxKey is the context key for the *models.UserSession.
	sessionCtxKey = ctxKey("stufflog_session")
	// userCtxKey is the context key for the *models.User.
	userCtxKey = ctxKey("stufflog_user")
)

// AuthService handles registration, login/logout and the gin session cookie,
// storing users and sessions in a database.Database.
type AuthService struct {
	db database.Database
}

// UserFromContext returns the *models.User stored in ctx (as done by
// GinSessionMiddleware). It returns nil when ctx holds no user, or when
// SessionFromContext finds no unexpired session in ctx. A *gin.Context may be
// passed; its request context is consulted instead.
func (s *AuthService) UserFromContext(ctx context.Context) *models.User {
	if ginCtx, isGin := ctx.(*gin.Context); isGin {
		ctx = ginCtx.Request.Context()
	}

	// A user is only reported while its session is still valid.
	if s.SessionFromContext(ctx) == nil {
		return nil
	}

	u, _ := ctx.Value(userCtxKey).(*models.User)
	return u
}

// SessionFromContext returns the *models.UserSession stored in ctx, or nil if
// there is none or it has expired. A *gin.Context may be passed; its request
// context is consulted instead.
func (s *AuthService) SessionFromContext(ctx context.Context) *models.UserSession {
	if ginCtx, isGin := ctx.(*gin.Context); isGin {
		ctx = ginCtx.Request.Context()
	}

	sess, _ := ctx.Value(sessionCtxKey).(*models.UserSession)
	switch {
	case sess == nil:
		return nil
	case time.Now().After(sess.Expires):
		// Expired, which includes sessions whose Expires was zeroed by Logout.
		return nil
	}
	return sess
}

// Register creates and saves a new user with the given username and password.
//
// Registration must be enabled in the configuration, the password must be at
// least 6 bytes long, and the username must be non-empty and not already in
// use. Violations are reported as *slerrors.SLError with codes 403, 400 and 409
// respectively. Errors from password hashing and from the database are
// returned unchanged.
func (s *AuthService) Register(ctx context.Context, username, password string) (*models.User, error) {
	if !config.Get().Users.AllowRegister {
		return nil, &slerrors.SLError{Code: 403, Text: "Registration is disabled"}
	}
	if len(password) < 6 {
		return nil, &slerrors.SLError{Code: 400, Text: "Password is too short"}
	}
	if len(username) == 0 {
		return nil, &slerrors.SLError{Code: 400, Text: "No username provided"}
	}

	// Only a not-found result means the name is free. Any other lookup error is
	// a database failure and must not be reported as a naming conflict.
	_, err := s.db.Users().FindID(ctx, username)
	switch {
	case err == nil:
		return nil, &slerrors.SLError{Code: 409, Text: "Username is not available"}
	case !slerrors.IsNotFound(err):
		return nil, err
	}

	user := models.User{
		ID:   username,
		Name: username,
	}
	if err = user.SetPassword(password); err != nil {
		return nil, err
	}

	// Report persistence failures instead of returning a user that was never
	// stored.
	if err = s.db.Users().Save(ctx, user); err != nil {
		return nil, err
	}

	return &user, nil
}

// Login checks userID and password. On success it creates and stores a new
// session that expires 30 days from now and returns it together with the user.
// The session is not attached to the request; pass it to GinSaveSession to set
// the session cookie.
//
// A failed user lookup and a wrong password both yield slerrors.LoginFailed().
func (s *AuthService) Login(ctx context.Context, userID, password string) (*models.User, *models.UserSession, error) {
	user, err := s.db.Users().FindID(ctx, userID)
	if err != nil {
		// Run a bcrypt comparison anyway, so that unknown users are not
		// rejected noticeably faster than wrong passwords.
		_ = bcrypt.CompareHashAndPassword(dummyPassword, []byte("u r a dummy too"))
		return nil, nil, slerrors.LoginFailed()
	}

	if err = user.CheckPassword(password); err != nil {
		return nil, nil, slerrors.LoginFailed()
	}

	session := models.UserSession{
		ID:      generate.ID("S", 64),
		Expires: time.Now().Add(time.Hour * 24 * 30),
		UserID:  user.ID,
	}
	if err = s.db.UserSessions().Save(ctx, session); err != nil {
		return nil, nil, err
	}

	return user, &session, nil
}

// Logout removes the session found in ctx from the database. It returns an
// slerrors.Unauthorized() error when ctx carries no unexpired session. The
// session cookie is not touched; use GinClearSession for that.
func (s *AuthService) Logout(ctx context.Context) error {
	session := s.SessionFromContext(ctx)
	if session == nil {
		return slerrors.Unauthorized().WithText("You're already logged out.")
	}

	// session is the pointer stored in the context. Zeroing Expires makes
	// SessionFromContext (and UserFromContext) treat this request as logged
	// out from here on, even if the removal below fails.
	session.Expires = time.Time{}
	return s.db.UserSessions().Remove(ctx, *session)
}

// GinSaveSession sets the session cookie on the response to the session's ID.
func (s *AuthService) GinSaveSession(c *gin.Context, session models.UserSession) { c.SetCookie(sessionCookie, session.ID, 336*3600, "", "", false, true) } // SaveSession saves a session to gin. func (s *AuthService) GinClearSession(c *gin.Context) { c.SetCookie(sessionCookie, "", 0, "", "", false, true) } // Middleware is a middleware that adds a session to the context, or // gives a 403 Unauthorized error if there is none. func (s *AuthService) GinSessionMiddleware(required bool) gin.HandlerFunc { return func(c *gin.Context) { ctx := c.Request.Context() // Find cookie cookie, err := c.Cookie(sessionCookie) if err != nil || cookie == "" { if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return } c.Next() return } // Check session existence session, err := s.db.UserSessions().FindID(ctx, cookie) if err != nil { c.SetCookie(sessionCookie, "", 0, "", "", false, true) if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return } c.Next() return } ctx = context.WithValue(ctx, sessionCtxKey, session) // Find the user associated with the session user, err := s.db.Users().FindID(ctx, session.UserID) if err != nil { c.SetCookie(sessionCookie, "", 0, "", "", false, true) _ = s.db.UserSessions().Remove(ctx, *session) if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return } c.Next() return } ctx = context.WithValue(ctx, userCtxKey, user) // Check if session has expired if time.Now().After(session.Expires) { c.SetCookie(sessionCookie, "", 0, "", "", false, true) _ = s.db.UserSessions().Remove(ctx, *session) if required { slerrors.GinRespond(c, slerrors.Unauthorized()) c.Abort() return } c.Next() return } // If there's less than 27.75 days, reset expiry. if time.Until(session.Expires) < time.Hour*666 { session.Expires = time.Now().Add(time.Hour * 672) s.GinSaveSession(c, *session) _ = s.db.UserSessions().Save(c.Request.Context(), *session) } // Proceed. 
c.Request = c.Request.WithContext(ctx) c.Next() } } func NewAuthService(db database.Database) *AuthService { return &AuthService{db: db} } func init() { dummyPassword, _ = bcrypt.GenerateFromPassword([]byte("u r a dummy"), bcrypt.DefaultCost) }