Plan stuff. Log stuff.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

288 lines
5.7 KiB

package bolt
import (
"context"
"github.com/gisle/stufflog/database/repositories"
"github.com/gisle/stufflog/models"
"github.com/gisle/stufflog/slerrors"
"github.com/vmihailenco/msgpack/v4"
"go.etcd.io/bbolt"
)
// bnPeriods names the bbolt bucket that holds msgpack-encoded periods,
// keyed by period ID.
var bnPeriods = []byte("Period")

// periodRepository is the bbolt-backed repositories.PeriodRepository.
// Alongside the primary bucket it keeps two secondary indexes that map a
// lookup value to the IDs of the periods carrying it.
type periodRepository struct {
	db *bbolt.DB

	// userIdIdx is queried by UserID and yields period IDs.
	userIdIdx *index
	// activityIdIdx is queried by a goal's ActivityID and yields period IDs.
	activityIdIdx *index
}
func (r *periodRepository) FindID(ctx context.Context, id string) (*models.Period, error) {
period := new(models.Period)
err := r.db.View(func(tx *bbolt.Tx) error {
value := tx.Bucket(bnPeriods).Get(unsafeStringToBytes(id))
if value == nil {
return slerrors.NotFound("Period")
}
err := msgpack.Unmarshal(value, period)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return period, nil
}
// List decodes and returns every period in the bucket, in the order the
// bucket cursor visits the keys. The first decoding error aborts the scan.
func (r *periodRepository) List(ctx context.Context) ([]*models.Period, error) {
	result := make([]*models.Period, 0, 16)
	err := r.db.View(func(tx *bbolt.Tx) error {
		c := tx.Bucket(bnPeriods).Cursor()
		k, v := c.First()
		for k != nil {
			p := new(models.Period)
			if err := msgpack.Unmarshal(v, p); err != nil {
				return err
			}
			result = append(result, p)
			k, v = c.Next()
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}
// ListUser returns the periods whose IDs the UserID index lists for user.ID.
// IDs that no longer resolve to a stored period are skipped silently.
func (r *periodRepository) ListUser(ctx context.Context, user models.User) ([]*models.Period, error) {
	result := make([]*models.Period, 0, 16)
	err := r.db.View(func(tx *bbolt.Tx) error {
		ids, err := r.userIdIdx.WithTx(tx).Get(user.ID)
		if err != nil {
			return err
		}
		b := tx.Bucket(bnPeriods)
		for _, id := range ids {
			raw := b.Get(id)
			if raw == nil {
				// Presumably a stale index entry; nothing to decode.
				continue
			}
			p := new(models.Period)
			if err := msgpack.Unmarshal(raw, p); err != nil {
				return err
			}
			result = append(result, p)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}
// ListActivity returns the periods whose IDs the activity index lists for
// activity.ID (i.e. periods with a goal referencing that activity). IDs that
// no longer resolve to a stored period are skipped silently.
func (r *periodRepository) ListActivity(ctx context.Context, activity models.Activity) ([]*models.Period, error) {
	result := make([]*models.Period, 0, 16)
	err := r.db.View(func(tx *bbolt.Tx) error {
		ids, err := r.activityIdIdx.WithTx(tx).Get(activity.ID)
		if err != nil {
			return err
		}
		b := tx.Bucket(bnPeriods)
		for _, id := range ids {
			raw := b.Get(id)
			if raw == nil {
				// Presumably a stale index entry; nothing to decode.
				continue
			}
			p := new(models.Period)
			if err := msgpack.Unmarshal(raw, p); err != nil {
				return err
			}
			result = append(result, p)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}
// Insert msgpack-encodes period, stores it under period.ID and registers it
// in both secondary indexes within a single write transaction. No existence
// check is made, so Put replaces any value already stored under that ID.
func (r *periodRepository) Insert(ctx context.Context, period models.Period) error {
	encoded, err := msgpack.Marshal(&period)
	if err != nil {
		return err
	}
	return r.db.Update(func(tx *bbolt.Tx) error {
		if err := tx.Bucket(bnPeriods).Put(unsafeStringToBytes(period.ID), encoded); err != nil {
			return err
		}
		return r.index(tx, &period)
	})
}
// Update re-reads the stored copy of the period identified by period.ID
// inside a write transaction and applies updates to it in order.
//
// If at least one update reports a change, the period is re-encoded,
// written back and its secondary-index entries are refreshed. If none does,
// the transaction is aborted with errUnchanged and nothing is written, but
// the freshly read period is still returned with a nil error.
//
// Only period.ID is taken from the argument; every other field comes from
// the bucket, so stale caller data cannot overwrite newer stored data. A
// missing period yields slerrors.NotFound("Period"); errors from decoding,
// ApplyUpdate, encoding or indexing abort the transaction and are returned.
func (r *periodRepository) Update(ctx context.Context, period models.Period, updates []*models.PeriodUpdate) (*models.Period, error) {
	// Decode into a fresh value rather than into the argument. msgpack
	// decodes in place into existing slices, maps and pointees. The
	// argument's slices share storage with the caller's copy, so decoding
	// (and then ApplyUpdate) into it would silently mutate the caller's data.
	stored := new(models.Period)

	err := r.db.Update(func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(bnPeriods)

		// Re-Get to guarantee consistency.
		value := bucket.Get(unsafeStringToBytes(period.ID))
		if value == nil {
			return slerrors.NotFound("Period")
		}
		if err := msgpack.Unmarshal(value, stored); err != nil {
			return err
		}

		didChange := false
		for _, update := range updates {
			changed, err := stored.ApplyUpdate(*update)
			if err != nil {
				return err
			}
			didChange = didChange || changed
		}
		if !didChange {
			// Roll back: there is nothing to write.
			return errUnchanged
		}

		encoded, err := msgpack.Marshal(stored)
		if err != nil {
			return err
		}
		if err = bucket.Put(unsafeStringToBytes(stored.ID), encoded); err != nil {
			return err
		}
		return r.index(tx, stored)
	})
	if err != nil && err != errUnchanged {
		return nil, err
	}
	return stored, nil
}
// Remove deletes the period stored under period.ID and clears its entries
// in both secondary indexes, all within one write transaction. Only
// period.ID is consulted.
func (r *periodRepository) Remove(ctx context.Context, period models.Period) error {
	return r.db.Update(func(tx *bbolt.Tx) error {
		if err := tx.Bucket(bnPeriods).Delete(unsafeStringToBytes(period.ID)); err != nil {
			return err
		}
		return r.unIndex(tx, &period)
	})
}
// index records period in the secondary indexes within tx. The UserID
// index is set to period.UserID. The activity index is set to the distinct
// ActivityIDs of period.Goals, in first-seen order with duplicates dropped.
func (r *periodRepository) index(tx *bbolt.Tx, period *models.Period) error {
	key := unsafeStringToBytes(period.ID)
	if err := r.userIdIdx.WithTx(tx).Set(key, period.UserID); err != nil {
		return err
	}

	seen := make(map[string]struct{}, len(period.Goals))
	distinct := make([]string, 0, len(period.Goals))
	for _, g := range period.Goals {
		if _, dup := seen[g.ActivityID]; dup {
			continue
		}
		seen[g.ActivityID] = struct{}{}
		distinct = append(distinct, g.ActivityID)
	}
	return r.activityIdIdx.WithTx(tx).Set(key, distinct...)
}
// unIndex calls Set with no values for period.ID on the UserID index and
// then on the activity index, presumably clearing both entries. It stops at
// the first error. Only period.ID is used.
func (r *periodRepository) unIndex(tx *bbolt.Tx, period *models.Period) error {
	key := unsafeStringToBytes(period.ID)
	for _, idx := range []*index{r.userIdIdx, r.activityIdIdx} {
		if err := idx.WithTx(tx).Set(key); err != nil {
			return err
		}
	}
	return nil
}
// newPeriodRepository ensures the period bucket exists, opens the "UserID"
// and "Goals.ActivityID" indexes for the "Period" model, and returns a
// repository wired to db and those indexes. Any setup error is returned
// with a nil repository.
func newPeriodRepository(db *bbolt.DB) (repositories.PeriodRepository, error) {
	createBucket := func(tx *bbolt.Tx) error {
		_, err := tx.CreateBucketIfNotExists(bnPeriods)
		return err
	}
	if err := db.Update(createBucket); err != nil {
		return nil, err
	}

	repo := &periodRepository{db: db}
	var err error
	if repo.userIdIdx, err = newModelIndex(db, "Period", "UserID"); err != nil {
		return nil, err
	}
	if repo.activityIdIdx, err = newModelIndex(db, "Period", "Goals.ActivityID"); err != nil {
		return nil, err
	}
	return repo, nil
}