GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

345 lines
7.9 KiB

  1. package mongodb
  2. import (
  3. "context"
  4. "git.aiterp.net/rpdata/api/internal/generate"
  5. "git.aiterp.net/rpdata/api/models"
  6. "git.aiterp.net/rpdata/api/repositories"
  7. "github.com/globalsign/mgo"
  8. "github.com/globalsign/mgo/bson"
  9. "log"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. type postRepository struct {
  15. logs *mgo.Collection
  16. posts *mgo.Collection
  17. orderMutex sync.Mutex
  18. }
  19. func newPostRepository(db *mgo.Database) (*postRepository, error) {
  20. posts := db.C("logbot3.posts")
  21. err := posts.EnsureIndexKey("logId")
  22. if err != nil {
  23. return nil, err
  24. }
  25. err = posts.EnsureIndexKey("time")
  26. if err != nil {
  27. return nil, err
  28. }
  29. err = posts.EnsureIndexKey("kind")
  30. if err != nil {
  31. return nil, err
  32. }
  33. err = posts.EnsureIndexKey("position")
  34. if err != nil {
  35. return nil, err
  36. }
  37. err = posts.EnsureIndex(mgo.Index{
  38. Key: []string{"$text:text"},
  39. })
  40. if err != nil {
  41. return nil, err
  42. }
  43. return &postRepository{
  44. posts: posts,
  45. logs: db.C("logbot3.logs"),
  46. }, nil
  47. }
  48. func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
  49. post := new(models.Post)
  50. err := r.posts.FindId(id).One(post)
  51. if err != nil {
  52. return nil, err
  53. }
  54. return post, nil
  55. }
  56. func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
  57. query := bson.M{}
  58. if filter.LogID != nil && *filter.LogID != "" {
  59. logId := *filter.LogID
  60. if !strings.HasPrefix(logId, "L") {
  61. // Resolve long id to short id
  62. log := new(models.Log)
  63. err := r.logs.FindId(logId).Select(bson.M{"logId": 1, "_id": 1}).One(log)
  64. if err != nil {
  65. return nil, err
  66. }
  67. logId = log.ShortID
  68. }
  69. query["logId"] = logId
  70. }
  71. if len(filter.IDs) > 0 {
  72. query["_id"] = bson.M{"$in": filter.IDs}
  73. }
  74. if len(filter.Kinds) > 0 {
  75. query["kind"] = bson.M{"$in": filter.Kinds}
  76. }
  77. if filter.Search != nil {
  78. query["$text"] = bson.M{"$search": *filter.Search}
  79. }
  80. posts := make([]*models.Post, 0, 32)
  81. err := r.posts.Find(query).Sort("-logId", "position").Limit(filter.Limit).All(&posts)
  82. if err != nil {
  83. if err == mgo.ErrNotFound {
  84. return []*models.Post{}, nil
  85. }
  86. return nil, err
  87. }
  88. return posts, nil
  89. }
  90. func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
  91. r.orderMutex.Lock()
  92. defer r.orderMutex.Unlock()
  93. lastPost := new(models.Post)
  94. err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  95. if err != nil && err != mgo.ErrNotFound {
  96. return nil, err
  97. }
  98. post.ID = generate.PostID()
  99. post.Position = lastPost.Position + 1 // Position 1 is first position, so this is safe.
  100. err = r.posts.Insert(post)
  101. if err != nil {
  102. return nil, err
  103. }
  104. return &post, nil
  105. }
  106. func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
  107. if len(posts) == 0 {
  108. return []*models.Post{}, nil
  109. }
  110. logId := posts[0].LogID
  111. for _, post := range posts {
  112. if post.LogID != logId {
  113. return nil, repositories.ErrParentMismatch
  114. }
  115. post.ID = generate.PostID()
  116. }
  117. r.orderMutex.Lock()
  118. defer r.orderMutex.Unlock()
  119. lastPost := new(models.Post)
  120. err := r.posts.Find(bson.M{"logId": posts[0].LogID}).Sort("-position").One(lastPost)
  121. if err != nil && err != mgo.ErrNotFound {
  122. return nil, err
  123. }
  124. docs := make([]interface{}, len(posts))
  125. for i := range posts {
  126. posts[i].Position = lastPost.Position + 1 + i
  127. docs[i] = posts[i]
  128. }
  129. err = r.posts.Insert(docs...)
  130. if err != nil && err != mgo.ErrNotFound {
  131. return nil, err
  132. }
  133. return posts, nil
  134. }
  135. func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
  136. updateBson := bson.M{}
  137. if update.Time != nil {
  138. updateBson["time"] = *update.Time
  139. post.Time = *update.Time
  140. }
  141. if update.Kind != nil {
  142. updateBson["kind"] = *update.Kind
  143. post.Kind = *update.Kind
  144. }
  145. if update.Nick != nil {
  146. updateBson["nick"] = *update.Nick
  147. post.Nick = *update.Nick
  148. }
  149. if update.Text != nil {
  150. updateBson["text"] = *update.Text
  151. post.Text = *update.Text
  152. }
  153. err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson})
  154. if err != nil {
  155. return nil, err
  156. }
  157. return &post, nil
  158. }
  159. func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
  160. // If only MongoDB transactions weren't awful, this function would have been safe.
  161. // Since it isn't, then good luck.
  162. // Validate lower position bound.
  163. if position < 1 {
  164. return nil, repositories.ErrInvalidPosition
  165. }
  166. // Determine the operations on adjacent posts
  167. var resultFilter bson.M
  168. var pushFilter bson.M
  169. var increment int
  170. if post.Position > position {
  171. pushFilter = bson.M{"position": bson.M{
  172. "$gte": position,
  173. "$lt": post.Position,
  174. }}
  175. increment = 1
  176. resultFilter = bson.M{"position": bson.M{
  177. "$lte": post.Position,
  178. "$gte": position,
  179. }}
  180. } else {
  181. pushFilter = bson.M{"position": bson.M{
  182. "$lte": position,
  183. "$gt": post.Position,
  184. }}
  185. increment = -1
  186. resultFilter = bson.M{"position": bson.M{
  187. "$lte": position,
  188. "$gte": post.Position,
  189. }}
  190. }
  191. pushFilter["logId"] = post.LogID
  192. resultFilter["logId"] = post.LogID
  193. // From here on out, sync is required
  194. r.orderMutex.Lock()
  195. defer r.orderMutex.Unlock()
  196. // Detect ninja shenanigans
  197. post2 := new(models.Post)
  198. err := r.posts.FindId(post.ID).One(post2)
  199. if err != nil && post2.Position != post.Position {
  200. return nil, repositories.ErrNotFound
  201. }
  202. // Validate upper position bound
  203. lastPost := new(models.Post)
  204. err = r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  205. if err != nil && err != mgo.ErrNotFound {
  206. return nil, err
  207. }
  208. if position > lastPost.Position {
  209. return nil, repositories.ErrInvalidPosition
  210. }
  211. // Move the posts
  212. changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}})
  213. if err != nil && err != mgo.ErrNotFound {
  214. return nil, err
  215. }
  216. err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}})
  217. if err != nil {
  218. // Try to undo it
  219. _, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}})
  220. if err != nil && err != mgo.ErrNotFound {
  221. return nil, err
  222. }
  223. return nil, err
  224. }
  225. results := make([]*models.Post, 0, changeInfo.Matched+1)
  226. err = r.posts.Find(resultFilter).All(&results)
  227. if err != nil {
  228. return nil, err
  229. }
  230. return results, nil
  231. }
  232. func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
  233. r.orderMutex.Lock()
  234. defer r.orderMutex.Unlock()
  235. err := r.posts.RemoveId(post.ID)
  236. if err != nil {
  237. return err
  238. }
  239. _, _ = r.posts.UpdateAll(bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}}, bson.M{"$inc": bson.M{"position": -1}})
  240. return nil
  241. }
  242. func (r *postRepository) fixPositions(logRepo *logRepository) {
  243. disorders := make([]int, 0, 16)
  244. diffs := make([]int, 0, 16)
  245. startTime := time.Now()
  246. timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
  247. defer cancel()
  248. logs, err := logRepo.List(timeout, models.LogFilter{})
  249. if err != nil {
  250. log.Println("Failed to get logs for position fix:", err)
  251. }
  252. log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
  253. for _, l := range logs {
  254. r.orderMutex.Lock()
  255. posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID})
  256. if err != nil {
  257. r.orderMutex.Unlock()
  258. continue
  259. }
  260. disorders = disorders[:0]
  261. diffs = diffs[:0]
  262. for i, post := range posts {
  263. if post.Position != (i + 1) {
  264. disorders = append(disorders, i)
  265. diffs = append(diffs, post.Position-(i+1))
  266. }
  267. }
  268. if len(disorders) > 0 {
  269. log.Println(len(disorders), "order errors detected in", l.ID)
  270. ops := 0
  271. for i, post := range posts {
  272. if (i + 1) != posts[i].Position {
  273. ops++
  274. err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}})
  275. if err != nil {
  276. log.Println(l.ShortID, "fix failed after", ops, "ops:", err)
  277. break
  278. }
  279. }
  280. }
  281. log.Println(l.ShortID, "fixed after", ops, "ops.")
  282. }
  283. r.orderMutex.Unlock()
  284. }
  285. log.Println("Log position fixing finished in", time.Since(startTime))
  286. }