package mysql import ( "context" "database/sql" "git.aiterp.net/stufflog3/stufflog3/entities" "git.aiterp.net/stufflog3/stufflog3/internal/genutils" "git.aiterp.net/stufflog3/stufflog3/models" "git.aiterp.net/stufflog3/stufflog3/ports/mysql/mysqlcore" "git.aiterp.net/stufflog3/stufflog3/ports/mysql/sqltypes" "github.com/Masterminds/squirrel" "sort" "strings" ) type itemRepository struct { db *sql.DB q *mysqlcore.Queries } func (r *itemRepository) Find(ctx context.Context, scopeID, itemID int) (*entities.Item, error) { row, err := r.q.GetItem(ctx, mysqlcore.GetItemParams{ScopeID: scopeID, ID: itemID}) if err != nil { return nil, err } tags, err := r.q.ListTagsByObject(ctx, mysqlcore.ListTagsByObjectParams{ ObjectKind: tagObjectKindItem, ObjectID: row.ID, }) return &entities.Item{ ID: row.ID, ScopeID: row.ScopeID, OwnerID: row.OwnerID, ProjectID: intPtr(row.ProjectID), RequirementID: intPtr(row.ProjectRequirementID), Name: row.Name, Description: row.Description, CreatedTime: row.CreatedTime, AcquiredTime: timePtr(row.AcquiredTime), ScheduledDate: row.ScheduledDate.AsPtr(), Tags: tags, }, nil } func (r *itemRepository) Fetch(ctx context.Context, filter models.ItemFilter) ([]entities.Item, error) { // Blank arrays are not the same as nulls if filter.IDs != nil && len(filter.IDs) == 0 { return []entities.Item{}, nil } if filter.ScopeIDs != nil && len(filter.ScopeIDs) == 0 { return []entities.Item{}, nil } if filter.ProjectIDs != nil && len(filter.ProjectIDs) == 0 { return []entities.Item{}, nil } if filter.RequirementIDs != nil && len(filter.RequirementIDs) == 0 { return []entities.Item{}, nil } if filter.StatIDs != nil && len(filter.StatIDs) == 0 { return []entities.Item{}, nil } if filter.Tags != nil && len(filter.Tags) == 0 { return []entities.Item{}, nil } sq := squirrel.Select( "i.id, i.scope_id, i.project_requirement_id, pr.project_id, i.owner_id, i.name," + " i.description, i.created_time, i.acquired_time, i.scheduled_date", ).From("item i").LeftJoin("project_requirement pr ON pr.id = i.project_requirement_id") dateOr := squirrel.Or{} if filter.CreatedTime != nil { dateOr = append(dateOr, squirrel.And{ squirrel.GtOrEq{"i.created_time": filter.CreatedTime.Min}, squirrel.Lt{"i.created_time": filter.CreatedTime.Max}, }) } if filter.UnAcquired { sq = sq.Where("i.acquired_time IS NULL") } else if filter.AcquiredTime != nil { dateOr = append(dateOr, squirrel.And{ squirrel.GtOrEq{"i.acquired_time": filter.AcquiredTime.Min}, squirrel.Lt{"i.acquired_time": filter.AcquiredTime.Max}, }) } if filter.UnScheduled { sq = sq.Where("i.scheduled_date IS NULL") } else if filter.ScheduledDate != nil { dateOr = append(dateOr, squirrel.And{ squirrel.GtOrEq{"i.scheduled_date": filter.ScheduledDate.Min.AsTime()}, squirrel.Lt{"i.scheduled_date": filter.ScheduledDate.Max.AsTime()}, }) } if len(dateOr) > 0 { sq = sq.Where(dateOr) } if len(filter.IDs) > 0 { sq = sq.Where(squirrel.Eq{"i.id": filter.IDs}) } if len(filter.ScopeIDs) > 0 { sq = sq.Where(squirrel.Eq{"i.scope_id": filter.ScopeIDs}) } if len(filter.RequirementIDs) > 0 { sq = sq.Where(squirrel.Eq{"i.project_requirement_id": filter.RequirementIDs}) } if len(filter.ProjectIDs) > 0 { sq = sq.Where(squirrel.Eq{"pr.project_id": filter.ProjectIDs}) } if len(filter.StatIDs) > 0 { sq = sq.LeftJoin("item_stat_progress isp ON isp.item_id = i.id") sq = sq.Where(squirrel.Eq{"isp.stat_id": filter.StatIDs}) } if filter.OwnerID != nil { sq = sq.Where(squirrel.Eq{"i.owner_id": filter.OwnerID}) } if filter.Loose { sq = sq.Where("i.project_requirement_id IS NULL") } query, params, err := sq.ToSql() if err != nil { return nil, err } rows, err := r.db.QueryContext(ctx, query, params...) if err != nil { if err == sql.ErrNoRows { return []entities.Item{}, nil } return nil, err } seen := make(map[int]bool, 32) res := make([]entities.Item, 0, 32) ids := genutils.Set[int]{} projectIDs := genutils.Set[int]{} requirementIDs := genutils.Set[int]{} for rows.Next() { item := entities.Item{} var projectRequirementId, projectId sql.NullInt32 var acquiredTime sql.NullTime var scheduledDate sqltypes.NullDate err = rows.Scan( &item.ID, &item.ScopeID, &projectRequirementId, &projectId, &item.OwnerID, &item.Name, &item.Description, &item.CreatedTime, &acquiredTime, &scheduledDate, ) if err != nil { if err == sql.ErrNoRows { break } return nil, err } if seen[item.ID] { continue } seen[item.ID] = true item.RequirementID = intPtr(projectRequirementId) item.ProjectID = intPtr(projectId) item.AcquiredTime = timePtr(acquiredTime) item.ScheduledDate = scheduledDate.AsPtr() item.Tags = []string{} ids.Add(item.ID) if item.ProjectID != nil { projectIDs.Add(*item.ProjectID) requirementIDs.Add(*item.RequirementID) } res = append(res, item) } err = fetchTags(ctx, r.db, tagObjectKindItem, ids.Values(), func(id int, tag string) { for i := range res { if id == res[i].ID { res[i].Tags = append(res[i].Tags, tag) } } }) if err != nil { return nil, err } if len(filter.Tags) > 0 { projectTagMap := make(map[int][]string, 64) err = fetchTags(ctx, r.db, tagObjectKindProject, projectIDs.Values(), func(id int, tag string) { projectTagMap[id] = append(projectTagMap[id], tag) }) if err != nil { return nil, err } requirementTagMap := make(map[int][]string, 64) err = fetchTags(ctx, r.db, tagObjectKindRequirement, requirementIDs.Values(), func(id int, tag string) { requirementTagMap[id] = append(requirementTagMap[id], tag) }) if err != nil { return nil, err } res = genutils.RetainInPlace(res, func(item entities.Item) bool { good := false for _, tag := range filter.Tags { if strings.HasPrefix(tag, "!") { tag = tag[1:] if !item.HasTag(tag) { if item.RequirementID != nil { if !genutils.Contains(requirementTagMap[*item.RequirementID], tag) && !genutils.Contains(projectTagMap[*item.ProjectID], tag) { return false } } else { return false } } good = true } else { if item.HasTag(tag) { good = true } else if item.RequirementID != nil && (genutils.Contains(requirementTagMap[*item.RequirementID], tag) || genutils.Contains(projectTagMap[*item.ProjectID], tag)) { good = true } } } return good }) } sort.Slice(res, func(i, j int) bool { // Acquired time, descending ati, atj := res[i].AcquiredTime, res[j].AcquiredTime if ati == nil && atj != nil { return true } if ati != nil && atj == nil { return false } if ati != nil && atj != nil { return ati.After(*atj) } // Scheduled date, ascending cti, ctj := res[i].CreatedTime, res[j].CreatedTime sdi, sdj := res[i].ScheduledDate, res[j].ScheduledDate if sdi != nil && sdj == nil { return true } if sdi == nil && sdj != nil { return false } if sdi != nil && sdj != nil { if *sdi == *sdj { // This should change the behavior on the front page only. #hax if filter.UnAcquired { return cti.Before(ctj) } else { return cti.After(ctj) } } return sdi.Before(*sdj) } // Created time, descending return cti.After(ctj) }) return res, nil } func (r *itemRepository) Insert(ctx context.Context, item entities.Item) (*entities.Item, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() q := mysqlcore.New(tx) res, err := q.InsertItem(ctx, mysqlcore.InsertItemParams{ ScopeID: item.ScopeID, ProjectRequirementID: sqlIntPtr(item.RequirementID), Name: item.Name, Description: item.Description, CreatedTime: item.CreatedTime, OwnerID: item.OwnerID, AcquiredTime: sqlTimePtr(item.AcquiredTime), ScheduledDate: sqlDatePtr(item.ScheduledDate), }) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } item.ID = int(id) for _, tag := range item.Tags { err := q.InsertTag(ctx, mysqlcore.InsertTagParams{ ObjectKind: tagObjectKindItem, ObjectID: item.ID, TagName: tag, }) if err != nil { return nil, err } } err = tx.Commit() if err != nil { return nil, err } return &item, nil } func (r *itemRepository) Update(ctx context.Context, item entities.Item, update models.ItemUpdate) error { item.ApplyUpdate(update) tx, err := r.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() q := mysqlcore.New(tx) err = q.UpdateItem(ctx, mysqlcore.UpdateItemParams{ ProjectRequirementID: sqlIntPtr(item.RequirementID), Name: item.Name, Description: item.Description, AcquiredTime: sqlTimePtr(item.AcquiredTime), ScheduledDate: sqlDatePtr(item.ScheduledDate), OwnerID: item.OwnerID, ID: item.ID, }) if err != nil { return err } for _, tag := range update.RemoveTags { err := q.DeleteTag(ctx, mysqlcore.DeleteTagParams{ ObjectKind: tagObjectKindItem, ObjectID: item.ID, TagName: tag, }) if err != nil { return err } } for _, tag := range update.AddTags { err = q.InsertTag(ctx, mysqlcore.InsertTagParams{ ObjectKind: tagObjectKindItem, ObjectID: item.ID, TagName: tag, }) if err != nil { return err } } return tx.Commit() } func (r *itemRepository) Delete(ctx context.Context, item entities.Item) error { err := r.q.DeleteTagByObject(ctx, mysqlcore.DeleteTagByObjectParams{ ObjectKind: tagObjectKindItem, ObjectID: item.ID, }) if err != nil { return err } return r.q.DeleteItem(ctx, item.ID) } func (r *itemRepository) ListStat(ctx context.Context, items ...entities.Item) ([]entities.ItemStat, error) { if len(items) == 0 { return []entities.ItemStat{}, nil } ids := make([]interface{}, 0, 64) for _, item := range items { ids = append(ids, item.ID) } query := ` SELECT item_id, stat_id, acquired, required FROM item_stat_progress WHERE item_id IN (?` + strings.Repeat(",?", len(ids)-1) + `); ` rows, err := r.db.QueryContext(ctx, query, ids...) if err != nil { if err == sql.ErrNoRows { return []entities.ItemStat{}, nil } return nil, err } res := make([]entities.ItemStat, 0, 8) for rows.Next() { progress := entities.ItemStat{} err = rows.Scan(&progress.ItemID, &progress.StatID, &progress.Acquired, &progress.Required) if err != nil { return nil, err } res = append(res, progress) } return res, nil } func (r *itemRepository) UpdateStat(ctx context.Context, item entities.ItemStat) error { if item.Required <= 0 { return r.q.DeleteItemStatProgress(ctx, mysqlcore.DeleteItemStatProgressParams{ItemID: item.ItemID, StatID: item.StatID}) } return r.q.ReplaceItemStatProgress(ctx, mysqlcore.ReplaceItemStatProgressParams{ ItemID: item.ItemID, StatID: item.StatID, Acquired: item.Acquired, Required: item.Required, }) }