Plan stuff. Log stuff.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

288 lines
5.7 KiB

4 years ago
  1. package bolt
  2. import (
  3. "context"
  4. "github.com/gisle/stufflog/database/repositories"
  5. "github.com/gisle/stufflog/models"
  6. "github.com/gisle/stufflog/slerrors"
  7. "github.com/vmihailenco/msgpack/v4"
  8. "go.etcd.io/bbolt"
  9. )
  var (
  	// bnPeriods names the bbolt bucket that holds msgpack-encoded periods,
  	// keyed by period ID (see FindID / Insert).
  	bnPeriods = []byte("Period")
  )
  11. type periodRepository struct {
  12. db *bbolt.DB
  13. userIdIdx *index
  14. activityIdIdx *index
  15. }
  16. func (r *periodRepository) FindID(ctx context.Context, id string) (*models.Period, error) {
  17. period := new(models.Period)
  18. err := r.db.View(func(tx *bbolt.Tx) error {
  19. value := tx.Bucket(bnPeriods).Get(unsafeStringToBytes(id))
  20. if value == nil {
  21. return slerrors.NotFound("Period")
  22. }
  23. err := msgpack.Unmarshal(value, period)
  24. if err != nil {
  25. return err
  26. }
  27. return nil
  28. })
  29. if err != nil {
  30. return nil, err
  31. }
  32. return period, nil
  33. }
  34. func (r *periodRepository) List(ctx context.Context) ([]*models.Period, error) {
  35. periods := make([]*models.Period, 0, 16)
  36. err := r.db.View(func(tx *bbolt.Tx) error {
  37. cursor := tx.Bucket(bnPeriods).Cursor()
  38. for key, value := cursor.First(); key != nil; key, value = cursor.Next() {
  39. period := new(models.Period)
  40. err := msgpack.Unmarshal(value, period)
  41. if err != nil {
  42. return err
  43. }
  44. periods = append(periods, period)
  45. }
  46. return nil
  47. })
  48. if err != nil {
  49. return nil, err
  50. }
  51. return periods, nil
  52. }
  53. func (r *periodRepository) ListUser(ctx context.Context, user models.User) ([]*models.Period, error) {
  54. periods := make([]*models.Period, 0, 16)
  55. err := r.db.View(func(tx *bbolt.Tx) error {
  56. bucket := tx.Bucket(bnPeriods)
  57. ids, err := r.userIdIdx.WithTx(tx).Get(user.ID)
  58. if err != nil {
  59. return err
  60. }
  61. for _, id := range ids {
  62. value := bucket.Get(id)
  63. if value == nil {
  64. continue
  65. }
  66. period := new(models.Period)
  67. err := msgpack.Unmarshal(value, period)
  68. if err != nil {
  69. return err
  70. }
  71. periods = append(periods, period)
  72. }
  73. return nil
  74. })
  75. if err != nil {
  76. return nil, err
  77. }
  78. return periods, nil
  79. }
  80. func (r *periodRepository) ListActivity(ctx context.Context, activity models.Activity) ([]*models.Period, error) {
  81. periods := make([]*models.Period, 0, 16)
  82. err := r.db.View(func(tx *bbolt.Tx) error {
  83. bucket := tx.Bucket(bnPeriods)
  84. ids, err := r.activityIdIdx.WithTx(tx).Get(activity.ID)
  85. if err != nil {
  86. return err
  87. }
  88. for _, id := range ids {
  89. value := bucket.Get(id)
  90. if value == nil {
  91. continue
  92. }
  93. period := new(models.Period)
  94. err := msgpack.Unmarshal(value, period)
  95. if err != nil {
  96. return err
  97. }
  98. periods = append(periods, period)
  99. }
  100. return nil
  101. })
  102. if err != nil {
  103. return nil, err
  104. }
  105. return periods, nil
  106. }
  107. func (r *periodRepository) Insert(ctx context.Context, period models.Period) error {
  108. value, err := msgpack.Marshal(&period)
  109. if err != nil {
  110. return err
  111. }
  112. return r.db.Update(func(tx *bbolt.Tx) error {
  113. err := tx.Bucket(bnPeriods).Put(unsafeStringToBytes(period.ID), value)
  114. if err != nil {
  115. return err
  116. }
  117. err = r.index(tx, &period)
  118. if err != nil {
  119. return err
  120. }
  121. return nil
  122. })
  123. }
  124. func (r *periodRepository) Update(ctx context.Context, period models.Period, updates []*models.PeriodUpdate) (*models.Period, error) {
  125. err := r.db.Update(func(tx *bbolt.Tx) error {
  126. bucket := tx.Bucket(bnPeriods)
  127. // Re-Get to guarantee consistency.
  128. value := bucket.Get(unsafeStringToBytes(period.ID))
  129. if value == nil {
  130. return slerrors.NotFound("Activity")
  131. }
  132. err := msgpack.Unmarshal(value, &period)
  133. if err != nil {
  134. return err
  135. }
  136. // Perform updates
  137. didChange := false
  138. for _, update := range updates {
  139. changed, err := period.ApplyUpdate(*update)
  140. if err != nil {
  141. return err
  142. }
  143. if changed {
  144. didChange = true
  145. }
  146. }
  147. // Put back into bucket
  148. if didChange {
  149. value, err := msgpack.Marshal(&period)
  150. if err != nil {
  151. return err
  152. }
  153. err = bucket.Put(unsafeStringToBytes(period.ID), value)
  154. if err != nil {
  155. return err
  156. }
  157. err = r.index(tx, &period)
  158. if err != nil {
  159. return err
  160. }
  161. } else {
  162. return errUnchanged
  163. }
  164. return nil
  165. })
  166. if err != nil && err != errUnchanged {
  167. return nil, err
  168. }
  169. return &period, nil
  170. }
  171. func (r *periodRepository) Remove(ctx context.Context, period models.Period) error {
  172. return r.db.Update(func(tx *bbolt.Tx) error {
  173. err := tx.Bucket(bnPeriods).Delete(unsafeStringToBytes(period.ID))
  174. if err != nil {
  175. return err
  176. }
  177. err = r.unIndex(tx, &period)
  178. if err != nil {
  179. return err
  180. }
  181. return nil
  182. })
  183. }
  184. func (r *periodRepository) index(tx *bbolt.Tx, period *models.Period) error {
  185. idBytes := unsafeStringToBytes(period.ID)
  186. err := r.userIdIdx.WithTx(tx).Set(idBytes, period.UserID)
  187. if err != nil {
  188. return err
  189. }
  190. activityIDs := make([]string, 0, len(period.Goals))
  191. added := make(map[string]bool)
  192. for _, goal := range period.Goals {
  193. if !added[goal.ActivityID] {
  194. added[goal.ActivityID] = true
  195. activityIDs = append(activityIDs, goal.ActivityID)
  196. }
  197. }
  198. err = r.activityIdIdx.WithTx(tx).Set(idBytes, activityIDs...)
  199. if err != nil {
  200. return err
  201. }
  202. return nil
  203. }
  204. func (r *periodRepository) unIndex(tx *bbolt.Tx, period *models.Period) error {
  205. idBytes := unsafeStringToBytes(period.ID)
  206. err := r.userIdIdx.WithTx(tx).Set(idBytes)
  207. if err != nil {
  208. return err
  209. }
  210. err = r.activityIdIdx.WithTx(tx).Set(idBytes)
  211. if err != nil {
  212. return err
  213. }
  214. return nil
  215. }
  216. func newPeriodRepository(db *bbolt.DB) (repositories.PeriodRepository, error) {
  217. err := db.Update(func(tx *bbolt.Tx) error {
  218. _, err := tx.CreateBucketIfNotExists(bnPeriods)
  219. return err
  220. })
  221. if err != nil {
  222. return nil, err
  223. }
  224. userIdIdx, err := newModelIndex(db, "Period", "UserID")
  225. if err != nil {
  226. return nil, err
  227. }
  228. activityIdIdx, err := newModelIndex(db, "Period", "Goals.ActivityID")
  229. if err != nil {
  230. return nil, err
  231. }
  232. return &periodRepository{
  233. db: db,
  234. userIdIdx: userIdIdx,
  235. activityIdIdx: activityIdIdx,
  236. }, nil
  237. }