package mysqldriver import ( "context" "database/sql" "git.aiterp.net/stufflog/server/internal/generate" "git.aiterp.net/stufflog/server/internal/xlerrors" "git.aiterp.net/stufflog/server/models" sq "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" ) type itemRepository struct { db *sqlx.DB } func (r *itemRepository) Find(ctx context.Context, id string) (*models.Item, error) { item := models.Item{} err := r.db.GetContext(ctx, &item, "SELECT * FROM item WHERE item_id=?", id) if err != nil { if err == sql.ErrNoRows { return nil, xlerrors.NotFound("Project") } return nil, err } err = r.db.SelectContext(ctx, &item.Tags, "SELECT tag FROM item_tag WHERE item_id=? ORDER BY tag", id) if err != nil { return nil, err } return &item, nil } func (r *itemRepository) List(ctx context.Context, filter models.ItemFilter) ([]*models.Item, error) { q := sq.Select("item.*").From("item").OrderBy("name") if len(filter.ItemIDs) > 0 { q = q.Where(sq.Eq{"item_id": filter.ItemIDs}) } if len(filter.Tags) > 0 { q = q.LeftJoin("item_tag ON item_tag.item_id = item.item_id"). Where(sq.Eq{"item_tag.tag": filter.Tags}). GroupBy("item.item_id") } query, args, err := q.ToSql() if err != nil { return nil, err } results := make([]*models.Item, 0, 16) err = r.db.SelectContext(ctx, &results, query, args...) if err != nil { if err == sql.ErrNoRows { return []*models.Item{}, nil } return nil, err } err = r.fillTags(ctx, results) if err != nil { return nil, err } return results, nil } func (r *itemRepository) Insert(ctx context.Context, item models.Item) (*models.Item, error) { item.ID = generate.ItemID() tx, err := r.db.BeginTxx(ctx, nil) if err != nil { return nil, err } _, err = tx.NamedExecContext(ctx, ` INSERT INTO item (item_id, name, description, image_url) VALUES (:item_id, :name, :description, :image_url) `, item) if err != nil { _ = tx.Rollback() return nil, err } if len(item.Tags) > 0 { q := sq.Insert("item_tag").Columns("item_id", "tag") for _, tag := range item.Tags { q = q.Values(item.ID, tag) } tagQuery, args, err := q.ToSql() if err != nil { _ = tx.Rollback() return nil, err } _, err = r.db.ExecContext(ctx, tagQuery, args...) if err != nil { _ = tx.Rollback() return nil, err } } err = tx.Commit() if err != nil { return nil, err } return &item, nil } func (r *itemRepository) Save(ctx context.Context, item models.Item) error { tx, err := r.db.BeginTxx(ctx, nil) if err != nil { return err } _, err = tx.NamedExecContext(ctx, ` UPDATE item SET name=:name, description=:description, image_url=:image_url WHERE item_id=:item_id `, item) if err != nil { _ = tx.Rollback() return err } _, err = r.db.ExecContext(ctx, "DELETE FROM item_tag WHERE item_id=?", item.ID) if err != nil && err != sql.ErrNoRows { _ = tx.Rollback() return err } if len(item.Tags) > 0 { q := sq.Insert("item_tag").Columns("item_id", "tag") for _, tag := range item.Tags { q = q.Values(item.ID, tag) } tagQuery, args, err := q.ToSql() if err != nil { _ = tx.Rollback() return err } _, err = r.db.ExecContext(ctx, tagQuery, args...) if err != nil { _ = tx.Rollback() return err } } err = tx.Commit() if err != nil { return err } return nil } func (r *itemRepository) Delete(ctx context.Context, item models.Item) error { _, err := r.db.ExecContext(ctx, "DELETE FROM item WHERE item_id=?", item.ID) if err != nil { return err } _, err = r.db.ExecContext(ctx, "DELETE FROM item_tag WHERE item_id=?", item.ID) if err != nil { return err } return err } func (r *itemRepository) GetTags(ctx context.Context) ([]string, error) { tags := make([]string, 0, 16) err := r.db.SelectContext(ctx, &tags, "SELECT DISTINCT(tag) FROM item_tag") if err != nil { return nil, err } return tags, nil } func (r *itemRepository) fillTags(ctx context.Context, items []*models.Item) error { ids := make([]string, len(items)) idMap := make(map[string]int, len(items)) for i, item := range items { ids[i] = item.ID idMap[item.ID] = i } query, args, err := sq.Select("*").From("item_tag").Where(sq.Eq{"item_id": ids}).ToSql() if err != nil { return err } results := make([]struct { ItemID string `db:"item_id"` Tag string `db:"tag"` }, 0, len(items)*4) err = r.db.SelectContext(ctx, &results, query, args...) if err != nil { return err } for _, result := range results { item := items[idMap[result.ItemID]] item.Tags = append(item.Tags, result.Tag) } return nil }