diff --git a/graphql/loader/channel.go b/graphql/loader/channel.go index 3bb5bcb..5677e04 100644 --- a/graphql/loader/channel.go +++ b/graphql/loader/channel.go @@ -19,6 +19,8 @@ func (loader *Loader) Channel(key, value string) (channel.Channel, error) { return channel.Channel{}, errors.New("unsupported key") } + loader.loadPrimed(key) + thunk := loader.loaders[key].Load(loader.ctx, dataloader.StringKey(value)) res, err := thunk() if err != nil { @@ -33,6 +35,15 @@ func (loader *Loader) Channel(key, value string) (channel.Channel, error) { return channel, nil } +// PrimeChannels primes channels for loading along with the first one. +func (loader *Loader) PrimeChannels(key string, values ...string) { + if !strings.HasPrefix(key, "Channel.") { + key = "Channel." + key + } + + loader.prime(key, values) +} + func channelNameBatch(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { var results []*dataloader.Result names := keys.Keys() diff --git a/graphql/loader/character.go b/graphql/loader/character.go index 172a68d..85ae51d 100644 --- a/graphql/loader/character.go +++ b/graphql/loader/character.go @@ -15,6 +15,8 @@ func (loader *Loader) Character(key, value string) (character.Character, error) key = "Character." + key } + loader.loadPrimed(key) + if loader.loaders[key] == nil { return character.Character{}, errors.New("unsupported key") } @@ -43,6 +45,8 @@ func (loader *Loader) Characters(key string, values ...string) ([]character.Char return nil, errors.New("unsupported key") } + loader.loadPrimed(key) + thunk := loader.loaders[key].LoadMany(loader.ctx, dataloader.NewKeysFromStrings(values)) res, errs := thunk() for _, err := range errs { @@ -65,6 +69,17 @@ func (loader *Loader) Characters(key string, values ...string) ([]character.Char return chars, nil } +// PrimeCharacters adds a set of characters to be loaded if, and only if, characters +// are going to be loaded. This will fill up the cache and speed up subsequent dataloader +// runs. +func (loader *Loader) PrimeCharacters(key string, values ...string) { + if !strings.HasPrefix(key, "Character.") { + key = "Character." + key + } + + loader.prime(key, values) +} + func characterIDBatch(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { results := make([]*dataloader.Result, 0, len(keys)) ids := keys.Keys() diff --git a/graphql/loader/loader.go b/graphql/loader/loader.go index 9cb0946..61b7e1b 100644 --- a/graphql/loader/loader.go +++ b/graphql/loader/loader.go @@ -3,6 +3,7 @@ package loader import ( "context" "errors" + "sync" "time" "github.com/graph-gophers/dataloader" @@ -16,8 +17,11 @@ var ErrNotFound = errors.New("not found") // A Loader is a collection of data loaders and functions to act on them. It's supposed to be // request-scoped, and will thus keep things cached indefinitely. type Loader struct { + mutex sync.Mutex ctx context.Context loaders map[string]*dataloader.Loader + + primedKeys map[string]map[string]bool } // New initializes the loader. @@ -29,6 +33,7 @@ func New() *Loader { "Character.nick": dataloader.NewBatchedLoader(characterNickBatch, dataloader.WithWait(time.Millisecond)), "Channel.name": dataloader.NewBatchedLoader(channelNameBatch, dataloader.WithWait(time.Millisecond)), }, + primedKeys: make(map[string]map[string]bool), } } @@ -47,3 +52,29 @@ func (loader *Loader) ToContext(ctx context.Context) context.Context { loader.ctx = ctx return context.WithValue(ctx, &contextKey, loader) } + +func (loader *Loader) prime(key string, values []string) { + loader.mutex.Lock() + if loader.primedKeys[key] == nil { + loader.primedKeys[key] = make(map[string]bool) + } + + for _, value := range values { + loader.primedKeys[key][value] = true + } + loader.mutex.Unlock() +} + +func (loader *Loader) loadPrimed(key string) { + loader.mutex.Lock() + if len(loader.primedKeys[key]) > 0 { + primedKeys := make([]string, 0, len(loader.primedKeys[key])) + for key := range loader.primedKeys[key] { + primedKeys = append(primedKeys, key) + } + + loader.loaders[key].LoadMany(loader.ctx, dataloader.NewKeysFromStrings(primedKeys)) + loader.primedKeys[key] = nil + } + loader.mutex.Unlock() +} diff --git a/graphql/resolver/queries/logs.go b/graphql/resolver/queries/logs.go index 5dd9992..91a2e21 100644 --- a/graphql/resolver/queries/logs.go +++ b/graphql/resolver/queries/logs.go @@ -2,7 +2,9 @@ package queries import ( "context" + "errors" + "git.aiterp.net/rpdata/api/graphql/loader" "git.aiterp.net/rpdata/api/graphql/resolver/types" "git.aiterp.net/rpdata/api/model/log" ) @@ -70,6 +72,18 @@ func (r *QueryResolver) Logs(ctx context.Context, args *LogsArgs) ([]*types.LogR } } + if len(logs) >= 100 { + loader := loader.FromContext(ctx) + if loader == nil { + return nil, errors.New("no loader") + } + + for _, log := range logs { + loader.PrimeCharacters("id", log.CharacterIDs...) + loader.PrimeChannels("name", log.Channel) + } + } + resolvers := make([]*types.LogResolver, len(logs)) for i := range logs { resolvers[i] = &types.LogResolver{L: logs[i]}