You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

421 lines
11 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "git.aiterp.net/lucifer/new-server/models"
  7. sq "github.com/Masterminds/squirrel"
  8. "github.com/jmoiron/sqlx"
  9. "log"
  10. "strings"
  11. )
  // The record types below mirror rows of the device tables one-to-one; the
// db struct tags name the columns sqlx scans into and binds named
// parameters from.
type (
	// deviceRecord is a row of the device table. Capabilities and
	// ButtonNames hold comma-joined lists, and SceneAssignmentJSON holds the
	// JSON-encoded scene assignments.
	deviceRecord struct {
		ID                  int             `db:"id"`
		BridgeID            int             `db:"bridge_id"`
		InternalID          string          `db:"internal_id"`
		Icon                string          `db:"icon"`
		Name                string          `db:"name"`
		Capabilities        string          `db:"capabilities"`
		ButtonNames         string          `db:"button_names"`
		SceneAssignmentJSON json.RawMessage `db:"scene_assignments"`
	}

	// deviceStateRecord is a row of the device_state table holding the
	// stored power, color, intensity and temperature of one device.
	deviceStateRecord struct {
		DeviceID    int     `db:"device_id"`
		Hue         float64 `db:"hue"`
		Saturation  float64 `db:"saturation"`
		Kelvin      int     `db:"kelvin"`
		Power       bool    `db:"power"`
		Intensity   float64 `db:"intensity"`
		Temperature int     `db:"temperature"`
	}

	// devicePropertyRecord is a row of the device_property table. IsUser
	// separates user properties (plain strings) from driver properties
	// (JSON-encoded values).
	devicePropertyRecord struct {
		DeviceID int    `db:"device_id"`
		Key      string `db:"prop_key"`
		Value    string `db:"prop_value"`
		IsUser   bool   `db:"is_user"`
	}

	// deviceTagRecord is a row of the device_tag table linking a device to
	// one tag name.
	deviceTagRecord struct {
		DeviceID int    `db:"device_id"`
		TagName  string `db:"tag_name"`
	}
)
  41. type DeviceRepo struct {
  42. DBX *sqlx.DB
  43. }
  44. func (r *DeviceRepo) Find(ctx context.Context, id int) (*models.Device, error) {
  45. var device deviceRecord
  46. err := r.DBX.GetContext(ctx, &device, "SELECT * FROM device WHERE id = ?", id)
  47. if err != nil {
  48. return nil, dbErr(err)
  49. }
  50. return r.populateOne(ctx, device)
  51. }
  52. func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.ReferenceKind, value string) ([]models.Device, error) {
  53. q := sq.Select("device.*").From("device")
  54. var requiredTags []string
  55. var unwantedTags []string
  56. switch kind {
  57. case models.RKDeviceID:
  58. q = q.Where(sq.Eq{"id": strings.Split(value, ",")})
  59. case models.RKBridgeID:
  60. q = q.Where(sq.Eq{"bridge_id": strings.Split(value, ",")})
  61. case models.RKName:
  62. value = strings.ReplaceAll(value, "*", "%")
  63. q = q.Where(sq.Like{"name": value})
  64. case models.RKTag:
  65. allTags := strings.Split(strings.ReplaceAll(strings.ReplaceAll(value, "-", ",-"), "+", ",+"), ",")
  66. optionalTags := make([]string, 0, len(allTags))
  67. for _, tag := range allTags {
  68. if tag == "" || tag == "+" || tag == "-" {
  69. continue
  70. }
  71. if strings.HasPrefix(tag, "+") {
  72. requiredTags = append(requiredTags, tag[1:])
  73. } else if strings.HasPrefix(tag, "-") {
  74. unwantedTags = append(unwantedTags, tag[1:])
  75. } else {
  76. optionalTags = append(optionalTags, tag)
  77. }
  78. }
  79. q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": optionalTags})
  80. case models.RKAll:
  81. default:
  82. log.Println("Unknown reference kind used for device fetch:", kind)
  83. return []models.Device{}, nil
  84. }
  85. query, args, err := q.OrderBy("name", "id").ToSql()
  86. if err != nil {
  87. if err == sql.ErrNoRows {
  88. return []models.Device{}, nil
  89. }
  90. return nil, dbErr(err)
  91. }
  92. records := make([]deviceRecord, 0, 8)
  93. err = r.DBX.SelectContext(ctx, &records, query, args...)
  94. if err != nil {
  95. return nil, dbErr(err)
  96. }
  97. if len(requiredTags) > 0 || len(unwantedTags) > 0 {
  98. return r.populateFiltered(ctx, records, requiredTags, unwantedTags)
  99. }
  100. return r.populate(ctx, records)
  101. }
  102. func (r *DeviceRepo) SaveMany(ctx context.Context, mode models.SaveMode, devices []models.Device) error {
  103. tx, err := r.DBX.Beginx()
  104. if err != nil {
  105. return dbErr(err)
  106. }
  107. defer tx.Rollback()
  108. for i, device := range devices {
  109. scenesJSON, err := json.Marshal(device.SceneAssignments)
  110. if err != nil {
  111. return err
  112. }
  113. record := deviceRecord{
  114. ID: device.ID,
  115. BridgeID: device.BridgeID,
  116. InternalID: device.InternalID,
  117. SceneAssignmentJSON: scenesJSON,
  118. Icon: device.Icon,
  119. Name: device.Name,
  120. Capabilities: strings.Join(models.DeviceCapabilitiesToStrings(device.Capabilities), ","),
  121. ButtonNames: strings.Join(device.ButtonNames, ","),
  122. }
  123. if device.ID > 0 {
  124. _, err := tx.NamedExecContext(ctx, `
  125. UPDATE device SET
  126. internal_id = :internal_id,
  127. icon = :icon,
  128. name = :name,
  129. capabilities = :capabilities,
  130. button_names = :button_names,
  131. scene_assignments = :scene_assignments
  132. WHERE id=:id
  133. `, record)
  134. if err != nil {
  135. return dbErr(err)
  136. }
  137. // Let's just be lazy for now, optimize later if need be.
  138. if mode == 0 || mode&models.SMTags != 0 {
  139. _, err = tx.ExecContext(ctx, "DELETE FROM device_tag WHERE device_id=?", record.ID)
  140. if err != nil {
  141. return dbErr(err)
  142. }
  143. }
  144. if mode == 0 || mode&models.SMProperties != 0 {
  145. _, err = tx.ExecContext(ctx, "DELETE FROM device_property WHERE device_id=?", record.ID)
  146. if err != nil {
  147. return dbErr(err)
  148. }
  149. }
  150. } else {
  151. res, err := tx.NamedExecContext(ctx, `
  152. INSERT INTO device (bridge_id, internal_id, icon, name, capabilities, button_names)
  153. VALUES (:bridge_id, :internal_id, :icon, :name, :capabilities, :button_names)
  154. `, record)
  155. if err != nil {
  156. return dbErr(err)
  157. }
  158. lastID, err := res.LastInsertId()
  159. if err != nil {
  160. return dbErr(err)
  161. }
  162. record.ID = int(lastID)
  163. devices[i].ID = int(lastID)
  164. }
  165. if mode == 0 || mode&models.SMTags != 0 {
  166. for _, tag := range device.Tags {
  167. _, err := tx.ExecContext(ctx, "INSERT INTO device_tag (device_id, tag_name) VALUES (?, ?)", record.ID, tag)
  168. if err != nil {
  169. return dbErr(err)
  170. }
  171. }
  172. }
  173. if mode == 0 || mode&models.SMProperties != 0 {
  174. for key, value := range device.UserProperties {
  175. _, err := tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 1)",
  176. record.ID, key, value,
  177. )
  178. if err != nil {
  179. return dbErr(err)
  180. }
  181. }
  182. for key, value := range device.DriverProperties {
  183. j, err := json.Marshal(value)
  184. if err != nil {
  185. // Eh, it'll get filled by the driver anyway
  186. continue
  187. }
  188. _, err = tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 0)",
  189. record.ID, key, string(j),
  190. )
  191. if err != nil {
  192. // Return err here anyway, it might put the tx in a bad state to ignore it.
  193. return dbErr(err)
  194. }
  195. }
  196. }
  197. if mode == 0 || mode&models.SMState != 0 {
  198. _, err = tx.NamedExecContext(ctx, `
  199. REPLACE INTO device_state(device_id, hue, saturation, kelvin, power, intensity, temperature)
  200. VALUES (:device_id, :hue, :saturation, :kelvin, :power, :intensity, :temperature)
  201. `, deviceStateRecord{
  202. DeviceID: record.ID,
  203. Hue: device.State.Color.Hue,
  204. Saturation: device.State.Color.Saturation,
  205. Kelvin: device.State.Color.Kelvin,
  206. Power: device.State.Power,
  207. Intensity: device.State.Intensity,
  208. Temperature: device.State.Temperature,
  209. })
  210. if err != nil {
  211. return dbErr(err)
  212. }
  213. }
  214. }
  215. return tx.Commit()
  216. }
  217. func (r *DeviceRepo) Save(ctx context.Context, device *models.Device, mode models.SaveMode) error {
  218. devices := []models.Device{*device}
  219. err := r.SaveMany(ctx, mode, devices)
  220. if err != nil {
  221. return err
  222. }
  223. *device = devices[0]
  224. return nil
  225. }
  226. func (r *DeviceRepo) Delete(ctx context.Context, device *models.Device) error {
  227. _, err := r.DBX.ExecContext(ctx, "DELETE FROM device WHERE Id=?", device.ID)
  228. if err != nil {
  229. return dbErr(err)
  230. }
  231. return nil
  232. }
  233. func (r *DeviceRepo) populateOne(ctx context.Context, record deviceRecord) (*models.Device, error) {
  234. records, err := r.populate(ctx, []deviceRecord{record})
  235. if err != nil {
  236. return nil, err
  237. }
  238. return &records[0], nil
  239. }
  240. func (r *DeviceRepo) populateFiltered(ctx context.Context, records []deviceRecord, requiredTags []string, unwantedTags []string) ([]models.Device, error) {
  241. devices, err := r.populate(ctx, records)
  242. if err != nil {
  243. return nil, err
  244. }
  245. filteredDevices := make([]models.Device, 0, len(devices))
  246. for _, device := range devices {
  247. if device.HasTag(unwantedTags...) {
  248. continue
  249. }
  250. if len(requiredTags) == 0 || device.HasTag(requiredTags...) {
  251. filteredDevices = append(filteredDevices, device)
  252. }
  253. }
  254. return filteredDevices, nil
  255. }
  256. func (r *DeviceRepo) populate(ctx context.Context, records []deviceRecord) ([]models.Device, error) {
  257. if len(records) == 0 {
  258. return []models.Device{}, nil
  259. }
  260. ids := make([]int, 0, len(records))
  261. for _, record := range records {
  262. ids = append(ids, record.ID)
  263. }
  264. tagsQuery, tagsArgs, err := sq.Select("*").From("device_tag").Where(sq.Eq{"device_id": ids}).ToSql()
  265. if err != nil {
  266. return nil, dbErr(err)
  267. }
  268. propsQuery, propsArgs, err := sq.Select("*").From("device_property").Where(sq.Eq{"device_id": ids}).ToSql()
  269. if err != nil {
  270. return nil, dbErr(err)
  271. }
  272. stateQuery, stateArgs, err := sq.Select("*").From("device_state").Where(sq.Eq{"device_id": ids}).ToSql()
  273. if err != nil {
  274. return nil, dbErr(err)
  275. }
  276. states := make([]deviceStateRecord, 0, len(records))
  277. props := make([]devicePropertyRecord, 0, len(records)*8)
  278. tags := make([]deviceTagRecord, 0, len(records)*4)
  279. err = r.DBX.SelectContext(ctx, &states, stateQuery, stateArgs...)
  280. if err != nil {
  281. return nil, dbErr(err)
  282. }
  283. err = r.DBX.SelectContext(ctx, &props, propsQuery, propsArgs...)
  284. if err != nil {
  285. return nil, dbErr(err)
  286. }
  287. err = r.DBX.SelectContext(ctx, &tags, tagsQuery, tagsArgs...)
  288. if err != nil {
  289. return nil, dbErr(err)
  290. }
  291. hasAdded := make(map[int]bool, len(records))
  292. devices := make([]models.Device, 0, len(records))
  293. for _, record := range records {
  294. if hasAdded[record.ID] {
  295. continue
  296. }
  297. device := models.Device{
  298. ID: record.ID,
  299. BridgeID: record.BridgeID,
  300. InternalID: record.InternalID,
  301. Icon: record.Icon,
  302. Name: record.Name,
  303. SceneAssignments: make([]models.DeviceSceneAssignment, 0, 4),
  304. ButtonNames: strings.Split(record.ButtonNames, ","),
  305. DriverProperties: make(map[string]interface{}, 8),
  306. UserProperties: make(map[string]string, 8),
  307. Tags: make([]string, 0, 8),
  308. }
  309. _ = json.Unmarshal(record.SceneAssignmentJSON, &device.SceneAssignments)
  310. if device.ButtonNames[0] == "" {
  311. device.ButtonNames = device.ButtonNames[:0]
  312. }
  313. caps := make([]models.DeviceCapability, 0, 16)
  314. for _, capStr := range strings.Split(record.Capabilities, ",") {
  315. caps = append(caps, models.DeviceCapability(capStr))
  316. }
  317. device.Capabilities = caps
  318. for _, state := range states {
  319. if state.DeviceID == record.ID {
  320. device.State = models.DeviceState{
  321. Power: state.Power,
  322. Color: models.ColorValue{
  323. Hue: state.Hue,
  324. Saturation: state.Saturation,
  325. Kelvin: state.Kelvin,
  326. },
  327. Intensity: state.Intensity,
  328. Temperature: state.Temperature,
  329. }
  330. }
  331. }
  332. driverProps := make(map[string]json.RawMessage, 8)
  333. for _, prop := range props {
  334. if prop.DeviceID == record.ID {
  335. if prop.IsUser {
  336. device.UserProperties[prop.Key] = prop.Value
  337. } else {
  338. driverProps[prop.Key] = json.RawMessage(prop.Value)
  339. }
  340. }
  341. }
  342. if len(driverProps) > 0 {
  343. j, err := json.Marshal(driverProps)
  344. if err != nil {
  345. return nil, dbErr(err)
  346. }
  347. err = json.Unmarshal(j, &device.DriverProperties)
  348. if err != nil {
  349. return nil, dbErr(err)
  350. }
  351. }
  352. for _, tag := range tags {
  353. if tag.DeviceID == record.ID {
  354. device.Tags = append(device.Tags, tag.TagName)
  355. }
  356. }
  357. hasAdded[record.ID] = true
  358. devices = append(devices, device)
  359. }
  360. return devices, nil
  361. }