package mysql

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"git.aiterp.net/stufflog3/stufflog3/entities"
	"git.aiterp.net/stufflog3/stufflog3/internal/genutils"
	"git.aiterp.net/stufflog3/stufflog3/models"
	"git.aiterp.net/stufflog3/stufflog3/ports/mysql/mysqlcore"
	"github.com/Masterminds/squirrel"
)

// sprintRepository persists sprints and their part links using the
// mysqlcore generated queries.
//
// q is used for the generated queries. db is used directly for the
// transaction in Delete and for the dynamically built multi-sprint
// query in ListParts.
type sprintRepository struct {
	db *sql.DB
	q  *mysqlcore.Queries
}

// Find returns the sprint with ID sprintID within scope scopeID.
//
// It returns models.NotFoundError("Sprint") if the query reports
// sql.ErrNoRows. Any other query error is returned unchanged.
func (r *sprintRepository) Find(ctx context.Context, scopeID, sprintID int) (*entities.Sprint, error) {
	row, err := r.q.GetSprint(ctx, mysqlcore.GetSprintParams{ScopeID: scopeID, ID: sprintID})
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, models.NotFoundError("Sprint")
		}
		return nil, err
	}

	return &entities.Sprint{
		ID:                row.ID,
		ScopeID:           row.ScopeID,
		Name:              row.Name,
		Description:       row.Description,
		Kind:              models.SprintKind(row.Kind),
		FromTime:          row.FromTime,
		ToTime:            row.ToTime,
		IsTimed:           row.IsTimed,
		IsCoarse:          row.IsCoarse,
		IsUnweighted:      row.IsUnweighted,
		AggregateName:     row.AggregateName,
		AggregateRequired: row.AggregateRequired,
	}, nil
}

// ListAt returns the sprints in scope scopeID that the ListSprintsAt query
// selects for the instant at. Presumably these are the sprints whose
// [FromTime, ToTime] range contains at; confirm against the SQL.
func (r *sprintRepository) ListAt(ctx context.Context, scopeID int, at time.Time) ([]entities.Sprint, error) {
	rows, err := r.q.ListSprintsAt(ctx, mysqlcore.ListSprintsAtParams{ScopeID: scopeID, Time: at})
	if err != nil {
		// NOTE(review): multi-row generated queries usually return an empty
		// slice rather than sql.ErrNoRows, so this branch is likely dead.
		// It is kept to preserve the existing contract.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, models.NotFoundError("Sprint")
		}
		return nil, err
	}

	sprints := make([]entities.Sprint, 0, len(rows))
	for _, row := range rows {
		sprints = append(sprints, entities.Sprint{
			ID:                row.ID,
			ScopeID:           row.ScopeID,
			Name:              row.Name,
			Description:       row.Description,
			Kind:              models.SprintKind(row.Kind),
			FromTime:          row.FromTime,
			ToTime:            row.ToTime,
			IsTimed:           row.IsTimed,
			IsCoarse:          row.IsCoarse,
			IsUnweighted:      row.IsUnweighted,
			AggregateName:     row.AggregateName,
			AggregateRequired: row.AggregateRequired,
		})
	}

	return sprints, nil
}

// ListBetween returns the sprints in scope scopeID that the
// ListSprintsBetween query selects for the window [from, to].
func (r *sprintRepository) ListBetween(ctx context.Context, scopeID int, from, to time.Time) ([]entities.Sprint, error) {
	// NOTE(review): the parameter pairs look like three overlap tests:
	// a sprint containing `from`, a sprint containing `to`, and a sprint
	// lying inside [from, to]. Confirm against the SQL before changing
	// any of them.
	rows, err := r.q.ListSprintsBetween(ctx, mysqlcore.ListSprintsBetweenParams{
		ScopeID:    scopeID,
		FromTime:   from,
		ToTime:     from,
		FromTime_2: to,
		ToTime_2:   to,
		FromTime_3: from,
		ToTime_3:   to,
	})
	if err != nil {
		// NOTE(review): likely unreachable for a multi-row query; kept to
		// preserve the existing contract.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, models.NotFoundError("Sprint")
		}
		return nil, err
	}

	sprints := make([]entities.Sprint, 0, len(rows))
	for _, row := range rows {
		sprints = append(sprints, entities.Sprint{
			ID:                row.ID,
			ScopeID:           row.ScopeID,
			Name:              row.Name,
			Description:       row.Description,
			Kind:              models.SprintKind(row.Kind),
			FromTime:          row.FromTime,
			ToTime:            row.ToTime,
			IsTimed:           row.IsTimed,
			IsCoarse:          row.IsCoarse,
			IsUnweighted:      row.IsUnweighted,
			AggregateName:     row.AggregateName,
			AggregateRequired: row.AggregateRequired,
		})
	}

	return sprints, nil
}

