Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for embedded structs #3242

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Config struct {
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
EmbeddedStructsPrefix string `yaml:"embedded_structs_prefix,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
GoBuildTags StringList `yaml:"go_build_tags,omitempty"`
GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"`
Expand All @@ -42,6 +43,7 @@ type Config struct {
OmitRootModels bool `yaml:"omit_root_models,omitempty"`
OmitResolverFields bool `yaml:"omit_resolver_fields,omitempty"`
OmitPanicHandler bool `yaml:"omit_panic_handler,omitempty"`
OmitEmbeddedStructs bool `yaml:"omit_embedded_structs,omitempty"`
// If this is set to true, argument directives that
// decorate a field with a null value will still be called.
//
Expand Down Expand Up @@ -77,6 +79,8 @@ func DefaultConfig() *Config {
ReturnPointersInUnmarshalInput: false,
ResolversAlwaysReturnPointers: true,
NullableInputOmittable: false,
OmitEmbeddedStructs: true,
EmbeddedStructsPrefix: "Base",
}
}

Expand Down
7 changes: 7 additions & 0 deletions docs/content/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ resolver:
# Optional: turn on to exclude resolver fields from the generated models file.
# omit_resolver_fields: false

# Optional: turn off to generate the models using embedding where a base struct is created that implements an interface
# and subsequent graphql types that implement that particular interface will embed the base struct.
# omit_embedded_structs: true

# Optional: turn on to set a different prefix to the generated base structs used for embedding.
# embedded_structs_prefix: "Base"

# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
Expand Down
197 changes: 144 additions & 53 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,73 +104,48 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
PackageName: cfg.Model.Package,
}

// We need to generate the models in a more deterministic order, therefor, we iterate through `interfaces` and `unions
// first and afterward, the rest of the types.
for _, schemaType := range cfg.Schema.Types {
if cfg.Models.UserDefined(schemaType.Name) {
continue
}
switch schemaType.Kind {
case ast.Interface, ast.Union:
var fields []*Field
var err error
if !cfg.OmitGetters {
fields, err = m.generateFields(cfg, schemaType)
it, err := m.getInterface(cfg, schemaType, b)
if err != nil {
return err
}

if !cfg.OmitEmbeddedStructs {
ob, err := m.getObject(cfg, schemaType, b)
if err != nil {
return err
}
}

it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: fields,
OmitCheck: cfg.OmitInterfaceChecks,
}
if ob == nil {
continue
}

// if the interface has a key directive as an entity interface, allow it to implement _Entity
if schemaType.Directives.ForName("key") != nil {
it.Implements = append(it.Implements, "_Entity")
ob.Name = fmt.Sprintf("%s%s", cfg.EmbeddedStructsPrefix, ob.Name)
b.Models = append(b.Models, ob)
}

b.Interfaces = append(b.Interfaces, it)
case ast.Object, ast.InputObject:
if cfg.IsRoot(schemaType) {
if !cfg.OmitRootModels {
b.Models = append(b.Models, &Object{
Description: schemaType.Description,
Name: schemaType.Name,
})
}
continue
}
}
}

fields, err := m.generateFields(cfg, schemaType)
for _, schemaType := range cfg.Schema.Types {
if cfg.Models.UserDefined(schemaType.Name) {
continue
}
switch schemaType.Kind {
case ast.Object, ast.InputObject:
it, err := m.getObject(cfg, schemaType, b)
if err != nil {
return err
}

it := &Object{
Description: schemaType.Description,
Name: schemaType.Name,
Fields: fields,
}

// If Interface A implements interface B, and Interface C also implements interface B
// then both A and C have methods of B.
// The reason for checking unique is to prevent the same method B from being generated twice.
uniqueMap := map[string]bool{}
for _, implementor := range cfg.Schema.GetImplements(schemaType) {
if !uniqueMap[implementor.Name] {
it.Implements = append(it.Implements, implementor.Name)
uniqueMap[implementor.Name] = true
}
// for interface implements
for _, iface := range implementor.Interfaces {
if !uniqueMap[iface] {
it.Implements = append(it.Implements, iface)
uniqueMap[iface] = true
}
}
if it == nil {
continue
}

b.Models = append(b.Models, it)
Expand All @@ -192,6 +167,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
b.Scalars = append(b.Scalars, schemaType.Name)
}
}

sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
Expand Down Expand Up @@ -278,7 +254,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer && cfg.OmitEmbeddedStructs {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
Expand All @@ -290,7 +266,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)

if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer && cfg.OmitEmbeddedStructs {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
Expand Down Expand Up @@ -328,11 +304,52 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
return nil
}

func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition, model *ModelBuild) ([]*Field, error) {
// func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
binder := cfg.NewBinder()
fields := make([]*Field, 0)
embeddedFields := map[string]*Field{}

for _, field := range schemaType.Fields {
if model != nil && !cfg.OmitEmbeddedStructs && schemaType.Kind == ast.Object && len(schemaType.Interfaces) > 0 {
interfaceHasField := false
interfaceName := ""

for _, iface := range schemaType.Interfaces {
// We skip the node interface as it should be present in all interfaces implementing it, and it
// creates an issue with ambiguity when referencing `ID`.
if iface == "Node" {
continue
}

for _, modelInterface := range model.Interfaces {
if modelInterface.Name == iface {
for _, mField := range modelInterface.Fields {
if field.Name == mField.Name {
interfaceHasField = true
interfaceName = iface
break
}
}
}
}
}

if interfaceHasField {
embeddedField := &Field{
Type: types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(fmt.Sprintf("%s%s", cfg.EmbeddedStructsPrefix, interfaceName)), nil),
types.NewStruct(nil, nil),
nil,
),
}

embeddedFields[interfaceName] = embeddedField

continue
}
}

