Skip to content

Commit

Permalink
Fix concurrent for resolvers
Browse files Browse the repository at this point in the history
  • Loading branch information
Вячеслав Крупянский committed Sep 16, 2024
1 parent 4157ef9 commit 259ea5a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 6 deletions.
2 changes: 1 addition & 1 deletion codegen/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgu

argDirs, err := b.getDirectives(arg.Directives)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: %w", arg.Name, err)
}
newArg := FieldArgument{
ArgumentDefinition: arg,
Expand Down
50 changes: 50 additions & 0 deletions codegen/concurrent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package codegen

import "github.com/vektah/gqlparser/v2/ast"

const concurrentDirectiveName = "concurrent"

func makeConcurrentObjectAndField(obj *Object, f *Field) {
var hasConcurrentDirective bool
for _, dir := range obj.Directives {
if dir.Name == concurrentDirectiveName {
hasConcurrentDirective = true
break
}
}

if !hasConcurrentDirective {
obj.Directives = append(obj.Directives, &Directive{
DirectiveDefinition: &ast.DirectiveDefinition{
Name: concurrentDirectiveName,
},
Name: concurrentDirectiveName,
Builtin: true,
})
obj.DisableConcurrency = false
}

if obj.Definition != nil && obj.Definition.Directives.ForName(concurrentDirectiveName) == nil {
obj.Definition.Directives = append(obj.Definition.Directives, &ast.Directive{
Name: concurrentDirectiveName,
Definition: &ast.DirectiveDefinition{
Name: concurrentDirectiveName,
},
})
}

if f.TypeReference != nil && f.TypeReference.Definition != nil {
for _, dir := range f.TypeReference.Definition.Directives {
if dir.Name == concurrentDirectiveName {
hasConcurrentDirective = true
break
}
}

if !hasConcurrentDirective {
f.TypeReference.Definition.Directives = append(f.TypeReference.Definition.Directives, &ast.Directive{
Name: concurrentDirectiveName,
})
}
}
}
4 changes: 4 additions & 0 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ func (c *Config) injectTypesFromSchema() error {
SkipRuntime: true,
}

c.Directives["concurrent"] = DirectiveConfig{
SkipRuntime: true,
}

for _, schemaType := range c.Schema.Types {
if c.IsRoot(schemaType) {
continue
Expand Down
42 changes: 42 additions & 0 deletions codegen/data.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"container/list"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -153,6 +154,8 @@ func BuildData(cfg *config.Config, plugins ...any) (*Data, error) {
return nil, err
}

handleConcurrent(s.Objects)

s.ReferencedTypes = b.buildTypes()

sort.Slice(s.Objects, func(i, j int) bool {
Expand Down Expand Up @@ -234,3 +237,42 @@ func (b *builder) injectIntrospectionRoots(s *Data) error {

return nil
}

func handleConcurrent(objects Objects) {
concurrentObjects := make([]*Object, 0)
for _, obj := range objects {
for _, dir := range obj.Directives {
if dir.Name == concurrentDirectiveName {
concurrentObjects = append(concurrentObjects, obj)
break
}
}
}

queue := list.New()
for _, obj := range concurrentObjects {
queue.PushBack(obj)
}

concurrentObjectsCache := make(map[string]struct{}, 0)

for queue.Len() > 0 {
v := queue.Front()
concurrentObject := v.Value.(*Object)
for _, obj := range objects {
if _, ok := concurrentObjectsCache[obj.Name]; ok {
continue
}

for _, f := range obj.Fields {
if f.TypeReference.Definition == concurrentObject.Definition {
makeConcurrentObjectAndField(obj, f)

queue.PushBack(obj)
concurrentObjectsCache[obj.Name] = struct{}{}
}
}
}
queue.Remove(v)
}
}
5 changes: 3 additions & 2 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Field struct {
func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
dirs, err := b.getDirectives(field.Directives)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: %w", field.Name, err)
}

f := Field{
Expand Down Expand Up @@ -95,7 +95,7 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
if f.TypeReference != nil {
dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
if err != nil {
errret = err
errret = fmt.Errorf("%s: %w", f.Name, err)
}
for _, dir := range obj.Directives {
if dir.IsLocation(ast.LocationInputObject) {
Expand Down Expand Up @@ -137,6 +137,7 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
return nil
case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
f.IsResolver = true
makeConcurrentObjectAndField(obj, f)
return nil
case obj.Type == config.MapType:
f.GoFieldType = GoFieldMap
Expand Down
7 changes: 4 additions & 3 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
}
caser := cases.Title(language.English, cases.NoLower)
obj := &Object{
Definition: typ,
Root: b.Config.IsRoot(typ),
DisableConcurrency: typ == b.Schema.Mutation || typ.Directives.ForName("concurrent") == nil,
Definition: typ,
Root: b.Config.IsRoot(typ),
DisableConcurrency: typ == b.Schema.Mutation ||
typ.Directives.ForName(concurrentDirectiveName) == nil,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
PointersInUnmarshalInput: b.Config.ReturnPointersInUnmarshalInput,
Expand Down

0 comments on commit 259ea5a

Please sign in to comment.