GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

362 lines
8.3 KiB

package mongodb
import (
"context"
"errors"
"git.aiterp.net/rpdata/api/internal/generate"
"git.aiterp.net/rpdata/api/models"
"git.aiterp.net/rpdata/api/repositories"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"log"
"strings"
"sync"
"time"
)
// postRepository stores posts in a MongoDB collection and keeps a handle to
// the logs collection so that long log IDs can be resolved to short IDs when
// filtering posts.
type postRepository struct {
	// restoreIDs lets Insert and InsertMany keep caller-supplied post IDs
	// instead of always generating new ones.
	restoreIDs bool

	logs, posts *mgo.Collection

	// orderMutex serializes the operations that read or rewrite post positions.
	orderMutex sync.Mutex
}
// newPostRepository wraps the "logbot3.posts" collection of db and makes sure
// the indexes the repository queries on exist before returning it.
//
// Single-field indexes are ensured on logId, time, kind and position, followed
// by a text index on the "text" field (used by the $text search in List). The
// first index error aborts construction and is returned unchanged.
func newPostRepository(db *mgo.Database, restoreIDs bool) (*postRepository, error) {
	posts := db.C("logbot3.posts")

	for _, key := range []string{"logId", "time", "kind", "position"} {
		if err := posts.EnsureIndexKey(key); err != nil {
			return nil, err
		}
	}

	textIndex := mgo.Index{Key: []string{"$text:text"}}
	if err := posts.EnsureIndex(textIndex); err != nil {
		return nil, err
	}

	repo := &postRepository{
		restoreIDs: restoreIDs,
		posts:      posts,
		logs:       db.C("logbot3.logs"),
	}
	return repo, nil
}
// Find loads the post whose _id equals id. Errors from the lookup (including
// mgo.ErrNotFound when no post matches) are returned unchanged.
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
	var post models.Post
	if err := r.posts.FindId(id).One(&post); err != nil {
		return nil, err
	}
	return &post, nil
}
// List returns the posts matching filter.
//
// Filter semantics:
//   - LogID, when non-nil and non-empty, restricts results to one log. An ID
//     that does not start with "L" is treated as a long log ID and resolved
//     to the log's short ID first; an error from that lookup is returned as-is.
//   - IDs and Kinds restrict results with $in.
//   - Search runs a MongoDB $text search.
//   - Limit is passed to mgo's Query.Limit.
//
// Results are sorted by position whenever filter.LogID is non-nil; otherwise
// they come back in server order. An empty result is an empty, non-nil slice.
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
	query := bson.M{}

	if filter.LogID != nil && *filter.LogID != "" {
		logID := *filter.LogID
		if !strings.HasPrefix(logID, "L") {
			// Posts are stored with the short log ID, so resolve a long ID first.
			// (Named logDoc so the imported log package is not shadowed.)
			logDoc := new(models.Log)
			err := r.logs.FindId(logID).Select(bson.M{"logId": 1, "_id": 1}).One(logDoc)
			if err != nil {
				return nil, err
			}
			logID = logDoc.ShortID
		}
		query["logId"] = logID
	}
	if len(filter.IDs) > 0 {
		query["_id"] = bson.M{"$in": filter.IDs}
	}
	if len(filter.Kinds) > 0 {
		query["kind"] = bson.M{"$in": filter.Kinds}
	}
	if filter.Search != nil {
		query["$text"] = bson.M{"$search": *filter.Search}
	}

	q := r.posts.Find(query)
	if filter.LogID != nil {
		q = q.Sort("position")
	}

	posts := make([]*models.Post, 0, 32)
	if err := q.Limit(filter.Limit).All(&posts); err != nil {
		if err == mgo.ErrNotFound {
			return []*models.Post{}, nil
		}
		return nil, err
	}
	return posts, nil
}
// Insert stores post as the last post of its log and returns the stored copy.
//
// Unless the repository restores IDs, post.ID is replaced with a freshly
// generated ID. In restore mode the caller's ID is kept, with two exceptions:
//   - an empty ID is replaced with a generated one, matching InsertMany;
//   - an ID with the "P" prefix but not the generated length is rejected.
//
// The new position is one past the log's current highest position, or 1 for
// a log with no posts. orderMutex is held while that position is read and the
// post inserted, so other position-changing calls on this repository cannot
// interleave.
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
	switch {
	case !r.restoreIDs || post.ID == "":
		post.ID = generate.PostID()
	case len(post.ID) != len(generate.PostID()) && strings.HasPrefix(post.ID, "P"):
		return nil, errors.New("invalid post id")
	}

	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}

	// Positions are 1-based; with no previous post lastPost.Position is 0.
	post.Position = lastPost.Position + 1
	if err := r.posts.Insert(post); err != nil {
		return nil, err
	}
	return &post, nil
}
// InsertMany appends posts, in order, to the end of their shared log and
// returns the same slice with each post's ID and Position filled in.
//
// All posts must have the first post's LogID. Otherwise
// repositories.ErrParentMismatch is returned before any post is modified.
// IDs are generated unless the repository restores IDs; in that mode only
// empty IDs are generated. An empty input returns an empty slice without
// touching the database.
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
	if len(posts) == 0 {
		return []*models.Post{}, nil
	}

	// Validate the whole batch first so a mismatch leaves the caller's posts untouched.
	logID := posts[0].LogID
	for _, post := range posts[1:] {
		if post.LogID != logID {
			return nil, repositories.ErrParentMismatch
		}
	}
	for _, post := range posts {
		if !r.restoreIDs || post.ID == "" {
			post.ID = generate.PostID()
		}
	}

	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": logID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}

	// Number the batch consecutively after the log's current last position.
	docs := make([]interface{}, len(posts))
	for i, post := range posts {
		post.Position = lastPost.Position + 1 + i
		docs[i] = post
	}
	if err := r.posts.Insert(docs...); err != nil {
		return nil, err
	}
	return posts, nil
}
// Update applies the non-nil fields of update (Time, Kind, Nick, Text) to the
// stored post identified by post.ID using a single $set. It returns a copy of
// post with the same fields changed in memory.
//
// If update sets no fields, no write is issued and post is returned unchanged.
// Otherwise errors from UpdateId, such as mgo.ErrNotFound when no post has
// post.ID, are returned as-is.
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
	updateBson := bson.M{}
	if update.Time != nil {
		updateBson["time"] = *update.Time
		post.Time = *update.Time
	}
	if update.Kind != nil {
		updateBson["kind"] = *update.Kind
		post.Kind = *update.Kind
	}
	if update.Nick != nil {
		updateBson["nick"] = *update.Nick
		post.Nick = *update.Nick
	}
	if update.Text != nil {
		updateBson["text"] = *update.Text
		post.Text = *update.Text
	}

	if len(updateBson) == 0 {
		// Nothing to write, and MongoDB servers before 5.0 reject an empty $set.
		return &post, nil
	}

	if err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson}); err != nil {
		return nil, err
	}
	return &post, nil
}
// Move changes post's position within its log to position (1-based). It
// shifts the posts in between by one to make room, and returns every post
// whose position range was affected (post included), sorted by position.
//
// Errors:
//   - repositories.ErrInvalidPosition if position is below 1 or above the
//     log's current last position.
//   - repositories.ErrNotFound if post no longer exists, or if its stored
//     position no longer equals post.Position.
//   - Any other database error, unchanged.
//
// The shift and the final $set are separate, non-transactional writes. If
// the $set fails, the shift is undone on a best-effort basis and the original
// error is returned.
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
	if position < 1 {
		return nil, repositories.ErrInvalidPosition
	}

	// Moving to a lower position pushes the posts in [position, old) down by
	// one. Moving to a higher position pulls the posts in (old, position] up by one.
	var pushFilter, resultFilter bson.M
	var increment int
	if post.Position > position {
		pushFilter = bson.M{"position": bson.M{"$gte": position, "$lt": post.Position}}
		resultFilter = bson.M{"position": bson.M{"$gte": position, "$lte": post.Position}}
		increment = 1
	} else {
		pushFilter = bson.M{"position": bson.M{"$gt": post.Position, "$lte": position}}
		resultFilter = bson.M{"position": bson.M{"$gte": post.Position, "$lte": position}}
		increment = -1
	}
	pushFilter["logId"] = post.LogID
	resultFilter["logId"] = post.LogID

	// From here on, positions must not change underneath us.
	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	// The filters above were built from post.Position, so make sure the post
	// still exists and still sits where the caller thinks it does.
	stored := new(models.Post)
	if err := r.posts.FindId(post.ID).One(stored); err != nil {
		if err == mgo.ErrNotFound {
			return nil, repositories.ErrNotFound
		}
		return nil, err
	}
	if stored.Position != post.Position {
		return nil, repositories.ErrNotFound
	}

	// Validate the upper bound against the log's current last post.
	lastPost := new(models.Post)
	err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
	if err != nil && err != mgo.ErrNotFound {
		return nil, err
	}
	if position > lastPost.Position {
		return nil, repositories.ErrInvalidPosition
	}

	changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}})
	if err != nil {
		return nil, err
	}

	err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}})
	if err != nil {
		// Undo the shift. The caller needs the original failure, so an undo
		// failure is only logged.
		_, undoErr := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}})
		if undoErr != nil {
			log.Println("Failed to undo position shift for post", post.ID, "in", post.LogID, ":", undoErr)
		}
		return nil, err
	}

	// changeInfo is nil when the session does not request write acknowledgement.
	capacity := 1
	if changeInfo != nil {
		capacity += changeInfo.Matched
	}
	results := make([]*models.Post, 0, capacity)
	if err := r.posts.Find(resultFilter).Sort("position").All(&results); err != nil {
		return nil, err
	}
	return results, nil
}
// Delete removes post and closes the gap it leaves by moving every later post
// in the same log up one position.
//
// The stored copy of the post is read first, under orderMutex. The shift then
// uses its current log and position rather than the possibly stale values in
// the argument. Lookup and removal errors (including mgo.ErrNotFound when the
// post does not exist) are returned unchanged.
//
// A failure to shift the remaining posts is logged but not returned, because
// the post itself has already been removed at that point.
func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	stored := new(models.Post)
	if err := r.posts.FindId(post.ID).One(stored); err != nil {
		return err
	}
	if err := r.posts.RemoveId(post.ID); err != nil {
		return err
	}

	_, err := r.posts.UpdateAll(
		bson.M{"logId": stored.LogID, "position": bson.M{"$gt": stored.Position}},
		bson.M{"$inc": bson.M{"position": -1}},
	)
	if err != nil {
		log.Println("Failed to shift positions after deleting post", post.ID, ":", err)
	}
	return nil
}
// fixPositions renumbers the posts of every log so their positions run
// 1, 2, 3, … in their current position order, removing any gaps or duplicate
// positions. Problems are only logged; nothing is returned. The whole run
// shares a one-minute timeout context.
func (r *postRepository) fixPositions(logRepo *logRepository) {
	startTime := time.Now()

	timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	logs, err := logRepo.List(timeout, models.LogFilter{})
	if err != nil {
		log.Println("Failed to get logs for position fix:", err)
		return
	}

	log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
	for _, l := range logs {
		r.fixLogPositions(timeout, l.ID, l.ShortID)
	}
	log.Println("Log position fixing finished in", time.Since(startTime))
}

// fixLogPositions renumbers the posts of the log with short ID shortID while
// holding orderMutex. id is the log's full ID and is only used in log output.
func (r *postRepository) fixLogPositions(ctx context.Context, id, shortID string) {
	r.orderMutex.Lock()
	defer r.orderMutex.Unlock()

	posts, err := r.List(ctx, models.PostFilter{LogID: &shortID})
	if err != nil {
		log.Println(shortID, "position fix skipped, listing posts failed:", err)
		return
	}

	// List sorts by position, so the i-th post should sit at position i+1.
	disorders := 0
	for i, post := range posts {
		if post.Position != i+1 {
			disorders++
		}
	}
	if disorders == 0 {
		return
	}

	log.Println(disorders, "order errors detected in", id)
	ops := 0
	for i, post := range posts {
		if post.Position == i+1 {
			continue
		}
		ops++
		if err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}}); err != nil {
			log.Println(shortID, "fix failed after", ops, "ops:", err)
			return
		}
	}
	log.Println(shortID, "fixed after", ops, "ops.")
}