f, err := m.generateField(cfg, binder, schemaType, field)
if err != nil {
return nil, err
Expand All @@ -345,6 +362,10 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
fields = append(fields, f)
}

for _, field := range embeddedFields {
fields = append(fields, field)
}

fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
Expand Down Expand Up @@ -723,3 +744,73 @@ func readModelTemplate(customModelTemplate string) string {
}
return string(contentBytes)
}

func (m *Plugin) getInterface(cfg *config.Config, schemaType *ast.Definition, b *ModelBuild) (*Interface, error) {
var fields []*Field
var err error
if !cfg.OmitGetters {
fields, err = m.generateFields(cfg, schemaType, nil)
if err != nil {
return nil, err
}
}

it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: fields,
OmitCheck: cfg.OmitInterfaceChecks,
}

// if the interface has a key directive as an entity interface, allow it to implement _Entity
if schemaType.Directives.ForName("key") != nil {
it.Implements = append(it.Implements, "_Entity")
}

return it, nil
}

func (m *Plugin) getObject(cfg *config.Config, schemaType *ast.Definition, b *ModelBuild) (*Object, error) {
if cfg.IsRoot(schemaType) {
if !cfg.OmitRootModels {
return &Object{
Description: schemaType.Description,
Name: schemaType.Name,
}, nil
}

return nil, nil
}

fields, err := m.generateFields(cfg, schemaType, b)
if err != nil {
return nil, err
}

it := &Object{
Description: schemaType.Description,
Name: schemaType.Name,
Fields: fields,
}

// If Interface A implements interface B, and Interface C also implements interface B
// then both A and C have methods of B.
// The reason for checking unique is to prevent the same method B from being generated twice.
uniqueMap := map[string]bool{}
for _, implementor := range cfg.Schema.GetImplements(schemaType) {
if !uniqueMap[implementor.Name] {
it.Implements = append(it.Implements, implementor.Name)
uniqueMap[implementor.Name] = true
}
// for interface implements
for _, iface := range implementor.Interfaces {
if !uniqueMap[iface] {
it.Implements = append(it.Implements, iface)
uniqueMap[iface] = true
}
}
}

return it, nil
}
2 changes: 1 addition & 1 deletion plugin/modelgen/models.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
{{ $field.GoName }} {{$field.Type | ref}} `{{$field.Tag}}`
{{ $field.GoName }} {{$field.Type | ref}} {{if $field.Tag}}`{{end}}{{$field.Tag}}{{if $field.Tag}}`{{end}}
{{- end }}
}

Expand Down
41 changes: 41 additions & 0 deletions plugin/modelgen/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,32 @@ func TestModelGenerationOmitRootModels(t *testing.T) {
require.NotContains(t, string(generated), "type Subscription struct")
}

func TestModelGenerationDontOmitEmbeddedStructs(t *testing.T) {
cfg, err := config.LoadConfig("testdata/gqlgen_embedded_structs_models.yml")
require.NoError(t, err)
require.NoError(t, cfg.Init())
p := Plugin{
FieldHook: DefaultFieldMutateHook,
}
require.NoError(t, p.MutateConfig(cfg))
require.NoError(t, goBuild(t, "./out_embedded_structs_models/"))
generated, err := os.ReadFile("./out_embedded_structs_models/generated_embedded_structs_models.go")
require.NoError(t, err)
require.Contains(t, string(generated), "type BaseElement")

carbonStr := getStringInBetween(string(generated), "type Carbon struct {", "}")
require.NotEqual(t, "", carbonStr)
require.Contains(t, carbonStr, "BaseElement")

magnesiumStr := getStringInBetween(string(generated), "type Magnesium struct {", "}")
require.NotEqual(t, "", magnesiumStr)
require.Contains(t, magnesiumStr, "BaseElement")

potassiumStr := getStringInBetween(string(generated), "type Potassium struct {", "}")
require.NotEqual(t, "", potassiumStr)
require.Contains(t, potassiumStr, "BaseElement")
}

func TestModelGenerationOmitResolverFields(t *testing.T) {
cfg, err := config.LoadConfig("testdata/gqlgen_omit_resolver_fields.yml")
require.NoError(t, err)
Expand Down Expand Up @@ -699,3 +725,18 @@ func TestCustomTemplate(t *testing.T) {
}
require.NoError(t, p.MutateConfig(cfg))
}

func getStringInBetween(str, start, end string) string {
startIndex := strings.Index(str, start)
if startIndex == -1 {
return ""
}

newStr := str[startIndex+len(start):]
e := strings.Index(newStr, end)
if e == -1 {
return ""
}

return newStr[:e]
}
Loading
Loading