package mysql

import (
	"context"
	"database/sql"
	"errors"

	"git.aiterp.net/stufflog3/stufflog3/entities"
	"git.aiterp.net/stufflog3/stufflog3/internal/genutils"
	"git.aiterp.net/stufflog3/stufflog3/models"
	"git.aiterp.net/stufflog3/stufflog3/ports/mysql/mysqlcore"
	"github.com/Masterminds/squirrel"
)

// projectRepository stores projects, project requirements, requirement stats
// and the tags attached to projects and requirements in MySQL. Fixed-shape
// queries go through the generated mysqlcore.Queries; queries over
// variable-length ID lists are built with squirrel and run directly on db.
type projectRepository struct {
	db *sql.DB
	q  *mysqlcore.Queries
}

// Find returns project projectID from scope scopeID together with its tags.
// If GetProject finds no row, models.NotFoundError("Project") is returned.
func (r *projectRepository) Find(ctx context.Context, scopeID, projectID int) (*entities.Project, error) {
	row, err := r.q.GetProject(ctx, mysqlcore.GetProjectParams{ID: projectID, ScopeID: scopeID})
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, models.NotFoundError("Project")
		}
		return nil, err
	}

	tags, err := r.q.ListTagsByObject(ctx, mysqlcore.ListTagsByObjectParams{
		ObjectKind: tagObjectKindProject,
		ObjectID:   row.ID,
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, err
	}
	if tags == nil {
		// Match the other read paths, which always hand out a non-nil tag list.
		tags = []string{}
	}

	return &entities.Project{
		ID:          row.ID,
		ScopeID:     row.ScopeID,
		OwnerID:     row.OwnerID,
		CreatedTime: row.CreatedTime,
		Name:        row.Name,
		Description: row.Description,
		Status:      models.Status(row.Status),
		Tags:        tags,
	}, nil
}

// FetchProjects returns the projects with the given IDs; a scopeID of -1
// disables the scope filter. IDs that match no project are skipped, except in
// the single-ID scoped case, which delegates to Find and therefore reports
// models.NotFoundError("Project") for a missing project.
func (r *projectRepository) FetchProjects(ctx context.Context, scopeID int, ids ...int) ([]entities.Project, error) {
	switch {
	case len(ids) == 0:
		return []entities.Project{}, nil
	case len(ids) == 1 && scopeID != -1:
		project, err := r.Find(ctx, scopeID, ids[0])
		if err != nil {
			return nil, err
		}
		return []entities.Project{*project}, nil
	default:
		return r.queryProjects(ctx, scopeID, ids)
	}
}

// queryProjects loads every project whose ID is in ids (restricted to scopeID
// unless it is -1) with tags filled in. Unknown IDs are skipped silently.
// ids must be non-empty.
func (r *projectRepository) queryProjects(ctx context.Context, scopeID int, ids []int) ([]entities.Project, error) {
	sq := squirrel.Select("id,scope_id,owner_id,name,status,description,created_time").
		From("project").
		Where(squirrel.Eq{"id": ids})
	if scopeID != -1 {
		sq = sq.Where(squirrel.Eq{"scope_id": scopeID})
	}
	query, args, err := sq.ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// The deferred Close covers early returns; the explicit Close below
	// releases the connection before the tag query runs. Close is idempotent.
	defer rows.Close()

	projects := make([]entities.Project, 0, len(ids))
	for rows.Next() {
		project := entities.Project{Tags: []string{}}
		if err := rows.Scan(
			&project.ID,
			&project.ScopeID,
			&project.OwnerID,
			&project.Name,
			&project.Status,
			&project.Description,
			&project.CreatedTime,
		); err != nil {
			return nil, err
		}
		projects = append(projects, project)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}

	if err := r.fillProjectTags(ctx, projects); err != nil {
		return nil, err
	}
	return projects, nil
}

