package mysql

import (
	"context"
	"database/sql"
	"strings"
	"time"

	"git.aiterp.net/stufflog3/stufflog3-api/internal/database/mysql/mysqlcore"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/models"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/slerrors"
	"git.aiterp.net/stufflog3/stufflog3-api/internal/sqltypes"
)

// itemRepository is a MySQL-backed item repository bound to a single scope.
// Items (and project requirements) that belong to a different scope are
// treated as not found.
type itemRepository struct {
	db      *sql.DB
	q       *mysqlcore.Queries
	scopeID int
}

// Find loads the item with the given ID together with its stat progress
// entries. It returns a NotFound error when the item does not exist or
// belongs to another scope.
func (r *itemRepository) Find(ctx context.Context, id int) (*models.Item, error) {
	res, err := r.q.GetItem(ctx, id)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, slerrors.NotFound("Item")
		}
		return nil, err
	}
	if res.ScopeID != r.scopeID {
		return nil, slerrors.NotFound("Item")
	}

	item := r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(res))

	// Propagate failures instead of silently returning an item whose stats
	// are missing; the List* methods do the same via fillStats.
	stats, err := r.q.ListItemStatProgress(ctx, id)
	if err != nil {
		return nil, err
	}
	for _, stat := range stats {
		item.Stats = append(item.Stats, models.StatProgressEntry{
			StatEntry: models.StatEntry{
				ID:     int(stat.ID.Int32),
				Name:   stat.Name.String,
				Weight: stat.Weight,
			},
			Acquired: stat.Acquired,
			Required: stat.Required,
		})
	}

	return &item, nil
}

// ListCreated returns the items in this scope whose created time falls
// between from and to (bound inclusivity is defined by the
// ListItemsCreatedBetween query), with their stat progress filled in.
func (r *itemRepository) ListCreated(ctx context.Context, from, to time.Time) ([]models.Item, error) {
	rows, err := r.q.ListItemsCreatedBetween(ctx, mysqlcore.ListItemsCreatedBetweenParams{
		CreatedTime:   from,
		CreatedTime_2: to,
		ScopeID:       r.scopeID,
	})
	if err != nil {
		return nil, err
	}

	items := make([]models.Item, 0, len(rows))
	for _, row := range rows {
		items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
	}

	err = r.fillStats(ctx, items)
	if err != nil {
		return nil, err
	}

	return items, nil
}

// ListAcquired returns the items in this scope whose acquired time falls
// between from and to (bound inclusivity is defined by the
// ListItemsAcquiredBetween query), with their stat progress filled in.
func (r *itemRepository) ListAcquired(ctx context.Context, from, to time.Time) ([]models.Item, error) {
	rows, err := r.q.ListItemsAcquiredBetween(ctx, mysqlcore.ListItemsAcquiredBetweenParams{
		AcquiredTime:   sql.NullTime{Valid: true, Time: from},
		AcquiredTime_2: sql.NullTime{Valid: true, Time: to},
		ScopeID:        r.scopeID,
	})
	if err != nil {
		return nil, err
	}

	items := make([]models.Item, 0, len(rows))
	for _, row := range rows {
		items = append(items, r.resToItem(row))
	}

	err = r.fillStats(ctx, items)
	if err != nil {
		return nil, err
	}

	return items, nil
}

// ListScheduled returns the items in this scope whose scheduled date falls
// between from and to (bound inclusivity is defined by the
// ListItemsScheduledBetween query), with their stat progress filled in.
func (r *itemRepository) ListScheduled(ctx context.Context, from, to models.Date) ([]models.Item, error) {
	rows, err := r.q.ListItemsScheduledBetween(ctx, mysqlcore.ListItemsScheduledBetweenParams{
		ScheduledDate:   sqltypes.NullDate{Valid: true, Date: from},
		ScheduledDate_2: sqltypes.NullDate{Valid: true, Date: to},
		ScopeID:         r.scopeID,
	})
	if err != nil {
		return nil, err
	}

	items := make([]models.Item, 0, len(rows))
	for _, row := range rows {
		items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
	}

	err = r.fillStats(ctx, items)
	if err != nil {
		return nil, err
	}

	return items, nil
}

