diff --git a/internal/mysql/devicerepo.go b/internal/mysql/devicerepo.go index 25d8df9..9892e64 100644 --- a/internal/mysql/devicerepo.go +++ b/internal/mysql/devicerepo.go @@ -58,13 +58,33 @@ func (r *DeviceRepo) Find(ctx context.Context, id int) (*models.Device, error) { func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.ReferenceKind, value string) ([]models.Device, error) { q := sq.Select("device.*").From("device") + var requiredTags []string + var unwantedTags []string + switch kind { case models.RKDeviceID: q = q.Where(sq.Eq{"id": strings.Split(value, ",")}) case models.RKBridgeID: q = q.Where(sq.Eq{"bridge_id": strings.Split(value, ",")}) case models.RKTag: - q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": strings.Split(value, ",")}) + allTags := strings.Split(strings.ReplaceAll(strings.ReplaceAll(value, "-", ",-"), "+", ",+"), ",") + optionalTags := make([]string, 0, len(allTags)) + + for _, tag := range allTags { + if tag == "" || tag == "+" || tag == "-" { + continue + } + + if strings.HasPrefix(tag, "+") { + requiredTags = append(requiredTags, tag[1:]) + } else if strings.HasPrefix(tag, "-") { + unwantedTags = append(unwantedTags, tag[1:]) + } else { + optionalTags = append(optionalTags, tag) + } + } + + q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": optionalTags}) case models.RKAll: default: log.Println("Unknown reference kind used for device fetch:", kind) @@ -82,6 +102,10 @@ func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.Reference return nil, dbErr(err) } + if len(requiredTags) > 0 || len(unwantedTags) > 0 { + return r.populateFiltered(ctx, records, requiredTags, unwantedTags) + } + return r.populate(ctx, records) } @@ -211,6 +235,27 @@ func (r *DeviceRepo) populateOne(ctx context.Context, record deviceRecord) (*mod return &records[0], nil } +func (r *DeviceRepo) populateFiltered(ctx context.Context, records []deviceRecord, requiredTags []string, unwantedTags []string) ([]models.Device, error) { + devices, err := r.populate(ctx, records) + if err != nil { + return nil, err + } + + filteredDevices := make([]models.Device, 0, len(devices)) + + for _, device := range devices { + if device.HasTag(unwantedTags...) { + continue + } + + if len(requiredTags) == 0 || device.HasTag(requiredTags...) { + filteredDevices = append(filteredDevices, device) + } + } + + return filteredDevices, nil +} + func (r *DeviceRepo) populate(ctx context.Context, records []deviceRecord) ([]models.Device, error) { if len(records) == 0 { return []models.Device{}, nil diff --git a/models/device.go b/models/device.go index c2d2af7..fe02e23 100644 --- a/models/device.go +++ b/models/device.go @@ -99,6 +99,18 @@ func (d *Device) Validate() error { return nil } +func (d *Device) HasTag(tags ...string) bool { + for _, c := range d.Tags { + for _, c2 := range tags { + if c == c2 { + return true + } + } + } + + return false +} + func (d *Device) HasCapability(capabilities ...DeviceCapability) bool { for _, c := range d.Capabilities { for _, c2 := range capabilities {