Plan stuff. Log stuff.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

221 lines
5.5 KiB

5 years ago
  1. package services
  2. import (
  3. "context"
  4. "github.com/gin-gonic/gin"
  5. "github.com/gisle/stufflog/config"
  6. "github.com/gisle/stufflog/database"
  7. "github.com/gisle/stufflog/internal/generate"
  8. "github.com/gisle/stufflog/models"
  9. "github.com/gisle/stufflog/slerrors"
  10. "golang.org/x/crypto/bcrypt"
  11. "time"
  12. )
  // dummyPassword is a bcrypt hash that Login compares against when the
  // requested user cannot be found. It is filled in by init.
  var dummyPassword []byte

  // ctxKey is the type of the keys this file stores in a context.Context.
  type ctxKey string

  // sessionCookie is the name of the cookie that carries the session ID.
  const sessionCookie = "stufflog_sessid"

  // Context keys under which the session middleware stores the session and
  // the user it belongs to.
  const (
  	sessionCtxKey ctxKey = "stufflog_session"
  	userCtxKey    ctxKey = "stufflog_user"
  )
  18. type AuthService struct {
  19. db database.Database
  20. }
  21. // UserFromContext gets the User object stored in the context.
  22. func (s *AuthService) UserFromContext(ctx context.Context) *models.User {
  23. if gctx, ok := ctx.(*gin.Context); ok {
  24. ctx = gctx.Request.Context()
  25. }
  26. session := s.SessionFromContext(ctx)
  27. if session == nil {
  28. return nil
  29. }
  30. user, _ := ctx.Value(userCtxKey).(*models.User)
  31. return user
  32. }
  33. // SessionFromContext gets the Session object stored in the context.
  34. func (s *AuthService) SessionFromContext(ctx context.Context) *models.UserSession {
  35. if gctx, ok := ctx.(*gin.Context); ok {
  36. ctx = gctx.Request.Context()
  37. }
  38. session, _ := ctx.Value(sessionCtxKey).(*models.UserSession)
  39. if session == nil || time.Now().After(session.Expires) {
  40. return nil
  41. }
  42. return session
  43. }
  44. // Register registers a new user.
  45. func (s *AuthService) Register(ctx context.Context, username, password string) (*models.User, error) {
  46. if !config.Get().Users.AllowRegister {
  47. return nil, &slerrors.SLError{Code: 403, Text: "Registration is disabled"}
  48. }
  49. if len(password) < 6 {
  50. return nil, &slerrors.SLError{Code: 400, Text: "Password is too short"}
  51. }
  52. if len(username) == 0 {
  53. return nil, &slerrors.SLError{Code: 400, Text: "No username provided"}
  54. }
  55. if _, err := s.db.Users().FindID(ctx, username); !slerrors.IsNotFound(err) {
  56. return nil, &slerrors.SLError{Code: 409, Text: "Username is not available"}
  57. }
  58. user := models.User{
  59. ID: username,
  60. Name: username,
  61. }
  62. err := user.SetPassword(password)
  63. if err != nil {
  64. return nil, err
  65. }
  66. err = s.db.Users().Save(ctx, user)
  67. return &user, nil
  68. }
  69. // Login logs in the user and creates the session. The session is not saved to the request. For that you need to use
  70. // GinSaveSession on the returned session.
  71. func (s *AuthService) Login(ctx context.Context, userID, password string) (*models.User, *models.UserSession, error) {
  72. user, err := s.db.Users().FindID(ctx, userID)
  73. if err != nil {
  74. _ = bcrypt.CompareHashAndPassword(dummyPassword, []byte("u r a dummy too"))
  75. return nil, nil, slerrors.LoginFailed()
  76. }
  77. err = user.CheckPassword(password)
  78. if err != nil {
  79. return nil, nil, slerrors.LoginFailed()
  80. }
  81. session := models.UserSession{
  82. ID: generate.ID("S", 64),
  83. Expires: time.Now().Add(time.Hour * 24 * 30),
  84. UserID: user.ID,
  85. }
  86. err = s.db.UserSessions().Save(ctx, session)
  87. if err != nil {
  88. return nil, nil, err
  89. }
  90. return user, &session, nil
  91. }
  92. // Logout deletes the session associated with the context.
  93. func (s *AuthService) Logout(ctx context.Context) error {
  94. session := s.SessionFromContext(ctx)
  95. if session == nil {
  96. return slerrors.Unauthorized().WithText("You're already logged out.")
  97. }
  98. session.Expires = time.Time{}
  99. return s.db.UserSessions().Remove(ctx, *session)
  100. }
  101. // SaveSession saves a session to gin.
  102. func (s *AuthService) GinSaveSession(c *gin.Context, session models.UserSession) {
  103. c.SetCookie(sessionCookie, session.ID, 336*3600, "", "", false, true)
  104. }
  105. // SaveSession saves a session to gin.
  106. func (s *AuthService) GinClearSession(c *gin.Context) {
  107. c.SetCookie(sessionCookie, "", 0, "", "", false, true)
  108. }
  109. // Middleware is a middleware that adds a session to the context, or
  110. // gives a 403 Unauthorized error if there is none.
  111. func (s *AuthService) GinSessionMiddleware(required bool) gin.HandlerFunc {
  112. return func(c *gin.Context) {
  113. ctx := c.Request.Context()
  114. // Find cookie
  115. cookie, err := c.Cookie(sessionCookie)
  116. if err != nil || cookie == "" {
  117. if required {
  118. slerrors.GinRespond(c, slerrors.Unauthorized())
  119. c.Abort()
  120. return
  121. }
  122. c.Next()
  123. return
  124. }
  125. // Check session existence
  126. session, err := s.db.UserSessions().FindID(ctx, cookie)
  127. if err != nil {
  128. c.SetCookie(sessionCookie, "", 0, "", "", false, true)
  129. if required {
  130. slerrors.GinRespond(c, slerrors.Unauthorized())
  131. c.Abort()
  132. return
  133. }
  134. c.Next()
  135. return
  136. }
  137. ctx = context.WithValue(ctx, sessionCtxKey, session)
  138. // Find the user associated with the session
  139. user, err := s.db.Users().FindID(ctx, session.UserID)
  140. if err != nil {
  141. c.SetCookie(sessionCookie, "", 0, "", "", false, true)
  142. _ = s.db.UserSessions().Remove(ctx, *session)
  143. if required {
  144. slerrors.GinRespond(c, slerrors.Unauthorized())
  145. c.Abort()
  146. return
  147. }
  148. c.Next()
  149. return
  150. }
  151. ctx = context.WithValue(ctx, userCtxKey, user)
  152. // Check if session has expired
  153. if time.Now().After(session.Expires) {
  154. c.SetCookie(sessionCookie, "", 0, "", "", false, true)
  155. _ = s.db.UserSessions().Remove(ctx, *session)
  156. if required {
  157. slerrors.GinRespond(c, slerrors.Unauthorized())
  158. c.Abort()
  159. return
  160. }
  161. c.Next()
  162. return
  163. }
  164. // If there's less than 27.75 days, reset expiry.
  165. if time.Until(session.Expires) < time.Hour*666 {
  166. session.Expires = time.Now().Add(time.Hour * 672)
  167. s.GinSaveSession(c, *session)
  168. _ = s.db.UserSessions().Save(c.Request.Context(), *session)
  169. }
  170. // Proceed.
  171. c.Request = c.Request.WithContext(ctx)
  172. c.Next()
  173. }
  174. }
  175. func NewAuthService(db database.Database) *AuthService {
  176. return &AuthService{db: db}
  177. }
  178. func init() {
  179. dummyPassword, _ = bcrypt.GenerateFromPassword([]byte("u r a dummy"), bcrypt.DefaultCost)
  180. }