Skip to content

Commit

Permalink
add support for embedded structs when a graphql type implements an in…
Browse files Browse the repository at this point in the history
…terface
  • Loading branch information
Adrian Lungu committed Sep 5, 2024
1 parent 4c4be0a commit 384d4a0
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 58 deletions.
4 changes: 4 additions & 0 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Config struct {
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
EmbeddedStructsPrefix string `yaml:"embedded_structs_prefix,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
GoBuildTags StringList `yaml:"go_build_tags,omitempty"`
GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"`
Expand All @@ -42,6 +43,7 @@ type Config struct {
OmitRootModels bool `yaml:"omit_root_models,omitempty"`
OmitResolverFields bool `yaml:"omit_resolver_fields,omitempty"`
OmitPanicHandler bool `yaml:"omit_panic_handler,omitempty"`
OmitEmbeddedStructs bool `yaml:"omit_embedded_structs,omitempty"`
// If this is set to true, argument directives that
// decorate a field with a null value will still be called.
//
Expand Down Expand Up @@ -77,6 +79,8 @@ func DefaultConfig() *Config {
ReturnPointersInUnmarshalInput: false,
ResolversAlwaysReturnPointers: true,
NullableInputOmittable: false,
OmitEmbeddedStructs: true,
EmbeddedStructsPrefix: "Base",
}
}

Expand Down
7 changes: 7 additions & 0 deletions docs/content/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ resolver:
# Optional: turn on to exclude resolver fields from the generated models file.
# omit_resolver_fields: false

# Optional: turn off to generate the models using embedding where a base struct is created that implements an interface
# and subsequent graphql types that implement that particular interface will embed the base struct.
# omit_embedded_structs: true

# Optional: turn on to set a different prefix to the generated base structs used for embedding.
# embedded_structs_prefix: "Base"

# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
Expand Down
200 changes: 142 additions & 58 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,75 +104,34 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
PackageName: cfg.Model.Package,
}

// We need to generate the models in a more deterministic order, therefor, we iterate through `interfaces` and `unions
// first and afterward, the rest of the types.
for _, schemaType := range cfg.Schema.Types {
if cfg.Models.UserDefined(schemaType.Name) {
continue
}
switch schemaType.Kind {
case ast.Interface, ast.Union:
var fields []*Field
var err error
if !cfg.OmitGetters {
fields, err = m.generateFields(cfg, schemaType)
if err != nil {
return err
}
}

it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: fields,
OmitCheck: cfg.OmitInterfaceChecks,
}

// if the interface has a key directive as an entity interface, allow it to implement _Entity
if schemaType.Directives.ForName("key") != nil {
it.Implements = append(it.Implements, "_Entity")
it, err := m.getInterface(cfg, schemaType, b)
if err != nil {
return err
}

b.Interfaces = append(b.Interfaces, it)
case ast.Object, ast.InputObject:
if cfg.IsRoot(schemaType) {
if !cfg.OmitRootModels {
b.Models = append(b.Models, &Object{
Description: schemaType.Description,
Name: schemaType.Name,
})
}
continue
}
}
}

