package mysql import ( "context" "database/sql" "encoding/json" "git.aiterp.net/stufflog3/stufflog3-api/internal/database/mysql/mysqlcore" "git.aiterp.net/stufflog3/stufflog3-api/internal/models" "git.aiterp.net/stufflog3/stufflog3-api/internal/slerrors" "git.aiterp.net/stufflog3/stufflog3-api/internal/sqltypes" ) type statsRepository struct { db *sql.DB q *mysqlcore.Queries scopeID int } func (r *statsRepository) Find(ctx context.Context, id int) (*models.Stat, error) { row, err := r.q.GetStat(ctx, mysqlcore.GetStatParams{ScopeID: r.scopeID, ID: id}) if err != nil { if err == sql.ErrNoRows { return nil, slerrors.NotFound("Stat") } return nil, err } return r.rowToStat(row), nil } func (r *statsRepository) List(ctx context.Context) ([]models.Stat, error) { rows, err := r.q.ListStats(ctx, r.scopeID) if err != nil && err != sql.ErrNoRows { return nil, err } stats := make([]models.Stat, 0, len(rows)) for _, row := range rows { stats = append(stats, *r.rowToStat(mysqlcore.GetStatRow(row))) } return stats, nil } func (r *statsRepository) Create(ctx context.Context, stat models.Stat) (*models.Stat, error) { allowedAmounts := sqltypes.NullRawMessage{} if stat.AllowedAmounts != nil && len(stat.AllowedAmounts) > 0 { allowedAmounts.Valid = true allowedAmounts.RawMessage, _ = json.Marshal(stat.AllowedAmounts) } res, err := r.q.InsertStat(ctx, mysqlcore.InsertStatParams{ ScopeID: r.scopeID, Name: stat.Name, Description: stat.Description, Weight: stat.Weight, AllowedAmounts: allowedAmounts, }) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } return r.Find(ctx, int(id)) } func (r *statsRepository) Update(ctx context.Context, stat models.Stat, update models.StatUpdate) (*models.Stat, error) { stat.Update(update) allowedAmounts := sqltypes.NullRawMessage{} if stat.AllowedAmounts != nil && len(stat.AllowedAmounts) > 0 { allowedAmounts.Valid = true allowedAmounts.RawMessage, _ = json.Marshal(stat.AllowedAmounts) } err := r.q.UpdateStat(ctx, mysqlcore.UpdateStatParams{ Name: stat.Name, Description: stat.Description, Weight: stat.Weight, AllowedAmounts: allowedAmounts, ID: stat.ID, ScopeID: r.scopeID, }) if err != nil { return nil, err } return &stat, nil } func (r *statsRepository) Delete(ctx context.Context, stat models.Stat) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() q := r.q.WithTx(tx) err = q.DeleteStat(ctx, mysqlcore.DeleteStatParams{ScopeID: r.scopeID, ID: stat.ID}) if err == sql.ErrNoRows { return slerrors.NotFound("Stat") } if err != nil { return err } err = q.CLearItemStatProgressByStat(ctx, stat.ID) if err != nil { return err } err = q.DeleteAllProjectRequirementStatsByStat(ctx, stat.ID) if err != nil { return err } // TODO: delete from Sprints return tx.Commit() } func (r *statsRepository) rowToStat(row mysqlcore.GetStatRow) *models.Stat { stat := models.Stat{ StatEntry: models.StatEntry{ ID: row.ID, Name: row.Name, Weight: row.Weight, }, Description: row.Description, AllowedAmounts: nil, } if row.AllowedAmounts.Valid { stat.AllowedAmounts = make(map[string]int) _ = json.Unmarshal(row.AllowedAmounts.RawMessage, &stat) if len(stat.AllowedAmounts) == 0 { stat.AllowedAmounts = nil } } return &stat }