GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

313 lines
6.9 KiB

  1. package postgres
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "git.aiterp.net/rpdata/api/database/postgres/psqlcore"
  7. "git.aiterp.net/rpdata/api/internal/generate"
  8. "git.aiterp.net/rpdata/api/models"
  9. "strings"
  10. )
// postRepository stores and retrieves log posts in PostgreSQL, issuing its
// queries through the generated psqlcore package.
type postRepository struct {
	// insertWithIDs, when true, lets Insert and InsertMany keep a
	// caller-supplied post ID as long as it is at least 8 characters long;
	// otherwise a fresh ID is always generated.
	insertWithIDs bool
	// db is the handle every query and transaction is run against.
	db *sql.DB
}
  15. func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
  16. post, err := psqlcore.New(r.db).SelectPost(ctx, id)
  17. if err != nil {
  18. return nil, err
  19. }
  20. return r.post(post), nil
  21. }
  22. func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
  23. q := psqlcore.New(r.db)
  24. params := psqlcore.SelectPostsParams{LimitSize: 0}
  25. if filter.LogID != nil {
  26. params.FilterLogShortID = true
  27. if !strings.HasPrefix(*filter.LogID, "L") {
  28. log, err := q.SelectLog(ctx, *filter.LogID)
  29. if err != nil {
  30. return nil, err
  31. }
  32. params.LogShortID = log.ShortID
  33. } else {
  34. params.LogShortID = *filter.LogID
  35. }
  36. }
  37. if filter.Kinds != nil {
  38. params.FilterKinds = true
  39. params.Kinds = filter.Kinds
  40. }
  41. if filter.Search != nil {
  42. params.FilterSearch = true
  43. params.Search = TSQueryFromSearch(*filter.Search)
  44. }
  45. if filter.Limit > 0 {
  46. params.LimitSize = int32(filter.Limit)
  47. }
  48. posts, err := q.SelectPosts(ctx, params)
  49. if err != nil {
  50. return nil, err
  51. }
  52. return r.posts(posts), nil
  53. }
  54. func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
  55. if !r.insertWithIDs || len(post.ID) < 8 {
  56. post.ID = generate.PostID()
  57. }
  58. position, err := psqlcore.New(r.db).InsertPost(ctx, psqlcore.InsertPostParams{
  59. ID: post.ID,
  60. LogShortID: post.LogID,
  61. Time: post.Time.UTC(),
  62. Kind: post.Kind,
  63. Nick: post.Nick,
  64. Text: post.Text,
  65. })
  66. if err != nil {
  67. return nil, err
  68. }
  69. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  70. post.Position = int(position)
  71. return &post, nil
  72. }
  73. func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
  74. if len(posts) == 0 {
  75. return []*models.Post{}, nil
  76. }
  77. allowedLogID := ""
  78. for _, post := range posts {
  79. if allowedLogID == "" {
  80. allowedLogID = post.LogID
  81. } else if allowedLogID != post.LogID {
  82. return nil, errors.New("cannot insert multiple posts with different log IDs")
  83. }
  84. }
  85. params := psqlcore.InsertPostsParams{
  86. LogShortID: posts[0].LogID,
  87. }
  88. for _, post := range posts {
  89. if !r.insertWithIDs || len(post.ID) < 8 {
  90. post.ID = generate.PostID()
  91. }
  92. params.Ids = append(params.Ids, post.ID)
  93. params.Kinds = append(params.Kinds, post.Kind)
  94. params.Nicks = append(params.Nicks, post.Nick)
  95. params.Offsets = append(params.Offsets, int32(len(params.Offsets)+1))
  96. params.Times = append(params.Times, post.Time.UTC())
  97. params.Texts = append(params.Texts, post.Text)
  98. }
  99. offset, err := psqlcore.New(r.db).InsertPosts(ctx, params)
  100. if err != nil {
  101. return nil, err
  102. }
  103. for i, post := range posts {
  104. post.Position = int(offset) + i
  105. }
  106. err = psqlcore.New(r.db).GenerateLogTSVector(ctx, posts[0].LogID)
  107. if err != nil {
  108. return nil, err
  109. }
  110. return posts, nil
  111. }
  112. func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
  113. post.ApplyUpdate(update)
  114. err := psqlcore.New(r.db).UpdatePost(ctx, psqlcore.UpdatePostParams{
  115. Time: post.Time,
  116. Kind: post.Kind,
  117. Nick: post.Nick,
  118. Text: post.Text,
  119. ID: post.ID,
  120. })
  121. if err != nil {
  122. return nil, err
  123. }
  124. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  125. return &post, nil
  126. }
  127. func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
  128. if position == post.Position {
  129. return r.List(ctx, models.PostFilter{LogID: &post.LogID})
  130. }
  131. tx, err := r.db.BeginTx(ctx, nil)
  132. if err != nil {
  133. return nil, err
  134. }
  135. defer func() { _ = tx.Rollback() }()
  136. _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
  137. if err != nil {
  138. return nil, err
  139. }
  140. var lowest, highest int32
  141. q := psqlcore.New(tx)
  142. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
  143. if err != nil {
  144. return nil, err
  145. }
  146. if position > post.Position {
  147. lowest = int32(post.Position)
  148. highest = int32(position)
  149. err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
  150. ShiftOffset: -1,
  151. LogShortID: post.LogID,
  152. FromPosition: lowest + 1,
  153. ToPosition: highest,
  154. })
  155. if err != nil {
  156. return nil, err
  157. }
  158. } else {
  159. lowest = int32(position)
  160. highest = int32(post.Position)
  161. err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
  162. ShiftOffset: 1,
  163. LogShortID: post.LogID,
  164. FromPosition: lowest,
  165. ToPosition: highest - 1,
  166. })
  167. if err != nil {
  168. return nil, err
  169. }
  170. }
  171. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)})
  172. if err != nil {
  173. return nil, err
  174. }
  175. posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{
  176. LogShortID: post.LogID,
  177. FromPosition: lowest,
  178. ToPosition: highest,
  179. })
  180. if err != nil {
  181. return nil, err
  182. }
  183. positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID)
  184. if err != nil {
  185. return nil, err
  186. }
  187. prev := int32(0)
  188. for _, pos := range positions {
  189. if pos != prev+1 {
  190. return nil, errors.New("post discontinuity detected")
  191. }
  192. prev = pos
  193. }
  194. err = tx.Commit()
  195. if err != nil {
  196. return nil, err
  197. }
  198. return r.posts2(posts), nil
  199. }
  200. func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
  201. tx, err := r.db.BeginTx(ctx, nil)
  202. if err != nil {
  203. return err
  204. }
  205. defer func() { _ = tx.Rollback() }()
  206. q := psqlcore.New(tx)
  207. _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
  208. if err != nil {
  209. return err
  210. }
  211. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
  212. if err != nil {
  213. return err
  214. }
  215. err = q.DeletePost(ctx, post.ID)
  216. if err != nil {
  217. return err
  218. }
  219. err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{
  220. ShiftOffset: -1,
  221. LogShortID: post.LogID,
  222. FromPosition: int32(post.Position + 1),
  223. })
  224. if err != nil {
  225. return err
  226. }
  227. err = tx.Commit()
  228. if err != nil {
  229. return err
  230. }
  231. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  232. return nil
  233. }
  234. func (r *postRepository) post(post psqlcore.SelectPostRow) *models.Post {
  235. return &models.Post{
  236. ID: post.ID,
  237. LogID: post.LogShortID,
  238. Time: post.Time,
  239. Kind: post.Kind,
  240. Nick: post.Nick,
  241. Text: post.Text,
  242. Position: int(post.Position),
  243. }
  244. }
  245. func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post {
  246. results := make([]*models.Post, 0, len(posts))
  247. for _, post := range posts {
  248. results = append(results, r.post(psqlcore.SelectPostRow(post)))
  249. }
  250. return results
  251. }
  252. func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post {
  253. results := make([]*models.Post, 0, len(posts))
  254. for _, post := range posts {
  255. results = append(results, r.post(psqlcore.SelectPostRow(post)))
  256. }
  257. return results
  258. }