You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

418 lines
11 KiB

  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "git.aiterp.net/lucifer/new-server/models"
  7. sq "github.com/Masterminds/squirrel"
  8. "github.com/jmoiron/sqlx"
  9. "log"
  10. "strings"
  11. )
  // Row types for the device-related tables. Columns are mapped onto fields
  // through the `db` struct tags read by sqlx.
  type (
  	// deviceRecord mirrors one row of the device table. Capabilities and
  	// ButtonNames hold comma-joined lists, and SceneAssignmentJSON holds the
  	// JSON-encoded scene assignments.
  	deviceRecord struct {
  		ID                  int             `db:"id"`
  		BridgeID            int             `db:"bridge_id"`
  		InternalID          string          `db:"internal_id"`
  		Icon                string          `db:"icon"`
  		Name                string          `db:"name"`
  		Capabilities        string          `db:"capabilities"`
  		ButtonNames         string          `db:"button_names"`
  		SceneAssignmentJSON json.RawMessage `db:"scene_assignments"`
  	}

  	// deviceStateRecord mirrors one row of the device_state table.
  	deviceStateRecord struct {
  		DeviceID   int     `db:"device_id"`
  		Hue        float64 `db:"hue"`
  		Saturation float64 `db:"saturation"`
  		Kelvin     int     `db:"kelvin"`
  		Power      bool    `db:"power"`
  		Intensity  float64 `db:"intensity"`
  	}

  	// devicePropertyRecord mirrors one row of the device_property table.
  	// IsUser separates user-set properties (plain strings) from driver-set
  	// ones, whose Value is written JSON-encoded.
  	devicePropertyRecord struct {
  		DeviceID int    `db:"device_id"`
  		Key      string `db:"prop_key"`
  		Value    string `db:"prop_value"`
  		IsUser   bool   `db:"is_user"`
  	}

  	// deviceTagRecord mirrors one row of the device_tag table.
  	deviceTagRecord struct {
  		DeviceID int    `db:"device_id"`
  		TagName  string `db:"tag_name"`
  	}
  )
  40. type DeviceRepo struct {
  41. DBX *sqlx.DB
  42. }
  43. func (r *DeviceRepo) Find(ctx context.Context, id int) (*models.Device, error) {
  44. var device deviceRecord
  45. err := r.DBX.GetContext(ctx, &device, "SELECT * FROM device WHERE id = ?", id)
  46. if err != nil {
  47. return nil, dbErr(err)
  48. }
  49. return r.populateOne(ctx, device)
  50. }
  51. func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.ReferenceKind, value string) ([]models.Device, error) {
  52. q := sq.Select("device.*").From("device")
  53. var requiredTags []string
  54. var unwantedTags []string
  55. switch kind {
  56. case models.RKDeviceID:
  57. q = q.Where(sq.Eq{"id": strings.Split(value, ",")})
  58. case models.RKBridgeID:
  59. q = q.Where(sq.Eq{"bridge_id": strings.Split(value, ",")})
  60. case models.RKName:
  61. value = strings.ReplaceAll(value, "*", "%")
  62. q = q.Where(sq.Like{"name": value})
  63. case models.RKTag:
  64. allTags := strings.Split(strings.ReplaceAll(strings.ReplaceAll(value, "-", ",-"), "+", ",+"), ",")
  65. optionalTags := make([]string, 0, len(allTags))
  66. for _, tag := range allTags {
  67. if tag == "" || tag == "+" || tag == "-" {
  68. continue
  69. }
  70. if strings.HasPrefix(tag, "+") {
  71. requiredTags = append(requiredTags, tag[1:])
  72. } else if strings.HasPrefix(tag, "-") {
  73. unwantedTags = append(unwantedTags, tag[1:])
  74. } else {
  75. optionalTags = append(optionalTags, tag)
  76. }
  77. }
  78. q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": optionalTags})
  79. case models.RKAll:
  80. default:
  81. log.Println("Unknown reference kind used for device fetch:", kind)
  82. return []models.Device{}, nil
  83. }
  84. query, args, err := q.OrderBy("name", "id").ToSql()
  85. if err != nil {
  86. if err == sql.ErrNoRows {
  87. return []models.Device{}, nil
  88. }
  89. return nil, dbErr(err)
  90. }
  91. records := make([]deviceRecord, 0, 8)
  92. err = r.DBX.SelectContext(ctx, &records, query, args...)
  93. if err != nil {
  94. return nil, dbErr(err)
  95. }
  96. if len(requiredTags) > 0 || len(unwantedTags) > 0 {
  97. return r.populateFiltered(ctx, records, requiredTags, unwantedTags)
  98. }
  99. return r.populate(ctx, records)
  100. }
  101. func (r *DeviceRepo) SaveMany(ctx context.Context, mode models.SaveMode, devices []models.Device) error {
  102. tx, err := r.DBX.Beginx()
  103. if err != nil {
  104. return dbErr(err)
  105. }
  106. defer tx.Rollback()
  107. for i, device := range devices {
  108. scenesJSON, err := json.Marshal(device.SceneAssignments)
  109. if err != nil {
  110. return err
  111. }
  112. record := deviceRecord{
  113. ID: device.ID,
  114. BridgeID: device.BridgeID,
  115. InternalID: device.InternalID,
  116. SceneAssignmentJSON: scenesJSON,
  117. Icon: device.Icon,
  118. Name: device.Name,
  119. Capabilities: strings.Join(models.DeviceCapabilitiesToStrings(device.Capabilities), ","),
  120. ButtonNames: strings.Join(device.ButtonNames, ","),
  121. }
  122. if device.ID > 0 {
  123. _, err := tx.NamedExecContext(ctx, `
  124. UPDATE device SET
  125. internal_id = :internal_id,
  126. icon = :icon,
  127. name = :name,
  128. capabilities = :capabilities,
  129. button_names = :button_names,
  130. scene_assignments = :scene_assignments
  131. WHERE id=:id
  132. `, record)
  133. if err != nil {
  134. return dbErr(err)
  135. }
  136. // Let's just be lazy for now, optimize later if need be.
  137. if mode == 0 || mode&models.SMTags != 0 {
  138. _, err = tx.ExecContext(ctx, "DELETE FROM device_tag WHERE device_id=?", record.ID)
  139. if err != nil {
  140. return dbErr(err)
  141. }
  142. }
  143. if mode == 0 || mode&models.SMProperties != 0 {
  144. _, err = tx.ExecContext(ctx, "DELETE FROM device_property WHERE device_id=?", record.ID)
  145. if err != nil {
  146. return dbErr(err)
  147. }
  148. }
  149. } else {
  150. res, err := tx.NamedExecContext(ctx, `
  151. INSERT INTO device (bridge_id, internal_id, icon, name, capabilities, button_names)
  152. VALUES (:bridge_id, :internal_id, :icon, :name, :capabilities, :button_names)
  153. `, record)
  154. if err != nil {
  155. return dbErr(err)
  156. }
  157. lastID, err := res.LastInsertId()
  158. if err != nil {
  159. return dbErr(err)
  160. }
  161. record.ID = int(lastID)
  162. devices[i].ID = int(lastID)
  163. }
  164. if mode == 0 || mode&models.SMTags != 0 {
  165. for _, tag := range device.Tags {
  166. _, err := tx.ExecContext(ctx, "INSERT INTO device_tag (device_id, tag_name) VALUES (?, ?)", record.ID, tag)
  167. if err != nil {
  168. return dbErr(err)
  169. }
  170. }
  171. }
  172. if mode == 0 || mode&models.SMProperties != 0 {
  173. for key, value := range device.UserProperties {
  174. _, err := tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 1)",
  175. record.ID, key, value,
  176. )
  177. if err != nil {
  178. return dbErr(err)
  179. }
  180. }
  181. for key, value := range device.DriverProperties {
  182. j, err := json.Marshal(value)
  183. if err != nil {
  184. // Eh, it'll get filled by the driver anyway
  185. continue
  186. }
  187. _, err = tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 0)",
  188. record.ID, key, string(j),
  189. )
  190. if err != nil {
  191. // Return err here anyway, it might put the tx in a bad state to ignore it.
  192. return dbErr(err)
  193. }
  194. }
  195. }
  196. if mode == 0 || mode&models.SMState != 0 {
  197. _, err = tx.NamedExecContext(ctx, `
  198. REPLACE INTO device_state(device_id, hue, saturation, kelvin, power, intensity)
  199. VALUES (:device_id, :hue, :saturation, :kelvin, :power, :intensity)
  200. `, deviceStateRecord{
  201. DeviceID: record.ID,
  202. Hue: device.State.Color.Hue,
  203. Saturation: device.State.Color.Saturation,
  204. Kelvin: device.State.Color.Kelvin,
  205. Power: device.State.Power,
  206. Intensity: device.State.Intensity,
  207. })
  208. if err != nil {
  209. return dbErr(err)
  210. }
  211. }
  212. }
  213. return tx.Commit()
  214. }
  215. func (r *DeviceRepo) Save(ctx context.Context, device *models.Device, mode models.SaveMode) error {
  216. devices := []models.Device{*device}
  217. err := r.SaveMany(ctx, mode, devices)
  218. if err != nil {
  219. return err
  220. }
  221. *device = devices[0]
  222. return nil
  223. }
  224. func (r *DeviceRepo) Delete(ctx context.Context, device *models.Device) error {
  225. _, err := r.DBX.ExecContext(ctx, "DELETE FROM device WHERE Id=?", device.ID)
  226. if err != nil {
  227. return dbErr(err)
  228. }
  229. return nil
  230. }
  231. func (r *DeviceRepo) populateOne(ctx context.Context, record deviceRecord) (*models.Device, error) {
  232. records, err := r.populate(ctx, []deviceRecord{record})
  233. if err != nil {
  234. return nil, err
  235. }
  236. return &records[0], nil
  237. }
  238. func (r *DeviceRepo) populateFiltered(ctx context.Context, records []deviceRecord, requiredTags []string, unwantedTags []string) ([]models.Device, error) {
  239. devices, err := r.populate(ctx, records)
  240. if err != nil {
  241. return nil, err
  242. }
  243. filteredDevices := make([]models.Device, 0, len(devices))
  244. for _, device := range devices {
  245. if device.HasTag(unwantedTags...) {
  246. continue
  247. }
  248. if len(requiredTags) == 0 || device.HasTag(requiredTags...) {
  249. filteredDevices = append(filteredDevices, device)
  250. }
  251. }
  252. return filteredDevices, nil
  253. }
  254. func (r *DeviceRepo) populate(ctx context.Context, records []deviceRecord) ([]models.Device, error) {
  255. if len(records) == 0 {
  256. return []models.Device{}, nil
  257. }
  258. ids := make([]int, 0, len(records))
  259. for _, record := range records {
  260. ids = append(ids, record.ID)
  261. }
  262. tagsQuery, tagsArgs, err := sq.Select("*").From("device_tag").Where(sq.Eq{"device_id": ids}).ToSql()
  263. if err != nil {
  264. return nil, dbErr(err)
  265. }
  266. propsQuery, propsArgs, err := sq.Select("*").From("device_property").Where(sq.Eq{"device_id": ids}).ToSql()
  267. if err != nil {
  268. return nil, dbErr(err)
  269. }
  270. stateQuery, stateArgs, err := sq.Select("*").From("device_state").Where(sq.Eq{"device_id": ids}).ToSql()
  271. if err != nil {
  272. return nil, dbErr(err)
  273. }
  274. states := make([]deviceStateRecord, 0, len(records))
  275. props := make([]devicePropertyRecord, 0, len(records)*8)
  276. tags := make([]deviceTagRecord, 0, len(records)*4)
  277. err = r.DBX.SelectContext(ctx, &states, stateQuery, stateArgs...)
  278. if err != nil {
  279. return nil, dbErr(err)
  280. }
  281. err = r.DBX.SelectContext(ctx, &props, propsQuery, propsArgs...)
  282. if err != nil {
  283. return nil, dbErr(err)
  284. }
  285. err = r.DBX.SelectContext(ctx, &tags, tagsQuery, tagsArgs...)
  286. if err != nil {
  287. return nil, dbErr(err)
  288. }
  289. hasAdded := make(map[int]bool, len(records))
  290. devices := make([]models.Device, 0, len(records))
  291. for _, record := range records {
  292. if hasAdded[record.ID] {
  293. continue
  294. }
  295. device := models.Device{
  296. ID: record.ID,
  297. BridgeID: record.BridgeID,
  298. InternalID: record.InternalID,
  299. Icon: record.Icon,
  300. Name: record.Name,
  301. SceneAssignments: make([]models.DeviceSceneAssignment, 0, 4),
  302. ButtonNames: strings.Split(record.ButtonNames, ","),
  303. DriverProperties: make(map[string]interface{}, 8),
  304. UserProperties: make(map[string]string, 8),
  305. Tags: make([]string, 0, 8),
  306. }
  307. _ = json.Unmarshal(record.SceneAssignmentJSON, &device.SceneAssignments)
  308. if device.ButtonNames[0] == "" {
  309. device.ButtonNames = device.ButtonNames[:0]
  310. }
  311. caps := make([]models.DeviceCapability, 0, 16)
  312. for _, capStr := range strings.Split(record.Capabilities, ",") {
  313. caps = append(caps, models.DeviceCapability(capStr))
  314. }
  315. device.Capabilities = caps
  316. for _, state := range states {
  317. if state.DeviceID == record.ID {
  318. device.State = models.DeviceState{
  319. Power: state.Power,
  320. Color: models.ColorValue{
  321. Hue: state.Hue,
  322. Saturation: state.Saturation,
  323. Kelvin: state.Kelvin,
  324. },
  325. Intensity: state.Intensity,
  326. }
  327. }
  328. }
  329. driverProps := make(map[string]json.RawMessage, 8)
  330. for _, prop := range props {
  331. if prop.DeviceID == record.ID {
  332. if prop.IsUser {
  333. device.UserProperties[prop.Key] = prop.Value
  334. } else {
  335. driverProps[prop.Key] = json.RawMessage(prop.Value)
  336. }
  337. }
  338. }
  339. if len(driverProps) > 0 {
  340. j, err := json.Marshal(driverProps)
  341. if err != nil {
  342. return nil, dbErr(err)
  343. }
  344. err = json.Unmarshal(j, &device.DriverProperties)
  345. if err != nil {
  346. return nil, dbErr(err)
  347. }
  348. }
  349. for _, tag := range tags {
  350. if tag.DeviceID == record.ID {
  351. device.Tags = append(device.Tags, tag.TagName)
  352. }
  353. }
  354. hasAdded[record.ID] = true
  355. devices = append(devices, device)
  356. }
  357. return devices, nil
  358. }