Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config option for generating embedded structs for GraphQL interfaces #2220

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@ import (
)

type Config struct {
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
GenerateEmbeddedStructsForInterfaces bool `yaml:"generate_embedded_structs_for_interfaces,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`

// Deprecated: use Federation instead. Will be removed next release
Federated bool `yaml:"federated,omitempty"`
Expand All @@ -41,11 +44,13 @@ var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
// DefaultConfig creates a copy of the default config
func DefaultConfig() *Config {
return &Config{
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ResolversAlwaysReturnPointers: true,
}
}

Expand Down
16 changes: 9 additions & 7 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,15 @@ func BuildData(cfg *config.Config) (*Data, error) {

for _, schemaType := range b.Schema.Types {
switch schemaType.Kind {
case ast.Object:
case ast.Union, ast.Interface, ast.Object:
if (schemaType.Kind == ast.Union || schemaType.Kind == ast.Interface) && !cfg.GenerateEmbeddedStructsForInterfaces {
s.Interfaces[schemaType.Name], err = b.buildInterface(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to bind to interface: %w", err)
}

continue
}
obj, err := b.buildObject(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to build object definition: %w", err)
Expand All @@ -123,12 +131,6 @@ func BuildData(cfg *config.Config) (*Data, error) {
}

s.Inputs = append(s.Inputs, input)

case ast.Union, ast.Interface:
s.Interfaces[schemaType.Name], err = b.buildInterface(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to bind to interface: %w", err)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, e
log.Println(err.Error())
}

if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
f.TypeReference = b.Binder.PointerTo(f.TypeReference)
}

Expand Down
12 changes: 11 additions & 1 deletion docs/content/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ resolver:
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false

# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true

# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true

# Optional: turn on to generate getter/setter methods for accessing interface fields instead of exporting the fields
# generate_interface_getters_setters: false

# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true

Expand Down Expand Up @@ -116,7 +126,7 @@ type User @goModel(model: "github.com/my/app/models.User") {
}
```

The builtin directives `goField`, `goModel` and `goTag` are automatically registered to `skip_runtime`. Any directives registered as `skip_runtime` will not exposed during introspection and are used during code generation only.
The builtin directives `goField`, `goModel` and `goTag` are automatically registered to `skip_runtime`. Any directives registered as `skip_runtime` will not exposed during introspection and are used during code generation only.

If you have created a new code generation plugin using a directive which does not require runtime execution, the directive will need to be set to `skip_runtime`.

Expand Down
10 changes: 10 additions & 0 deletions init-templates/gqlgen.yml.gotmpl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ resolver:
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false

# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true

# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true

# Optional: turn on to generate embedded structs when processing GraphQL interfaces
# generate_embedded_structs_for_interfaces: false

# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true

Expand Down
98 changes: 83 additions & 15 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
}

type ModelBuild struct {
PackageName string
Interfaces []*Interface
Models []*Object
Enums []*Enum
Scalars []string
Config *config.Config
Interfaces []*Interface
Models []*Object
Enums []*Enum
Scalars []string
}

type Interface struct {
Description string
Name string
Fields ast.FieldList
Implements []string
}

Expand Down Expand Up @@ -86,23 +87,30 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
binder := cfg.NewBinder()

b := &ModelBuild{
PackageName: cfg.Model.Package,
Config: cfg,
}

for _, schemaType := range cfg.Schema.Types {
if cfg.Models.UserDefined(schemaType.Name) {
continue
}
switch schemaType.Kind {
case ast.Interface, ast.Union:
it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
case ast.Interface, ast.Union, ast.Object, ast.InputObject:
if schemaType.Kind == ast.Interface || schemaType.Kind == ast.Union {
it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: schemaType.Fields,
}

b.Interfaces = append(b.Interfaces, it)

if !cfg.GenerateEmbeddedStructsForInterfaces {
continue
}
}

b.Interfaces = append(b.Interfaces, it)
case ast.Object, ast.InputObject:
if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
continue
}
Expand Down Expand Up @@ -185,8 +193,10 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {

typ = binder.CopyModifiersFromAst(field.Type, typ)

if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
}

f := &Field{
Expand Down Expand Up @@ -230,6 +240,12 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })

// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
// check for cyclical relationships and recursive structs
if !cfg.StructFieldsAlwaysPointers {
findAndHandleCyclicalRelationships(b)
}

for _, it := range b.Enums {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
Expand Down Expand Up @@ -303,3 +319,55 @@ func isStruct(t types.Type) bool {
_, is := t.Underlying().(*types.Struct)
return is
}

// findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
// with pointers. These relationships will produce compilation errors if they are not pointers.
// Also handles recursive structs.
func findAndHandleCyclicalRelationships(b *ModelBuild) {
for ii, structA := range b.Models {
for _, fieldA := range structA.Fields {
if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
fmt.Print()
}
if !isStruct(fieldA.Type) {
continue
}

// the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA"
// we only want the part after the last dot: "LoopA"
// this could lead to false positives, as we are only checking the name of the struct type, but these
// should be extremely rare, if it is even possible at all.
fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]

// find this struct type amongst the generated structs
for jj, structB := range b.Models {
if structB.Name != fieldAStructName {
continue
}

// check if structB contains a cyclical reference back to structA
var cyclicalReferenceFound bool
for _, fieldB := range structB.Fields {
if !isStruct(fieldB.Type) {
continue
}

fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
if fieldBStructName == structA.Name {
cyclicalReferenceFound = true
fieldB.Type = types.NewPointer(fieldB.Type)
// keep looping in case this struct has additional fields of this type
}
}

// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
if cyclicalReferenceFound && ii != jj {
fieldA.Type = types.NewPointer(fieldA.Type)
break
}
}
}
}
}
61 changes: 48 additions & 13 deletions plugin/modelgen/models.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,64 @@
{{ reserveImport "github.com/99designs/gqlgen/graphql" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }}

{{- range $model := .Interfaces }}
{{ with .Description }} {{.|prefixLines "// "}} {{ end }}
type {{.Name|go }} interface {
{{- range $impl := .Implements }}
{{ $impl|go }}
{{- end }}
Is{{.Name|go }}()
}
{{- if not $.Config.GenerateEmbeddedStructsForInterfaces }}
{{- range $interface := .Interfaces }}
{{ with .Description }} {{.|prefixLines "// "}} {{ end }}
type {{.Name|go }} interface {
{{- range $impl := .Implements }}
{{ $impl|go }}
{{- end }}
Is{{.Name|go }}()
}
{{- end }}
{{- end }}

{{ range $model := .Models }}
{{with .Description }} {{.|prefixLines "// "}} {{end}}
type {{ .Name|go }} struct {
{{- range $impl := $model.Implements }}
{{ $impl|go }}
{{- end}}
{{- range $field := .Fields }}
{{- with .Description }}
{{.|prefixLines "// "}}
{{- /* If we are generating embedded structs for GraphQL interfaces,
we need to determine which of the struct's fields are for the purpose of implementing an interface
and ignore those in favor of simply embedding the interface's struct */ -}}
{{- $found := false }}
{{- if and $.Config.GenerateEmbeddedStructsForInterfaces $model.Implements}}
{{- range $impl := $model.Implements }}
{{- range $interface := $.Interfaces }}
{{- if eq $impl $interface.Name }}
{{- range $interfaceField := $interface.Fields }}
{{- if eq $interfaceField.Name $field.Name }}
{{- $found = true }}
{{- break}}
{{- end }}
{{- end }}
{{- end }}

{{- if $found }}
{{- break}}
{{- end }}
{{- end }}

{{- if $found }}
{{- break}}
{{- end }}
{{- end }}
{{- end }}
{{- if not $found }}
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
{{ $field.Name|go }} {{$field.Type | ref}} `{{$field.Tag}}`
{{- end}}
{{ $field.Name|go }} {{$field.Type | ref}} `{{$field.Tag}}`
{{- end }}
}

{{- range $iface := .Implements }}
func ({{ $model.Name|go }}) Is{{ $iface|go }}() {}
{{- if not $.Config.GenerateEmbeddedStructsForInterfaces }}
{{- range $iface := .Implements }}
func ({{ $model.Name|go }}) Is{{ $iface|go }}() {}
{{- end }}
{{- end }}
{{- end}}

Expand Down
Loading