You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

422 lines
11 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "git.aiterp.net/lucifer/new-server/internal/color"
  7. "git.aiterp.net/lucifer/new-server/models"
  8. sq "github.com/Masterminds/squirrel"
  9. "github.com/jmoiron/sqlx"
  10. "log"
  11. "strings"
  12. )
  13. type deviceRecord struct {
  14. ID int `db:"id"`
  15. BridgeID int `db:"bridge_id"`
  16. InternalID string `db:"internal_id"`
  17. Icon string `db:"icon"`
  18. Name string `db:"name"`
  19. Capabilities string `db:"capabilities"`
  20. ButtonNames string `db:"button_names"`
  21. SceneAssignmentJSON json.RawMessage `db:"scene_assignments"`
  22. }
  23. type deviceStateRecord struct {
  24. DeviceID int `db:"device_id"`
  25. Hue float64 `db:"hue"`
  26. Saturation float64 `db:"saturation"`
  27. Kelvin int `db:"kelvin"`
  28. Power bool `db:"power"`
  29. Intensity float64 `db:"intensity"`
  30. Temperature int `db:"temperature"`
  31. Color string `db:"color"`
  32. }
  33. type devicePropertyRecord struct {
  34. DeviceID int `db:"device_id"`
  35. Key string `db:"prop_key"`
  36. Value string `db:"prop_value"`
  37. IsUser bool `db:"is_user"`
  38. }
  39. type deviceTagRecord struct {
  40. DeviceID int `db:"device_id"`
  41. TagName string `db:"tag_name"`
  42. }
  43. type DeviceRepo struct {
  44. DBX *sqlx.DB
  45. }
  46. func (r *DeviceRepo) Find(ctx context.Context, id int) (*models.Device, error) {
  47. var device deviceRecord
  48. err := r.DBX.GetContext(ctx, &device, "SELECT * FROM device WHERE id = ?", id)
  49. if err != nil {
  50. return nil, dbErr(err)
  51. }
  52. return r.populateOne(ctx, device)
  53. }
  54. func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.ReferenceKind, value string) ([]models.Device, error) {
  55. q := sq.Select("device.*").From("device")
  56. var requiredTags []string
  57. var unwantedTags []string
  58. switch kind {
  59. case models.RKDeviceID:
  60. q = q.Where(sq.Eq{"id": strings.Split(value, ",")})
  61. case models.RKBridgeID:
  62. q = q.Where(sq.Eq{"bridge_id": strings.Split(value, ",")})
  63. case models.RKName:
  64. value = strings.ReplaceAll(value, "*", "%")
  65. q = q.Where(sq.Like{"name": value})
  66. case models.RKTag:
  67. allTags := strings.Split(strings.ReplaceAll(strings.ReplaceAll(value, "-", ",-"), "+", ",+"), ",")
  68. optionalTags := make([]string, 0, len(allTags))
  69. for _, tag := range allTags {
  70. if tag == "" || tag == "+" || tag == "-" {
  71. continue
  72. }
  73. if strings.HasPrefix(tag, "+") {
  74. requiredTags = append(requiredTags, tag[1:])
  75. } else if strings.HasPrefix(tag, "-") {
  76. unwantedTags = append(unwantedTags, tag[1:])
  77. } else {
  78. optionalTags = append(optionalTags, tag)
  79. }
  80. }
  81. q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": optionalTags})
  82. case models.RKAll:
  83. default:
  84. log.Println("Unknown reference kind used for device fetch:", kind)
  85. return []models.Device{}, nil
  86. }
  87. query, args, err := q.OrderBy("name", "id").ToSql()
  88. if err != nil {
  89. if err == sql.ErrNoRows {
  90. return []models.Device{}, nil
  91. }
  92. return nil, dbErr(err)
  93. }
  94. records := make([]deviceRecord, 0, 8)
  95. err = r.DBX.SelectContext(ctx, &records, query, args...)
  96. if err != nil {
  97. return nil, dbErr(err)
  98. }
  99. if len(requiredTags) > 0 || len(unwantedTags) > 0 {
  100. return r.populateFiltered(ctx, records, requiredTags, unwantedTags)
  101. }
  102. return r.populate(ctx, records)
  103. }
  104. func (r *DeviceRepo) SaveMany(ctx context.Context, mode models.SaveMode, devices []models.Device) error {
  105. tx, err := r.DBX.Beginx()
  106. if err != nil {
  107. return dbErr(err)
  108. }
  109. defer tx.Rollback()
  110. for i, device := range devices {
  111. scenesJSON, err := json.Marshal(device.SceneAssignments)
  112. if err != nil {
  113. return err
  114. }
  115. record := deviceRecord{
  116. ID: device.ID,
  117. BridgeID: device.BridgeID,
  118. InternalID: device.InternalID,
  119. SceneAssignmentJSON: scenesJSON,
  120. Icon: device.Icon,
  121. Name: device.Name,
  122. Capabilities: strings.Join(models.DeviceCapabilitiesToStrings(device.Capabilities), ","),
  123. ButtonNames: strings.Join(device.ButtonNames, ","),
  124. }
  125. if device.ID > 0 {
  126. _, err := tx.NamedExecContext(ctx, `
  127. UPDATE device SET
  128. internal_id = :internal_id,
  129. icon = :icon,
  130. name = :name,
  131. capabilities = :capabilities,
  132. button_names = :button_names,
  133. scene_assignments = :scene_assignments
  134. WHERE id=:id
  135. `, record)
  136. if err != nil {
  137. return dbErr(err)
  138. }
  139. // Let's just be lazy for now, optimize later if need be.
  140. if mode == 0 || mode&models.SMTags != 0 {
  141. _, err = tx.ExecContext(ctx, "DELETE FROM device_tag WHERE device_id=?", record.ID)
  142. if err != nil {
  143. return dbErr(err)
  144. }
  145. }
  146. if mode == 0 || mode&models.SMProperties != 0 {
  147. _, err = tx.ExecContext(ctx, "DELETE FROM device_property WHERE device_id=?", record.ID)
  148. if err != nil {
  149. return dbErr(err)
  150. }
  151. }
  152. } else {
  153. res, err := tx.NamedExecContext(ctx, `
  154. INSERT INTO device (bridge_id, internal_id, icon, name, capabilities, button_names)
  155. VALUES (:bridge_id, :internal_id, :icon, :name, :capabilities, :button_names)
  156. `, record)
  157. if err != nil {
  158. return dbErr(err)
  159. }
  160. lastID, err := res.LastInsertId()
  161. if err != nil {
  162. return dbErr(err)
  163. }
  164. record.ID = int(lastID)
  165. devices[i].ID = int(lastID)
  166. }
  167. if mode == 0 || mode&models.SMTags != 0 {
  168. for _, tag := range device.Tags {
  169. _, err := tx.ExecContext(ctx, "INSERT INTO device_tag (device_id, tag_name) VALUES (?, ?)", record.ID, tag)
  170. if err != nil {
  171. return dbErr(err)
  172. }
  173. }
  174. }
  175. if mode == 0 || mode&models.SMProperties != 0 {
  176. for key, value := range device.UserProperties {
  177. _, err := tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 1)",
  178. record.ID, key, value,
  179. )
  180. if err != nil {
  181. return dbErr(err)
  182. }
  183. }
  184. for key, value := range device.DriverProperties {
  185. j, err := json.Marshal(value)
  186. if err != nil {
  187. // Eh, it'll get filled by the driver anyway
  188. continue
  189. }
  190. _, err = tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 0)",
  191. record.ID, key, string(j),
  192. )
  193. if err != nil {
  194. // Return err here anyway, it might put the tx in a bad state to ignore it.
  195. return dbErr(err)
  196. }
  197. }
  198. }
  199. if mode == 0 || mode&models.SMState != 0 {
  200. _, err = tx.NamedExecContext(ctx, `
  201. REPLACE INTO device_state(device_id, hue, saturation, kelvin, power, intensity, color, temperature)
  202. VALUES (:device_id, :hue, :saturation, :kelvin, :power, :intensity, :color, :temperature)
  203. `, deviceStateRecord{
  204. DeviceID: record.ID,
  205. Hue: 40,
  206. Saturation: 0,
  207. Kelvin: 0,
  208. Color: device.State.Color.String(),
  209. Power: device.State.Power,
  210. Intensity: device.State.Intensity,
  211. Temperature: device.State.Temperature,
  212. })
  213. if err != nil {
  214. return dbErr(err)
  215. }
  216. }
  217. }
  218. return tx.Commit()
  219. }
  220. func (r *DeviceRepo) Save(ctx context.Context, device *models.Device, mode models.SaveMode) error {
  221. devices := []models.Device{*device}
  222. err := r.SaveMany(ctx, mode, devices)
  223. if err != nil {
  224. return err
  225. }
  226. *device = devices[0]
  227. return nil
  228. }
  229. func (r *DeviceRepo) Delete(ctx context.Context, device *models.Device) error {
  230. _, err := r.DBX.ExecContext(ctx, "DELETE FROM device WHERE Id=?", device.ID)
  231. if err != nil {
  232. return dbErr(err)
  233. }
  234. return nil
  235. }
  236. func (r *DeviceRepo) populateOne(ctx context.Context, record deviceRecord) (*models.Device, error) {
  237. records, err := r.populate(ctx, []deviceRecord{record})
  238. if err != nil {
  239. return nil, err
  240. }
  241. return &records[0], nil
  242. }
  243. func (r *DeviceRepo) populateFiltered(ctx context.Context, records []deviceRecord, requiredTags []string, unwantedTags []string) ([]models.Device, error) {
  244. devices, err := r.populate(ctx, records)
  245. if err != nil {
  246. return nil, err
  247. }
  248. filteredDevices := make([]models.Device, 0, len(devices))
  249. for _, device := range devices {
  250. if device.HasTag(unwantedTags...) {
  251. continue
  252. }
  253. if len(requiredTags) == 0 || device.HasTag(requiredTags...) {
  254. filteredDevices = append(filteredDevices, device)
  255. }
  256. }
  257. return filteredDevices, nil
  258. }
  259. func (r *DeviceRepo) populate(ctx context.Context, records []deviceRecord) ([]models.Device, error) {
  260. if len(records) == 0 {
  261. return []models.Device{}, nil
  262. }
  263. ids := make([]int, 0, len(records))
  264. for _, record := range records {
  265. ids = append(ids, record.ID)
  266. }
  267. tagsQuery, tagsArgs, err := sq.Select("*").From("device_tag").Where(sq.Eq{"device_id": ids}).ToSql()
  268. if err != nil {
  269. return nil, dbErr(err)
  270. }
  271. propsQuery, propsArgs, err := sq.Select("*").From("device_property").Where(sq.Eq{"device_id": ids}).ToSql()
  272. if err != nil {
  273. return nil, dbErr(err)
  274. }
  275. stateQuery, stateArgs, err := sq.Select("*").From("device_state").Where(sq.Eq{"device_id": ids}).ToSql()
  276. if err != nil {
  277. return nil, dbErr(err)
  278. }
  279. states := make([]deviceStateRecord, 0, len(records))
  280. props := make([]devicePropertyRecord, 0, len(records)*8)
  281. tags := make([]deviceTagRecord, 0, len(records)*4)
  282. err = r.DBX.SelectContext(ctx, &states, stateQuery, stateArgs...)
  283. if err != nil {
  284. return nil, dbErr(err)
  285. }
  286. err = r.DBX.SelectContext(ctx, &props, propsQuery, propsArgs...)
  287. if err != nil {
  288. return nil, dbErr(err)
  289. }
  290. err = r.DBX.SelectContext(ctx, &tags, tagsQuery, tagsArgs...)
  291. if err != nil {
  292. return nil, dbErr(err)
  293. }
  294. hasAdded := make(map[int]bool, len(records))
  295. devices := make([]models.Device, 0, len(records))
  296. for _, record := range records {
  297. if hasAdded[record.ID] {
  298. continue
  299. }
  300. device := models.Device{
  301. ID: record.ID,
  302. BridgeID: record.BridgeID,
  303. InternalID: record.InternalID,
  304. Icon: record.Icon,
  305. Name: record.Name,
  306. SceneAssignments: make([]models.DeviceSceneAssignment, 0, 4),
  307. ButtonNames: strings.Split(record.ButtonNames, ","),
  308. DriverProperties: make(map[string]interface{}, 8),
  309. UserProperties: make(map[string]string, 8),
  310. Tags: make([]string, 0, 8),
  311. }
  312. _ = json.Unmarshal(record.SceneAssignmentJSON, &device.SceneAssignments)
  313. if device.ButtonNames[0] == "" {
  314. device.ButtonNames = device.ButtonNames[:0]
  315. }
  316. caps := make([]models.DeviceCapability, 0, 16)
  317. for _, capStr := range strings.Split(record.Capabilities, ",") {
  318. caps = append(caps, models.DeviceCapability(capStr))
  319. }
  320. device.Capabilities = caps
  321. for _, state := range states {
  322. if state.DeviceID == record.ID {
  323. color, _ := color.Parse(state.Color)
  324. device.State = models.DeviceState{
  325. Power: state.Power,
  326. Color: color,
  327. Intensity: state.Intensity,
  328. Temperature: state.Temperature,
  329. }
  330. }
  331. }
  332. driverProps := make(map[string]json.RawMessage, 8)
  333. for _, prop := range props {
  334. if prop.DeviceID == record.ID {
  335. if prop.IsUser {
  336. device.UserProperties[prop.Key] = prop.Value
  337. } else {
  338. driverProps[prop.Key] = json.RawMessage(prop.Value)
  339. }
  340. }
  341. }
  342. if len(driverProps) > 0 {
  343. j, err := json.Marshal(driverProps)
  344. if err != nil {
  345. return nil, dbErr(err)
  346. }
  347. err = json.Unmarshal(j, &device.DriverProperties)
  348. if err != nil {
  349. return nil, dbErr(err)
  350. }
  351. }
  352. for _, tag := range tags {
  353. if tag.DeviceID == record.ID {
  354. device.Tags = append(device.Tags, tag.TagName)
  355. }
  356. }
  357. hasAdded[record.ID] = true
  358. devices = append(devices, device)
  359. }
  360. return devices, nil
  361. }