You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
362 lines
8.3 KiB
362 lines
8.3 KiB
package mongodb
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"git.aiterp.net/rpdata/api/internal/generate"
|
|
"git.aiterp.net/rpdata/api/models"
|
|
"git.aiterp.net/rpdata/api/repositories"
|
|
"github.com/globalsign/mgo"
|
|
"github.com/globalsign/mgo/bson"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type postRepository struct {
|
|
restoreIDs bool
|
|
logs *mgo.Collection
|
|
posts *mgo.Collection
|
|
|
|
orderMutex sync.Mutex
|
|
}
|
|
|
|
func newPostRepository(db *mgo.Database, restoreIDs bool) (*postRepository, error) {
|
|
posts := db.C("logbot3.posts")
|
|
|
|
err := posts.EnsureIndexKey("logId")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = posts.EnsureIndexKey("time")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = posts.EnsureIndexKey("kind")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = posts.EnsureIndexKey("position")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = posts.EnsureIndex(mgo.Index{
|
|
Key: []string{"$text:text"},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postRepository{
|
|
restoreIDs: restoreIDs,
|
|
posts: posts,
|
|
logs: db.C("logbot3.logs"),
|
|
}, nil
|
|
}
|
|
|
|
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
|
|
post := new(models.Post)
|
|
err := r.posts.FindId(id).One(post)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return post, nil
|
|
}
|
|
|
|
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
|
|
query := bson.M{}
|
|
if filter.LogID != nil && *filter.LogID != "" {
|
|
logId := *filter.LogID
|
|
if !strings.HasPrefix(logId, "L") {
|
|
// Resolve long id to short id
|
|
log := new(models.Log)
|
|
err := r.logs.FindId(logId).Select(bson.M{"logId": 1, "_id": 1}).One(log)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logId = log.ShortID
|
|
}
|
|
|
|
query["logId"] = logId
|
|
}
|
|
if len(filter.IDs) > 0 {
|
|
query["_id"] = bson.M{"$in": filter.IDs}
|
|
}
|
|
if len(filter.Kinds) > 0 {
|
|
query["kind"] = bson.M{"$in": filter.Kinds}
|
|
}
|
|
if filter.Search != nil {
|
|
query["$text"] = bson.M{"$search": *filter.Search}
|
|
}
|
|
|
|
posts := make([]*models.Post, 0, 32)
|
|
var err error
|
|
if filter.LogID != nil {
|
|
err = r.posts.Find(query).Sort("position").Limit(filter.Limit).All(&posts)
|
|
} else {
|
|
err = r.posts.Find(query).Limit(filter.Limit).All(&posts)
|
|
}
|
|
if err != nil {
|
|
if err == mgo.ErrNotFound {
|
|
return []*models.Post{}, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return posts, nil
|
|
}
|
|
|
|
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
|
|
r.orderMutex.Lock()
|
|
defer r.orderMutex.Unlock()
|
|
|
|
lastPost := new(models.Post)
|
|
err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
if !r.restoreIDs {
|
|
post.ID = generate.PostID()
|
|
} else {
|
|
if len(post.ID) != len(generate.PostID()) && strings.HasPrefix(post.ID, "P") {
|
|
return nil, errors.New("invalid story id")
|
|
}
|
|
}
|
|
|
|
post.Position = lastPost.Position + 1 // Position 1 is first position, so this is safe.
|
|
|
|
err = r.posts.Insert(post)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &post, nil
|
|
}
|
|
|
|
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
|
|
if len(posts) == 0 {
|
|
return []*models.Post{}, nil
|
|
}
|
|
|
|
logId := posts[0].LogID
|
|
for _, post := range posts {
|
|
if post.LogID != logId {
|
|
return nil, repositories.ErrParentMismatch
|
|
}
|
|
|
|
if !r.restoreIDs || post.ID == "" {
|
|
post.ID = generate.PostID()
|
|
}
|
|
}
|
|
|
|
r.orderMutex.Lock()
|
|
defer r.orderMutex.Unlock()
|
|
|
|
lastPost := new(models.Post)
|
|
err := r.posts.Find(bson.M{"logId": posts[0].LogID}).Sort("-position").One(lastPost)
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
docs := make([]interface{}, len(posts))
|
|
for i := range posts {
|
|
posts[i].Position = lastPost.Position + 1 + i
|
|
docs[i] = posts[i]
|
|
}
|
|
|
|
err = r.posts.Insert(docs...)
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
return posts, nil
|
|
}
|
|
|
|
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
|
|
updateBson := bson.M{}
|
|
if update.Time != nil {
|
|
updateBson["time"] = *update.Time
|
|
post.Time = *update.Time
|
|
}
|
|
if update.Kind != nil {
|
|
updateBson["kind"] = *update.Kind
|
|
post.Kind = *update.Kind
|
|
}
|
|
if update.Nick != nil {
|
|
updateBson["nick"] = *update.Nick
|
|
post.Nick = *update.Nick
|
|
}
|
|
if update.Text != nil {
|
|
updateBson["text"] = *update.Text
|
|
post.Text = *update.Text
|
|
}
|
|
|
|
err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &post, nil
|
|
}
|
|
|
|
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
|
|
// If only MongoDB transactions weren't awful, this function would have been safe.
|
|
// Since it isn't, then good luck.
|
|
|
|
// Validate lower position bound.
|
|
if position < 1 {
|
|
return nil, repositories.ErrInvalidPosition
|
|
}
|
|
|
|
// Determine the operations on adjacent posts
|
|
var resultFilter bson.M
|
|
var pushFilter bson.M
|
|
var increment int
|
|
if post.Position > position {
|
|
pushFilter = bson.M{"position": bson.M{
|
|
"$gte": position,
|
|
"$lt": post.Position,
|
|
}}
|
|
increment = 1
|
|
resultFilter = bson.M{"position": bson.M{
|
|
"$lte": post.Position,
|
|
"$gte": position,
|
|
}}
|
|
} else {
|
|
pushFilter = bson.M{"position": bson.M{
|
|
"$lte": position,
|
|
"$gt": post.Position,
|
|
}}
|
|
increment = -1
|
|
resultFilter = bson.M{"position": bson.M{
|
|
"$lte": position,
|
|
"$gte": post.Position,
|
|
}}
|
|
}
|
|
pushFilter["logId"] = post.LogID
|
|
resultFilter["logId"] = post.LogID
|
|
|
|
// From here on out, sync is required
|
|
r.orderMutex.Lock()
|
|
defer r.orderMutex.Unlock()
|
|
|
|
// Detect ninja shenanigans
|
|
post2 := new(models.Post)
|
|
err := r.posts.FindId(post.ID).One(post2)
|
|
if err != nil && post2.Position != post.Position {
|
|
return nil, repositories.ErrNotFound
|
|
}
|
|
|
|
// Validate upper position bound
|
|
lastPost := new(models.Post)
|
|
err = r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
if position > lastPost.Position {
|
|
return nil, repositories.ErrInvalidPosition
|
|
}
|
|
|
|
// Move the posts
|
|
changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}})
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}})
|
|
if err != nil {
|
|
// Try to undo it
|
|
_, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}})
|
|
if err != nil && err != mgo.ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
results := make([]*models.Post, 0, changeInfo.Matched+1)
|
|
err = r.posts.Find(resultFilter).All(&results)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
|
|
r.orderMutex.Lock()
|
|
defer r.orderMutex.Unlock()
|
|
|
|
err := r.posts.RemoveId(post.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, _ = r.posts.UpdateAll(bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}}, bson.M{"$inc": bson.M{"position": -1}})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *postRepository) fixPositions(logRepo *logRepository) {
|
|
disorders := make([]int, 0, 16)
|
|
diffs := make([]int, 0, 16)
|
|
startTime := time.Now()
|
|
|
|
timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
|
|
defer cancel()
|
|
|
|
logs, err := logRepo.List(timeout, models.LogFilter{})
|
|
if err != nil {
|
|
log.Println("Failed to get logs for position fix:", err)
|
|
}
|
|
|
|
log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
|
|
|
|
for _, l := range logs {
|
|
r.orderMutex.Lock()
|
|
|
|
posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID})
|
|
if err != nil {
|
|
r.orderMutex.Unlock()
|
|
continue
|
|
}
|
|
|
|
disorders = disorders[:0]
|
|
diffs = diffs[:0]
|
|
for i, post := range posts {
|
|
if post.Position != (i + 1) {
|
|
disorders = append(disorders, i)
|
|
diffs = append(diffs, post.Position-(i+1))
|
|
}
|
|
}
|
|
|
|
if len(disorders) > 0 {
|
|
log.Println(len(disorders), "order errors detected in", l.ID)
|
|
|
|
ops := 0
|
|
|
|
for i, post := range posts {
|
|
if (i + 1) != posts[i].Position {
|
|
ops++
|
|
|
|
err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}})
|
|
if err != nil {
|
|
log.Println(l.ShortID, "fix failed after", ops, "ops:", err)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Println(l.ShortID, "fixed after", ops, "ops.")
|
|
}
|
|
|
|
r.orderMutex.Unlock()
|
|
}
|
|
|
|
log.Println("Log position fixing finished in", time.Since(startTime))
|
|
}
|