package mongodb

import (
	"context"
	"errors"
	"log"
	"strings"
	"sync"
	"time"

	"git.aiterp.net/rpdata/api/internal/generate"
	"git.aiterp.net/rpdata/api/models"
	"git.aiterp.net/rpdata/api/repositories"
	"github.com/globalsign/mgo"
	"github.com/globalsign/mgo/bson"
)

// postRepository is the MongoDB-backed post repository. Posts are stored in
// the "logbot3.posts" collection and ordered within their log by a 1-based
// "position" field.
type postRepository struct {
	// restoreIDs keeps caller-supplied post IDs on insert instead of always
	// generating new ones.
	restoreIDs bool

	logs  *mgo.Collection
	posts *mgo.Collection

	// orderMutex serializes every operation that reads and rewrites post
	// positions. It is an in-process lock only; other processes writing to
	// the same database are not coordinated by it.
	orderMutex sync.Mutex
}

// newPostRepository ensures the indexes used by the queries below exist and
// returns a repository bound to db.
func newPostRepository(db *mgo.Database, restoreIDs bool) (*postRepository, error) {
	posts := db.C("logbot3.posts")

	for _, key := range []string{"logId", "time", "kind", "position"} {
		if err := posts.EnsureIndexKey(key); err != nil {
			return nil, err
		}
	}

	// Text index on the "text" field; required by List's $text search.
	err := posts.EnsureIndex(mgo.Index{
		Key: []string{"$text:text"},
	})
	if err != nil {
		return nil, err
	}

	return &postRepository{
		restoreIDs: restoreIDs,
		posts:      posts,
		logs:       db.C("logbot3.logs"),
	}, nil
}

// Find returns the post with the given ID, or repositories.ErrNotFound if
// there is none.
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
	post := new(models.Post)
	if err := r.posts.FindId(id).One(post); err != nil {
		if err == mgo.ErrNotFound {
			return nil, repositories.ErrNotFound
		}
		return nil, err
	}

	return post, nil
}

// List returns the posts matching filter, capped at filter.Limit (0 means no
// limit). A non-empty log ID that does not start with "L" is treated as a
// long log ID and resolved to the log's short ID first. Whenever filter.LogID
// is set, the results are sorted by position.
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
	query := bson.M{}

	if filter.LogID != nil && *filter.LogID != "" {
		logID := *filter.LogID
		if !strings.HasPrefix(logID, "L") {
			// Resolve long id to short id.
			logDoc := new(models.Log)
			err := r.logs.FindId(logID).Select(bson.M{"logId": 1, "_id": 1}).One(logDoc)
			if err != nil {
				if err == mgo.ErrNotFound {
					return nil, repositories.ErrNotFound
				}
				return nil, err
			}

			logID = logDoc.ShortID
		}

		query["logId"] = logID
	}
	if len(filter.IDs) > 0 {
		query["_id"] = bson.M{"$in": filter.IDs}
	}
	if len(filter.Kinds) > 0 {
		query["kind"] = bson.M{"$in": filter.Kinds}
	}
	if filter.Search != nil {
		query["$text"] = bson.M{"$search": *filter.Search}
	}

	q := r.posts.Find(query)
	if filter.LogID != nil {
		q = q.Sort("position")
	}

	posts := make([]*models.Post, 0, 32)
	if err := q.Limit(filter.Limit).All(&posts); err != nil {
		return nil, err
	}

	return posts, nil
}

// Insert stores post as the last post of its log and returns the stored copy.
//
// Unless restoreIDs is set, a new ID is generated. With restoreIDs, an empty
// ID is also replaced by a generated one (matching InsertMany). A "P"-prefixed
// ID of the wrong length is rejected; other IDs are kept as-is.
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
	if !r.restoreIDs || post.ID == "" {
		post.ID = generate.PostID()
	} else if len(post.ID) != len(generate.PostID()) && strings.HasPrefix(post.ID, "P") {
		// NOTE(review): non-"P" IDs pass unchecked, presumably so legacy IDs
		// can be restored — confirm this is intended.
		return nil, errors.New("invalid post id")
	}

	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}

	// Position 1 is the first position; with no existing posts lastPost is
	// zero-valued, so this yields 1.
	post.Position = lastPost.Position + 1

	if err := r.posts.Insert(post); err != nil {
		return nil, err
	}

	return &post, nil
}

// InsertMany appends posts, in order, to the end of their log. All posts must
// share the same log ID, otherwise repositories.ErrParentMismatch is returned
// before any post is modified. IDs and positions are written into the given
// posts, which are also returned.
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
	if len(posts) == 0 {
		return []*models.Post{}, nil
	}

	logID := posts[0].LogID
	for _, post := range posts {
		if post.LogID != logID {
			return nil, repositories.ErrParentMismatch
		}
	}
	for _, post := range posts {
		if !r.restoreIDs || post.ID == "" {
			post.ID = generate.PostID()
		}
	}

	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": logID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}

	docs := make([]interface{}, len(posts))
	for i, post := range posts {
		post.Position = lastPost.Position + 1 + i
		docs[i] = post
	}

	if err := r.posts.Insert(docs...); err != nil {
		return nil, err
	}

	return posts, nil
}

