package mysql import ( "context" "database/sql" "git.aiterp.net/stufflog3/stufflog3-api/internal/database/mysql/mysqlcore" "git.aiterp.net/stufflog3/stufflog3-api/internal/models" "git.aiterp.net/stufflog3/stufflog3-api/internal/slerrors" "golang.org/x/sync/errgroup" ) type projectRepository struct { db *sql.DB q *mysqlcore.Queries items *itemRepository scopeID int } func (r *projectRepository) Find(ctx context.Context, id int) (*models.Project, error) { row, err := r.q.GetProject(ctx, id) if err == sql.ErrNoRows || row.ScopeID != r.scopeID { return nil, slerrors.NotFound("Project") } else if err != nil { return nil, err } project := models.Project{ ProjectEntry: models.ProjectEntry{ ID: row.ID, OwnerID: row.AuthorID, CreatedTime: row.CreatedTime, Name: row.Name, Status: models.Status(row.Status), }, Description: row.Description, Requirements: []models.ProjectRequirement{}, } reqs, err := r.q.ListProjectRequirementsByProjectID(ctx, id) if err != nil && err != sql.ErrNoRows { return nil, err } itemRows, err := r.q.ListItemsByProject(ctx, id) if err != nil && err != sql.ErrNoRows { return nil, err } items := make([]models.Item, 0, len(itemRows)) for _, itemRow := range itemRows { item := r.items.resToItem(mysqlcore.ListItemsAcquiredBetweenRow(itemRow)) items = append(items, item) } err = r.items.fillStats(ctx, items) if err != nil { return nil, err } eg, ctx := errgroup.WithContext(ctx) for i := range reqs { project.Requirements = append(project.Requirements, models.ProjectRequirement{ ID: reqs[i].ID, Name: reqs[i].Name, Description: reqs[i].Description, Status: models.Status(reqs[i].Status), Stats: []models.StatProgressEntry{}, Items: nil, }) requirement := &project.Requirements[len(project.Requirements)-1] for _, item := range items { if *item.ProjectRequirementID == requirement.ID { requirement.Items = append(requirement.Items, item) } } eg.Go(func() error { stats, err := r.q.ListProjectRequirementStats(ctx, requirement.ID) if err != nil && err != sql.ErrNoRows { return err } for _, statRow := range stats { stat := models.StatProgressEntry{ StatEntry: models.StatEntry{ ID: statRow.ID, Name: statRow.Name, Weight: statRow.Weight, }, Acquired: 0, Required: int(statRow.Required.Int32), } for _, item := range requirement.Items { for _, stat2 := range item.Stats { if stat2.ID == stat.ID { stat.Acquired += stat2.Acquired } } } requirement.Stats = append(requirement.Stats, stat) } return nil }) } err = eg.Wait() if err != nil { return nil, err } return &project, nil } func (r *projectRepository) List(ctx context.Context) ([]models.ProjectEntry, error) { rows, err := r.q.ListProjectEntries(ctx, r.scopeID) if err != nil && err != sql.ErrNoRows { return nil, err } projects := make([]models.ProjectEntry, 0, len(rows)) for _, row := range rows { projects = append(projects, models.ProjectEntry{ ID: row.ID, OwnerID: row.AuthorID, CreatedTime: row.CreatedTime, Name: row.Name, Status: models.Status(row.Status), }) } return projects, nil } func (r *projectRepository) Create(ctx context.Context, project models.Project) (*models.Project, error) { res, err := r.q.InsertProject(ctx, mysqlcore.InsertProjectParams{ ScopeID: r.scopeID, AuthorID: project.OwnerID, Name: project.Name, Status: int(project.Status), Description: project.Description, CreatedTime: project.CreatedTime, }) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } return r.Find(ctx, int(id)) } func (r *projectRepository) Update(ctx context.Context, project models.Project, update models.ProjectUpdate) (*models.Project, error) { project.Update(update) err := r.q.UpdateProject(ctx, mysqlcore.UpdateProjectParams{ Name: project.Name, Status: int(project.Status), Description: project.Description, ID: project.ID, }) if err != nil { return nil, err } return r.Find(ctx, project.ID) } func (r *projectRepository) Delete(ctx context.Context, project models.ProjectEntry, deleteItems bool) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() q := r.q.WithTx(tx) err = q.DeleteProject(ctx, project.ID) if err != nil { return err } if deleteItems { reqs, err := q.ListProjectRequirementsByProjectID(ctx, project.ID) if err != nil { return err } for _, req := range reqs { err = q.DeleteItemForRequirement(ctx, sql.NullInt32{Valid: true, Int32: int32(req.ID)}) if err != nil { return err } err = q.DeleteAllProjectRequirementStats(ctx, req.ID) if err != nil { return err } } } else { err = q.ClearItemProjectRequirementByProjectID(ctx, project.ID) if err != nil { return err } } return tx.Commit() } func (r *projectRepository) AddRequirement(ctx context.Context, project models.ProjectEntry, requirement models.ProjectRequirement) (*models.Project, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := r.q.WithTx(tx) res, err := q.InsertProjectRequirement(ctx, mysqlcore.InsertProjectRequirementParams{ ScopeID: r.scopeID, ProjectID: project.ID, Name: requirement.Name, Status: int(requirement.Status), Description: requirement.Description, }) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } for _, stat := range requirement.Stats { err = q.ReplaceProjectRequirementStat(ctx, mysqlcore.ReplaceProjectRequirementStatParams{ ProjectRequirementID: int(id), StatID: stat.ID, Required: stat.Required, }) if err != nil { return nil, err } } err = tx.Commit() if err != nil { return nil, err } return r.Find(ctx, project.ID) } func (r *projectRepository) UpdateRequirement(ctx context.Context, project models.ProjectEntry, requirement models.ProjectRequirement, update models.ProjectRequirementUpdate) (*models.Project, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := r.q.WithTx(tx) requirement.Update(update) err = q.UpdateProjectRequirement(ctx, mysqlcore.UpdateProjectRequirementParams{ Name: requirement.Name, Status: int(requirement.Status), Description: requirement.Description, ID: requirement.ID, }) if err != nil { return nil, err } for _, stat := range requirement.Stats { if stat.Required < 0 { err = q.DeleteProjectRequirementStat(ctx, mysqlcore.DeleteProjectRequirementStatParams{ ProjectRequirementID: requirement.ID, StatID: stat.ID, }) } else { err = q.ReplaceProjectRequirementStat(ctx, mysqlcore.ReplaceProjectRequirementStatParams{ ProjectRequirementID: requirement.ID, StatID: stat.ID, Required: stat.Required, }) } if err != nil { return nil, err } } err = tx.Commit() if err != nil { return nil, err } return r.Find(ctx, project.ID) } func (r *projectRepository) DeleteRequirement(ctx context.Context, project models.ProjectEntry, requirement models.ProjectRequirement, deleteItems bool) (*models.Project, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := r.q.WithTx(tx) err = q.DeleteProjectRequirement(ctx, requirement.ID) if err != nil { if err == sql.ErrNoRows { return nil, slerrors.NotFound("Project requirement") } return nil, err } if deleteItems { err = q.DeleteItemForRequirement(ctx, sql.NullInt32{Valid: true, Int32: int32(requirement.ID)}) if err != nil { return nil, err } } else { err = q.ClearItemProjectRequirement(ctx, sql.NullInt32{Valid: true, Int32: int32(requirement.ID)}) if err != nil { return nil, err } err = q.DeleteAllProjectRequirementStats(ctx, requirement.ID) if err != nil { return nil, err } } err = tx.Commit() if err != nil { return nil, err } return r.Find(ctx, project.ID) }