|
|
package postgres
import ( "context" "database/sql" "errors" "git.aiterp.net/rpdata/api/database/postgres/psqlcore" "git.aiterp.net/rpdata/api/internal/generate" "git.aiterp.net/rpdata/api/models" "strings" )
// postRepository stores and retrieves posts in PostgreSQL, issuing its
// queries through the psqlcore package.
type postRepository struct {
	// insertWithIDs lets Insert/InsertMany keep a caller-supplied post ID
	// (of at least 8 characters) instead of always generating a new one.
	insertWithIDs bool

	// db is the handle every query and transaction is run against.
	db *sql.DB
}
func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) { post, err := psqlcore.New(r.db).SelectPost(ctx, id) if err != nil { return nil, err }
return r.post(post), nil }
func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) { q := psqlcore.New(r.db) params := psqlcore.SelectPostsParams{LimitSize: 0} if filter.LogID != nil { params.FilterLogShortID = true
if !strings.HasPrefix(*filter.LogID, "L") { log, err := q.SelectLog(ctx, *filter.LogID) if err != nil { return nil, err }
params.LogShortID = log.ShortID } else { params.LogShortID = *filter.LogID } } if filter.Kinds != nil { params.FilterKinds = true params.Kinds = filter.Kinds } if filter.Search != nil { params.FilterSearch = true params.Search = TSQueryFromSearch(*filter.Search) } if filter.Limit > 0 { params.LimitSize = int32(filter.Limit) }
posts, err := q.SelectPosts(ctx, params) if err != nil { return nil, err }
return r.posts(posts), nil }
func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) { if !r.insertWithIDs || len(post.ID) < 8 { post.ID = generate.PostID() }
position, err := psqlcore.New(r.db).InsertPost(ctx, psqlcore.InsertPostParams{ ID: post.ID, LogShortID: post.LogID, Time: post.Time.UTC(), Kind: post.Kind, Nick: post.Nick, Text: post.Text, }) if err != nil { return nil, err }
_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
post.Position = int(position)
return &post, nil }
func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) { if len(posts) == 0 { return []*models.Post{}, nil }
allowedLogID := "" for _, post := range posts { if allowedLogID == "" { allowedLogID = post.LogID } else if allowedLogID != post.LogID { return nil, errors.New("cannot insert multiple posts with different log IDs") } }
params := psqlcore.InsertPostsParams{ LogShortID: posts[0].LogID, }
for _, post := range posts { if !r.insertWithIDs || len(post.ID) < 8 { post.ID = generate.PostID() }
params.Ids = append(params.Ids, post.ID) params.Kinds = append(params.Kinds, post.Kind) params.Nicks = append(params.Nicks, post.Nick) params.Offsets = append(params.Offsets, int32(len(params.Offsets)+1)) params.Times = append(params.Times, post.Time.UTC()) params.Texts = append(params.Texts, post.Text) }
offset, err := psqlcore.New(r.db).InsertPosts(ctx, params) if err != nil { return nil, err }
for i, post := range posts { post.Position = int(offset) + i }
err = psqlcore.New(r.db).GenerateLogTSVector(ctx, posts[0].LogID) if err != nil { return nil, err }
return posts, nil }
func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) { post.ApplyUpdate(update)
err := psqlcore.New(r.db).UpdatePost(ctx, psqlcore.UpdatePostParams{ Time: post.Time, Kind: post.Kind, Nick: post.Nick, Text: post.Text, ID: post.ID, }) if err != nil { return nil, err }
_ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID)
return &post, nil }
func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) { if position == post.Position { return r.List(ctx, models.PostFilter{LogID: &post.LogID}) }
tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer func() { _ = tx.Rollback() }()
_, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE") if err != nil { return nil, err }
var lowest, highest int32
q := psqlcore.New(tx)
err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1}) if err != nil { return nil, err }
if position > post.Position { lowest = int32(post.Position) highest = int32(position)
err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{ ShiftOffset: -1, LogShortID: post.LogID, FromPosition: lowest + 1, ToPosition: highest, }) if err != nil { return nil, err } } else { lowest = int32(position) highest = int32(post.Position)
err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{ ShiftOffset: 1, LogShortID: post.LogID, FromPosition: lowest, ToPosition: highest - 1, }) if err != nil { return nil, err } }
err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)}) if err != nil { return nil, err }
posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{ LogShortID: post.LogID, FromPosition: lowest, ToPosition: highest, }) if err != nil { return nil, err }
positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID) if err != nil { return nil, err } prev := int32(0) for _, pos := range positions { if pos != prev+1 { return nil, errors.New("post discontinuity detected") }
prev = pos }
err = tx.Commit() if err != nil { return nil, err }
return r.posts2(posts), nil }
func (r *postRepository) Delete(ctx context.Context, post models.Post) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer func() { _ = tx.Rollback() }() q := psqlcore.New(tx)
_, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE") if err != nil { return err }
err = q.DeletePost(ctx, post.ID) if err != nil { return err }
err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{ ShiftOffset: -1, LogShortID: post.LogID, FromPosition: int32(post.Position + 1), }) if err != nil { return err }
_ = q.GenerateLogTSVector(ctx, post.LogID)
return tx.Commit() }
// post converts one psqlcore post row into the API model, mapping the row's
// log short ID to LogID and widening Position to int.
func (r *postRepository) post(row psqlcore.SelectPostRow) *models.Post {
	p := &models.Post{}
	p.ID = row.ID
	p.LogID = row.LogShortID
	p.Time = row.Time
	p.Kind = row.Kind
	p.Nick = row.Nick
	p.Text = row.Text
	p.Position = int(row.Position)
	return p
}
func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post { results := make([]*models.Post, 0, len(posts)) for _, post := range posts { results = append(results, r.post(psqlcore.SelectPostRow(post))) }
return results }
func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post { results := make([]*models.Post, 0, len(posts)) for _, post := range posts { results = append(results, r.post(psqlcore.SelectPostRow(post))) }
return results }
|