GraphQL API and utilities for the rpdata project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

308 lines
6.8 KiB

package postgres
import (
"context"
"database/sql"
"errors"
"git.aiterp.net/rpdata/api/database/postgres/psqlcore"
"git.aiterp.net/rpdata/api/internal/generate"
"git.aiterp.net/rpdata/api/models"
"strings"
)
// postRepository implements post storage on top of a *sql.DB, issuing its
// queries through the psqlcore package.
type postRepository struct {
// insertWithIDs allows Insert and InsertMany to keep a caller-supplied post
// ID when it is at least 8 bytes long; otherwise a generated ID is used.
insertWithIDs bool
// db is the database handle used for plain queries and for opening
// transactions.
db *sql.DB
}
// Find looks up a single post by its ID and converts the resulting row to a
// models.Post. Any error from the query is returned unchanged.
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
	queries := psqlcore.New(r.db)

	row, err := queries.SelectPost(ctx, id)
	if err != nil {
		return nil, err
	}

	result := r.post(row)
	return result, nil
}
// List returns the posts matching filter.
//
// Each filter field that is set switches on the corresponding Filter* flag in
// the query parameters:
//   - LogID: values starting with "L" are used directly as the log short ID;
//     any other value is resolved through SelectLog first.
//   - Kinds: applied when the slice is non-nil.
//   - Search: converted with TSQueryFromSearch.
//   - Limit: applied only when greater than zero; otherwise LimitSize stays 0.
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
	queries := psqlcore.New(r.db)

	var args psqlcore.SelectPostsParams

	if logID := filter.LogID; logID != nil {
		args.FilterLogShortID = true

		if strings.HasPrefix(*logID, "L") {
			args.LogShortID = *logID
		} else {
			// Not a short ID: look the log up to find its short ID.
			logRow, err := queries.SelectLog(ctx, *logID)
			if err != nil {
				return nil, err
			}
			args.LogShortID = logRow.ShortID
		}
	}

	if filter.Kinds != nil {
		args.FilterKinds = true
		args.Kinds = filter.Kinds
	}

	if search := filter.Search; search != nil {
		args.FilterSearch = true
		args.Search = TSQueryFromSearch(*search)
	}

	if filter.Limit > 0 {
		args.LimitSize = int32(filter.Limit)
	}

	rows, err := queries.SelectPosts(ctx, args)
	if err != nil {
		return nil, err
	}

	return r.posts(rows), nil
}
// Insert stores a single post and returns it with its ID and Position filled
// in.
//
// The supplied ID is kept only when the repository has insertWithIDs set and
// the ID is at least 8 bytes long; in every other case a new ID is generated.
// The time is written in UTC. After a successful insert GenerateLogTSVector is
// called for the post's log; its error is intentionally discarded, so a
// failure there does not fail the insert.
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
	keepID := r.insertWithIDs && len(post.ID) >= 8
	if !keepID {
		post.ID = generate.PostID()
	}

	queries := psqlcore.New(r.db)

	args := psqlcore.InsertPostParams{
		ID:         post.ID,
		LogShortID: post.LogID,
		Time:       post.Time.UTC(),
		Kind:       post.Kind,
		Nick:       post.Nick,
		Text:       post.Text,
	}

	pos, err := queries.InsertPost(ctx, args)
	if err != nil {
		return nil, err
	}
	post.Position = int(pos)

	// Best effort: the post is already stored at this point.
	_ = queries.GenerateLogTSVector(ctx, post.LogID)

	return &post, nil
}
// InsertMany stores several posts belonging to the same log in one query and
// returns the same post pointers with ID and Position updated in place.
//
// An empty call returns an empty, non-nil slice. Every post must be non-nil
// and must carry the same LogID as the first post; otherwise an error is
// returned before anything is written. IDs are kept only when insertWithIDs is
// set and the ID is at least 8 bytes long, otherwise they are generated. Times
// are written in UTC. Positions are assigned as the value returned by
// InsertPosts plus the post's index in the argument list.
//
// Unlike Insert, a GenerateLogTSVector failure is returned as an error here
// even though the posts have already been inserted.
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
	if len(posts) == 0 {
		return []*models.Post{}, nil
	}

	// Validate against the first post's log ID directly. Using "" as a
	// "not yet set" sentinel would let a batch through when the first post has
	// an empty LogID and a later one does not.
	logID := ""
	for i, post := range posts {
		if post == nil {
			return nil, errors.New("cannot insert nil post")
		}
		if i == 0 {
			logID = post.LogID
			continue
		}
		if post.LogID != logID {
			return nil, errors.New("cannot insert multiple posts with different log IDs")
		}
	}

	params := psqlcore.InsertPostsParams{
		LogShortID: logID,
	}
	for i, post := range posts {
		if !r.insertWithIDs || len(post.ID) < 8 {
			post.ID = generate.PostID()
		}

		params.Ids = append(params.Ids, post.ID)
		params.Kinds = append(params.Kinds, post.Kind)
		params.Nicks = append(params.Nicks, post.Nick)
		// Offsets are 1-based and follow the argument order.
		params.Offsets = append(params.Offsets, int32(i+1))
		params.Times = append(params.Times, post.Time.UTC())
		params.Texts = append(params.Texts, post.Text)
	}

	queries := psqlcore.New(r.db)

	offset, err := queries.InsertPosts(ctx, params)
	if err != nil {
		return nil, err
	}
	for i, post := range posts {
		post.Position = int(offset) + i
	}

	if err := queries.GenerateLogTSVector(ctx, logID); err != nil {
		return nil, err
	}

	return posts, nil
}
// Update applies update to post, writes the resulting time, kind, nick and
// text to the row with the post's ID, and returns the updated post.
//
// The time is written in UTC, matching what Insert and InsertMany store; the
// returned post keeps whatever location ApplyUpdate left on its Time. After a
// successful write GenerateLogTSVector is called for the post's log; its error
// is intentionally discarded, so a failure there does not fail the update.
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
	post.ApplyUpdate(update)

	queries := psqlcore.New(r.db)

	err := queries.UpdatePost(ctx, psqlcore.UpdatePostParams{
		// Normalised to UTC for consistency with the insert paths.
		Time: post.Time.UTC(),
		Kind: post.Kind,
		Nick: post.Nick,
		Text: post.Text,
		ID:   post.ID,
	})
	if err != nil {
		return nil, err
	}

	// Best effort: the post is already updated at this point.
	_ = queries.GenerateLogTSVector(ctx, post.LogID)

	return &post, nil
}
// Move changes the position of post within its log and returns the posts
// whose positions lie between the old and the new position (inclusive), as
// read back inside the same transaction.
//
// If position equals the post's current position nothing is written and the
// full post list of the log is returned via List instead.
//
// Otherwise the work happens in one transaction:
//  1. log_post is locked in SHARE UPDATE EXCLUSIVE mode.
//  2. The post is moved to position -1 (presumably to vacate its slot while
//     the neighbours are shifted; confirm against the schema constraints).
//  3. The posts between the old and new position are shifted by one towards
//     the vacated slot.
//  4. The post is moved to the requested position.
//  5. All positions of the log are read and must form the sequence 1, 2, 3, …
//     in the order returned by SelectPositionsByLogShortID; otherwise the
//     transaction is rolled back with a "post discontinuity detected" error.
//     An out-of-range target position is therefore rejected by this check.
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
	if position == post.Position {
		return r.List(ctx, models.PostFilter{LogID: &post.LogID})
	}

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Covers every early return below. After a successful Commit the rollback
	// only reports that the transaction is already done, which is discarded.
	defer func() { _ = tx.Rollback() }()

	// ExecContext rather than Exec so that waiting for the table lock is
	// cancelled together with ctx.
	if _, err := tx.ExecContext(ctx, "LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE"); err != nil {
		return nil, err
	}

	q := psqlcore.New(tx)

	if err := q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1}); err != nil {
		return nil, err
	}

	// lowest/highest bound the affected range; shift describes how the posts
	// between the old and new slot move to close the gap.
	var lowest, highest int32
	shift := psqlcore.ShiftPostsBetweenParams{LogShortID: post.LogID}
	if position > post.Position {
		// Moving down the list: everything after the old slot up to the
		// target moves one step up.
		lowest = int32(post.Position)
		highest = int32(position)
		shift.ShiftOffset = -1
		shift.FromPosition = lowest + 1
		shift.ToPosition = highest
	} else {
		// Moving up the list: everything from the target up to just before
		// the old slot moves one step down.
		lowest = int32(position)
		highest = int32(post.Position)
		shift.ShiftOffset = 1
		shift.FromPosition = lowest
		shift.ToPosition = highest - 1
	}
	if err := q.ShiftPostsBetween(ctx, shift); err != nil {
		return nil, err
	}

	if err := q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)}); err != nil {
		return nil, err
	}

	posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{
		LogShortID:   post.LogID,
		FromPosition: lowest,
		ToPosition:   highest,
	})
	if err != nil {
		return nil, err
	}

	// Sanity check before committing: positions must be 1..n without gaps.
	positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID)
	if err != nil {
		return nil, err
	}
	prev := int32(0)
	for _, pos := range positions {
		if pos != prev+1 {
			return nil, errors.New("post discontinuity detected")
		}
		prev = pos
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}

	return r.posts2(posts), nil
}
// Delete removes post and closes the gap it leaves in its log.
//
// Inside one transaction, log_post is locked in SHARE UPDATE EXCLUSIVE mode,
// the post is deleted by ID, and every post of the same log from position
// post.Position+1 onwards is shifted by -1. After the commit
// GenerateLogTSVector is called for the log outside the transaction; its error
// is intentionally discarded, so a failure there does not fail the delete.
func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Covers every early return below. After a successful Commit the rollback
	// only reports that the transaction is already done, which is discarded.
	defer func() { _ = tx.Rollback() }()

	q := psqlcore.New(tx)

	// ExecContext rather than Exec so that waiting for the table lock is
	// cancelled together with ctx.
	if _, err := tx.ExecContext(ctx, "LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE"); err != nil {
		return err
	}

	if err := q.DeletePost(ctx, post.ID); err != nil {
		return err
	}

	err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{
		ShiftOffset:  -1,
		LogShortID:   post.LogID,
		FromPosition: int32(post.Position + 1),
	})
	if err != nil {
		return err
	}

	if err := tx.Commit(); err != nil {
		return err
	}

	// Best effort: the delete is already committed at this point.
	_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)

	return nil
}
// post converts a SelectPost row into a models.Post, copying the fields one to
// one; LogShortID becomes LogID and Position is widened to int.
func (r *postRepository) post(post psqlcore.SelectPostRow) *models.Post {
	result := new(models.Post)

	result.ID = post.ID
	result.LogID = post.LogShortID
	result.Time = post.Time
	result.Kind = post.Kind
	result.Nick = post.Nick
	result.Text = post.Text
	result.Position = int(post.Position)

	return result
}
// posts converts SelectPosts rows into models.Post pointers, preserving order.
// The result is always non-nil, even for empty input.
func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post {
	converted := make([]*models.Post, len(posts))
	for i := range posts {
		row := psqlcore.SelectPostRow(posts[i])
		converted[i] = r.post(row)
	}
	return converted
}
// posts2 converts SelectPostsByPositionRange rows into models.Post pointers,
// preserving order. The result is always non-nil, even for empty input.
func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post {
	converted := make([]*models.Post, len(posts))
	for i := range posts {
		row := psqlcore.SelectPostRow(posts[i])
		converted[i] = r.post(row)
	}
	return converted
}