// Update applies the non-nil fields of update to post, both in the database
// and on the returned copy. An update with no fields set is a no-op.
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
	updateBson := bson.M{}
	if update.Time != nil {
		updateBson["time"] = *update.Time
		post.Time = *update.Time
	}
	if update.Kind != nil {
		updateBson["kind"] = *update.Kind
		post.Kind = *update.Kind
	}
	if update.Nick != nil {
		updateBson["nick"] = *update.Nick
		post.Nick = *update.Nick
	}
	if update.Text != nil {
		updateBson["text"] = *update.Text
		post.Text = *update.Text
	}

	// MongoDB versions before 5.0 reject an empty $set, so skip the round-trip.
	if len(updateBson) == 0 {
		return &post, nil
	}

	if err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson}); err != nil {
		if err == mgo.ErrNotFound {
			return nil, repositories.ErrNotFound
		}
		return nil, err
	}

	return &post, nil
}

// Move places post at the given 1-based position within its log, shifting
// the posts in between by one place. It returns, sorted by position, every
// post between the old and new position (inclusive).
//
// The shift and the move are separate writes without a transaction; if the
// move fails, the shift is reverted on a best-effort basis.
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
	// Validate lower position bound.
	if position < 1 {
		return nil, repositories.ErrInvalidPosition
	}

	// pushFilter selects the posts that must shift by increment. undoFilter
	// selects those same posts after the shift; it excludes the moved post,
	// which still sits at its old position if the move itself fails.
	var pushFilter, undoFilter, resultFilter bson.M
	var increment int
	if post.Position > position {
		// Moving earlier: posts in [position, old) move one place later.
		pushFilter = bson.M{"position": bson.M{"$gte": position, "$lt": post.Position}}
		undoFilter = bson.M{"position": bson.M{"$gt": position, "$lte": post.Position}}
		resultFilter = bson.M{"position": bson.M{"$gte": position, "$lte": post.Position}}
		increment = 1
	} else {
		// Moving later: posts in (old, position] move one place earlier.
		pushFilter = bson.M{"position": bson.M{"$gt": post.Position, "$lte": position}}
		undoFilter = bson.M{"position": bson.M{"$gte": post.Position, "$lt": position}}
		resultFilter = bson.M{"position": bson.M{"$gte": post.Position, "$lte": position}}
		increment = -1
	}
	pushFilter["logId"] = post.LogID
	undoFilter["logId"] = post.LogID
	undoFilter["_id"] = bson.M{"$ne": post.ID}
	resultFilter["logId"] = post.LogID

	// From here on out, sync is required.
	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	// Detect ninja shenanigans: the filters above are only valid if the
	// stored post still exists at the position the caller saw.
	current := new(models.Post)
	if err := r.posts.FindId(post.ID).One(current); err != nil {
		if err == mgo.ErrNotFound {
			return nil, repositories.ErrNotFound
		}
		return nil, err
	}
	if current.Position != post.Position {
		return nil, repositories.ErrNotFound
	}

	// Validate upper position bound.
	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}
	if position > lastPost.Position {
		return nil, repositories.ErrInvalidPosition
	}

	// Move the posts.
	if _, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}}); err != nil {
		return nil, err
	}
	if err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}}); err != nil {
		// Try to undo the shift; the move error is what the caller needs.
		if _, undoErr := r.posts.UpdateAll(undoFilter, bson.M{"$inc": bson.M{"position": -increment}}); undoErr != nil {
			log.Println("Failed to undo post shift in", post.LogID, "after failed move:", undoErr)
		}
		return nil, err
	}

	distance := post.Position - position
	if distance < 0 {
		distance = -distance
	}

	results := make([]*models.Post, 0, distance+1)
	if err := r.posts.Find(resultFilter).Sort("position").All(&results); err != nil {
		return nil, err
	}

	return results, nil
}

// Delete removes post and moves every later post in the same log one place
// earlier to close the gap.
func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	if err := r.posts.RemoveId(post.ID); err != nil {
		if err == mgo.ErrNotFound {
			return repositories.ErrNotFound
		}
		return err
	}

	// The post is already gone, so the delete is reported as successful even
	// if renumbering fails; fixPositions can repair the resulting gap.
	_, err := r.posts.UpdateAll(
		bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}},
		bson.M{"$inc": bson.M{"position": -1}},
	)
	if err != nil {
		log.Println("Failed to shift posts in", post.LogID, "after deleting", post.ID+":", err)
	}

	return nil
}

// fixPositions renumbers the posts of every log so that, in position order,
// they occupy positions 1..n, closing gaps and resolving duplicates. It logs
// its progress and holds orderMutex while working on each log.
func (r *postRepository) fixPositions(logRepo *logRepository) {
	startTime := time.Now()

	timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	logs, err := logRepo.List(timeout, models.LogFilter{})
	if err != nil {
		log.Println("Failed to get logs for position fix:", err)
		return
	}

	log.Println("Starting log position fixing, this should not take longer than 10 seconds.")

	for _, l := range logs {
		// An empty short ID would make List drop the log filter and return
		// (and this loop renumber) the posts of every log at once.
		if l.ShortID == "" {
			continue
		}

		r.orderMutex.Lock()

		posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID})
		if err != nil {
			r.orderMutex.Unlock()
			log.Println(l.ShortID, "post listing for position fix failed:", err)
			continue
		}

		disorders := 0
		for i, post := range posts {
			if post.Position != i+1 {
				disorders++
			}
		}

		if disorders > 0 {
			log.Println(disorders, "order errors detected in", l.ID)

			ops := 0
			failed := false
			for i, post := range posts {
				if post.Position == i+1 {
					continue
				}

				ops++
				err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}})
				if err != nil {
					log.Println(l.ShortID, "fix failed after", ops, "ops:", err)
					failed = true
					break
				}
			}
			if !failed {
				log.Println(l.ShortID, "fixed after", ops, "ops.")
			}
		}

		r.orderMutex.Unlock()
	}

	log.Println("Log position fixing finished in", time.Since(startTime))
}