// fillProjectTags appends each project's stored tags to its Tags slice.
func (r *projectRepository) fillProjectTags(ctx context.Context, projects []entities.Project) error {
	if len(projects) == 0 {
		return nil
	}

	index := make(map[int]int, len(projects))
	ids := make([]int, 0, len(projects))
	for i := range projects {
		index[projects[i].ID] = i
		ids = append(ids, projects[i].ID)
	}

	return fetchTags(ctx, r.db, tagObjectKindProject, ids, func(id int, tag string) {
		if i, ok := index[id]; ok {
			projects[i].Tags = append(projects[i].Tags, tag)
		}
	})
}

// List returns every project in scopeID with tags filled in.
func (r *projectRepository) List(ctx context.Context, scopeID int) ([]entities.Project, error) {
	rows, err := r.q.ListProjects(ctx, scopeID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return []entities.Project{}, nil
		}
		return nil, err
	}

	res := make([]entities.Project, 0, len(rows))
	for _, row := range rows {
		res = append(res, entities.Project{
			ID:          row.ID,
			ScopeID:     row.ScopeID,
			OwnerID:     row.OwnerID,
			CreatedTime: row.CreatedTime,
			Name:        row.Name,
			Description: row.Description,
			Status:      models.Status(row.Status),
			Tags:        []string{},
		})
	}

	if err := r.fillProjectTags(ctx, res); err != nil {
		return nil, err
	}
	return res, nil
}

// ListByTags returns the projects in scopeID (any scope when scopeID is -1)
// that carry every tag in tags. Duplicate entries in tags are ignored, and an
// empty tag list matches nothing.
func (r *projectRepository) ListByTags(ctx context.Context, scopeID int, tags []string) ([]entities.Project, error) {
	wanted := make([]string, 0, len(tags))
	seen := make(map[string]struct{}, len(tags))
	for _, tag := range tags {
		if _, dup := seen[tag]; !dup {
			seen[tag] = struct{}{}
			wanted = append(wanted, tag)
		}
	}
	if len(wanted) == 0 {
		return []entities.Project{}, nil
	}

	// DISTINCT so a duplicated tag row cannot count twice toward a match.
	query, args, err := squirrel.Select("object_id, tag_name").
		Distinct().
		From("tag").
		Where(squirrel.Eq{"tag_name": wanted, "object_kind": tagObjectKindProject}).
		ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// ids keeps first-seen order; matches counts distinct wanted tags per project.
	ids := make([]int, 0, 16)
	matches := make(map[int]int, 64)
	for rows.Next() {
		var objectID int
		var tagName string
		if err := rows.Scan(&objectID, &tagName); err != nil {
			return nil, err
		}
		if matches[objectID] == 0 {
			ids = append(ids, objectID)
		}
		matches[objectID]++
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}

	ids = genutils.RetainInPlace(ids, func(id int) bool {
		return matches[id] == len(wanted)
	})
	if len(ids) == 0 {
		return []entities.Project{}, nil
	}

	// Use the query path directly: the tag lookup is not scope-filtered, so a
	// lone match may live in another scope. That must produce an empty result,
	// not the NotFoundError that FetchProjects' single-ID Find shortcut raises.
	return r.queryProjects(ctx, scopeID, ids)
}

