|
|
package mysql
import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"time"

	"git.aiterp.net/stufflog3/stufflog3-api/internal/database/mysql/mysqlcore"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/models"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/slerrors"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/sqltypes"
)
// itemRepository loads and stores items through the mysqlcore queries.
// It is bound to one scope: Create inserts new items into scopeID, and Find
// and the List* methods only return items that belong to it.
type itemRepository struct {
	db      *sql.DB // database handle used for queries and transactions
	scopeID int     // scope that items are created in and filtered by
}
func (r *itemRepository) Find(ctx context.Context, id int) (*models.Item, error) { res, err := mysqlcore.New(r.db).GetItem(ctx, id) if err != nil { if err == sql.ErrNoRows { return nil, slerrors.NotFound("Item") } return nil, err } if res.ScopeID != r.scopeID { return nil, slerrors.NotFound("Item") }
item := r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(res))
stats, _ := mysqlcore.New(r.db).ListItemStatProgress(ctx, id) for _, stat := range stats { item.Stats = append(item.Stats, models.StatProgressEntry{ StatEntry: models.StatEntry{ ID: int(stat.ID.Int32), Name: stat.Name.String, Weight: stat.Weight, }, Acquired: stat.Acquired, Required: stat.Required, }) }
return &item, nil }
func (r *itemRepository) ListCreated(ctx context.Context, from, to time.Time) ([]models.Item, error) { rows, err := mysqlcore.New(r.db).ListItemsCreatedBetween(ctx, mysqlcore.ListItemsCreatedBetweenParams{ CreatedTime: from, CreatedTime_2: to, ScopeID: r.scopeID, }) if err != nil { return nil, err }
items := make([]models.Item, 0, len(rows)) for _, row := range rows { items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row))) }
err = r.fillStats(ctx, items) if err != nil { return nil, err }
return items, nil }
func (r *itemRepository) ListAcquired(ctx context.Context, from, to time.Time) ([]models.Item, error) { rows, err := mysqlcore.New(r.db).ListItemsAcquiredBetween(ctx, mysqlcore.ListItemsAcquiredBetweenParams{ AcquiredTime: sql.NullTime{Valid: true, Time: from}, AcquiredTime_2: sql.NullTime{Valid: true, Time: to}, ScopeID: r.scopeID, }) if err != nil { return nil, err }
items := make([]models.Item, 0, len(rows)) for _, row := range rows { items = append(items, r.resToItem(row)) }
err = r.fillStats(ctx, items) if err != nil { return nil, err }
return items, nil }
func (r *itemRepository) ListScheduled(ctx context.Context, from, to models.Date) ([]models.Item, error) { rows, err := mysqlcore.New(r.db).ListItemsScheduledBetween(ctx, mysqlcore.ListItemsScheduledBetweenParams{ ScheduledDate: sqltypes.NullDate{Valid: true, Date: from}, ScheduledDate_2: sqltypes.NullDate{Valid: true, Date: to}, ScopeID: r.scopeID, }) if err != nil { return nil, err }
items := make([]models.Item, 0, len(rows)) for _, row := range rows { items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row))) }
err = r.fillStats(ctx, items) if err != nil { return nil, err }
return items, nil }
func (r *itemRepository) ListLoose(ctx context.Context, from, to time.Time) ([]models.Item, error) { rows, err := mysqlcore.New(r.db).ListItemsLooseBetween(ctx, mysqlcore.ListItemsLooseBetweenParams{ CreatedTime: from, CreatedTime_2: to, ScopeID: r.scopeID, }) if err != nil { return nil, err }
items := make([]models.Item, 0, len(rows)) for _, row := range rows { items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row))) }
err = r.fillStats(ctx, items) if err != nil { return nil, err }
return items, nil }
func (r *itemRepository) Create(ctx context.Context, item models.Item) (*models.Item, error) { item.Stats = append(item.Stats[:0:0], item.Stats...)
tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := mysqlcore.New(tx)
prID, acqTime, schDate := r.generateNullables(item)
res, err := q.InsertItem(ctx, mysqlcore.InsertItemParams{ ScopeID: r.scopeID, ProjectRequirementID: prID, Name: item.Name, Description: item.Description, CreatedTime: time.Now(), CreatedUserID: item.OwnerID, AcquiredTime: acqTime, ScheduledDate: schDate, }) if err != nil { return nil, err }
id, err := res.LastInsertId() if err != nil { return nil, err }
for _, stat := range item.Stats { err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{ ItemID: int(id), StatID: stat.ID, Acquired: stat.Acquired, Required: stat.Required, }) if err != nil { return nil, err } }
err = tx.Commit() if err != nil { return nil, err }
return r.Find(ctx, int(id)) }
func (r *itemRepository) generateNullables(item models.Item) (prID sql.NullInt32, acqTime sql.NullTime, schDate sqltypes.NullDate) { if item.ProjectRequirementID != nil { prID.Valid = true prID.Int32 = int32(*item.ProjectRequirementID) } if item.AcquiredTime != nil { acqTime.Valid = true acqTime.Time = *item.AcquiredTime } if item.ScheduledDate != nil { schDate.Valid = true schDate.Date = *item.ScheduledDate }
return }
func (r *itemRepository) Update(ctx context.Context, item models.Item, update models.ItemUpdate) (*models.Item, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := mysqlcore.New(r.db)
item.ApplyUpdate(update)
prID, acqTime, schDate := r.generateNullables(item)
err = q.UpdateItem(ctx, mysqlcore.UpdateItemParams{ ID: item.ID,
ProjectRequirementID: prID, Name: item.Name, Description: item.Description, AcquiredTime: acqTime, ScheduledDate: schDate, CreatedUserID: item.OwnerID, }) if err != nil { return nil, err }
for _, stat := range update.Stats { if stat.Acquired == 0 && stat.Required == 0 { err = q.DeleteItemStatProgress(ctx, mysqlcore.DeleteItemStatProgressParams{ ItemID: item.ID, StatID: stat.ID, }) } else { err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{ ItemID: item.ID, StatID: stat.ID, Acquired: stat.Acquired, Required: stat.Required, }) }
if err != nil { return nil, err } }
err = tx.Commit() if err != nil { return nil, err }
return r.Find(ctx, item.ID) }
func (r *itemRepository) Delete(ctx context.Context, item models.Item) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() q := mysqlcore.New(r.db)
err = q.DeleteItem(ctx, item.ID) if err != nil { return err } err = q.ClearItemStatProgress(ctx, item.ID) if err != nil { return err }
return tx.Commit() }
func (r *itemRepository) resToItem(res mysqlcore.ListItemsAcquiredBetweenRow) models.Item { item := models.Item{ ID: res.ID, ScopeID: res.ScopeID, OwnerID: res.CreatedUserID, Name: res.Name, Description: res.Description, CreatedTime: res.CreatedTime.UTC(), } if res.ProjectRequirementID.Valid { projectRequirementID := int(res.ProjectRequirementID.Int32) projectID := int(res.ProjectID.Int32) item.ProjectRequirementID = &projectRequirementID item.ProjectID = &projectID } if res.ScheduledDate.Valid { item.ScheduledDate = &res.ScheduledDate.Date } if res.AcquiredTime.Valid { item.AcquiredTime = &res.AcquiredTime.Time }
return item }
func (r *itemRepository) fillStats(ctx context.Context, items []models.Item) error { if len(items) == 0 { return nil }
ids := make([]interface{}, 0, len(items)) for _, item := range items { ids = append(ids, item.ID) } query := ` SELECT isp.item_id, isp.required, isp.acquired, s.id, s.name, s.weight FROM item_stat_progress isp LEFT JOIN stat s ON s.id = isp.stat_id WHERE item_id IN (?` + strings.Repeat(",?", len(ids)-1) + `); `
rows, err := r.db.QueryContext(ctx, query, ids...) if err != nil { return err }
for rows.Next() { var itemID, required, acquired, statID int var statName string var statWeight float64
err = rows.Scan(&itemID, &required, &acquired, &statID, &statName, &statWeight) if err != nil { return err }
for i := range items { if items[i].ID == itemID { items[i].Stats = append(items[i].Stats, models.StatProgressEntry{ StatEntry: models.StatEntry{ ID: statID, Name: statName, Weight: statWeight, }, Acquired: acquired, Required: required, }) } } }
return nil }
|