package postgres import ( "context" "database/sql" "errors" "fmt" "git.aiterp.net/rpdata/api/database/postgres/psqlcore" "git.aiterp.net/rpdata/api/internal/generate" "git.aiterp.net/rpdata/api/models" "strings" ) type postRepository struct { insertWithIDs bool db *sql.DB } func (r *postRepository) Find(ctx context.Context, id string) (*models.Post, error) { post, err := psqlcore.New(r.db).SelectPost(ctx, id) if err != nil { return nil, err } return r.post(post), nil } func (r *postRepository) List(ctx context.Context, filter models.PostFilter) ([]*models.Post, error) { q := psqlcore.New(r.db) params := psqlcore.SelectPostsParams{LimitSize: 0} if filter.LogID != nil { params.FilterLogShortID = true if !strings.HasPrefix(*filter.LogID, "L") { log, err := q.SelectLog(ctx, *filter.LogID) if err != nil { return nil, err } params.LogShortID = log.ShortID } else { params.LogShortID = *filter.LogID } } if filter.Kinds != nil { params.FilterKinds = true params.Kinds = filter.Kinds } if filter.Search != nil { params.FilterSearch = true params.Search = TSQueryFromSearch(*filter.Search) } if filter.Limit > 0 { params.LimitSize = int32(filter.Limit) } posts, err := q.SelectPosts(ctx, params) if err != nil { return nil, err } return r.posts(posts), nil } func (r *postRepository) Insert(ctx context.Context, post models.Post) (*models.Post, error) { if !r.insertWithIDs || len(post.ID) < 8 { post.ID = generate.PostID() } position, err := psqlcore.New(r.db).InsertPost(ctx, psqlcore.InsertPostParams{ ID: post.ID, LogShortID: post.LogID, Time: post.Time.UTC(), Kind: post.Kind, Nick: post.Nick, Text: post.Text, }) if err != nil { return nil, err } _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID) post.Position = int(position) return &post, nil } func (r *postRepository) InsertMany(ctx context.Context, posts ...*models.Post) ([]*models.Post, error) { if len(posts) == 0 { return []*models.Post{}, nil } allowedLogID := "" for _, post := range posts { if allowedLogID == "" { allowedLogID = post.LogID } else if allowedLogID != post.LogID { return nil, errors.New("cannot insert multiple posts with different log IDs") } } params := psqlcore.InsertPostsParams{ LogShortID: posts[0].LogID, } for _, post := range posts { if !r.insertWithIDs || len(post.ID) < 8 { post.ID = generate.PostID() } params.Ids = append(params.Ids, post.ID) params.Kinds = append(params.Kinds, post.Kind) params.Nicks = append(params.Nicks, post.Nick) params.Offsets = append(params.Offsets, int32(len(params.Offsets)+1)) params.Times = append(params.Times, post.Time.UTC()) params.Texts = append(params.Texts, post.Text) } offset, err := psqlcore.New(r.db).InsertPosts(ctx, params) if err != nil { return nil, err } for i, post := range posts { post.Position = int(offset) + i } err = psqlcore.New(r.db).GenerateLogTSVector(ctx, posts[0].LogID) if err != nil { return nil, err } return posts, nil } func (r *postRepository) Update(ctx context.Context, post models.Post, update models.PostUpdate) (*models.Post, error) { post.ApplyUpdate(update) err := psqlcore.New(r.db).UpdatePost(ctx, psqlcore.UpdatePostParams{ Time: post.Time, Kind: post.Kind, Nick: post.Nick, Text: post.Text, ID: post.ID, }) if err != nil { return nil, err } _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID) return &post, nil } func (r *postRepository) Move(ctx context.Context, post models.Post, position int) ([]*models.Post, error) { if position == post.Position { return r.List(ctx, models.PostFilter{LogID: &post.LogID}) } tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer func() { _ = tx.Rollback() }() _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE") if err != nil { return nil, err } var lowest, highest int32 q := psqlcore.New(tx) err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1}) if err != nil { return nil, err } if position > post.Position { lowest = int32(post.Position) highest = int32(position) err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{ ShiftOffset: -1, LogShortID: post.LogID, FromPosition: lowest + 1, ToPosition: highest, }) if err != nil { return nil, err } } else { lowest = int32(position) highest = int32(post.Position) err := q.ShiftPostsBetween(ctx, psqlcore.ShiftPostsBetweenParams{ ShiftOffset: 1, LogShortID: post.LogID, FromPosition: lowest, ToPosition: highest - 1, }) if err != nil { return nil, err } } err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: int32(position)}) if err != nil { return nil, err } posts, err := q.SelectPostsByPositionRange(ctx, psqlcore.SelectPostsByPositionRangeParams{ LogShortID: post.LogID, FromPosition: lowest, ToPosition: highest, }) if err != nil { return nil, err } positions, err := q.SelectPositionsByLogShortID(ctx, post.LogID) if err != nil { return nil, err } prev := int32(0) for _, pos := range positions { if pos != prev+1 { return nil, errors.New("post discontinuity detected") } prev = pos } err = tx.Commit() if err != nil { return nil, err } return r.posts2(posts), nil } func (r *postRepository) Delete(ctx context.Context, post models.Post) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer func() { _ = tx.Rollback() }() q := psqlcore.New(tx) _, err = tx.Exec("LOCK TABLE log_post IN SHARE UPDATE EXCLUSIVE MODE") if err != nil { return err } err = q.MovePost(ctx, psqlcore.MovePostParams{ID: post.ID, Position: -1}) if err != nil { return fmt.Errorf("pre-delete move failed: %s", err) } err = q.ShiftPostsAfter(ctx, psqlcore.ShiftPostsAfterParams{ ShiftOffset: -1, LogShortID: post.LogID, FromPosition: int32(post.Position + 1), }) if err != nil { return fmt.Errorf("shift failed: %s", err) } err = q.DeletePost(ctx, post.ID) if err != nil { return fmt.Errorf("delete failed: %s", err) } err = tx.Commit() if err != nil { return fmt.Errorf("tx commit failed: %s", err) } _ = psqlcore.New(r.db).GenerateLogTSVector(ctx, post.LogID) return nil } func (r *postRepository) post(post psqlcore.SelectPostRow) *models.Post { return &models.Post{ ID: post.ID, LogID: post.LogShortID, Time: post.Time, Kind: post.Kind, Nick: post.Nick, Text: post.Text, Position: int(post.Position), } } func (r *postRepository) posts(posts []psqlcore.SelectPostsRow) []*models.Post { results := make([]*models.Post, 0, len(posts)) for _, post := range posts { results = append(results, r.post(psqlcore.SelectPostRow(post))) } return results } func (r *postRepository) posts2(posts []psqlcore.SelectPostsByPositionRangeRow) []*models.Post { results := make([]*models.Post, 0, len(posts)) for _, post := range posts { results = append(results, r.post(psqlcore.SelectPostRow(post))) } return results }