// ListLoose returns the items selected by the ListItemsLooseBetween query
// for this scope and the created-time range [from, to], with their stat
// progress filled in. NOTE(review): "loose" presumably means items that are
// neither acquired nor scheduled — confirm against the SQL query.
func (r *itemRepository) ListLoose(ctx context.Context, from, to time.Time) ([]models.Item, error) {
	rows, err := r.q.ListItemsLooseBetween(ctx, mysqlcore.ListItemsLooseBetweenParams{
		CreatedTime:   from,
		CreatedTime_2: to,
		ScopeID:       r.scopeID,
	})
	if err != nil {
		return nil, err
	}

	items := make([]models.Item, 0, len(rows))
	for _, row := range rows {
		items = append(items, r.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(row)))
	}

	err = r.fillStats(ctx, items)
	if err != nil {
		return nil, err
	}

	return items, nil
}

// Create inserts item (with CreatedTime set to now) and its stat progress
// entries in a single transaction, then reloads and returns it via Find.
// A referenced project requirement must exist in this scope.
func (r *itemRepository) Create(ctx context.Context, item models.Item) (*models.Item, error) {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	q := r.q.WithTx(tx)

	if item.ProjectRequirementID != nil {
		if err := r.checkProjectRequirement(ctx, q, *item.ProjectRequirementID); err != nil {
			return nil, err
		}
	}

	prID, acqTime, schDate := r.generateNullables(item)

	res, err := q.InsertItem(ctx, mysqlcore.InsertItemParams{
		ScopeID:              r.scopeID,
		ProjectRequirementID: prID,
		Name:                 item.Name,
		Description:          item.Description,
		CreatedTime:          time.Now(),
		CreatedUserID:        item.OwnerID,
		AcquiredTime:         acqTime,
		ScheduledDate:        schDate,
	})
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	for _, stat := range item.Stats {
		err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{
			ItemID:   int(id),
			StatID:   stat.ID,
			Acquired: stat.Acquired,
			Required: stat.Required,
		})
		if err != nil {
			return nil, err
		}
	}

	err = tx.Commit()
	if err != nil {
		return nil, err
	}

	return r.Find(ctx, int(id))
}

// checkProjectRequirement verifies that the project requirement with the
// given ID exists and belongs to this repository's scope. A missing or
// foreign requirement yields a NotFound error; any other database error is
// returned unchanged so it is not misreported as "not found".
func (r *itemRepository) checkProjectRequirement(ctx context.Context, q *mysqlcore.Queries, id int) error {
	pr, err := q.GetProjectRequirement(ctx, id)
	if err == sql.ErrNoRows || (err == nil && pr.ScopeID != r.scopeID) {
		return slerrors.NotFound("Project requirement")
	}
	return err
}

// generateNullables converts the item's optional (pointer) fields into the
// nullable column types used by the generated queries. A nil pointer yields
// a value with Valid == false.
func (r *itemRepository) generateNullables(item models.Item) (prID sql.NullInt32, acqTime sql.NullTime, schDate sqltypes.NullDate) {
	if item.ProjectRequirementID != nil {
		prID.Valid = true
		prID.Int32 = int32(*item.ProjectRequirementID)
	}
	if item.AcquiredTime != nil {
		acqTime.Valid = true
		acqTime.Time = *item.AcquiredTime
	}
	if item.ScheduledDate != nil {
		schDate.Valid = true
		schDate.Date = *item.ScheduledDate
	}

	return
}

// Update applies update to item and persists it, together with the stat
// progress changes in update.Stats, in a single transaction. It then
// reloads and returns the item via Find. A stat entry whose Acquired and
// Required are both zero deletes that stat's progress row; any other entry
// replaces it.
func (r *itemRepository) Update(ctx context.Context, item models.Item, update models.ItemUpdate) (*models.Item, error) {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	q := r.q.WithTx(tx)

	if update.ProjectRequirementID != nil {
		if err := r.checkProjectRequirement(ctx, q, *update.ProjectRequirementID); err != nil {
			return nil, err
		}
	}

	item.ApplyUpdate(update)
	prID, acqTime, schDate := r.generateNullables(item)

	err = q.UpdateItem(ctx, mysqlcore.UpdateItemParams{
		ID:                   item.ID,
		ProjectRequirementID: prID,
		Name:                 item.Name,
		Description:          item.Description,
		AcquiredTime:         acqTime,
		ScheduledDate:        schDate,
		CreatedUserID:        item.OwnerID,
	})
	if err != nil {
		return nil, err
	}

	for _, stat := range update.Stats {
		if stat.Acquired == 0 && stat.Required == 0 {
			err = q.DeleteItemStatProgress(ctx, mysqlcore.DeleteItemStatProgressParams{
				ItemID: item.ID,
				StatID: stat.ID,
			})
		} else {
			err = q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{
				ItemID:   item.ID,
				StatID:   stat.ID,
				Acquired: stat.Acquired,
				Required: stat.Required,
			})
		}
		if err != nil {
			return nil, err
		}
	}

	err = tx.Commit()
	if err != nil {
		return nil, err
	}

	return r.Find(ctx, item.ID)
}

