GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

314 lines
7.0 KiB

3 years ago
3 years ago
  1. package postgres
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "git.aiterp.net/rpdata/api/database/postgres/psqlcore"
  8. "git.aiterp.net/rpdata/api/internal/generate"
  9. "git.aiterp.net/rpdata/api/models"
  10. "strings"
  11. )
  12. type postRepository struct {
  13. insertWithIDs bool
  14. db *sql.DB
  15. }
  16. func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
  17. post, err := psqlcore.New(r.db).SelectPost(ctx, id)
  18. if err != nil {
  19. return nil, err
  20. }
  21. return r.post(post), nil
  22. }
  23. func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
  24. q := psqlcore.New(r.db)
  25. params := psqlcore.SelectPostsParams{LimitSize: 0}
  26. if filter.LogID != nil {
  27. params.FilterLogShortID = true
  28. if !strings.HasPrefix(*filter.LogID, "L") {
  29. log, err := q.SelectLog(ctx, *filter.LogID)
  30. if err != nil {
  31. return nil, err
  32. }
  33. params.LogShortID = log.ShortID
  34. } else {
  35. params.LogShortID = *filter.LogID
  36. }
  37. }
  38. if filter.Kinds != nil {
  39. params.FilterKinds = true
  40. params.Kinds = filter.Kinds
  41. }
  42. if filter.Search != nil {
  43. params.FilterSearch = true
  44. params.Search = TSQueryFromSearch(*filter.Search)
  45. }
  46. if filter.Limit > 0 {
  47. params.LimitSize = int32(filter.Limit)
  48. }
  49. posts, err := q.SelectPosts(ctx, params)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return r.posts(posts), nil
  54. }
  55. func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
  56. if !r.insertWithIDs || len(post.ID) < 8 {
  57. post.ID = generate.PostID()
  58. }
  59. position, err := psqlcore.New(r.db).InsertPost(ctx, psqlcore.InsertPostParams{
  60. ID: post.ID,
  61. LogShortID: post.LogID,
  62. Time: post.Time.UTC(),
  63. Kind: post.Kind,
  64. Nick: post.Nick,
  65. Text: post.Text,
  66. })
  67. if err != nil {
  68. return nil, err
  69. }
  70. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  71. post.Position = int(position)
  72. return &post, nil
  73. }
  74. func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
  75. if len(posts) == 0 {
  76. return []*models.Post{}, nil
  77. }
  78. allowedLogID := ""
  79. for _, post := range posts {
  80. if allowedLogID == "" {
  81. allowedLogID = post.LogID
  82. } else if allowedLogID != post.LogID {
  83. return nil, errors.New("cannot insert multiple posts with different log IDs")
  84. }
  85. }
  86. params := psqlcore.InsertPostsParams{
  87. LogShortID: posts[0].LogID,
  88. }
  89. for _, post := range posts {
  90. if !r.insertWithIDs || len(post.ID) < 8 {
  91. post.ID = generate.PostID()
  92. }
  93. params.Ids = append(params.Ids, post.ID)
  94. params.Kinds = append(params.Kinds, post.Kind)
  95. params.Nicks = append(params.Nicks, post.Nick)
  96. params.Offsets = append(params.Offsets, int32(len(params.Offsets)+1))
  97. params.Times = append(params.Times, post.Time.UTC())
  98. params.Texts = append(params.Texts, post.Text)
  99. }
  100. offset, err := psqlcore.New(r.db).InsertPosts(ctx, params)
  101. if err != nil {
  102. return nil, err
  103. }
  104. for i, post := range posts {
  105. post.Position = int(offset) + i
  106. }
  107. err = psqlcore.New(r.db).GenerateLogTSVector(ctx, posts[0].LogID)
  108. if err != nil {
  109. return nil, err
  110. }
  111. return posts, nil
  112. }
  113. func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
  114. post.ApplyUpdate(update)
  115. err := psqlcore.New(r.db).UpdatePost(ctx, psqlcore.UpdatePostParams{
  116. Time: post.Time,
  117. Kind: post.Kind,
  118. Nick: post.Nick,
  119. Text: post.Text,
  120. ID: post.ID,
  121. })
  122. if err != nil {
  123. return nil, err
  124. }
  125. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  126. return &post, nil
  127. }
  128. func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
  129. if position == post.Position {
  130. return r.List(ctx, models.PostFilter{LogID: &post.LogID})
  131. }
  132. tx, err := r.db.BeginTx(ctx, nil)
  133. if err != nil {
  134. return nil, err
  135. }
  136. defer func() { _ = tx.Rollback() }()
  137. _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
  138. if err != nil {
  139. return nil, err
  140. }
  141. var lowest, highest int32
  142. q := psqlcore.New(tx)
  143. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
  144. if err != nil {
  145. return nil, err
  146. }
  147. if position > post.Position {
  148. lowest = int32(post.Position)
  149. highest = int32(position)
  150. err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
  151. ShiftOffset: -1,
  152. LogShortID: post.LogID,
  153. FromPosition: lowest + 1,
  154. ToPosition: highest,
  155. })
  156. if err != nil {
  157. return nil, err
  158. }
  159. } else {
  160. lowest = int32(position)
  161. highest = int32(post.Position)
  162. err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
  163. ShiftOffset: 1,
  164. LogShortID: post.LogID,
  165. FromPosition: lowest,
  166. ToPosition: highest - 1,
  167. })
  168. if err != nil {
  169. return nil, err
  170. }
  171. }
  172. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)})
  173. if err != nil {
  174. return nil, err
  175. }
  176. posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{
  177. LogShortID: post.LogID,
  178. FromPosition: lowest,
  179. ToPosition: highest,
  180. })
  181. if err != nil {
  182. return nil, err
  183. }
  184. positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID)
  185. if err != nil {
  186. return nil, err
  187. }
  188. prev := int32(0)
  189. for _, pos := range positions {
  190. if pos != prev+1 {
  191. return nil, errors.New("post discontinuity detected")
  192. }
  193. prev = pos
  194. }
  195. err = tx.Commit()
  196. if err != nil {
  197. return nil, err
  198. }
  199. return r.posts2(posts), nil
  200. }
  201. func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
  202. tx, err := r.db.BeginTx(ctx, nil)
  203. if err != nil {
  204. return err
  205. }
  206. defer func() { _ = tx.Rollback() }()
  207. q := psqlcore.New(tx)
  208. _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
  209. if err != nil {
  210. return err
  211. }
  212. err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
  213. if err != nil {
  214. return fmt.Errorf("pre-delete move failed: %s", err)
  215. }
  216. err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{
  217. ShiftOffset: -1,
  218. LogShortID: post.LogID,
  219. FromPosition: int32(post.Position + 1),
  220. })
  221. if err != nil {
  222. return fmt.Errorf("shift failed: %s", err)
  223. }
  224. err = q.DeletePost(ctx, post.ID)
  225. if err != nil {
  226. return fmt.Errorf("delete failed: %s", err)
  227. }
  228. err = tx.Commit()
  229. if err != nil {
  230. return fmt.Errorf("tx commit failed: %s", err)
  231. }
  232. _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
  233. return nil
  234. }
  235. func (r *postRepository) post(post psqlcore.SelectPostRow) *models.Post {
  236. return &models.Post{
  237. ID: post.ID,
  238. LogID: post.LogShortID,
  239. Time: post.Time,
  240. Kind: post.Kind,
  241. Nick: post.Nick,
  242. Text: post.Text,
  243. Position: int(post.Position),
  244. }
  245. }
  246. func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post {
  247. results := make([]*models.Post, 0, len(posts))
  248. for _, post := range posts {
  249. results = append(results, r.post(psqlcore.SelectPostRow(post)))
  250. }
  251. return results
  252. }
  253. func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post {
  254. results := make([]*models.Post, 0, len(posts))
  255. for _, post := range posts {
  256. results = append(results, r.post(psqlcore.SelectPostRow(post)))
  257. }
  258. return results
  259. }