You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
374 lines
8.7 KiB
374 lines
8.7 KiB
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"git.aiterp.net/stufflog3/stufflog3-api/internal/database/mysql/mysqlcore"
|
|
"git.aiterp.net/stufflog3/stufflog3-api/internal/models"
|
|
"git.aiterp.net/stufflog3/stufflog3-api/internal/slerrors"
|
|
"git.aiterp.net/stufflog3/stufflog3-api/internal/sqltypes"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type itemRepository struct {
|
|
db *sql.DB
|
|
q *mysqlcore.Queries
|
|
scopeID int
|
|
}
|
|
|
|
func (r *itemRepository) Find(ctx context.Context, id int) (*models.Item, error) {
|
|
res, err := r.q.GetItem(ctx, id)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, slerrors.NotFound("Item")
|
|
}
|
|
return nil, err
|
|
}
|
|
if res.ScopeID != r.scopeID {
|
|
return nil, slerrors.NotFound("Item")
|
|
}
|
|
|
|
item := r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(res))
|
|
|
|
stats, _ := r.q.ListItemStatProgress(ctx, id)
|
|
for _, stat := range stats {
|
|
item.Stats = append(item.Stats, models.StatProgressEntry{
|
|
StatEntry: models.StatEntry{
|
|
ID: int(stat.ID.Int32),
|
|
Name: stat.Name.String,
|
|
Weight: stat.Weight,
|
|
},
|
|
Acquired: stat.Acquired,
|
|
Required: stat.Required,
|
|
})
|
|
}
|
|
|
|
return &item, nil
|
|
}
|
|
|
|
func (r *itemRepository) ListCreated(ctx context.Context, from, to time.Time) ([]models.Item, error) {
|
|
rows, err := r.q.ListItemsCreatedBetween(ctx, mysqlcore.ListItemsCreatedBetweenParams{
|
|
CreatedTime: from,
|
|
CreatedTime_2: to,
|
|
ScopeID: r.scopeID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
items := make([]models.Item, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
|
|
}
|
|
|
|
err = r.fillStats(ctx, items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
func (r *itemRepository) ListAcquired(ctx context.Context, from, to time.Time) ([]models.Item, error) {
|
|
rows, err := r.q.ListItemsAcquiredBetween(ctx, mysqlcore.ListItemsAcquiredBetweenParams{
|
|
AcquiredTime: sql.NullTime{Valid: true, Time: from},
|
|
AcquiredTime_2: sql.NullTime{Valid: true, Time: to},
|
|
ScopeID: r.scopeID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
items := make([]models.Item, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, r.resToItem(row))
|
|
}
|
|
|
|
err = r.fillStats(ctx, items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
func (r *itemRepository) ListScheduled(ctx context.Context, from, to models.Date) ([]models.Item, error) {
|
|
rows, err := r.q.ListItemsScheduledBetween(ctx, mysqlcore.ListItemsScheduledBetweenParams{
|
|
ScheduledDate: sqltypes.NullDate{Valid: true, Date: from},
|
|
ScheduledDate_2: sqltypes.NullDate{Valid: true, Date: to},
|
|
ScopeID: r.scopeID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
items := make([]models.Item, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
|
|
}
|
|
|
|
err = r.fillStats(ctx, items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
func (r *itemRepository) ListLoose(ctx context.Context, from, to time.Time) ([]models.Item, error) {
|
|
rows, err := r.q.ListItemsLooseBetween(ctx, mysqlcore.ListItemsLooseBetweenParams{
|
|
CreatedTime: from,
|
|
CreatedTime_2: to,
|
|
ScopeID: r.scopeID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
items := make([]models.Item, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
|
|
}
|
|
|
|
err = r.fillStats(ctx, items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
func (r *itemRepository) Create(ctx context.Context, item models.Item) (*models.Item, error) {
|
|
item.Stats = append(item.Stats[:0:0], item.Stats...)
|
|
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback()
|
|
q := r.q.WithTx(tx)
|
|
|
|
if item.ProjectRequirementID != nil {
|
|
pr, err := q.GetProjectRequirement(ctx, *item.ProjectRequirementID)
|
|
if err != nil || pr.ScopeID != r.scopeID {
|
|
return nil, slerrors.NotFound("Project requirement")
|
|
}
|
|
}
|
|
|
|
prID, acqTime, schDate := r.generateNullables(item)
|
|
|
|
res, err := q.InsertItem(ctx, mysqlcore.InsertItemParams{
|
|
ScopeID: r.scopeID,
|
|
ProjectRequirementID: prID,
|
|
Name: item.Name,
|
|
Description: item.Description,
|
|
CreatedTime: time.Now(),
|
|
CreatedUserID: item.OwnerID,
|
|
AcquiredTime: acqTime,
|
|
ScheduledDate: schDate,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, stat := range item.Stats {
|
|
err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{
|
|
ItemID: int(id),
|
|
StatID: stat.ID,
|
|
Acquired: stat.Acquired,
|
|
Required: stat.Required,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.Find(ctx, int(id))
|
|
}
|
|
|
|
func (r *itemRepository) generateNullables(item models.Item) (prID sql.NullInt32, acqTime sql.NullTime, schDate sqltypes.NullDate) {
|
|
if item.ProjectRequirementID != nil {
|
|
prID.Valid = true
|
|
prID.Int32 = int32(*item.ProjectRequirementID)
|
|
}
|
|
if item.AcquiredTime != nil {
|
|
acqTime.Valid = true
|
|
acqTime.Time = *item.AcquiredTime
|
|
}
|
|
if item.ScheduledDate != nil {
|
|
schDate.Valid = true
|
|
schDate.Date = *item.ScheduledDate
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *itemRepository) Update(ctx context.Context, item models.Item, update models.ItemUpdate) (*models.Item, error) {
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
q := r.q.WithTx(tx)
|
|
|
|
if update.ProjectRequirementID != nil {
|
|
pr, err := q.GetProjectRequirement(ctx, *update.ProjectRequirementID)
|
|
if err != nil || pr.ScopeID != r.scopeID {
|
|
return nil, slerrors.NotFound("Project requirement")
|
|
}
|
|
}
|
|
|
|
item.ApplyUpdate(update)
|
|
|
|
prID, acqTime, schDate := r.generateNullables(item)
|
|
|
|
err = q.UpdateItem(ctx, mysqlcore.UpdateItemParams{
|
|
ID: item.ID,
|
|
|
|
ProjectRequirementID: prID,
|
|
Name: item.Name,
|
|
Description: item.Description,
|
|
AcquiredTime: acqTime,
|
|
ScheduledDate: schDate,
|
|
CreatedUserID: item.OwnerID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, stat := range update.Stats {
|
|
if stat.Acquired == 0 && stat.Required == 0 {
|
|
err = q.DeleteItemStatProgress(ctx, mysqlcore.DeleteItemStatProgressParams{
|
|
ItemID: item.ID,
|
|
StatID: stat.ID,
|
|
})
|
|
} else {
|
|
err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{
|
|
ItemID: item.ID,
|
|
StatID: stat.ID,
|
|
Acquired: stat.Acquired,
|
|
Required: stat.Required,
|
|
})
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.Find(ctx, item.ID)
|
|
}
|
|
|
|
func (r *itemRepository) Delete(ctx context.Context, item models.Item) error {
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
q := r.q.WithTx(tx)
|
|
|
|
err = q.DeleteItem(ctx, item.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = q.ClearItemStatProgress(ctx, item.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (r *itemRepository) resToItem(res mysqlcore.ListItemsAcquiredBetweenRow) models.Item {
|
|
item := models.Item{
|
|
ID: res.ID,
|
|
ScopeID: res.ScopeID,
|
|
OwnerID: res.CreatedUserID,
|
|
Name: res.Name,
|
|
Description: res.Description,
|
|
CreatedTime: res.CreatedTime.UTC(),
|
|
}
|
|
if res.ProjectRequirementID.Valid {
|
|
projectRequirementID := int(res.ProjectRequirementID.Int32)
|
|
projectID := int(res.ProjectID.Int32)
|
|
item.ProjectRequirementID = &projectRequirementID
|
|
item.ProjectID = &projectID
|
|
}
|
|
if res.ScheduledDate.Valid {
|
|
item.ScheduledDate = &res.ScheduledDate.Date
|
|
}
|
|
if res.AcquiredTime.Valid {
|
|
item.AcquiredTime = &res.AcquiredTime.Time
|
|
}
|
|
|
|
return item
|
|
}
|
|
|
|
func (r *itemRepository) fillStats(ctx context.Context, items []models.Item) error {
|
|
if len(items) == 0 {
|
|
return nil
|
|
}
|
|
|
|
ids := make([]interface{}, 0, len(items))
|
|
for _, item := range items {
|
|
ids = append(ids, item.ID)
|
|
}
|
|
query := `
|
|
SELECT isp.item_id, isp.required, isp.acquired, s.id, s.name, s.weight FROM item_stat_progress isp
|
|
LEFT JOIN stat s ON s.id = isp.stat_id
|
|
WHERE item_id IN (?` + strings.Repeat(",?", len(ids)-1) + `);
|
|
`
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, ids...)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
for rows.Next() {
|
|
var itemID, required, acquired, statID int
|
|
var statName string
|
|
var statWeight float64
|
|
|
|
err = rows.Scan(&itemID, &required, &acquired, &statID, &statName, &statWeight)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := range items {
|
|
if items[i].ID == itemID {
|
|
items[i].Stats = append(items[i].Stats, models.StatProgressEntry{
|
|
StatEntry: models.StatEntry{
|
|
ID: statID,
|
|
Name: statName,
|
|
Weight: statWeight,
|
|
},
|
|
Acquired: acquired,
|
|
Required: required,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|