You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
314 lines
7.0 KiB
314 lines
7.0 KiB
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"git.aiterp.net/rpdata/api/database/postgres/psqlcore"
|
|
"git.aiterp.net/rpdata/api/internal/generate"
|
|
"git.aiterp.net/rpdata/api/models"
|
|
"strings"
|
|
)
|
|
|
|
// postRepository is the Postgres-backed repository for posts. Its methods run
// their queries through psqlcore against db.
type postRepository struct {
	// insertWithIDs, when true, lets Insert and InsertMany keep a caller-supplied
	// post ID of at least 8 characters; otherwise a new ID is generated.
	insertWithIDs bool
	// db is the database handle used for queries and for starting transactions.
	db *sql.DB
}
|
|
|
|
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) {
|
|
post, err := psqlcore.New(r.db).SelectPost(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.post(post), nil
|
|
}
|
|
|
|
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) {
|
|
q := psqlcore.New(r.db)
|
|
params := psqlcore.SelectPostsParams{LimitSize: 0}
|
|
if filter.LogID != nil {
|
|
params.FilterLogShortID = true
|
|
|
|
if !strings.HasPrefix(*filter.LogID, "L") {
|
|
log, err := q.SelectLog(ctx, *filter.LogID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params.LogShortID = log.ShortID
|
|
} else {
|
|
params.LogShortID = *filter.LogID
|
|
}
|
|
}
|
|
if filter.Kinds != nil {
|
|
params.FilterKinds = true
|
|
params.Kinds = filter.Kinds
|
|
}
|
|
if filter.Search != nil {
|
|
params.FilterSearch = true
|
|
params.Search = TSQueryFromSearch(*filter.Search)
|
|
}
|
|
if filter.Limit > 0 {
|
|
params.LimitSize = int32(filter.Limit)
|
|
}
|
|
|
|
posts, err := q.SelectPosts(ctx, params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.posts(posts), nil
|
|
}
|
|
|
|
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) {
|
|
if !r.insertWithIDs || len(post.ID) < 8 {
|
|
post.ID = generate.PostID()
|
|
}
|
|
|
|
position, err := psqlcore.New(r.db).InsertPost(ctx, psqlcore.InsertPostParams{
|
|
ID: post.ID,
|
|
LogShortID: post.LogID,
|
|
Time: post.Time.UTC(),
|
|
Kind: post.Kind,
|
|
Nick: post.Nick,
|
|
Text: post.Text,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
|
|
|
|
post.Position = int(position)
|
|
|
|
return &post, nil
|
|
}
|
|
|
|
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) {
|
|
if len(posts) == 0 {
|
|
return []*models.Post{}, nil
|
|
}
|
|
|
|
allowedLogID := ""
|
|
for _, post := range posts {
|
|
if allowedLogID == "" {
|
|
allowedLogID = post.LogID
|
|
} else if allowedLogID != post.LogID {
|
|
return nil, errors.New("cannot insert multiple posts with different log IDs")
|
|
}
|
|
}
|
|
|
|
params := psqlcore.InsertPostsParams{
|
|
LogShortID: posts[0].LogID,
|
|
}
|
|
|
|
for _, post := range posts {
|
|
if !r.insertWithIDs || len(post.ID) < 8 {
|
|
post.ID = generate.PostID()
|
|
}
|
|
|
|
params.Ids = append(params.Ids, post.ID)
|
|
params.Kinds = append(params.Kinds, post.Kind)
|
|
params.Nicks = append(params.Nicks, post.Nick)
|
|
params.Offsets = append(params.Offsets, int32(len(params.Offsets)+1))
|
|
params.Times = append(params.Times, post.Time.UTC())
|
|
params.Texts = append(params.Texts, post.Text)
|
|
}
|
|
|
|
offset, err := psqlcore.New(r.db).InsertPosts(ctx, params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i, post := range posts {
|
|
post.Position = int(offset) + i
|
|
}
|
|
|
|
err = psqlcore.New(r.db).GenerateLogTSVector(ctx, posts[0].LogID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return posts, nil
|
|
}
|
|
|
|
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) {
|
|
post.ApplyUpdate(update)
|
|
|
|
err := psqlcore.New(r.db).UpdatePost(ctx, psqlcore.UpdatePostParams{
|
|
Time: post.Time,
|
|
Kind: post.Kind,
|
|
Nick: post.Nick,
|
|
Text: post.Text,
|
|
ID: post.ID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
|
|
|
|
return &post, nil
|
|
}
|
|
|
|
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) {
|
|
if position == post.Position {
|
|
return r.List(ctx, models.PostFilter{LogID: &post.LogID})
|
|
}
|
|
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
_, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lowest, highest int32
|
|
|
|
q := psqlcore.New(tx)
|
|
|
|
err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if position > post.Position {
|
|
lowest = int32(post.Position)
|
|
highest = int32(position)
|
|
|
|
err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
|
|
ShiftOffset: -1,
|
|
LogShortID: post.LogID,
|
|
FromPosition: lowest + 1,
|
|
ToPosition: highest,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
lowest = int32(position)
|
|
highest = int32(post.Position)
|
|
|
|
err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{
|
|
ShiftOffset: 1,
|
|
LogShortID: post.LogID,
|
|
FromPosition: lowest,
|
|
ToPosition: highest - 1,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{
|
|
LogShortID: post.LogID,
|
|
FromPosition: lowest,
|
|
ToPosition: highest,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
prev := int32(0)
|
|
for _, pos := range positions {
|
|
if pos != prev+1 {
|
|
return nil, errors.New("post discontinuity detected")
|
|
}
|
|
|
|
prev = pos
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.posts2(posts), nil
|
|
}
|
|
|
|
func (r *postRepository) Delete(ctx context.Context, post models.Post) error {
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
q := psqlcore.New(tx)
|
|
|
|
_, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1})
|
|
if err != nil {
|
|
return fmt.Errorf("pre-delete move failed: %s", err)
|
|
}
|
|
|
|
err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{
|
|
ShiftOffset: -1,
|
|
LogShortID: post.LogID,
|
|
FromPosition: int32(post.Position + 1),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("shift failed: %s", err)
|
|
}
|
|
|
|
err = q.DeletePost(ctx, post.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete failed: %s", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("tx commit failed: %s", err)
|
|
}
|
|
|
|
_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *postRepository) post(post psqlcore.SelectPostRow) *models.Post {
|
|
return &models.Post{
|
|
ID: post.ID,
|
|
LogID: post.LogShortID,
|
|
Time: post.Time,
|
|
Kind: post.Kind,
|
|
Nick: post.Nick,
|
|
Text: post.Text,
|
|
Position: int(post.Position),
|
|
}
|
|
}
|
|
|
|
func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post {
|
|
results := make([]*models.Post, 0, len(posts))
|
|
for _, post := range posts {
|
|
results = append(results, r.post(psqlcore.SelectPostRow(post)))
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post {
|
|
results := make([]*models.Post, 0, len(posts))
|
|
for _, post := range posts {
|
|
results = append(results, r.post(psqlcore.SelectPostRow(post)))
|
|
}
|
|
|
|
return results
|
|
}
|