package bolt import ( "context" "github.com/gisle/stufflog/database/repositories" "github.com/gisle/stufflog/models" "github.com/gisle/stufflog/slerrors" "github.com/vmihailenco/msgpack/v4" "go.etcd.io/bbolt" ) var bnPeriods = []byte("Period") type periodRepository struct { db *bbolt.DB userIdIdx *index activityIdIdx *index } func (r *periodRepository) FindID(ctx context.Context, id string) (*models.Period, error) { period := new(models.Period) err := r.db.View(func(tx *bbolt.Tx) error { value := tx.Bucket(bnPeriods).Get(unsafeStringToBytes(id)) if value == nil { return slerrors.NotFound("Period") } err := msgpack.Unmarshal(value, period) if err != nil { return err } return nil }) if err != nil { return nil, err } return period, nil } func (r *periodRepository) List(ctx context.Context) ([]*models.Period, error) { periods := make([]*models.Period, 0, 16) err := r.db.View(func(tx *bbolt.Tx) error { cursor := tx.Bucket(bnPeriods).Cursor() for key, value := cursor.First(); key != nil; key, value = cursor.Next() { period := new(models.Period) err := msgpack.Unmarshal(value, period) if err != nil { return err } periods = append(periods, period) } return nil }) if err != nil { return nil, err } return periods, nil } func (r *periodRepository) ListUser(ctx context.Context, user models.User) ([]*models.Period, error) { periods := make([]*models.Period, 0, 16) err := r.db.View(func(tx *bbolt.Tx) error { bucket := tx.Bucket(bnPeriods) ids, err := r.userIdIdx.WithTx(tx).Get(user.ID) if err != nil { return err } for _, id := range ids { value := bucket.Get(id) if value == nil { continue } period := new(models.Period) err := msgpack.Unmarshal(value, period) if err != nil { return err } periods = append(periods, period) } return nil }) if err != nil { return nil, err } return periods, nil } func (r *periodRepository) ListActivity(ctx context.Context, activity models.Activity) ([]*models.Period, error) { periods := make([]*models.Period, 0, 16) err := r.db.View(func(tx *bbolt.Tx) error { bucket := tx.Bucket(bnPeriods) ids, err := r.activityIdIdx.WithTx(tx).Get(activity.ID) if err != nil { return err } for _, id := range ids { value := bucket.Get(id) if value == nil { continue } period := new(models.Period) err := msgpack.Unmarshal(value, period) if err != nil { return err } periods = append(periods, period) } return nil }) if err != nil { return nil, err } return periods, nil } func (r *periodRepository) Insert(ctx context.Context, period models.Period) error { value, err := msgpack.Marshal(&period) if err != nil { return err } return r.db.Update(func(tx *bbolt.Tx) error { err := tx.Bucket(bnPeriods).Put(unsafeStringToBytes(period.ID), value) if err != nil { return err } err = r.index(tx, &period) if err != nil { return err } return nil }) } func (r *periodRepository) Update(ctx context.Context, period models.Period, updates []*models.PeriodUpdate) (*models.Period, error) { err := r.db.Update(func(tx *bbolt.Tx) error { bucket := tx.Bucket(bnPeriods) // Re-Get to guarantee consistency. value := bucket.Get(unsafeStringToBytes(period.ID)) if value == nil { return slerrors.NotFound("Activity") } err := msgpack.Unmarshal(value, &period) if err != nil { return err } // Perform updates didChange := false for _, update := range updates { changed, err := period.ApplyUpdate(*update) if err != nil { return err } if changed { didChange = true } } // Put back into bucket if didChange { value, err := msgpack.Marshal(&period) if err != nil { return err } err = bucket.Put(unsafeStringToBytes(period.ID), value) if err != nil { return err } err = r.index(tx, &period) if err != nil { return err } } else { return errUnchanged } return nil }) if err != nil && err != errUnchanged { return nil, err } return &period, nil } func (r *periodRepository) Remove(ctx context.Context, period models.Period) error { return r.db.Update(func(tx *bbolt.Tx) error { err := tx.Bucket(bnPeriods).Delete(unsafeStringToBytes(period.ID)) if err != nil { return err } err = r.unIndex(tx, &period) if err != nil { return err } return nil }) } func (r *periodRepository) index(tx *bbolt.Tx, period *models.Period) error { idBytes := unsafeStringToBytes(period.ID) err := r.userIdIdx.WithTx(tx).Set(idBytes, period.UserID) if err != nil { return err } activityIDs := make([]string, 0, len(period.Goals)) added := make(map[string]bool) for _, goal := range period.Goals { if !added[goal.ActivityID] { added[goal.ActivityID] = true activityIDs = append(activityIDs, goal.ActivityID) } } err = r.activityIdIdx.WithTx(tx).Set(idBytes, activityIDs...) if err != nil { return err } return nil } func (r *periodRepository) unIndex(tx *bbolt.Tx, period *models.Period) error { idBytes := unsafeStringToBytes(period.ID) err := r.userIdIdx.WithTx(tx).Set(idBytes) if err != nil { return err } err = r.activityIdIdx.WithTx(tx).Set(idBytes) if err != nil { return err } return nil } func newPeriodRepository(db *bbolt.DB) (repositories.PeriodRepository, error) { err := db.Update(func(tx *bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists(bnPeriods) return err }) if err != nil { return nil, err } userIdIdx, err := newModelIndex(db, "Period", "UserID") if err != nil { return nil, err } activityIdIdx, err := newModelIndex(db, "Period", "Goals.ActivityID") if err != nil { return nil, err } return &periodRepository{ db: db, userIdIdx: userIdIdx, activityIdIdx: activityIdIdx, }, nil }