GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

350 lines
8.0 KiB

  1. package mongodb
  2. import (
  3. "context"
  4. "git.aiterp.net/rpdata/api/internal/generate"
  5. "git.aiterp.net/rpdata/api/models"
  6. "git.aiterp.net/rpdata/api/repositories"
  7. "github.com/globalsign/mgo"
  8. "github.com/globalsign/mgo/bson"
  9. "log"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. type postRepository struct {
  15. logs *mgo.Collection
  16. posts *mgo.Collection
  17. orderMutex sync.Mutex
  18. }
  19. func newPostRepository(db *mgo.Database) (*postRepository, error) {
  20. posts := db.C("logbot3.posts")
  21. err := posts.EnsureIndexKey("logId")
  22. if err != nil {
  23. return nil, err
  24. }
  25. err = posts.EnsureIndexKey("time")
  26. if err != nil {
  27. return nil, err
  28. }
  29. err = posts.EnsureIndexKey("kind")
  30. if err != nil {
  31. return nil, err
  32. }
  33. err = posts.EnsureIndexKey("position")
  34. if err != nil {
  35. return nil, err
  36. }
  37. err = posts.EnsureIndex(mgo.Index{
  38. Key: []string{"$text:text"},
  39. })
  40. if err != nil {
  41. return nil, err
  42. }
  43. return &postRepository{
  44. posts: posts,
  45. logs: db.C("logbot3.logs"),
  46. }, nil
  47. }
  48. func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
  49. post := new(models.Post)
  50. err := r.posts.FindId(id).One(post)
  51. if err != nil {
  52. return nil, err
  53. }
  54. return post, nil
  55. }
  56. func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
  57. query := bson.M{}
  58. if filter.LogID != nil && *filter.LogID != "" {
  59. logId := *filter.LogID
  60. if !strings.HasPrefix(logId, "L") {
  61. // Resolve long id to short id
  62. log := new(models.Log)
  63. err := r.logs.FindId(logId).Select(bson.M{"logId": 1, "_id": 1}).One(log)
  64. if err != nil {
  65. return nil, err
  66. }
  67. logId = log.ShortID
  68. }
  69. query["logId"] = logId
  70. }
  71. if len(filter.IDs) > 0 {
  72. query["_id"] = bson.M{"$in": filter.IDs}
  73. }
  74. if len(filter.Kinds) > 0 {
  75. query["kind"] = bson.M{"$in": filter.Kinds}
  76. }
  77. if filter.Search != nil {
  78. query["$text"] = bson.M{"$search": *filter.Search}
  79. }
  80. posts := make([]*models.Post, 0, 32)
  81. var err error
  82. if filter.LogID != nil {
  83. err = r.posts.Find(query).Sort("position").Limit(filter.Limit).All(&posts)
  84. } else {
  85. err = r.posts.Find(query).Limit(filter.Limit).All(&posts)
  86. }
  87. if err != nil {
  88. if err == mgo.ErrNotFound {
  89. return []*models.Post{}, nil
  90. }
  91. return nil, err
  92. }
  93. return posts, nil
  94. }
  95. func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
  96. r.orderMutex.Lock()
  97. defer r.orderMutex.Unlock()
  98. lastPost := new(models.Post)
  99. err := r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  100. if err != nil && err != mgo.ErrNotFound {
  101. return nil, err
  102. }
  103. post.ID = generate.PostID()
  104. post.Position = lastPost.Position + 1 // Position 1 is first position, so this is safe.
  105. err = r.posts.Insert(post)
  106. if err != nil {
  107. return nil, err
  108. }
  109. return &post, nil
  110. }
  111. func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
  112. if len(posts) == 0 {
  113. return []*models.Post{}, nil
  114. }
  115. logId := posts[0].LogID
  116. for _, post := range posts {
  117. if post.LogID != logId {
  118. return nil, repositories.ErrParentMismatch
  119. }
  120. post.ID = generate.PostID()
  121. }
  122. r.orderMutex.Lock()
  123. defer r.orderMutex.Unlock()
  124. lastPost := new(models.Post)
  125. err := r.posts.Find(bson.M{"logId": posts[0].LogID}).Sort("-position").One(lastPost)
  126. if err != nil && err != mgo.ErrNotFound {
  127. return nil, err
  128. }
  129. docs := make([]interface{}, len(posts))
  130. for i := range posts {
  131. posts[i].Position = lastPost.Position + 1 + i
  132. docs[i] = posts[i]
  133. }
  134. err = r.posts.Insert(docs...)
  135. if err != nil && err != mgo.ErrNotFound {
  136. return nil, err
  137. }
  138. return posts, nil
  139. }
  140. func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
  141. updateBson := bson.M{}
  142. if update.Time != nil {
  143. updateBson["time"] = *update.Time
  144. post.Time = *update.Time
  145. }
  146. if update.Kind != nil {
  147. updateBson["kind"] = *update.Kind
  148. post.Kind = *update.Kind
  149. }
  150. if update.Nick != nil {
  151. updateBson["nick"] = *update.Nick
  152. post.Nick = *update.Nick
  153. }
  154. if update.Text != nil {
  155. updateBson["text"] = *update.Text
  156. post.Text = *update.Text
  157. }
  158. err := r.posts.UpdateId(post.ID, bson.M{"$set": updateBson})
  159. if err != nil {
  160. return nil, err
  161. }
  162. return &post, nil
  163. }
  164. func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
  165. // If only MongoDB transactions weren't awful, this function would have been safe.
  166. // Since it isn't, then good luck.
  167. // Validate lower position bound.
  168. if position < 1 {
  169. return nil, repositories.ErrInvalidPosition
  170. }
  171. // Determine the operations on adjacent posts
  172. var resultFilter bson.M
  173. var pushFilter bson.M
  174. var increment int
  175. if post.Position > position {
  176. pushFilter = bson.M{"position": bson.M{
  177. "$gte": position,
  178. "$lt": post.Position,
  179. }}
  180. increment = 1
  181. resultFilter = bson.M{"position": bson.M{
  182. "$lte": post.Position,
  183. "$gte": position,
  184. }}
  185. } else {
  186. pushFilter = bson.M{"position": bson.M{
  187. "$lte": position,
  188. "$gt": post.Position,
  189. }}
  190. increment = -1
  191. resultFilter = bson.M{"position": bson.M{
  192. "$lte": position,
  193. "$gte": post.Position,
  194. }}
  195. }
  196. pushFilter["logId"] = post.LogID
  197. resultFilter["logId"] = post.LogID
  198. // From here on out, sync is required
  199. r.orderMutex.Lock()
  200. defer r.orderMutex.Unlock()
  201. // Detect ninja shenanigans
  202. post2 := new(models.Post)
  203. err := r.posts.FindId(post.ID).One(post2)
  204. if err != nil && post2.Position != post.Position {
  205. return nil, repositories.ErrNotFound
  206. }
  207. // Validate upper position bound
  208. lastPost := new(models.Post)
  209. err = r.posts.Find(bson.M{"logId": post.LogID}).Sort("-position").One(lastPost)
  210. if err != nil && err != mgo.ErrNotFound {
  211. return nil, err
  212. }
  213. if position > lastPost.Position {
  214. return nil, repositories.ErrInvalidPosition
  215. }
  216. // Move the posts
  217. changeInfo, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": increment}})
  218. if err != nil && err != mgo.ErrNotFound {
  219. return nil, err
  220. }
  221. err = r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": position}})
  222. if err != nil {
  223. // Try to undo it
  224. _, err := r.posts.UpdateAll(pushFilter, bson.M{"$inc": bson.M{"position": -increment}})
  225. if err != nil && err != mgo.ErrNotFound {
  226. return nil, err
  227. }
  228. return nil, err
  229. }
  230. results := make([]*models.Post, 0, changeInfo.Matched+1)
  231. err = r.posts.Find(resultFilter).All(&results)
  232. if err != nil {
  233. return nil, err
  234. }
  235. return results, nil
  236. }
  237. func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
  238. r.orderMutex.Lock()
  239. defer r.orderMutex.Unlock()
  240. err := r.posts.RemoveId(post.ID)
  241. if err != nil {
  242. return err
  243. }
  244. _, _ = r.posts.UpdateAll(bson.M{"logId": post.LogID, "position": bson.M{"$gt": post.Position}}, bson.M{"$inc": bson.M{"position": -1}})
  245. return nil
  246. }
  247. func (r *postRepository) fixPositions(logRepo *logRepository) {
  248. disorders := make([]int, 0, 16)
  249. diffs := make([]int, 0, 16)
  250. startTime := time.Now()
  251. timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
  252. defer cancel()
  253. logs, err := logRepo.List(timeout, models.LogFilter{})
  254. if err != nil {
  255. log.Println("Failed to get logs for position fix:", err)
  256. }
  257. log.Println("Starting log position fixing, this should not take longer than 10 seconds.")
  258. for _, l := range logs {
  259. r.orderMutex.Lock()
  260. posts, err := r.List(timeout, models.PostFilter{LogID: &l.ShortID})
  261. if err != nil {
  262. r.orderMutex.Unlock()
  263. continue
  264. }
  265. disorders = disorders[:0]
  266. diffs = diffs[:0]
  267. for i, post := range posts {
  268. if post.Position != (i + 1) {
  269. disorders = append(disorders, i)
  270. diffs = append(diffs, post.Position-(i+1))
  271. }
  272. }
  273. if len(disorders) > 0 {
  274. log.Println(len(disorders), "order errors detected in", l.ID)
  275. ops := 0
  276. for i, post := range posts {
  277. if (i + 1) != posts[i].Position {
  278. ops++
  279. err := r.posts.UpdateId(post.ID, bson.M{"$set": bson.M{"position": i + 1}})
  280. if err != nil {
  281. log.Println(l.ShortID, "fix failed after", ops, "ops:", err)
  282. break
  283. }
  284. }
  285. }
  286. log.Println(l.ShortID, "fixed after", ops, "ops.")
  287. }
  288. r.orderMutex.Unlock()
  289. }
  290. log.Println("Log position fixing finished in", time.Since(startTime))
  291. }