// Delete removes the item and all of its stat progress rows in a single
// transaction.
func (r *itemRepository) Delete(ctx context.Context, item models.Item) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	q := r.q.WithTx(tx)

	err = q.DeleteItem(ctx, item.ID)
	if err != nil {
		return err
	}

	err = q.ClearItemStatProgress(ctx, item.ID)
	if err != nil {
		return err
	}

	return tx.Commit()
}

// resToItem maps a generated query row onto a models.Item. CreatedTime is
// normalized to UTC. ProjectID is set only when ProjectRequirementID is
// valid. Stats are left empty; callers fill them separately.
func (r *itemRepository) resToItem(res mysqlcore.ListItemsAcquiredBetweenRow) models.Item {
	item := models.Item{
		ID:          res.ID,
		ScopeID:     res.ScopeID,
		OwnerID:     res.CreatedUserID,
		Name:        res.Name,
		Description: res.Description,
		CreatedTime: res.CreatedTime.UTC(),
	}

	if res.ProjectRequirementID.Valid {
		projectRequirementID := int(res.ProjectRequirementID.Int32)
		projectID := int(res.ProjectID.Int32)
		item.ProjectRequirementID = &projectRequirementID
		item.ProjectID = &projectID
	}
	if res.ScheduledDate.Valid {
		item.ScheduledDate = &res.ScheduledDate.Date
	}
	if res.AcquiredTime.Valid {
		item.AcquiredTime = &res.AcquiredTime.Time
	}

	return item
}

// fillStats loads the stat progress rows for all given items with a single
// query and appends them to each item's Stats, modifying items in place.
// It does nothing for an empty slice.
func (r *itemRepository) fillStats(ctx context.Context, items []models.Item) error {
	if len(items) == 0 {
		return nil
	}

	// ids feeds the IN (...) placeholders. indices maps an item ID to every
	// position it occupies in items, so each row is routed in O(1).
	ids := make([]interface{}, 0, len(items))
	indices := make(map[int][]int, len(items))
	for i := range items {
		id := items[i].ID
		ids = append(ids, id)
		indices[id] = append(indices[id], i)
	}

	query := `
		SELECT isp.item_id, isp.required, isp.acquired, s.id, s.name, s.weight
		FROM item_stat_progress isp
		LEFT JOIN stat s ON s.id = isp.stat_id
		WHERE item_id IN (?` + strings.Repeat(",?", len(ids)-1) + `);
	`

	rows, err := r.db.QueryContext(ctx, query, ids...)
	if err != nil {
		return err
	}
	// Closing rows returns the connection to the pool.
	defer rows.Close()

	for rows.Next() {
		var itemID, required, acquired int
		// The stat columns come from a LEFT JOIN and are NULL when the
		// referenced stat row is missing; scanning NULL into a plain
		// int/string/float64 would fail.
		var statID sql.NullInt32
		var statName sql.NullString
		var statWeight sql.NullFloat64

		err = rows.Scan(&itemID, &required, &acquired, &statID, &statName, &statWeight)
		if err != nil {
			return err
		}

		entry := models.StatProgressEntry{
			StatEntry: models.StatEntry{
				ID:     int(statID.Int32),
				Name:   statName.String,
				Weight: statWeight.Float64,
			},
			Acquired: acquired,
			Required: required,
		}
		for _, i := range indices[itemID] {
			items[i].Stats = append(items[i].Stats, entry)
		}
	}

	// Surface any error that ended the iteration early.
	return rows.Err()
}