package bolt import ( "bytes" "errors" "go.etcd.io/bbolt" ) import "github.com/vmihailenco/msgpack/v4" var bnMap = []byte("map") var bnRev = []byte("rev") type index struct { bucketNames [][]byte } // newModelIndex creates an index with the naming convention of `_idx.Model.Field`. func newModelIndex(db *bbolt.DB, model, field string) (*index, error) { return newIndex(db, "_idx", model, field) } // newIndex creates a new index and ensures the bucket chain is in order. func newIndex(db *bbolt.DB, buckets ...string) (*index, error) { if len(buckets) == 0 { panic("no buckets") } bucketNames := make([][]byte, len(buckets)) for i := range buckets { bucketNames[i] = []byte(buckets[i]) } err := db.Update(func(tx *bbolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists(bucketNames[0]) if err != nil { return err } for _, bucketName := range bucketNames[1:] { bucket, err = bucket.CreateBucketIfNotExists(bucketName) if err != nil { return err } } _, err = bucket.CreateBucketIfNotExists(bnMap) if err != nil { return err } _, err = bucket.CreateBucketIfNotExists(bnRev) if err != nil { return err } return nil }) return &index{bucketNames: bucketNames}, err } func (idx *index) Reset(tx *bbolt.Tx) error { rootBucket := tx.Bucket(idx.bucketNames[0]) for _, name := range idx.bucketNames[1:] { rootBucket = rootBucket.Bucket(name) } err := rootBucket.DeleteBucket(bnRev) if err != nil { return err } err = rootBucket.DeleteBucket(bnMap) if err != nil { return err } _, err = rootBucket.CreateBucket(bnRev) if err != nil { return err } _, err = rootBucket.CreateBucket(bnMap) if err != nil { return err } return nil } func (idx *index) buckets(tx *bbolt.Tx) (mapBucket, revBucket *bbolt.Bucket) { rootBucket := tx.Bucket(idx.bucketNames[0]) for _, name := range idx.bucketNames[1:] { rootBucket = rootBucket.Bucket(name) } return rootBucket.Bucket(bnMap), rootBucket.Bucket(bnRev) } func (idx *index) WithTx(tx *bbolt.Tx) *indexTx { mapBucket, revBucket := idx.buckets(tx) return &indexTx{ tx: tx, mapBucket: mapBucket, revBucket: revBucket, } } type indexTx struct { tx *bbolt.Tx mapBucket *bbolt.Bucket revBucket *bbolt.Bucket } func (itx *indexTx) Get(value string) (ids [][]byte, err error) { entry := itx.mapBucket.Get(unsafeStringToBytes(value)) if entry == nil { return } ids = make([][]byte, 0, 8) err = msgpack.Unmarshal(entry, &ids) return } // Reverse gets all values associated with the ID func (itx *indexTx) Reverse(id []byte) (values []string, err error) { value := itx.revBucket.Get(id) if value == nil { return } values = make([]string, 0, 8) err = msgpack.Unmarshal(value, &values) return } func (itx *indexTx) Before(value string) (ids [][]byte, err error) { cursor := itx.mapBucket.Cursor() ids = make([][]byte, 0, 16) valueBytes := unsafeStringToBytes(value) cursor.Seek(valueBytes) key, entryValue := cursor.Prev() if key == nil { return } entry := itx.mapBucket.Get([]byte(entryValue)) if entry == nil { return } err = msgpack.Unmarshal(entry, &ids) return } func (itx *indexTx) Between(a, b string) (ids [][]byte, err error) { cursor := itx.mapBucket.Cursor() ids = make([][]byte, 0, 16) aBytes := unsafeStringToBytes(a) bBytes := unsafeStringToBytes(b) for key, value := cursor.Seek(aBytes); key != nil && bytes.Compare(key, bBytes) < 1; key, value = cursor.Next() { entry := itx.mapBucket.Get([]byte(value)) if entry == nil { return } err = msgpack.Unmarshal(entry, &ids) if err != nil { return } } return } // Set sets the index to the given values. This removes any values not in the given list. func (itx *indexTx) Set(id []byte, values ...string) error { oldValues := make([]string, 0, len(values)) newValues := make([]string, 0, len(values)) // Check for duplicates for i, value := range values { for _, value2 := range values[i+1:] { if value == value2 { return errors.New("Duplicate value for index: " + value) } } } if value := itx.revBucket.Get(id); value != nil { // Existing ID existingValues := make([]string, 0, 16) err := msgpack.Unmarshal(value, &existingValues) if err != nil { return err } // Record new and old values if there are any. if len(values) > 0 { // Find old values OldValueLoop: for _, existingValue := range existingValues { for _, value := range values { if value == existingValue { continue OldValueLoop } } oldValues = append(oldValues, existingValue) } // Find new values NewValueLoop: for _, value := range values { if value == "" { continue } for _, existingValue := range existingValues { if value == existingValue { continue NewValueLoop } } newValues = append(newValues, value) } } else { // There aren't any values, all values should be considered old. oldValues = existingValues newValues = newValues[:0] } } else { // New ID, can be skipped if this is clearing the itx operation if len(values) == 0 { return nil } // Otherwise, all values are newValues. for _, value := range values { if value == "" { continue } newValues = append(newValues, value) } } // Put new reverse lookup entry if len(values) > 0 { revData, err := msgpack.Marshal(values) if err != nil { return err } err = itx.revBucket.Put(id, revData) if err != nil { return err } } // Remove old values for _, oldValue := range oldValues { ov := []byte(oldValue) value := itx.mapBucket.Get(ov) if value == nil { return errors.New("oldValue expected, but not found. itx probably corrupt") } ids := make([][]byte, 0, 8) err := msgpack.Unmarshal(value, &ids) if err != nil { return err } for i, existingId := range ids { if bytes.Equal(existingId, id) { ids = append(ids[:i], ids[i+1:]...) break } } if len(ids) == 0 { err = itx.mapBucket.Delete(ov) if err != nil { return err } } else { newValue, err := msgpack.Marshal(ids) if err != nil { return err } err = itx.mapBucket.Put(ov, newValue) if err != nil { return err } } } // Add new values for _, newValue := range newValues { nv := []byte(newValue) if existing := itx.mapBucket.Get(nv); existing != nil { ids := make([][]byte, 0, 8) err := msgpack.Unmarshal(existing, &ids) if err != nil { return err } newValue, err := msgpack.Marshal(append(ids, id)) if err != nil { return err } err = itx.mapBucket.Put(nv, newValue) if err != nil { return err } } else { newValue, err := msgpack.Marshal([][]byte{id}) if err != nil { return err } err = itx.mapBucket.Put(nv, newValue) if err != nil { return err } } } // If this will delete all values, then delete this entry if len(values) == 0 { err := itx.revBucket.Delete(id) if err != nil { return err } } return nil }