// Insert creates project and its tags in a single transaction and returns a
// copy carrying the ID assigned by the database.
func (r *projectRepository) Insert(ctx context.Context, project entities.Project) (*entities.Project, error) {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// After a successful Commit this Rollback just returns sql.ErrTxDone.
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	res, err := q.InsertProject(ctx, mysqlcore.InsertProjectParams{
		ScopeID:     project.ScopeID,
		OwnerID:     project.OwnerID,
		Name:        project.Name,
		Status:      int(project.Status),
		Description: project.Description,
	})
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	project.ID = int(id)

	for _, tag := range project.Tags {
		if err := q.InsertTag(ctx, mysqlcore.InsertTagParams{
			ObjectKind: tagObjectKindProject,
			ObjectID:   project.ID,
			TagName:    tag,
		}); err != nil {
			return nil, err
		}
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return &project, nil
}

// Update applies update to project (via project.Update), stores the resulting
// fields, then removes update.RemoveTags and adds update.AddTags, all in one
// transaction.
func (r *projectRepository) Update(ctx context.Context, project entities.Project, update models.ProjectUpdate) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	project.Update(update)
	if err := q.UpdateProject(ctx, mysqlcore.UpdateProjectParams{
		OwnerID:     project.OwnerID,
		Name:        project.Name,
		Status:      int(project.Status),
		Description: project.Description,
		ID:          project.ID,
		ScopeID:     project.ScopeID,
	}); err != nil {
		return err
	}

	for _, tag := range update.RemoveTags {
		if err := q.DeleteTag(ctx, mysqlcore.DeleteTagParams{
			ObjectKind: tagObjectKindProject,
			ObjectID:   project.ID,
			TagName:    tag,
		}); err != nil {
			return err
		}
	}
	for _, tag := range update.AddTags {
		if err := q.InsertTag(ctx, mysqlcore.InsertTagParams{
			ObjectKind: tagObjectKindProject,
			ObjectID:   project.ID,
			TagName:    tag,
		}); err != nil {
			return err
		}
	}

	return tx.Commit()
}

// Delete removes project in one transaction, together with its requirements,
// their stats and tags, and the project's own tags. For each requirement,
// ClearItemProjectRequirement is called first so items stop referencing it.
func (r *projectRepository) Delete(ctx context.Context, project entities.Project) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	reqs, err := q.ListProjectRequirements(ctx, project.ID)
	if err != nil {
		return err
	}
	for _, req := range reqs {
		if err := q.ClearItemProjectRequirement(ctx, sql.NullInt32{Valid: true, Int32: int32(req.ID)}); err != nil {
			return err
		}
		if err := q.DeleteAllProjectRequirementStats(ctx, req.ID); err != nil {
			return err
		}
		if err := q.DeleteTagByObject(ctx, mysqlcore.DeleteTagByObjectParams{
			ObjectKind: tagObjectKindRequirement,
			ObjectID:   req.ID,
		}); err != nil {
			return err
		}
	}

	if err := q.DeleteTagByObject(ctx, mysqlcore.DeleteTagByObjectParams{
		ObjectKind: tagObjectKindProject,
		ObjectID:   project.ID,
	}); err != nil {
		return err
	}
	if err := q.DeleteAllProjectRequirements(ctx, project.ID); err != nil {
		return err
	}
	if err := q.DeleteProject(ctx, mysqlcore.DeleteProjectParams{ID: project.ID, ScopeID: project.ScopeID}); err != nil {
		return err
	}

	return tx.Commit()
}

// FetchRequirements returns the requirements with the given IDs (restricted to
// scopeID unless it is -1), with tags, plus the stats of exactly those
// requirements. Unknown or out-of-scope IDs are skipped, and so are their stats.
func (r *projectRepository) FetchRequirements(ctx context.Context, scopeID int, requirementIDs ...int) ([]entities.Requirement, []entities.RequirementStat, error) {
	if len(requirementIDs) == 0 {
		return []entities.Requirement{}, []entities.RequirementStat{}, nil
	}

	requirements, err := r.queryRequirements(ctx, scopeID, requirementIDs)
	if err != nil {
		return nil, nil, err
	}
	if len(requirements) == 0 {
		return requirements, []entities.RequirementStat{}, nil
	}

	if err := r.fillRequirementTags(ctx, requirements); err != nil {
		return nil, nil, err
	}

	// Only ask for stats of requirements that survived the scope filter.
	ids := make([]int, 0, len(requirements))
	for i := range requirements {
		ids = append(ids, requirements[i].ID)
	}
	stats, err := r.queryRequirementStats(ctx, ids)
	if err != nil {
		return nil, nil, err
	}

	return requirements, stats, nil
}

