You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

393 lines
10 KiB

  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "git.aiterp.net/lucifer/new-server/models"
  7. sq "github.com/Masterminds/squirrel"
  8. "github.com/jmoiron/sqlx"
  9. "log"
  10. "strings"
  11. )
  12. type deviceRecord struct {
  13. ID int `db:"id"`
  14. BridgeID int `db:"bridge_id"`
  15. InternalID string `db:"internal_id"`
  16. Icon string `db:"icon"`
  17. Name string `db:"name"`
  18. Capabilities string `db:"capabilities"`
  19. ButtonNames string `db:"button_names"`
  20. }
  21. type deviceStateRecord struct {
  22. DeviceID int `db:"device_id"`
  23. Hue float64 `db:"hue"`
  24. Saturation float64 `db:"saturation"`
  25. Kelvin int `db:"kelvin"`
  26. Power bool `db:"power"`
  27. Intensity float64 `db:"intensity"`
  28. }
  29. type devicePropertyRecord struct {
  30. DeviceID int `db:"device_id"`
  31. Key string `db:"prop_key"`
  32. Value string `db:"prop_value"`
  33. IsUser bool `db:"is_user"`
  34. }
  35. type deviceTagRecord struct {
  36. DeviceID int `db:"device_id"`
  37. TagName string `db:"tag_name"`
  38. }
  39. type DeviceRepo struct {
  40. DBX *sqlx.DB
  41. }
  42. func (r *DeviceRepo) Find(ctx context.Context, id int) (*models.Device, error) {
  43. var device deviceRecord
  44. err := r.DBX.GetContext(ctx, &device, "SELECT * FROM device WHERE id = ?", id)
  45. if err != nil {
  46. return nil, dbErr(err)
  47. }
  48. return r.populateOne(ctx, device)
  49. }
  50. func (r *DeviceRepo) FetchByReference(ctx context.Context, kind models.ReferenceKind, value string) ([]models.Device, error) {
  51. q := sq.Select("device.*").From("device")
  52. var requiredTags []string
  53. var unwantedTags []string
  54. switch kind {
  55. case models.RKDeviceID:
  56. q = q.Where(sq.Eq{"id": strings.Split(value, ",")})
  57. case models.RKBridgeID:
  58. q = q.Where(sq.Eq{"bridge_id": strings.Split(value, ",")})
  59. case models.RKName:
  60. value = strings.ReplaceAll(value, "*", "%")
  61. q = q.Where(sq.Like{"name": value})
  62. case models.RKTag:
  63. allTags := strings.Split(strings.ReplaceAll(strings.ReplaceAll(value, "-", ",-"), "+", ",+"), ",")
  64. optionalTags := make([]string, 0, len(allTags))
  65. for _, tag := range allTags {
  66. if tag == "" || tag == "+" || tag == "-" {
  67. continue
  68. }
  69. if strings.HasPrefix(tag, "+") {
  70. requiredTags = append(requiredTags, tag[1:])
  71. } else if strings.HasPrefix(tag, "-") {
  72. unwantedTags = append(unwantedTags, tag[1:])
  73. } else {
  74. optionalTags = append(optionalTags, tag)
  75. }
  76. }
  77. q = q.Join("device_tag dt ON device.id=dt.device_id").Where(sq.Eq{"dt.tag_name": optionalTags})
  78. case models.RKAll:
  79. default:
  80. log.Println("Unknown reference kind used for device fetch:", kind)
  81. return []models.Device{}, nil
  82. }
  83. query, args, err := q.OrderBy("name", "id").ToSql()
  84. if err != nil {
  85. if err == sql.ErrNoRows {
  86. return []models.Device{}, nil
  87. }
  88. return nil, dbErr(err)
  89. }
  90. records := make([]deviceRecord, 0, 8)
  91. err = r.DBX.SelectContext(ctx, &records, query, args...)
  92. if err != nil {
  93. return nil, dbErr(err)
  94. }
  95. if len(requiredTags) > 0 || len(unwantedTags) > 0 {
  96. return r.populateFiltered(ctx, records, requiredTags, unwantedTags)
  97. }
  98. return r.populate(ctx, records)
  99. }
  100. func (r *DeviceRepo) Save(ctx context.Context, device *models.Device, mode models.SaveMode) error {
  101. tx, err := r.DBX.Beginx()
  102. if err != nil {
  103. return dbErr(err)
  104. }
  105. defer tx.Rollback()
  106. record := deviceRecord{
  107. ID: device.ID,
  108. BridgeID: device.BridgeID,
  109. InternalID: device.InternalID,
  110. Icon: device.Icon,
  111. Name: device.Name,
  112. Capabilities: strings.Join(models.DeviceCapabilitiesToStrings(device.Capabilities), ","),
  113. ButtonNames: strings.Join(device.ButtonNames, ","),
  114. }
  115. if device.ID > 0 {
  116. _, err := tx.NamedExecContext(ctx, `
  117. UPDATE device SET
  118. internal_id = :internal_id,
  119. icon = :icon,
  120. name = :name,
  121. capabilities = :capabilities,
  122. button_names = :button_names
  123. WHERE id=:id
  124. `, record)
  125. if err != nil {
  126. return dbErr(err)
  127. }
  128. // Let's just be lazy for now, optimize later if need be.
  129. if mode == 0 || mode&models.SMTags != 0 {
  130. _, err = tx.ExecContext(ctx, "DELETE FROM device_tag WHERE device_id=?", record.ID)
  131. if err != nil {
  132. return dbErr(err)
  133. }
  134. }
  135. if mode == 0 || mode&models.SMProperties != 0 {
  136. _, err = tx.ExecContext(ctx, "DELETE FROM device_property WHERE device_id=?", record.ID)
  137. if err != nil {
  138. return dbErr(err)
  139. }
  140. }
  141. } else {
  142. res, err := tx.NamedExecContext(ctx, `
  143. INSERT INTO device (bridge_id, internal_id, icon, name, capabilities, button_names)
  144. VALUES (:bridge_id, :internal_id, :icon, :name, :capabilities, :button_names)
  145. `, record)
  146. if err != nil {
  147. return dbErr(err)
  148. }
  149. lastID, err := res.LastInsertId()
  150. if err != nil {
  151. return dbErr(err)
  152. }
  153. record.ID = int(lastID)
  154. device.ID = int(lastID)
  155. }
  156. if mode == 0 || mode&models.SMTags != 0 {
  157. for _, tag := range device.Tags {
  158. _, err := tx.ExecContext(ctx, "INSERT INTO device_tag (device_id, tag_name) VALUES (?, ?)", record.ID, tag)
  159. if err != nil {
  160. return dbErr(err)
  161. }
  162. }
  163. }
  164. if mode == 0 || mode&models.SMProperties != 0 {
  165. for key, value := range device.UserProperties {
  166. _, err := tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 1)",
  167. record.ID, key, value,
  168. )
  169. if err != nil {
  170. return dbErr(err)
  171. }
  172. }
  173. for key, value := range device.DriverProperties {
  174. j, err := json.Marshal(value)
  175. if err != nil {
  176. // Eh, it'll get filled by the driver anyway
  177. continue
  178. }
  179. _, err = tx.ExecContext(ctx, "INSERT INTO device_property (device_id, prop_key, prop_value, is_user) VALUES (?, ?, ?, 0)",
  180. record.ID, key, string(j),
  181. )
  182. if err != nil {
  183. // Return err here anyway, it might put the tx in a bad state to ignore it.
  184. return dbErr(err)
  185. }
  186. }
  187. }
  188. if mode == 0 || mode&models.SMState != 0 {
  189. _, err = tx.NamedExecContext(ctx, `
  190. REPLACE INTO device_state(device_id, hue, saturation, kelvin, power, intensity)
  191. VALUES (:device_id, :hue, :saturation, :kelvin, :power, :intensity)
  192. `, deviceStateRecord{
  193. DeviceID: record.ID,
  194. Hue: device.State.Color.Hue,
  195. Saturation: device.State.Color.Saturation,
  196. Kelvin: device.State.Color.Kelvin,
  197. Power: device.State.Power,
  198. Intensity: device.State.Intensity,
  199. })
  200. if err != nil {
  201. return dbErr(err)
  202. }
  203. }
  204. return tx.Commit()
  205. }
  206. func (r *DeviceRepo) Delete(ctx context.Context, device *models.Device) error {
  207. _, err := r.DBX.ExecContext(ctx, "DELETE FROM device WHERE Id=?", device.ID)
  208. if err != nil {
  209. return dbErr(err)
  210. }
  211. return nil
  212. }
  213. func (r *DeviceRepo) populateOne(ctx context.Context, record deviceRecord) (*models.Device, error) {
  214. records, err := r.populate(ctx, []deviceRecord{record})
  215. if err != nil {
  216. return nil, err
  217. }
  218. return &records[0], nil
  219. }
  220. func (r *DeviceRepo) populateFiltered(ctx context.Context, records []deviceRecord, requiredTags []string, unwantedTags []string) ([]models.Device, error) {
  221. devices, err := r.populate(ctx, records)
  222. if err != nil {
  223. return nil, err
  224. }
  225. filteredDevices := make([]models.Device, 0, len(devices))
  226. for _, device := range devices {
  227. if device.HasTag(unwantedTags...) {
  228. continue
  229. }
  230. if len(requiredTags) == 0 || device.HasTag(requiredTags...) {
  231. filteredDevices = append(filteredDevices, device)
  232. }
  233. }
  234. return filteredDevices, nil
  235. }
  236. func (r *DeviceRepo) populate(ctx context.Context, records []deviceRecord) ([]models.Device, error) {
  237. if len(records) == 0 {
  238. return []models.Device{}, nil
  239. }
  240. ids := make([]int, 0, len(records))
  241. for _, record := range records {
  242. ids = append(ids, record.ID)
  243. }
  244. tagsQuery, tagsArgs, err := sq.Select("*").From("device_tag").Where(sq.Eq{"device_id": ids}).ToSql()
  245. if err != nil {
  246. return nil, dbErr(err)
  247. }
  248. propsQuery, propsArgs, err := sq.Select("*").From("device_property").Where(sq.Eq{"device_id": ids}).ToSql()
  249. if err != nil {
  250. return nil, dbErr(err)
  251. }
  252. stateQuery, stateArgs, err := sq.Select("*").From("device_state").Where(sq.Eq{"device_id": ids}).ToSql()
  253. if err != nil {
  254. return nil, dbErr(err)
  255. }
  256. states := make([]deviceStateRecord, 0, len(records))
  257. props := make([]devicePropertyRecord, 0, len(records)*8)
  258. tags := make([]deviceTagRecord, 0, len(records)*4)
  259. err = r.DBX.SelectContext(ctx, &states, stateQuery, stateArgs...)
  260. if err != nil {
  261. return nil, dbErr(err)
  262. }
  263. err = r.DBX.SelectContext(ctx, &props, propsQuery, propsArgs...)
  264. if err != nil {
  265. return nil, dbErr(err)
  266. }
  267. err = r.DBX.SelectContext(ctx, &tags, tagsQuery, tagsArgs...)
  268. if err != nil {
  269. return nil, dbErr(err)
  270. }
  271. hasAdded := make(map[int]bool, len(records))
  272. devices := make([]models.Device, 0, len(records))
  273. for _, record := range records {
  274. if hasAdded[record.ID] {
  275. continue
  276. }
  277. device := models.Device{
  278. ID: record.ID,
  279. BridgeID: record.BridgeID,
  280. InternalID: record.InternalID,
  281. Icon: record.Icon,
  282. Name: record.Name,
  283. ButtonNames: strings.Split(record.ButtonNames, ","),
  284. DriverProperties: make(map[string]interface{}, 8),
  285. UserProperties: make(map[string]string, 8),
  286. Tags: make([]string, 0, 8),
  287. }
  288. if device.ButtonNames[0] == "" {
  289. device.ButtonNames = device.ButtonNames[:0]
  290. }
  291. caps := make([]models.DeviceCapability, 0, 16)
  292. for _, capStr := range strings.Split(record.Capabilities, ",") {
  293. caps = append(caps, models.DeviceCapability(capStr))
  294. }
  295. device.Capabilities = caps
  296. for _, state := range states {
  297. if state.DeviceID == record.ID {
  298. device.State = models.DeviceState{
  299. Power: state.Power,
  300. Color: models.ColorValue{
  301. Hue: state.Hue,
  302. Saturation: state.Saturation,
  303. Kelvin: state.Kelvin,
  304. },
  305. Intensity: state.Intensity,
  306. }
  307. }
  308. }
  309. driverProps := make(map[string]json.RawMessage, 8)
  310. for _, prop := range props {
  311. if prop.DeviceID == record.ID {
  312. if prop.IsUser {
  313. device.UserProperties[prop.Key] = prop.Value
  314. } else {
  315. driverProps[prop.Key] = json.RawMessage(prop.Value)
  316. }
  317. }
  318. }
  319. if len(driverProps) > 0 {
  320. j, err := json.Marshal(driverProps)
  321. if err != nil {
  322. return nil, dbErr(err)
  323. }
  324. err = json.Unmarshal(j, &device.DriverProperties)
  325. if err != nil {
  326. return nil, dbErr(err)
  327. }
  328. }
  329. for _, tag := range tags {
  330. if tag.DeviceID == record.ID {
  331. device.Tags = append(device.Tags, tag.TagName)
  332. }
  333. }
  334. hasAdded[record.ID] = true
  335. devices = append(devices, device)
  336. }
  337. return devices, nil
  338. }