fields, err := m.generateFields(cfg, schemaType)
for _, schemaType := range cfg.Schema.Types {
if cfg.Models.UserDefined(schemaType.Name) {
continue
}
switch schemaType.Kind {
case ast.Object, ast.InputObject:
it, err := m.getObject(cfg, schemaType, b)
if err != nil {
return err
}

it := &Object{
Description: schemaType.Description,
Name: schemaType.Name,
Fields: fields,
}

// If Interface A implements interface B, and Interface C also implements interface B
// then both A and C have methods of B.
// The reason for checking unique is to prevent the same method B from being generated twice.
uniqueMap := map[string]bool{}
for _, implementor := range cfg.Schema.GetImplements(schemaType) {
if !uniqueMap[implementor.Name] {
it.Implements = append(it.Implements, implementor.Name)
uniqueMap[implementor.Name] = true
}
// for interface implements
for _, iface := range implementor.Interfaces {
if !uniqueMap[iface] {
it.Implements = append(it.Implements, iface)
uniqueMap[iface] = true
}
}
}

b.Models = append(b.Models, it)
case ast.Enum:
it := &Enum{
Expand All @@ -192,6 +151,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
b.Scalars = append(b.Scalars, schemaType.Name)
}
}

sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
Expand Down Expand Up @@ -278,7 +238,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer && cfg.OmitEmbeddedStructs {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
Expand All @@ -290,7 +250,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)

if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer && cfg.OmitEmbeddedStructs {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
Expand Down Expand Up @@ -328,11 +288,52 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
return nil
}

func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition, model *ModelBuild) ([]*Field, error) {
// func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
binder := cfg.NewBinder()
fields := make([]*Field, 0)
embeddedFields := map[string]*Field{}

for _, field := range schemaType.Fields {
if model != nil && !cfg.OmitEmbeddedStructs && schemaType.Kind == ast.Object && len(schemaType.Interfaces) > 0 {
interfaceHasField := false
interfaceName := ""

for _, iface := range schemaType.Interfaces {
// We skip the node interface as it should be present in all interfaces implementing it, and it
// creates an issue with ambiguity when referencing `ID`.
if iface == "Node" {
continue
}

for _, modelInterface := range model.Interfaces {
if modelInterface.Name == iface {
for _, mField := range modelInterface.Fields {
if field.Name == mField.Name {
interfaceHasField = true
interfaceName = iface
break
}
}
}
}
}

if interfaceHasField {
embeddedField := &Field{
Type: types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(fmt.Sprintf("%s%s", cfg.EmbeddedStructsPrefix, interfaceName)), nil),
types.NewStruct(nil, nil),
nil,
),
}

embeddedFields[interfaceName] = embeddedField

continue
}
}

f, err := m.generateField(cfg, binder, schemaType, field)
if err != nil {
return nil, err
Expand All @@ -345,6 +346,10 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
fields = append(fields, f)
}

for _, field := range embeddedFields {
fields = append(fields, field)
}

fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
Expand Down Expand Up @@ -723,3 +728,82 @@ func readModelTemplate(customModelTemplate string) string {
}
return string(contentBytes)
}

func (m *Plugin) getInterface(cfg *config.Config, schemaType *ast.Definition, b *ModelBuild) (*Interface, error) {
var fields []*Field
var err error
if !cfg.OmitGetters {
fields, err = m.generateFields(cfg, schemaType, nil)
if err != nil {
return nil, err
}
}

if !cfg.OmitEmbeddedStructs {
it := &Object{
Description: schemaType.Description,
// To not conflict with the interface name, we prefix the struct name with "Base"
Name: fmt.Sprintf("%s%s", cfg.EmbeddedStructsPrefix, schemaType.Name),
Fields: fields,
}

b.Models = append(b.Models, it)
}

it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: fields,
OmitCheck: cfg.OmitInterfaceChecks,
}

// if the interface has a key directive as an entity interface, allow it to implement _Entity
if schemaType.Directives.ForName("key") != nil {
it.Implements = append(it.Implements, "_Entity")
}

return it, nil
}

func (m *Plugin) getObject(cfg *config.Config, schemaType *ast.Definition, b *ModelBuild) (*Object, error) {
if cfg.IsRoot(schemaType) {
if !cfg.OmitRootModels {
return &Object{
Description: schemaType.Description,
Name: schemaType.Name,
}, nil
}
}

fields, err := m.generateFields(cfg, schemaType, b)
if err != nil {
return nil, err
}

it := &Object{
Description: schemaType.Description,
Name: schemaType.Name,
Fields: fields,
}

// If Interface A implements interface B, and Interface C also implements interface B
// then both A and C have methods of B.
// The reason for checking unique is to prevent the same method B from being generated twice.
uniqueMap := map[string]bool{}
for _, implementor := range cfg.Schema.GetImplements(schemaType) {
if !uniqueMap[implementor.Name] {
it.Implements = append(it.Implements, implementor.Name)
uniqueMap[implementor.Name] = true
}
// for interface implements
for _, iface := range implementor.Interfaces {
if !uniqueMap[iface] {
it.Implements = append(it.Implements, iface)
uniqueMap[iface] = true
}
}
}

return it, nil
}

0 comments on commit 384d4a0

Please sign in to comment.