GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

362 lines
8.3 KiB

  1. package mongodb
  2. import (
  3. "context"
  4. "errors"
  5. "git.aiterp.net/rpdata/api/internal/generate"
  6. "git.aiterp.net/rpdata/api/models"
  7. "git.aiterp.net/rpdata/api/repositories"
  8. "github.com/globalsign/mgo"
  9. "github.com/globalsign/mgo/bson"
  10. "log"
  11. "strings"
  12. "sync"
  13. "time"
  14. )
  15. type postRepository struct {
  16. restoreIDs bool
  17. logs *mgo.Collection
  18. posts *mgo.Collection
  19. orderMutex sync.Mutex
  20. }
  21. func newPostRepository(db *mgo.Database, restoreIDs bool) (*postRepository, error) {
  22. posts := db.C("logbot3.posts")
  23. err := posts.EnsureIndexKey("logId")
  24. if err != nil {
  25. return nil, err
  26. }
  27. err = posts.EnsureIndexKey("time")
  28. if err != nil {
  29. return nil, err
  30. }
  31. err = posts.EnsureIndexKey("kind")
  32. if err != nil {
  33. return nil, err
  34. }
  35. err = posts.EnsureIndexKey("position")
  36. if err != nil {
  37. return nil, err
  38. }
  39. err = posts.EnsureIndex(mgo.Index{
  40. Key: []string{"$text:text"},
  41. })
  42. if err != nil {
  43. return nil, err
  44. }
  45. return &postRepository{
  46. restoreIDs: restoreIDs,
  47. posts: posts,
  48. logs: db.C("logbot3.logs"),
  49. }, nil
  50. }
  51. func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
  52. post := new(models.Post)
  53. err := r.posts.FindId(id).One(post)
  54. if err != nil {
  55. return nil, err
  56. }
  57. return post, nil
  58. }
  59. func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
  60. query := bson.M{}
  61. if filter.LogID != nil && *filter.LogID != "" {
  62. logId := *filter.LogID
  63. if !strings.HasPrefix(logId, "L") {
  64. // Resolve long id to short id
  65. log := new(models.Log)
  66. err := r.logs.FindId(logId).Select(bson.M{"logId": 1, "_id": 1}).One(log)
  67. if err != nil {
  68. return nil, err
  69. }
  70. logId = log.ShortID
  71. }
  72. query["logId"] = logId
  73. }
  74. if len(filter.IDs) > 0 {
  75. query["_id"] = bson.M{"$in": filter.IDs}
  76. }
  77. if len(filter.Kinds) > 0 {
  78. query["kind"] = bson.M{"$in": filter.Kinds}
  79. }
  80. if filter.Search != nil {
  81. query["$text"] = bson.M{"$search": *filter.Search}
  82. }
  83. posts := make([]*models.Post, 0, 32)
  84. var err error
  85. if filter.LogID != nil {
  86. err = r.posts.Find(query).Sort("position").Limit(filter.Limit).All(&posts)
  87. } else {
  88. err = r.posts.Find(query).Limit(filter.Limit).All(&posts)
  89. }
  90. if err != nil {
  91. if err == mgo.ErrNotFound {
  92. return []*models.Post{}, nil
  93. }
  94. return nil, err
  95. }
  96. return posts, nil
  97. }
  98. func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
  99. r.orderMutex.Lock()
  100. defer r.orderMutex.Unlock()
  101. lastPost := new(models.Post)
  102. err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  103. if err != nil && err != mgo.ErrNotFound {
  104. return nil, err
  105. }
  106. if !r.restoreIDs {
  107. post.ID = generate.PostID()
  108. } else {
  109. if len(post.ID) != len(generate.PostID()) && strings.HasPrefix(post.ID, "P") {
  110. return nil, errors.New("invalid story id")
  111. }
  112. }
  113. post.Position = lastPost.Position + 1 // Position 1 is first position, so this is safe.
  114. err = r.posts.Insert(post)
  115. if err != nil {
  116. return nil, err
  117. }
  118. return &post, nil
  119. }
  120. func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
  121. if len(posts) == 0 {
  122. return []*models.Post{}, nil
  123. }
  124. logId := posts[0].LogID
  125. for _, post := range posts {
  126. if post.LogID != logId {
  127. return nil, repositories.ErrParentMismatch
  128. }
  129. if !r.restoreIDs || post.ID == "" {
  130. post.ID = generate.PostID()
  131. }
  132. }
  133. r.orderMutex.Lock()
  134. defer r.orderMutex.Unlock()
  135. lastPost := new(models.Post)
  136. err := r.posts.Find(bson.M{"logId": posts[0].LogID}).Sort("-position").One(lastPost)
  137. if err != nil && err != mgo.ErrNotFound {
  138. return nil, err
  139. }
  140. docs := make([]interface{}, len(posts))
  141. for i := range posts {
  142. posts[i].Position = lastPost.Position + 1 + i
  143. docs[i] = posts[i]
  144. }
  145. err = r.posts.Insert(docs...)
  146. if err != nil && err != mgo.ErrNotFound {
  147. return nil, err
  148. }
  149. return posts, nil
  150. }
  151. func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
  152. updateBson := bson.M{}
  153. if update.Time != nil {
  154. updateBson["time"] = *update.Time
  155. post.Time = *update.Time
  156. }
  157. if update.Kind != nil {
  158. updateBson["kind"] = *update.Kind
  159. post.Kind = *update.Kind
  160. }
  161. if update.Nick != nil {
  162. updateBson["nick"] = *update.Nick
  163. post.Nick = *update.Nick
  164. }
  165. if update.Text != nil {
  166. updateBson["text"] = *update.Text
  167. post.Text = *update.Text
  168. }
  169. err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson})
  170. if err != nil {
  171. return nil, err
  172. }
  173. return &post, nil
  174. }
  175. func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
  176. // If only MongoDB transactions weren't awful, this function would have been safe.
  177. // Since it isn't, then good luck.
  178. // Validate lower position bound.
  179. if position < 1 {
  180. return nil, repositories.ErrInvalidPosition
  181. }
  182. // Determine the operations on adjacent posts
  183. var resultFilter bson.M
  184. var pushFilter bson.M
  185. var increment int
  186. if post.Position > position {
  187. pushFilter = bson.M{"position": bson.M{
  188. "$gte": position,
  189. "$lt": post.Position,
  190. }}
  191. increment = 1
  192. resultFilter = bson.M{"position": bson.M{
  193. "$lte": post.Position,
  194. "$gte": position,
  195. }}
  196. } else {
  197. pushFilter = bson.M{"position": bson.M{
  198. "$lte": position,
  199. "$gt": post.Position,
  200. }}
  201. increment = -1
  202. resultFilter = bson.M{"position": bson.M{
  203. "$lte": position,
  204. "$gte": post.Position,
  205. }}
  206. }
  207. pushFilter["logId"] = post.LogID
  208. resultFilter["logId"] = post.LogID
  209. // From here on out, sync is required
  210. r.orderMutex.Lock()
  211. defer r.orderMutex.Unlock()
  212. // Detect ninja shenanigans
  213. post2 := new(models.Post)
  214. err := r.posts.FindId(post.ID).One(post2)
  215. if err != nil && post2.Position != post.Position {
  216. return nil, repositories.ErrNotFound
  217. }
  218. // Validate upper position bound
  219. lastPost := new(models.Post)
  220. err = r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  221. if err != nil && err != mgo.ErrNotFound {
  222. return nil, err
  223. }
  224. if position > lastPost.Position {
  225. return nil, repositories.ErrInvalidPosition
  226. }
  227. // Move the posts
  228. changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}})
  229. if err != nil && err != mgo.ErrNotFound {
  230. return nil, err
  231. }
  232. err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}})
  233. if err != nil {
  234. // Try to undo it
  235. _, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}})
  236. if err != nil && err != mgo.ErrNotFound {
  237. return nil, err
  238. }
  239. return nil, err
  240. }
  241. results := make([]*models.Post, 0, changeInfo.Matched+1)
  242. err = r.posts.Find(resultFilter).All(&results)
  243. if err != nil {
  244. return nil, err
  245. }
  246. return results, nil
  247. }
  248. func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
  249. r.orderMutex.Lock()
  250. defer r.orderMutex.Unlock()
  251. err := r.posts.RemoveId(post.ID)
  252. if err != nil {
  253. return err
  254. }
  255. _, _ = r.posts.UpdateAll(bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}}, bson.M{"$inc": bson.M{"position": -1}})
  256. return nil
  257. }
  258. func (r *postRepository) fixPositions(logRepo *logRepository) {
  259. disorders := make([]int, 0, 16)
  260. diffs := make([]int, 0, 16)
  261. startTime := time.Now()
  262. timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
  263. defer cancel()
  264. logs, err := logRepo.List(timeout, models.LogFilter{})
  265. if err != nil {
  266. log.Println("Failed to get logs for position fix:", err)
  267. }
  268. log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
  269. for _, l := range logs {
  270. r.orderMutex.Lock()
  271. posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID})
  272. if err != nil {
  273. r.orderMutex.Unlock()
  274. continue
  275. }
  276. disorders = disorders[:0]
  277. diffs = diffs[:0]
  278. for i, post := range posts {
  279. if post.Position != (i + 1) {
  280. disorders = append(disorders, i)
  281. diffs = append(diffs, post.Position-(i+1))
  282. }
  283. }
  284. if len(disorders) > 0 {
  285. log.Println(len(disorders), "order errors detected in", l.ID)
  286. ops := 0
  287. for i, post := range posts {
  288. if (i + 1) != posts[i].Position {
  289. ops++
  290. err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}})
  291. if err != nil {
  292. log.Println(l.ShortID, "fix failed after", ops, "ops:", err)
  293. break
  294. }
  295. }
  296. }
  297. log.Println(l.ShortID, "fixed after", ops, "ops.")
  298. }
  299. r.orderMutex.Unlock()
  300. }
  301. log.Println("Log position fixing finished in", time.Since(startTime))
  302. }