// Insert stores a new sprint and returns a copy of it with ID set to the
// insert ID reported by the database. The incoming sprint.ID is not sent
// to InsertSprint.
func (r *sprintRepository) Insert(ctx context.Context, sprint entities.Sprint) (*entities.Sprint, error) {
	res, err := r.q.InsertSprint(ctx, mysqlcore.InsertSprintParams{
		ScopeID:           sprint.ScopeID,
		Name:              sprint.Name,
		Description:       sprint.Description,
		Kind:              int(sprint.Kind),
		FromTime:          sprint.FromTime,
		ToTime:            sprint.ToTime,
		IsTimed:           sprint.IsTimed,
		IsCoarse:          sprint.IsCoarse,
		AggregateName:     sprint.AggregateName,
		AggregateRequired: sprint.AggregateRequired,
		IsUnweighted:      sprint.IsUnweighted,
	})
	if err != nil {
		return nil, err
	}

	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	sprint.ID = int(id)
	return &sprint, nil
}

// Update applies update to (a local copy of) sprint via ApplyUpdate and
// writes the resulting fields back, matching the row by sprint.ID.
// Kind and ScopeID are not passed to UpdateSprint.
func (r *sprintRepository) Update(ctx context.Context, sprint entities.Sprint, update models.SprintUpdate) error {
	sprint.ApplyUpdate(update)

	return r.q.UpdateSprint(ctx, mysqlcore.UpdateSprintParams{
		Name:              sprint.Name,
		Description:       sprint.Description,
		FromTime:          sprint.FromTime,
		ToTime:            sprint.ToTime,
		IsTimed:           sprint.IsTimed,
		IsCoarse:          sprint.IsCoarse,
		IsUnweighted:      sprint.IsUnweighted,
		AggregateName:     sprint.AggregateName,
		AggregateRequired: sprint.AggregateRequired,
		ID:                sprint.ID,
	})
}

// Delete removes all of the sprint's part links and then the sprint itself.
// Both deletes run in one transaction; it is rolled back if either fails.
func (r *sprintRepository) Delete(ctx context.Context, sprint entities.Sprint) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op; its error is
	// deliberately ignored.
	defer tx.Rollback()

	q := r.q.WithTx(tx)

	err = q.DeleteAllSprintParts(ctx, sprint.ID)
	if err != nil {
		return err
	}

	err = q.DeleteSprint(ctx, sprint.ID)
	if err != nil {
		return err
	}

	return tx.Commit()
}

// ListParts returns the part links for all given sprints.
//
// With no sprints it returns an empty, non-nil slice without touching the
// database. A single sprint uses the generated ListSprintParts query.
// Several sprints use one ad-hoc `sprint_id IN (...)` query.
func (r *sprintRepository) ListParts(ctx context.Context, sprints ...entities.Sprint) ([]entities.SprintPart, error) {
	if len(sprints) == 0 {
		return []entities.SprintPart{}, nil
	} else if len(sprints) == 1 {
		rows, err := r.q.ListSprintParts(ctx, sprints[0].ID)
		if err != nil {
			return nil, err
		}

		return genutils.Map(rows, func(row mysqlcore.SprintPart) entities.SprintPart {
			return entities.SprintPart{
				SprintID: row.SprintID,
				PartID:   row.ObjectID,
				Required: row.Required,
			}
		}), nil
	}

	ids := make([]int, 0, len(sprints))
	for _, sprint := range sprints {
		ids = append(ids, sprint.ID)
	}

	query, args, err := squirrel.Select("sprint_id, object_id, required").
		From("sprint_part").
		Where(squirrel.Eq{"sprint_id": ids}).
		ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// Close the result set on every return path. Without this, each call
	// leaks the rows and the pooled connection behind them.
	defer rows.Close()

	res := make([]entities.SprintPart, 0, 16)
	for rows.Next() {
		part := entities.SprintPart{}
		err = rows.Scan(&part.SprintID, &part.PartID, &part.Required)
		if err != nil {
			return nil, err
		}

		res = append(res, part)
	}
	// rows.Next returns false both at the end of the data and on error;
	// rows.Err tells the two apart so a failed read is not returned as a
	// truncated success.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return res, nil
}

// UpdatePart writes the sprint→object link described by part using the
// ReplaceSprintPart query.
func (r *sprintRepository) UpdatePart(ctx context.Context, part entities.SprintPart) error {
	return r.q.ReplaceSprintPart(ctx, mysqlcore.ReplaceSprintPartParams{
		SprintID: part.SprintID,
		ObjectID: part.PartID,
		Required: part.Required,
	})
}

// DeletePart removes the link between part.SprintID and part.PartID.
func (r *sprintRepository) DeletePart(ctx context.Context, part entities.SprintPart) error {
	return r.q.DeleteSprintPart(ctx, mysqlcore.DeleteSprintPartParams{
		SprintID: part.SprintID,
		ObjectID: part.PartID,
	})
}