|
|
package mongodb
import ( "context" "errors" "git.aiterp.net/rpdata/api/internal/generate" "git.aiterp.net/rpdata/api/models" "git.aiterp.net/rpdata/api/repositories" "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" "log" "strings" "sync" "time" )
// postRepository stores and orders log posts in MongoDB through mgo.
type postRepository struct {
	// restoreIDs makes the insert methods keep IDs supplied by the caller
	// instead of always generating fresh ones.
	restoreIDs bool
	// logs is the "logbot3.logs" collection; List uses it to resolve a long
	// log ID to the short ID that posts are keyed by.
	logs *mgo.Collection
	// posts is the "logbot3.posts" collection holding the posts themselves.
	posts *mgo.Collection

	// orderMutex is held by every method that reads and then rewrites post
	// positions (Insert, InsertMany, Move, Delete and fixPositions).
	orderMutex sync.Mutex
}
func newPostRepository(db *mgo.Database, restoreIDs bool) (*postRepository, error) { posts := db.C("logbot3.posts")
err := posts.EnsureIndexKey("logId") if err != nil { return nil, err } err = posts.EnsureIndexKey("time") if err != nil { return nil, err } err = posts.EnsureIndexKey("kind") if err != nil { return nil, err } err = posts.EnsureIndexKey("position") if err != nil { return nil, err }
err = posts.EnsureIndex(mgo.Index{ Key: []string{"$text:text"}, }) if err != nil { return nil, err }
return &postRepository{ restoreIDs: restoreIDs, posts: posts, logs: db.C("logbot3.logs"), }, nil }
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) { post := new(models.Post) err := r.posts.FindId(id).One(post) if err != nil { return nil, err }
return post, nil }
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) { query := bson.M{} if filter.LogID != nil && *filter.LogID != "" { logId := *filter.LogID if !strings.HasPrefix(logId, "L") { // Resolve long id to short id
log := new(models.Log) err := r.logs.FindId(logId).Select(bson.M{"logId": 1, "_id": 1}).One(log) if err != nil { return nil, err }
logId = log.ShortID }
query["logId"] = logId } if len(filter.IDs) > 0 { query["_id"] = bson.M{"$in": filter.IDs} } if len(filter.Kinds) > 0 { query["kind"] = bson.M{"$in": filter.Kinds} } if filter.Search != nil { query["$text"] = bson.M{"$search": *filter.Search} }
posts := make([]*models.Post, 0, 32) var err error if filter.LogID != nil { err = r.posts.Find(query).Sort("position").Limit(filter.Limit).All(&posts) } else { err = r.posts.Find(query).Limit(filter.Limit).All(&posts) } if err != nil { if err == mgo.ErrNotFound { return []*models.Post{}, nil }
return nil, err }
return posts, nil }
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) { r.orderMutex.Lock() defer r.orderMutex.Unlock()
lastPost := new(models.Post) err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost) if err != nil && err != mgo.ErrNotFound { return nil, err }
if !r.restoreIDs { post.ID = generate.PostID() } else { if len(post.ID) != len(generate.PostID()) && strings.HasPrefix(post.ID, "P") { return nil, errors.New("invalid story id") } }
post.Position = lastPost.Position + 1 // Position 1 is first position, so this is safe.
err = r.posts.Insert(post) if err != nil { return nil, err }
return &post, nil }
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) { if len(posts) == 0 { return []*models.Post{}, nil }
logId := posts[0].LogID for _, post := range posts { if post.LogID != logId { return nil, repositories.ErrParentMismatch }
if !r.restoreIDs || post.ID == "" { post.ID = generate.PostID() } }
r.orderMutex.Lock() defer r.orderMutex.Unlock()
lastPost := new(models.Post) err := r.posts.Find(bson.M{"logId": posts[0].LogID}).Sort("-position").One(lastPost) if err != nil && err != mgo.ErrNotFound { return nil, err }
docs := make([]interface{}, len(posts)) for i := range posts { posts[i].Position = lastPost.Position + 1 + i docs[i] = posts[i] }
err = r.posts.Insert(docs...) if err != nil && err != mgo.ErrNotFound { return nil, err }
return posts, nil }
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) { updateBson := bson.M{} if update.Time != nil { updateBson["time"] = *update.Time post.Time = *update.Time } if update.Kind != nil { updateBson["kind"] = *update.Kind post.Kind = *update.Kind } if update.Nick != nil { updateBson["nick"] = *update.Nick post.Nick = *update.Nick } if update.Text != nil { updateBson["text"] = *update.Text post.Text = *update.Text }
err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson}) if err != nil { return nil, err }
return &post, nil }
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) { // If only MongoDB transactions weren't awful, this function would have been safe.
// Since it isn't, then good luck.
// Validate lower position bound.
if position < 1 { return nil, repositories.ErrInvalidPosition }
// Determine the operations on adjacent posts
var resultFilter bson.M var pushFilter bson.M var increment int if post.Position > position { pushFilter = bson.M{"position": bson.M{ "$gte": position, "$lt": post.Position, }} increment = 1 resultFilter = bson.M{"position": bson.M{ "$lte": post.Position, "$gte": position, }} } else { pushFilter = bson.M{"position": bson.M{ "$lte": position, "$gt": post.Position, }} increment = -1 resultFilter = bson.M{"position": bson.M{ "$lte": position, "$gte": post.Position, }} } pushFilter["logId"] = post.LogID resultFilter["logId"] = post.LogID
// From here on out, sync is required
r.orderMutex.Lock() defer r.orderMutex.Unlock()
// Detect ninja shenanigans
post2 := new(models.Post) err := r.posts.FindId(post.ID).One(post2) if err != nil && post2.Position != post.Position { return nil, repositories.ErrNotFound }
// Validate upper position bound
lastPost := new(models.Post) err = r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost) if err != nil && err != mgo.ErrNotFound { return nil, err } if position > lastPost.Position { return nil, repositories.ErrInvalidPosition }
// Move the posts
changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}}) if err != nil && err != mgo.ErrNotFound { return nil, err } err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}}) if err != nil { // Try to undo it
_, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}}) if err != nil && err != mgo.ErrNotFound { return nil, err }
return nil, err }
results := make([]*models.Post, 0, changeInfo.Matched+1) err = r.posts.Find(resultFilter).All(&results) if err != nil { return nil, err }
return results, nil }
func (r *postRepository) Delete(ctx context.Context, post models.Post) error { r.orderMutex.Lock() defer r.orderMutex.Unlock()
err := r.posts.RemoveId(post.ID) if err != nil { return err }
_, _ = r.posts.UpdateAll(bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}}, bson.M{"$inc": bson.M{"position": -1}})
return nil }
func (r *postRepository) fixPositions(logRepo *logRepository) { disorders := make([]int, 0, 16) diffs := make([]int, 0, 16) startTime := time.Now()
timeout, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel()
logs, err := logRepo.List(timeout, models.LogFilter{}) if err != nil { log.Println("Failed to get logs for position fix:", err) }
log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
for _, l := range logs { r.orderMutex.Lock()
posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID}) if err != nil { r.orderMutex.Unlock() continue }
disorders = disorders[:0] diffs = diffs[:0] for i, post := range posts { if post.Position != (i + 1) { disorders = append(disorders, i) diffs = append(diffs, post.Position-(i+1)) } }
if len(disorders) > 0 { log.Println(len(disorders), "order errors detected in", l.ID)
ops := 0
for i, post := range posts { if (i + 1) != posts[i].Position { ops++
err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}}) if err != nil { log.Println(l.ShortID, "fix failed after", ops, "ops:", err) break } } }
log.Println(l.ShortID, "fixed after", ops, "ops.") }
r.orderMutex.Unlock() }
log.Println("Log position fixing finished in", time.Since(startTime)) }
|