// queryRequirements loads the requirements whose ID is in ids (restricted to
// scopeID unless it is -1), each with an empty, non-nil Tags slice. ids must
// be non-empty.
func (r *projectRepository) queryRequirements(ctx context.Context, scopeID int, ids []int) ([]entities.Requirement, error) {
	sq := squirrel.Select("id, scope_id, project_id, name, status, description, is_coarse, aggregate_required").
		From("project_requirement").
		Where(squirrel.Eq{"id": ids})
	if scopeID != -1 {
		sq = sq.Where(squirrel.Eq{"scope_id": scopeID})
	}
	query, args, err := sq.ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	requirements := make([]entities.Requirement, 0, len(ids))
	for rows.Next() {
		requirement := entities.Requirement{Tags: []string{}}
		if err := rows.Scan(
			&requirement.ID,
			&requirement.ScopeID,
			&requirement.ProjectID,
			&requirement.Name,
			&requirement.Status,
			&requirement.Description,
			&requirement.IsCoarse,
			&requirement.AggregateRequired,
		); err != nil {
			return nil, err
		}
		requirements = append(requirements, requirement)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}

	return requirements, nil
}

// queryRequirementStats loads the stats of the given requirement IDs.
// requirementIDs must be non-empty.
func (r *projectRepository) queryRequirementStats(ctx context.Context, requirementIDs []int) ([]entities.RequirementStat, error) {
	query, args, err := squirrel.Select("project_requirement_id, stat_id, required").
		From("project_requirement_stat").
		Where(squirrel.Eq{"project_requirement_id": requirementIDs}).
		ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	stats := make([]entities.RequirementStat, 0, len(requirementIDs))
	for rows.Next() {
		var stat entities.RequirementStat
		if err := rows.Scan(&stat.RequirementID, &stat.StatID, &stat.Required); err != nil {
			return nil, err
		}
		stats = append(stats, stat)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return stats, nil
}

// fillRequirementTags appends each requirement's stored tags to its Tags slice.
func (r *projectRepository) fillRequirementTags(ctx context.Context, requirements []entities.Requirement) error {
	if len(requirements) == 0 {
		return nil
	}

	index := make(map[int]int, len(requirements))
	ids := make([]int, 0, len(requirements))
	for i := range requirements {
		index[requirements[i].ID] = i
		ids = append(ids, requirements[i].ID)
	}

	return fetchTags(ctx, r.db, tagObjectKindRequirement, ids, func(id int, tag string) {
		if i, ok := index[id]; ok {
			requirements[i].Tags = append(requirements[i].Tags, tag)
		}
	})
}

// ListRequirements returns every requirement of projectID with tags, plus the
// stats returned by ListProjectRequirementsStats for that project.
func (r *projectRepository) ListRequirements(ctx context.Context, projectID int) ([]entities.Requirement, []entities.RequirementStat, error) {
	reqRows, err := r.q.ListProjectRequirements(ctx, projectID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, nil, err
	}
	statsRows, err := r.q.ListProjectRequirementsStats(ctx, projectID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, nil, err
	}

	requirements := make([]entities.Requirement, 0, len(reqRows))
	for _, row := range reqRows {
		requirements = append(requirements, entities.Requirement{
			ID:                row.ID,
			ScopeID:           row.ScopeID,
			ProjectID:         row.ProjectID,
			Name:              row.Name,
			Description:       row.Description,
			IsCoarse:          row.IsCoarse,
			AggregateRequired: row.AggregateRequired,
			Status:            models.Status(row.Status),
			Tags:              []string{},
		})
	}

	if err := r.fillRequirementTags(ctx, requirements); err != nil {
		return nil, nil, err
	}

	stats := make([]entities.RequirementStat, 0, len(statsRows))
	for _, row := range statsRows {
		stats = append(stats, entities.RequirementStat{
			RequirementID: row.ProjectRequirementID,
			StatID:        row.StatID,
			Required:      row.Required,
		})
	}

	return requirements, stats, nil
}

// CreateRequirement inserts requirement and its tags in one transaction and
// returns a copy carrying the ID assigned by the database.
func (r *projectRepository) CreateRequirement(ctx context.Context, requirement entities.Requirement) (*entities.Requirement, error) {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	res, err := q.InsertProjectRequirement(ctx, mysqlcore.InsertProjectRequirementParams{
		ScopeID:           requirement.ScopeID,
		ProjectID:         requirement.ProjectID,
		Name:              requirement.Name,
		Status:            int(requirement.Status),
		Description:       requirement.Description,
		IsCoarse:          requirement.IsCoarse,
		AggregateRequired: requirement.AggregateRequired,
	})
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	requirement.ID = int(id)

	for _, tag := range requirement.Tags {
		if err := q.InsertTag(ctx, mysqlcore.InsertTagParams{
			ObjectKind: tagObjectKindRequirement,
			ObjectID:   requirement.ID,
			TagName:    tag,
		}); err != nil {
			return nil, err
		}
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return &requirement, nil
}

// UpdateRequirement applies update to requirement (via requirement.Update),
// stores the resulting fields, then removes update.RemoveTags and adds
// update.AddTags, all in one transaction. A failure at any step rolls back.
func (r *projectRepository) UpdateRequirement(ctx context.Context, requirement entities.Requirement, update models.RequirementUpdate) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	requirement.Update(update)
	if err := q.UpdateProjectRequirement(ctx, mysqlcore.UpdateProjectRequirementParams{
		Name:              requirement.Name,
		Status:            int(requirement.Status),
		Description:       requirement.Description,
		IsCoarse:          requirement.IsCoarse,
		ID:                requirement.ID,
		ScopeID:           requirement.ScopeID,
		AggregateRequired: requirement.AggregateRequired,
		ProjectID:         requirement.ProjectID,
	}); err != nil {
		return err
	}

	for _, tag := range update.RemoveTags {
		if err := q.DeleteTag(ctx, mysqlcore.DeleteTagParams{
			ObjectKind: tagObjectKindRequirement,
			ObjectID:   requirement.ID,
			TagName:    tag,
		}); err != nil {
			return err
		}
	}
	for _, tag := range update.AddTags {
		if err := q.InsertTag(ctx, mysqlcore.InsertTagParams{
			ObjectKind: tagObjectKindRequirement,
			ObjectID:   requirement.ID,
			TagName:    tag,
		}); err != nil {
			return err
		}
	}

	return tx.Commit()
}

// DeleteRequirement removes requirement, its stats and its tags in one
// transaction, first calling ClearItemProjectRequirement so items stop
// referencing it.
func (r *projectRepository) DeleteRequirement(ctx context.Context, requirement entities.Requirement) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()
	q := r.q.WithTx(tx)

	if err := q.ClearItemProjectRequirement(ctx, sql.NullInt32{Valid: true, Int32: int32(requirement.ID)}); err != nil {
		return err
	}
	if err := q.DeleteAllProjectRequirementStats(ctx, requirement.ID); err != nil {
		return err
	}
	if err := q.DeleteProjectRequirement(ctx, mysqlcore.DeleteProjectRequirementParams{
		ScopeID: requirement.ScopeID,
		ID:      requirement.ID,
	}); err != nil {
		return err
	}
	if err := q.DeleteTagByObject(ctx, mysqlcore.DeleteTagByObjectParams{
		ObjectKind: tagObjectKindRequirement,
		ObjectID:   requirement.ID,
	}); err != nil {
		return err
	}

	return tx.Commit()
}

// UpsertRequirementStat writes stat through ReplaceProjectRequirementStat
// (presumably insert-or-replace keyed on requirement and stat; confirm in the
// query definition).
func (r *projectRepository) UpsertRequirementStat(ctx context.Context, stat entities.RequirementStat) error {
	return r.q.ReplaceProjectRequirementStat(ctx, mysqlcore.ReplaceProjectRequirementStatParams{
		ProjectRequirementID: stat.RequirementID,
		StatID:               stat.StatID,
		Required:             stat.Required,
	})
}

// DeleteRequirementStat removes the stat identified by stat.RequirementID and
// stat.StatID.
func (r *projectRepository) DeleteRequirementStat(ctx context.Context, stat entities.RequirementStat) error {
	return r.q.DeleteProjectRequirementStat(ctx, mysqlcore.DeleteProjectRequirementStatParams{
		ProjectRequirementID: stat.RequirementID,
		StatID:               stat.StatID,
	})
}