diff --git a/cmd/dex/config.go b/cmd/dex/config.go index eba919fa91..3e445fe682 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -35,6 +35,9 @@ type Config struct { // Write operations, like updating a connector, will fail. StaticConnectors []Connector `json:"connectors"` + // StaticMiddleware are global middleware specified in the ConfigMap. + StaticMiddleware []Middleware `json:"middleware"` + // StaticClients cause the server to use this list of clients rather than // querying the storage. Write operations, like creating a client, will fail. StaticClients []storage.Client `json:"staticClients"` @@ -218,6 +221,8 @@ type Connector struct { ID string `json:"id"` Config server.ConnectorConfig `json:"config"` + + Middleware []Middleware `json:"middleware"` } // UnmarshalJSON allows Connector to implement the unmarshaler interface to @@ -229,6 +234,8 @@ func (c *Connector) UnmarshalJSON(b []byte) error { ID string `json:"id"` Config json.RawMessage `json:"config"` + + Middleware []Middleware `json:"middleware"` } if err := json.Unmarshal(b, &conn); err != nil { return fmt.Errorf("parse connector: %v", err) @@ -246,10 +253,11 @@ func (c *Connector) UnmarshalJSON(b []byte) error { } } *c = Connector{ - Type: conn.Type, - Name: conn.Name, - ID: conn.ID, - Config: connConfig, + Type: conn.Type, + Name: conn.Name, + ID: conn.ID, + Config: connConfig, + Middleware: conn.Middleware, } return nil } @@ -261,10 +269,70 @@ func ToStorageConnector(c Connector) (storage.Connector, error) { return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err) } + mwares := make([]storage.Middleware, len(c.Middleware)) + for n, mware := range c.Middleware { + mwares[n], err = ToStorageMiddleware(mware) + if err != nil { + return storage.Connector{}, fmt.Errorf("failed to marshal connector middleware: %v", err) + } + } + return storage.Connector{ - ID: c.ID, - Type: c.Type, - Name: c.Name, + ID: c.ID, + Type: c.Type, + Name: c.Name, + Config: data, + Middleware: mwares, + }, nil +} + +// Middleware is another magic type, like Connector, that can unmarshal YAML +// dynamically. The Type field determines the middleware type, which is then +// customized for Config. +type Middleware struct { + Type string `json:"type"` + + Config server.MiddlewareConfig `json:"config"` +} + +// UnmarshalJSON allows Connector to implement the unmarshaler interface to +// dynamically determine the type of the middleware config. +func (m *Middleware) UnmarshalJSON(b []byte) error { + var mware struct { + Type string `json:"type"` + + Config json.RawMessage `json:"config"` + } + if err := json.Unmarshal(b, &mware); err != nil { + return fmt.Errorf("parse middleware: %v", err) + } + f, ok := server.MiddlewaresConfig[mware.Type] + if !ok { + return fmt.Errorf("unknown middleware type %q", mware.Type) + } + + mwareConfig := f() + if len(mware.Config) != 0 { + if err := json.Unmarshal(mware.Config, mwareConfig); err != nil { + return fmt.Errorf("parse middleware config: %v", err) + } + } + *m = Middleware{ + Type: mware.Type, + Config: mwareConfig, + } + return nil +} + +// ToStorageMiddleware converts an object to storage middleware type. +func ToStorageMiddleware(m Middleware) (storage.Middleware, error) { + data, err := json.Marshal(m.Config) + if err != nil { + return storage.Middleware{}, fmt.Errorf("failed to marshal middleware config: %v", err) + } + + return storage.Middleware{ + Type: m.Type, Config: data, }, nil } diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index ca74059366..6431c88ed4 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -216,6 +216,28 @@ func serve(cmd *cobra.Command, args []string) error { s = storage.WithStaticConnectors(s, storageConnectors) + storageMiddleware := make([]storage.Middleware, len(c.StaticMiddleware)) + for i, m := range c.StaticMiddleware { + if m.Type == "" { + return fmt.Errorf("invalid config: Type field is required for a middleware") + } + if m.Config == nil { + return fmt.Errorf("invalid config: no config field for middleware %d (%s)", i, m.Type) + } + logger.Infof("config middleware: %d (%s)", i, m.Type) + + // convert to a storage middleware object + mware, err := ToStorageMiddleware(m) + if err != nil { + return fmt.Errorf("failed to initialize storage middleware: %v", err) + } + storageMiddleware[i] = mware + } + + if len(storageMiddleware) > 0 { + s = storage.WithStaticMiddleware(s, storageMiddleware) + } + if len(c.OAuth2.ResponseTypes) > 0 { logger.Infof("config response types accepted: %s", c.OAuth2.ResponseTypes) } diff --git a/middleware/groups/groups.go b/middleware/groups/groups.go new file mode 100644 index 0000000000..1ddcf80deb --- /dev/null +++ b/middleware/groups/groups.go @@ -0,0 +1,186 @@ +// Package groups implements support for manipulating groups claims. +package groups + +import ( + "context" + "fmt" + "regexp" + "sort" + "strings" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/middleware" + + "github.com/dexidp/dex/pkg/log" +) + +// Config holds the configuration parameters for the groups middleware. +// The groups middleware provides the ability to filter and prefix groups +// returned by things further up the chain. +// +// An example config: +// +// type: groups +// config: +// actions: +// - discard: "^admin$" +// - replace: +// pattern: "\s+" +// with: " " +// - stripPrefix: "foo/" +// - addPrefix: "ldap/" +// inject: +// - DexUsers +// - Users +// sorted: true +// unique: true +// +type Config struct { + // A list of actions to perform on each group name + Actions []Action `json:"actions,omitempty"` + + // Additional groups to inject + Inject []string `json:"inject,omitempty"` + + // If true, sort the resulting list of groups + Sorted bool `json:"sorted,omitempty"` + + // If true, ensure that each group is listed at most once + Unique bool `json:"unique,omitempty"` +} + +// Replaces matches of the regular expression Pattern with With +type ReplaceAction struct { + Pattern string `json:"pattern"` + With string `json:"with"` +} + +// An action +type Action struct { + // Discards groups whose names match a regexp + Discard string `json:"discard,omitempty"` + + // Replace regexp matches in a group name + Replace *ReplaceAction `json:"replace,omitempty"` + + // Remove a prefix string from a group name + StripPrefix string `json:"stripPrefix,omitempty"` + + // Add a prefix string to a group name + AddPrefix string `json:"addPrefix,omitempty"` +} + +// The actual Middleware object uses these instead, so that we can pre-compile +// the regular expressions +type replaceAction struct { + Regexp *regexp.Regexp + With string +} + +type action struct { + Discard *regexp.Regexp + Replace *replaceAction + StripPrefix string + AddPrefix string +} + +// Open returns a groups Middleware +func (c *Config) Open(logger log.Logger) (middleware.Middleware, error) { + // Compile the regular expressions + actions := make([]action, len(c.Actions)) + for n, a := range c.Actions { + actions[n] = action{ + StripPrefix: a.StripPrefix, + AddPrefix: a.AddPrefix, + } + + if a.Discard != "" { + re, err := regexp.Compile(a.Discard) + if err != nil { + return nil, fmt.Errorf("groups: unable to compile discard regexp %q: %v", a.Discard, err) + } + actions[n].Discard = re + } + + if a.Replace != nil { + re, err := regexp.Compile(a.Replace.Pattern) + if err != nil { + return nil, fmt.Errorf("groups: unable to compile replace regexp %q: %v", a.Replace.Pattern, err) + } + actions[n].Replace = &replaceAction{ + Regexp: re, + With: a.Replace.With, + } + } + } + + return &groupsMiddleware{Config: *c, CompiledActions: actions}, nil +} + +type groupsMiddleware struct { + Config + + CompiledActions []action +} + +// Apply the actions to the groups in the incoming identity +func (g *groupsMiddleware) Process(ctx context.Context, identity connector.Identity) (connector.Identity, error) { + groupSet := map[string]struct{}{} + exists := struct{}{} + newGroups := []string{} + + if g.Unique && g.Inject != nil { + for _, group := range g.Inject { + groupSet[group] = exists + } + } + + for _, group := range identity.Groups { + discard := false + + for _, action := range g.CompiledActions { + if action.Discard != nil { + if action.Discard.MatchString(group) { + discard = true + break + } + } + + if action.Replace != nil { + group = action.Replace.Regexp.ReplaceAllString(group, + action.Replace.With) + } + + if action.StripPrefix != "" { + group = strings.TrimPrefix(group, action.StripPrefix) + } + + if action.AddPrefix != "" { + group = action.AddPrefix + group + } + } + + if !discard && g.Unique { + _, discard = groupSet[group] + if !discard { + groupSet[group] = exists + } + } + + if !discard { + newGroups = append(newGroups, group) + } + } + + if g.Inject != nil { + newGroups = append(newGroups, g.Inject...) + } + + if g.Sorted { + sort.Strings(newGroups) + } + + identity.Groups = newGroups + + return identity, nil +} diff --git a/middleware/groups/groups_test.go b/middleware/groups/groups_test.go new file mode 100644 index 0000000000..d1f8737bcc --- /dev/null +++ b/middleware/groups/groups_test.go @@ -0,0 +1,832 @@ +package groups + +import ( + "context" + "io/ioutil" + "testing" + + "github.com/kylelemons/godebug/pretty" + "github.com/sirupsen/logrus" + + "github.com/dexidp/dex/connector" +) + +type subtest struct { + // Name of the sub-test + name string + + // Input identity + input connector.Identity + + // Output + wantErr bool + want connector.Identity +} + +func TestDiscard(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + Discard: `\bd[a-z]+d\b`, + }, + }, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "nodiscard", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + }, + { + name: "discard1", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"should discard this"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "discard1mid", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "should discard this", "test2"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "test2"}, + }, + }, + { + name: "discard1start", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"should discard this", "test", "test2"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "test2"}, + }, + }, + { + name: "discard1end", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "test2", "should discard this"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "test2"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestReplace(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + Replace: &ReplaceAction{ + Pattern: "(dog|hound)", + With: "cat", + }, + }, + }, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "noreplace", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + }, + { + name: "replace1", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dog lovers"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cat lovers"}, + }, + }, + { + name: "replace2", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"hound enthusiasts", "dog lovers"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cat enthusiasts", "cat lovers"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestStripPrefix(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + StripPrefix: "fun/", + }, + }, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"fun/test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"fun/dogs", "fun/cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "cats"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestAddPrefix(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + AddPrefix: "fun/", + }, + }, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"fun/test"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"fun/dogs", "fun/cats"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestSorting(t *testing.T) { + c := &Config{ + Sorted: true, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cats", "dogs"}, + }, + }, + { + name: "three", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "rabbits", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cats", "dogs", "rabbits"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestUniquing(t *testing.T) { + c := &Config{ + Unique: true, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "dogs", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "cats"}, + }, + }, + { + name: "three", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cats", "dogs", "rabbits", "rabbits", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cats", "dogs", "rabbits"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestInject(t *testing.T) { + c := &Config{ + Inject: []string{ + "birds", "bees", "butterflies", + }, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds", "bees", "butterflies"}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"test", "birds", "bees", "butterflies"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "cats"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{ + "dogs", "cats", "birds", "bees", "butterflies", + }, + }, + }, + } + + runTests(t, c, tests) +} + +func TestInjectUnique(t *testing.T) { + c := &Config{ + Inject: []string{ + "birds", "bees", "butterflies", + }, + Unique: true, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds", "bees", "butterflies"}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"bees"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds", "bees", "butterflies"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds", "bats", "butterflies"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{ + "bats", "birds", "bees", "butterflies", + }, + }, + }, + } + + runTests(t, c, tests) +} + +func TestActionUnique(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + Replace: &ReplaceAction{ + Pattern: `\b(cats|dogs|rabbits)\b`, + With: "birds", + }, + }, + }, + Unique: true, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "one", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"cats", "birds"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds"}, + }, + }, + { + name: "two", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"dogs", "rabbits", "birds"}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"birds"}, + }, + }, + } + + runTests(t, c, tests) +} + +func TestMulti(t *testing.T) { + c := &Config{ + Actions: []Action{ + { + Discard: "^admin$", + }, + { + StripPrefix: "foobar/", + }, + { + Replace: &ReplaceAction{ + Pattern: `\b(cats|dogs|rabbits)\b`, + With: "birds", + }, + }, + { + AddPrefix: "foo/", + }, + }, + Sorted: true, + Unique: true, + } + + tests := []subtest{ + { + name: "nogroups", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{}, + }, + }, + { + name: "multi", + input: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{ + "alternative", + "admin", + "verbose", + "foobar/frobble", + "cats", + "rabbits", + "foobar/dogs", + "angry cats", + }, + }, + want: connector.Identity{ + UserID: "test", + Username: "test", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{ + "foo/alternative", + "foo/angry birds", + "foo/birds", + "foo/frobble", + "foo/verbose", + }, + }, + }, + } + + runTests(t, c, tests) +} + +func runTests(t *testing.T, config *Config, tests []subtest) { + l := &logrus.Logger{Out: ioutil.Discard, Formatter: &logrus.TextFormatter{}} + ctx := context.Background() + + mware, err := config.Open(l) + if err != nil { + t.Errorf("open middleware: %v", err) + } + + for _, test := range tests { + if test.name == "" { + t.Fatal("subtest has no name") + } + + t.Run(test.name, func(t *testing.T) { + got, err := mware.Process(ctx, test.input) + if err != nil { + if !test.wantErr { + t.Fatalf("middleware failed: %v", err) + } + return + } + if test.wantErr { + t.Fatal("middleware should have failed") + } + + if diff := pretty.Compare(test.want, got); diff != "" { + t.Error(diff) + return + } + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000000..bd655870d2 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,25 @@ +// Package middleware defines interfaces for pluggable identity middleware. +package middleware + +import ( + "context" + + "github.com/dexidp/dex/connector" +) + +// Middleware is a mechanism for allowing customisation of responses returned +// by a remote identity service. +// +// Each configured connector can have a stack of Middleware components; when +// the connector returns successfully with an Identity, this will be passed to +// the Middleware at the top of the stack, which can inspect the identity and +// take any required actions. Assuming that Middleware component succeeds +// and returns an Identity, the returned identity will be passed to the next +// Middleware in the chain until all middleware modules have processed the +// identity. +// +// Once the connector specific middleware has finished executing, there is also +// a global middleware chain that runs on the results. +type Middleware interface { + Process(ctx context.Context, identity connector.Identity) (connector.Identity, error) +} diff --git a/server/handlers.go b/server/handlers.go index 5a7244faaa..282a4651f4 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -397,6 +397,12 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } return } + identity, err = s.RunMiddleware(r.Context(), conn, identity) + if err != nil { + s.logger.Errorf("Failed to run middleware for login: %v", err) + s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) + return + } redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) @@ -480,6 +486,13 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } + identity, err = s.RunMiddleware(r.Context(), conn, identity) + if err != nil { + s.logger.Error("Failed to run middleware for login: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) @@ -1116,7 +1129,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - ident = newIdent + + // If we did a Refresh, run middleware on the result + ident, err = s.RunMiddleware(r.Context(), conn, newIdent) + if err != nil { + s.logger.Errorf("failed to run middleware for refresh: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } } claims := storage.Claims{ @@ -1309,6 +1329,12 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli return } + identity, err = s.RunMiddleware(r.Context(), conn, identity) + if err != nil { + s.tokenErrHelper(w, errServerError, fmt.Sprintf("Login error: %v", err), http.StatusInternalServerError) + return + } + // Build the claims to send the id token claims := storage.Claims{ UserID: identity.UserID, diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000000..98fc11d4f5 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,102 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/dexidp/dex/storage" + + "github.com/dexidp/dex/middleware" + "github.com/dexidp/dex/middleware/groups" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/log" +) + +// Middleware is a middleware with resource version metadata. +type Middleware struct { + ResourceVersion string + Middleware middleware.Middleware +} + +// MiddlewareConfig is a configuration that can open a middleware. +type MiddlewareConfig interface { + Open(logger log.Logger) (middleware.Middleware, error) +} + +// MiddlewaresConfig variable provides an easy way to return a config struct +// depending on the middleware type. +var MiddlewaresConfig = map[string]func() MiddlewareConfig{ + "groups": func() MiddlewareConfig { return new(groups.Config) }, +} + +// openMiddleware will parse the middleware config and open the middleware. +func openMiddleware(logger log.Logger, mware storage.Middleware) (middleware.Middleware, error) { + var m middleware.Middleware + + f, ok := MiddlewaresConfig[mware.Type] + if !ok { + return m, fmt.Errorf("unknown middleware type %q", mware.Type) + } + + mwareConfig := f() + if len(mware.Config) != 0 { + if err := json.Unmarshal(mware.Config, mwareConfig); err != nil { + return m, fmt.Errorf("parse middleware config: %v", err) + } + } + + m, err := mwareConfig.Open(logger) + if err != nil { + return m, fmt.Errorf("failed to create middleware %q: %v", mware.Type, err) + } + + return m, nil +} + +func (s *Server) OpenMiddleware(mware storage.Middleware) (Middleware, error) { + m, err := openMiddleware(s.logger, mware) + if err != nil { + return Middleware{}, fmt.Errorf("failed to open middleware: %v", err) + } + + middleware := Middleware{ + ResourceVersion: mware.ResourceVersion, + Middleware: m, + } + + s.mu.Lock() + s.middleware = append(s.middleware, middleware) + s.mu.Unlock() + + return middleware, nil +} + +// RunMiddleware executes the middleware for the specified connector, followed +// by the global middleware. +func (s *Server) RunMiddleware(ctx context.Context, conn Connector, identity connector.Identity) (connector.Identity, error) { + var err error + + // First, run the connector middleware + for _, mware := range conn.Middleware { + identity, err = mware.Process(ctx, identity) + if err != nil { + return identity, err + } + } + + // Grab a copy of the global middleware + s.mu.Lock() + middleware := s.middleware + s.mu.Unlock() + + for _, mware := range middleware { + identity, err = mware.Middleware.Process(ctx, identity) + if err != nil { + return identity, err + } + } + + return identity, err +} diff --git a/server/server.go b/server/server.go index c37b4fdc5e..bd1d9ca987 100644 --- a/server/server.go +++ b/server/server.go @@ -37,6 +37,9 @@ import ( "github.com/dexidp/dex/connector/oidc" "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" + + "github.com/dexidp/dex/middleware" + "github.com/dexidp/dex/pkg/log" "github.com/dexidp/dex/storage" ) @@ -49,6 +52,7 @@ const LocalConnector = "local" type Connector struct { ResourceVersion string Connector connector.Connector + Middleware []middleware.Middleware } // Config holds the server's configuration options. @@ -139,6 +143,9 @@ type Server struct { // Map of connector IDs to connectors. connectors map[string]Connector + // Global middleware + middleware []Middleware + storage storage.Storage mux http.Handler @@ -225,6 +232,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) s := &Server{ issuerURL: *issuerURL, connectors: make(map[string]Connector), + middleware: []Middleware{}, storage: newKeyCacher(c.Storage, now), supportedResponseTypes: supported, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), @@ -550,7 +558,17 @@ func (s *Server) OpenConnector(conn storage.Connector) (Connector, error) { connector := Connector{ ResourceVersion: conn.ResourceVersion, Connector: c, + Middleware: make([]middleware.Middleware, len(conn.Middleware)), } + + for n, mware := range conn.Middleware { + var err error + connector.Middleware[n], err = openMiddleware(s.logger, mware) + if err != nil { + return Connector{}, fmt.Errorf("failed to open connector middleware: %v", err) + } + } + s.mu.Lock() s.connectors[conn.ID] = connector s.mu.Unlock() diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index dd2083ae86..99718392c7 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { {"KeysCRUD", testKeysCRUD}, {"OfflineSessionCRUD", testOfflineSessionCRUD}, {"ConnectorCRUD", testConnectorCRUD}, + {"MiddlewareCRUD", testMiddlewareCRUD}, {"GarbageCollection", testGC}, {"TimezoneSupport", testTimezones}, {"DeviceRequestCRUD", testDeviceRequestCRUD}, @@ -80,6 +81,15 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { } } +func mustBeErrOutOfRange(t *testing.T, kind string, err error) { + switch { + case err == nil: + t.Errorf("expected ErrOutOfRange, got nil. (kind %q)", kind) + case err != storage.ErrOutOfRange: + t.Errorf("expected ErrOutOfRange, got %v. (kind %q)", kind, err) + } +} + func testAuthRequestCRUD(t *testing.T, s storage.Storage) { codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", @@ -687,6 +697,95 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { mustBeErrNotFound(t, "connector", err) } +func testMiddlewareCRUD(t *testing.T, s storage.Storage) { + emptyConfig := []byte(`{}`) + m1 := storage.Middleware{ + Type: "groups", + Config: emptyConfig, + } + + if err := s.InsertMiddleware(0, m1); err != nil { + t.Fatalf("insert middleware at start: %v", err) + } + + m2 := storage.Middleware{ + Type: "groups", + Config: emptyConfig, + } + + err := s.InsertMiddleware(5, m2) + mustBeErrOutOfRange(t, "middleware", err) + if err := s.InsertMiddleware(-1, m2); err != nil { + t.Fatalf("insert middleware at end: %v", err) + } + + m3 := storage.Middleware{ + Type: "groups", + Config: emptyConfig, + } + + if err := s.InsertMiddleware(1, m3); err != nil { + t.Fatalf("insert middleware in middle: %v", err) + } + + getAndCompare := func(ndx int, want storage.Middleware) { + gr, err := s.GetMiddleware(ndx) + if err != nil { + t.Errorf("get middleware: %v", err) + return + } + // ignore resource version comparison + gr.ResourceVersion = "" + if diff := pretty.Compare(want, gr); diff != "" { + t.Errorf("middleware retrieved from storage did not match: %s", diff) + } + } + + getAndCompare(0, m1) + + if err := s.UpdateMiddleware(0, func(old storage.Middleware) (storage.Middleware, error) { + old.Config = []byte(`{"inject": ["Test"]}`) + return old, nil + }); err != nil { + t.Fatalf("failed to update Middleware: %v", err) + } + + m1.Config = []byte(`{"inject": ["Test"]}`) + getAndCompare(0, m1) + + middlewareList := []storage.Middleware{m1, m3, m2} + listAndCompare := func(want []storage.Middleware) { + middlewares, err := s.ListMiddleware() + if err != nil { + t.Errorf("list middlewares: %v", err) + return + } + // ignore resource version comparison + for i := range middlewares { + middlewares[i].ResourceVersion = "" + } + if diff := pretty.Compare(want, middlewares); diff != "" { + t.Errorf("middleware list retrieved from storage did not match: %s", diff) + } + } + listAndCompare(middlewareList) + + if err := s.DeleteMiddleware(0); err != nil { + t.Fatalf("failed to delete middleware: %v", err) + } + + if err := s.DeleteMiddleware(1); err != nil { + t.Fatalf("failed to delete middleware: %v", err) + } + + if err := s.DeleteMiddleware(0); err != nil { + t.Fatalf("failed to delete middleware: %v", err) + } + + _, err = s.GetMiddleware(0) + mustBeErrOutOfRange(t, "middleware", err) +} + func testKeysCRUD(t *testing.T, s storage.Storage) { updateAndCompare := func(k storage.Keys) { err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) { diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index e8abe3d08f..e3b1295718 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -24,6 +24,7 @@ const ( keysName = "openid-connect-keys" deviceRequestPrefix = "device_req/" deviceTokenPrefix = "device_token/" + middlewareName = "middleware" // defaultStorageTimeout will be applied to all storage's operations. defaultStorageTimeout = 5 * time.Second @@ -637,3 +638,131 @@ func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.Dev return json.Marshal(fromStorageDeviceToken(updated)) }) } + +func (c *conn) InsertMiddleware(ndx int, m storage.Middleware) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, middlewareName, func(currentValue []byte) ([]byte, error) { + var middleware []storage.Middleware + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, &middleware); err != nil { + return nil, err + } + } + + if ndx == -1 { + ndx = len(middleware) + } + + if ndx < 0 || ndx > len(middleware) { + return currentValue, storage.ErrOutOfRange + } + + if ndx == len(middleware) { + middleware = append(middleware, m) + } else { + last := len(middleware) - 1 + middleware = append(middleware, middleware[last]) + copy(middleware[ndx+1:], middleware[ndx:last]) + middleware[ndx] = m + } + + return json.Marshal(middleware) + }) +} + +func (c *conn) GetMiddleware(ndx int) (storage.Middleware, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + var middleware []storage.Middleware + res, err := c.db.Get(ctx, middlewareName) + if err != nil { + return storage.Middleware{}, err + } + if res.Count == 0 || len(res.Kvs) == 0 { + return storage.Middleware{}, storage.ErrOutOfRange + } + + err = json.Unmarshal(res.Kvs[0].Value, &middleware) + if err != nil { + return storage.Middleware{}, err + } + + if ndx < 0 || ndx >= len(middleware) { + return storage.Middleware{}, storage.ErrOutOfRange + } + + return middleware[ndx], nil +} + +func (c *conn) ListMiddleware() ([]storage.Middleware, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + var middleware []storage.Middleware + res, err := c.db.Get(ctx, middlewareName) + if err != nil { + return []storage.Middleware{}, err + } + if res.Count == 0 || len(res.Kvs) == 0 { + return []storage.Middleware{}, nil + } + + err = json.Unmarshal(res.Kvs[0].Value, &middleware) + + return middleware, err +} + +func (c *conn) DeleteMiddleware(ndx int) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, middlewareName, func(currentValue []byte) ([]byte, error) { + var middleware []storage.Middleware + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, &middleware); err != nil { + return nil, err + } + } + + if ndx == -1 { + ndx = len(middleware) + } + + if ndx < 0 || ndx >= len(middleware) { + return currentValue, storage.ErrOutOfRange + } + + copy(middleware[ndx:], middleware[ndx+1:]) + middleware = middleware[:len(middleware)-1] + + return json.Marshal(middleware) + }) +} + +func (c *conn) UpdateMiddleware(ndx int, updater func(m storage.Middleware) (storage.Middleware, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, middlewareName, func(currentValue []byte) ([]byte, error) { + var middleware []storage.Middleware + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, &middleware); err != nil { + return nil, err + } + } + + if ndx == -1 { + ndx = len(middleware) + } + + if ndx < 0 || ndx >= len(middleware) { + return currentValue, storage.ErrOutOfRange + } + + newMware, err := updater(middleware[ndx]) + if err != nil { + return currentValue, err + } + + middleware[ndx] = newMware + return json.Marshal(middleware) + }) +} diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index c0d6eb91f7..10fbc0902a 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -24,6 +24,7 @@ const ( kindConnector = "Connector" kindDeviceRequest = "DeviceRequest" kindDeviceToken = "DeviceToken" + kindMiddlewareList = "MiddlewareList" ) const ( @@ -37,6 +38,7 @@ const ( resourceConnector = "connectors" resourceDeviceRequest = "devicerequests" resourceDeviceToken = "devicetokens" + resourceMiddlewareList = "middlewarelists" ) // Config values for the Kubernetes storage type. @@ -248,6 +250,54 @@ func (cli *client) CreateConnector(c storage.Connector) error { return cli.post(resourceConnector, cli.fromStorageConnector(c)) } +func (cli *client) InsertMiddleware(ndx int, m storage.Middleware) error { + firstUpdate := false + var mwareList MiddlewareList + if err := cli.get(resourceMiddlewareList, middlewareName, &mwareList); err != nil { + if err != storage.ErrNotFound { + return err + } + firstUpdate = true + mwareList = MiddlewareList{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindMiddlewareList, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: middlewareName, + Namespace: cli.namespace, + }, + Middleware: []Middleware{}, + } + } + + if ndx == -1 { + ndx = len(mwareList.Middleware) + } + + if ndx < 0 || ndx > len(mwareList.Middleware) { + return storage.ErrOutOfRange + } + + if ndx == len(mwareList.Middleware) { + mwareList.Middleware = append(mwareList.Middleware, + cli.fromStorageMiddleware(m)) + } else { + last := len(mwareList.Middleware) - 1 + mwareList.Middleware = append(mwareList.Middleware, + mwareList.Middleware[last]) + copy(mwareList.Middleware[ndx+1:], + mwareList.Middleware[ndx:last]) + mwareList.Middleware[ndx] = cli.fromStorageMiddleware(m) + } + + if firstUpdate { + return cli.post(resourceMiddlewareList, mwareList) + } + + return cli.put(resourceMiddlewareList, middlewareName, mwareList) +} + func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { var req AuthRequest if err := cli.get(resourceAuthRequest, id, &req); err != nil { @@ -354,6 +404,23 @@ func (cli *client) GetConnector(id string) (storage.Connector, error) { return toStorageConnector(c), nil } +func (cli *client) GetMiddleware(ndx int) (storage.Middleware, error) { + var mwareList MiddlewareList + if err := cli.get(resourceMiddlewareList, middlewareName, &mwareList); err != nil { + if err == storage.ErrNotFound { + err = storage.ErrOutOfRange + } + return storage.Middleware{}, err + } + + if ndx < 0 || ndx >= len(mwareList.Middleware) { + return storage.Middleware{}, storage.ErrOutOfRange + } + + return toStorageMiddleware(mwareList.ObjectMeta.ResourceVersion, + mwareList.Middleware[ndx]), nil +} + func (cli *client) ListClients() ([]storage.Client, error) { return nil, errors.New("not implemented") } @@ -395,6 +462,18 @@ func (cli *client) ListConnectors() (connectors []storage.Connector, err error) return } +func (cli *client) ListMiddleware() (middleware []storage.Middleware, err error) { + var middlewareList MiddlewareList + if err = cli.get(resourceMiddlewareList, middlewareName, &middlewareList); err != nil { + if err == storage.ErrNotFound { + return []storage.Middleware{}, nil + } + return []storage.Middleware{}, err + } + return toStorageMiddlewares(middlewareList.ObjectMeta.ResourceVersion, + middlewareList.Middleware), nil +} + func (cli *client) DeleteAuthRequest(id string) error { return cli.delete(resourceAuthRequest, id) } @@ -438,6 +517,25 @@ func (cli *client) DeleteConnector(id string) error { return cli.delete(resourceConnector, id) } +func (cli *client) DeleteMiddleware(ndx int) error { + var mwareList MiddlewareList + if err := cli.get(resourceMiddlewareList, middlewareName, &mwareList); err != nil { + if err == storage.ErrNotFound { + return storage.ErrOutOfRange + } + return err + } + + if ndx < 0 || ndx >= len(mwareList.Middleware) { + return storage.ErrOutOfRange + } + + copy(mwareList.Middleware[ndx:], mwareList.Middleware[ndx+1:]) + mwareList.Middleware = mwareList.Middleware[:len(mwareList.Middleware)-1] + + return cli.put(resourceMiddlewareList, middlewareName, mwareList) +} + func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { r, err := cli.getRefreshToken(id) if err != nil { @@ -585,6 +683,30 @@ func (cli *client) UpdateConnector(id string, updater func(a storage.Connector) return cli.put(resourceConnector, id, newConn) } +func (cli *client) UpdateMiddleware(ndx int, updater func(m storage.Middleware) (storage.Middleware, error)) (err error) { + var mwareList MiddlewareList + if err := cli.get(resourceMiddlewareList, middlewareName, &mwareList); err != nil { + if err == storage.ErrNotFound { + return storage.ErrOutOfRange + } + return err + } + + if ndx < 0 || ndx >= len(mwareList.Middleware) { + return storage.ErrOutOfRange + } + + updated, err := updater(toStorageMiddleware(mwareList.ObjectMeta.ResourceVersion, + mwareList.Middleware[ndx])) + if err != nil { + return err + } + + mwareList.Middleware[ndx] = cli.fromStorageMiddleware(updated) + + return cli.put(resourceMiddlewareList, middlewareName, mwareList) +} + func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) { var authRequests AuthRequestList if err := cli.list(resourceAuthRequest, &authRequests); err != nil { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 41b14f3731..eab10a8578 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -173,12 +173,30 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{ }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "middlewarelists.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "middlewarelists", + Singular: "middlewarelist", + Kind: "MiddlewareList", + }, + }, + }, } // There will only ever be a single keys resource. Maintain this by setting a // common name. const keysName = "openid-connect-keys" +// There will only ever be a single middleware resource. +const middlewareName = "middleware" + // Client is a mirrored struct from storage with JSON struct tags and // Kubernetes type metadata. type Client struct { @@ -648,6 +666,9 @@ type Connector struct { Name string `json:"name,omitempty"` // Config holds connector specific configuration information Config []byte `json:"config,omitempty"` + + // The list of middleware configured for this connector + Middleware []Middleware `json:"middleware,omitempty"` } func (cli *client) fromStorageConnector(c storage.Connector) Connector { @@ -660,10 +681,11 @@ func (cli *client) fromStorageConnector(c storage.Connector) Connector { Name: c.ID, Namespace: cli.namespace, }, - ID: c.ID, - Type: c.Type, - Name: c.Name, - Config: c.Config, + ID: c.ID, + Type: c.Type, + Name: c.Name, + Config: c.Config, + Middleware: cli.fromStorageMiddlewares(c.Middleware), } } @@ -674,7 +696,45 @@ func toStorageConnector(c Connector) storage.Connector { Name: c.Name, ResourceVersion: c.ObjectMeta.ResourceVersion, Config: c.Config, + Middleware: toStorageMiddlewares(c.ObjectMeta.ResourceVersion, c.Middleware), + } +} + +// Middleware is a mirrored struct from storage with JSON struct tags +type Middleware struct { + Type string `json:"type,omitempty"` + Config []byte `json:"config,omitempty"` +} + +func (cli *client) fromStorageMiddleware(m storage.Middleware) Middleware { + return Middleware{ + Type: m.Type, + Config: m.Config, + } +} + +func (cli *client) fromStorageMiddlewares(ms []storage.Middleware) []Middleware { + result := make([]Middleware, len(ms)) + for n, m := range ms { + result[n] = cli.fromStorageMiddleware(m) + } + return result +} + +func toStorageMiddleware(version string, m Middleware) storage.Middleware { + return storage.Middleware{ + Type: m.Type, + ResourceVersion: version, + Config: m.Config, + } +} + +func toStorageMiddlewares(version string, ms []Middleware) []storage.Middleware { + result := make([]storage.Middleware, len(ms)) + for n, m := range ms { + result[n] = toStorageMiddleware(version, m) } + return result } // ConnectorList is a list of Connectors. @@ -684,6 +744,13 @@ type ConnectorList struct { Connectors []Connector `json:"items"` } +// MiddlewareList is a list of Middlewares. +type MiddlewareList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + Middleware []Middleware `json:"items"` +} + // DeviceRequest is a mirrored struct from storage with JSON struct tags and // Kubernetes type metadata. type DeviceRequest struct { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 82264205e7..19733a6a8a 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -22,6 +22,7 @@ func New(logger log.Logger) storage.Storage { connectors: make(map[string]storage.Connector), deviceRequests: make(map[string]storage.DeviceRequest), deviceTokens: make(map[string]storage.DeviceToken), + middleware: []storage.Middleware{}, logger: logger, } } @@ -50,6 +51,7 @@ type memStorage struct { connectors map[string]storage.Connector deviceRequests map[string]storage.DeviceRequest deviceTokens map[string]storage.DeviceToken + middleware []storage.Middleware keys storage.Keys @@ -181,6 +183,29 @@ func (s *memStorage) CreateConnector(connector storage.Connector) (err error) { return } +func (s *memStorage) InsertMiddleware(ndx int, middleware storage.Middleware) (err error) { + s.tx(func() { + if ndx == -1 { + ndx = len(s.middleware) + } + + if ndx < 0 || ndx > len(s.middleware) { + err = storage.ErrOutOfRange + return + } + + if ndx == len(s.middleware) { + s.middleware = append(s.middleware, middleware) + } else { + last := len(s.middleware) - 1 + s.middleware = append(s.middleware, s.middleware[last]) + copy(s.middleware[ndx+1:], s.middleware[ndx:last]) + s.middleware[ndx] = middleware + } + }) + return +} + func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { s.tx(func() { var ok bool @@ -265,6 +290,17 @@ func (s *memStorage) GetConnector(id string) (connector storage.Connector, err e return } +func (s *memStorage) GetMiddleware(ndx int) (middleware storage.Middleware, err error) { + s.tx(func() { + if ndx < 0 || ndx >= len(s.middleware) { + err = storage.ErrOutOfRange + } else { + middleware = s.middleware[ndx] + } + }) + return +} + func (s *memStorage) ListClients() (clients []storage.Client, err error) { s.tx(func() { for _, client := range s.clients { @@ -301,6 +337,13 @@ func (s *memStorage) ListConnectors() (conns []storage.Connector, err error) { return } +func (s *memStorage) ListMiddleware() (mware []storage.Middleware, err error) { + s.tx(func() { + mware = s.middleware + }) + return +} + func (s *memStorage) DeletePassword(email string) (err error) { email = strings.ToLower(email) s.tx(func() { @@ -383,6 +426,18 @@ func (s *memStorage) DeleteConnector(id string) (err error) { return } +func (s *memStorage) DeleteMiddleware(ndx int) (err error) { + s.tx(func() { + if ndx < 0 || ndx >= len(s.middleware) { + err = storage.ErrOutOfRange + } else { + copy(s.middleware[ndx:], s.middleware[ndx+1:]) + s.middleware = s.middleware[:len(s.middleware)-1] + } + }) + return +} + func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) { s.tx(func() { client, ok := s.clients[id] @@ -482,6 +537,20 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector return } +func (s *memStorage) UpdateMiddleware(ndx int, updater func(m storage.Middleware) (storage.Middleware, error)) (err error) { + s.tx(func() { + if ndx < 0 || ndx >= len(s.middleware) { + err = storage.ErrOutOfRange + } else { + r := s.middleware[ndx] + if r, err = updater(r); err == nil { + s.middleware[ndx] = r + } + } + }) + return +} + func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) { s.tx(func() { if _, ok := s.deviceRequests[d.UserCode]; ok { diff --git a/storage/sql/crud.go b/storage/sql/crud.go index dedfd2a805..a497f6dba3 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -18,6 +18,10 @@ import ( // keysRowID is the ID of the only row we expect to populate the "keys" table. const keysRowID = "keys" +// orderStep says how far apart to space the order values in the global +// Middleware list +const orderStep = int64(1024) + // encoder wraps the underlying value in a JSON marshaler which is automatically // called by the database/sql package. // @@ -76,6 +80,7 @@ func (j jsonDecoder) Scan(dest interface{}) error { // Abstract conn vs trans. type querier interface { + Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row } @@ -379,6 +384,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { if err != nil { return nil, fmt.Errorf("query: %v", err) } + defer rows.Close() + var tokens []storage.RefreshToken for rows.Next() { r, err := scanRefresh(rows) @@ -557,6 +564,8 @@ func (c *conn) ListClients() ([]storage.Client, error) { if err != nil { return nil, err } + defer rows.Close() + var clients []storage.Client for rows.Next() { cli, err := scanClient(rows) @@ -653,6 +662,7 @@ func (c *conn) ListPasswords() ([]storage.Password, error) { if err != nil { return nil, err } + defer rows.Close() var passwords []storage.Password for rows.Next() { @@ -755,7 +765,8 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { } func (c *conn) CreateConnector(connector storage.Connector) error { - _, err := c.Exec(` + return c.ExecTx(func(tx *trans) error { + _, err := tx.Exec(` insert into connector ( id, type, name, resource_version, config ) @@ -763,15 +774,34 @@ func (c *conn) CreateConnector(connector storage.Connector) error { $1, $2, $3, $4, $5 ); `, - connector.ID, connector.Type, connector.Name, connector.ResourceVersion, connector.Config, - ) - if err != nil { - if c.alreadyExistsCheck(err) { - return storage.ErrAlreadyExists + connector.ID, connector.Type, connector.Name, connector.ResourceVersion, connector.Config, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert connector: %v", err) } - return fmt.Errorf("insert connector: %v", err) - } - return nil + + for n, mware := range connector.Middleware { + _, err := tx.Exec(` + insert into middleware ( + conn_id, mw_order, type, resource_version, config + ) + values ( + $1, $2, $3, $4, $5 + ); + `, + connector.ID, n, + mware.Type, mware.ResourceVersion, mware.Config, + ) + if err != nil { + return fmt.Errorf("create connector insert middleware: %v", err) + } + } + + return nil + }) } func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { @@ -799,21 +829,111 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto if err != nil { return fmt.Errorf("update connector: %v", err) } + + // Delete extra middleware entries + _, err = tx.Exec(` + delete from middleware where conn_id = $1 and mw_order >= $2; + `, + id, len(newConn.Middleware), + ) + if err != nil { + return fmt.Errorf("update connector delete middleware: %v", err) + } + + for n, mware := range newConn.Middleware { + if n < len(connector.Middleware) { + // Update the existing record + _, err := tx.Exec(` + update middleware + set + type = $1, + resource_version = $2, + config = $3 + where conn_id = $4 and mw_order = $5; + `, + mware.Type, + mware.ResourceVersion, + mware.Config, + id, + n) + if err != nil { + return fmt.Errorf("update connector middleware: %v", err) + } + } else { + // Insert a new record + _, err := tx.Exec(` + insert into middleware ( + conn_id, mw_order, type, resource_version, config + ) + values ( + $1, $2, $3, $4, $5 + ); + `, + id, n, mware.Type, mware.ResourceVersion, mware.Config, + ) + if err != nil { + return fmt.Errorf("update connector add middleware: %v", err) + } + } + } + return nil }) } -func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) +func (c *conn) GetConnector(id string) (conn storage.Connector, err error) { + err = c.ExecTx(func(tx *trans) error { + conn, err = getConnector(tx, id) + return err + }) + return conn, err } -func getConnector(q querier, id string) (storage.Connector, error) { - return scanConnector(q.QueryRow(` +func getConnector(tx *trans, id string) (storage.Connector, error) { + conn, err := scanConnector(tx.QueryRow(` select id, type, name, resource_version, config from connector where id = $1; `, id)) + if err != nil { + return storage.Connector{}, err + } + + conn.Middleware, err = getConnectorMiddleware(tx, id) + if err != nil { + return storage.Connector{}, err + } + + return conn, nil +} + +func getConnectorMiddleware(q querier, id string) ([]storage.Middleware, error) { + rows, err := q.Query(` + select + type, resource_version, config + from middleware + where conn_id = $1 + order by mw_order asc; + `, id) + if err != nil { + return []storage.Middleware{}, err + } + + middleware := []storage.Middleware{} + for rows.Next() { + mware, err := scanMiddleware(rows) + if err != nil { + return []storage.Middleware{}, err + } + + middleware = append(middleware, mware) + } + if err := rows.Err(); err != nil { + return []storage.Middleware{}, err + } + + return middleware, nil } func scanConnector(s scanner) (c storage.Connector, err error) { @@ -830,26 +950,391 @@ func scanConnector(s scanner) (c storage.Connector, err error) { } func (c *conn) ListConnectors() ([]storage.Connector, error) { + var connectors []storage.Connector + + err := c.ExecTx(func(tx *trans) error { + rows, err := tx.Query(` + select + id, type, name, resource_version, config + from connector; + `) + if err != nil { + return err + } + for rows.Next() { + conn, err := scanConnector(rows) + if err != nil { + return err + } + + connectors = append(connectors, conn) + } + if err := rows.Err(); err != nil { + return err + } + + // Have to do this after reading all the rows (otherwise we'd need to + // use a cursor). + for _, conn := range connectors { + conn.Middleware, err = getConnectorMiddleware(tx, conn.ID) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return connectors, nil +} + +func (c *conn) DeleteConnector(id string) error { + return c.ExecTx(func(tx *trans) error { + _, err := tx.Exec(` + delete from middleware where conn_id = $1; + `, + id) + if err != nil { + return fmt.Errorf("delete connector middleware: %v", err) + } + + result, err := tx.Exec(` + delete from connector where id = $1; + `, + id) + if err != nil { + return fmt.Errorf("delete connector: %v", err) + } + + // For now mandate that the driver implements RowsAffected. If we ever need to support + // a driver that doesn't implement this, we can run this in a transaction with a get beforehand. + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %v", err) + } + if n < 1 { + return storage.ErrNotFound + } + return nil + }) +} + +func (c *conn) InsertMiddleware(ndx int, mware storage.Middleware) error { + return c.ExecTx(func(tx *trans) error { + var order int64 + + switch ndx { + case 0: + var maybeOrder sql.NullInt64 + err := tx.QueryRow(` + select MIN(mw_order) from middleware where conn_id = ''; + `).Scan(&maybeOrder) + if err != nil { + return fmt.Errorf("insert middleware find min: %v", err) + } + + if maybeOrder.Valid { + order = maybeOrder.Int64 + + // We rely on this in the renumber routine; without it, + // order might not be unique when doing a renumber + if order == 0 { + order, err = renumberMiddleware(0, tx) + if err != nil { + return err + } + } + + order /= 2 + } else { + order = orderStep + } + case -1: + var maybeOrder sql.NullInt64 + err := tx.QueryRow(` + select MAX(mw_order) from middleware where conn_id = ''; + `).Scan(&maybeOrder) + if err != nil { + return fmt.Errorf("insert middleware find max: %v", err) + } + + if maybeOrder.Valid { + order = maybeOrder.Int64 + orderStep + } else { + order = orderStep + } + default: + rows, err := tx.Query(` + select + mw_order + from middleware + where conn_id = '' + order by mw_order asc + limit 2 offset $1 + `, ndx) + if err != nil { + return fmt.Errorf("insert middleware find low/high: %v", err) + } + defer rows.Close() + + if !rows.Next() { + return storage.ErrOutOfRange + } + + var low int64 + err = rows.Scan(&low) + if err != nil { + return fmt.Errorf("insert middleware scan low: %v", err) + } + + if !rows.Next() { + order = low + 1024 + } else { + var high int64 + err = rows.Scan(&high) + if err != nil { + return fmt.Errorf("insert middleware scan high: %v", err) + } + + if high == low+1 { + order, err = renumberMiddleware(ndx, tx) + if err != nil { + return err + } + order += orderStep / 2 + } else { + order = low + (high-low)/2 + } + } + + if err := rows.Close(); err != nil { + return fmt.Errorf("insert middleware find low/high close: %v", err) + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("insert middleware find low/high err: %v", err) + } + } + + _, err := tx.Exec(` + insert into middleware ( + conn_id, type, mw_order, resource_version, config + ) + values ( + '', $1, $2, $3, $4 + ); + `, + mware.Type, order, mware.ResourceVersion, mware.Config, + ) + + if err != nil { + return fmt.Errorf("insert middleware: %v", err) + } + + return nil + }) +} + +func renumberMiddleware(ndx int, tx *trans) (int64, error) { + rows, err := tx.Query(` + select mw_order from middleware where conn_id = '' order by mw_order asc; + `) + if err != nil { + return 0, fmt.Errorf("renumber middleware: %v", err) + } + + // The idea here is that since we know the minimum possible order is 0, + // and since the smallest it can increment by is 1, we can renumber the + // middleware entries as 0, 1, 2, 3, 4, 5, ... without having to worry + // that one of these numbers is in use already. + result := int64(0) + n := int64(0) + for rows.Next() { + var order int64 + err = rows.Scan(&order) + if err != nil { + return result, fmt.Errorf("renumber middleware scan: %v", err) + } + + _, err = tx.Exec(` + update middleware set mw_order = $1 where conn_id = '' and mw_order = $3; + `, + n, + order) + if err != nil { + return result, fmt.Errorf("renumber middleware update: %v", err) + } + + n++ + } + if err := rows.Err(); err != nil { + return result, fmt.Errorf("renumber middleware: %v", err) + } + + // Once we've done that step, update the order to go up in the step size + // increments rather than by ones. + _, err = tx.Exec(` + update middleware set mw_order = (mw_order + 1) * $1 where conn_id = ''; + `, + orderStep, + ) + if err != nil { + return result, fmt.Errorf("renumber middleware multiply: %v", err) + } + + return int64(ndx) * orderStep, nil +} + +func (c *conn) UpdateMiddleware(ndx int, updater func(m storage.Middleware) (storage.Middleware, error)) error { + return c.ExecTx(func(tx *trans) error { + var order int64 + + // Find the order value for this index + err := tx.QueryRow(` + select + mw_order + from middleware + where conn_id = '' + order by mw_order asc + limit 1 offset $1; + `, + ndx, + ).Scan(&order) + if err != nil { + if err == sql.ErrNoRows { + return storage.ErrOutOfRange + } + return fmt.Errorf("update middleware find index: %v", err) + } + + middleware, err := scanMiddleware(tx.QueryRow(` + select + type, resource_version, config + from middleware + where conn_id = '' and mw_order = $1; + `, order)) + if err != nil { + return err + } + + newMware, err := updater(middleware) + if err != nil { + return err + } + + _, err = tx.Exec(` + update middleware + set + type = $1, + resource_version = $2, + config = $3 + where conn_id = '' and mw_order = $4; + `, + newMware.Type, + newMware.ResourceVersion, + newMware.Config, + order, + ) + if err != nil { + return fmt.Errorf("update middleware: %v", err) + } + + return nil + }) +} + +func (c *conn) GetMiddleware(ndx int) (storage.Middleware, error) { + row := c.QueryRow(` + select + type, resource_version, config + from middleware + where conn_id = '' + order by mw_order asc + limit 1 offset $1; + `, ndx) + + mware, err := scanMiddleware(row) + if err != nil { + return storage.Middleware{}, err + } + + return mware, nil +} + +func (c *conn) ListMiddleware() ([]storage.Middleware, error) { + var middleware []storage.Middleware + rows, err := c.Query(` select - id, type, name, resource_version, config - from connector; + type, resource_version, config + from middleware + where conn_id = '' + order by mw_order asc; `) if err != nil { - return nil, err + return []storage.Middleware{}, err } - var connectors []storage.Connector for rows.Next() { - conn, err := scanConnector(rows) + mware, err := scanMiddleware(rows) if err != nil { - return nil, err + return []storage.Middleware{}, err } - connectors = append(connectors, conn) + + middleware = append(middleware, mware) } if err := rows.Err(); err != nil { - return nil, err + return []storage.Middleware{}, err } - return connectors, nil + + return middleware, nil +} + +func scanMiddleware(s scanner) (m storage.Middleware, err error) { + err = s.Scan( + &m.Type, &m.ResourceVersion, &m.Config, + ) + if err != nil { + if err == sql.ErrNoRows { + return m, storage.ErrOutOfRange + } + return m, fmt.Errorf("select middleware: %v", err) + } + return m, nil +} + +func (c *conn) DeleteMiddleware(ndx int) error { + return c.ExecTx(func(tx *trans) error { + var maybeOrder sql.NullInt64 + err := tx.QueryRow(` + select mw_order from middleware where conn_id = '' + order by mw_order asc + limit 1 offset $1; + `, ndx).Scan(&maybeOrder) + if err != nil { + return fmt.Errorf("delete middleware find order: %v", err) + } + + if !maybeOrder.Valid { + return storage.ErrOutOfRange + } + + order := maybeOrder.Int64 + + _, err = tx.Exec(` + delete from middleware + where conn_id = '' and mw_order = $1; + `, + order, + ) + if err != nil { + return fmt.Errorf("delete middleware: %v", err) + } + + return nil + }) } func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } @@ -859,8 +1344,6 @@ func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_tok func (c *conn) DeletePassword(email string) error { return c.delete("password", "email", strings.ToLower(email)) } -func (c *conn) DeleteConnector(id string) error { return c.delete("connector", "id", id) } - func (c *conn) DeleteOfflineSessions(userID string, connID string) error { result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID) if err != nil { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 8201d443d4..5e24f2d41d 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -265,4 +265,16 @@ var migrations = []migration{ add column code_challenge_method text not null default '';`, }, }, + { + stmts: []string{` + create table middleware ( + conn_id text not null, + mw_order integer not null, + type text not null, + resource_version text not null, + config bytea, + PRIMARY KEY (conn_id, mw_order) + );`, + }, + }, } diff --git a/storage/static.go b/storage/static.go index 806b61f9cd..0cc1558c02 100644 --- a/storage/static.go +++ b/storage/static.go @@ -230,3 +230,40 @@ func (s staticConnectorsStorage) UpdateConnector(id string, updater func(old Con } return s.Storage.UpdateConnector(id, updater) } + +// staticMiddlewareStorage represents a storage with a read-only set of middleware. +type staticMiddlewareStorage struct { + Storage + + // A read-only set of middleware + middleware []Middleware +} + +// WithStaticMiddleware returns a storage with a read-only set of Middleware. Write actions, +// such as updating existing Middleware, will fail. +func WithStaticMiddleware(s Storage, staticMiddleware []Middleware) Storage { + return staticMiddlewareStorage{s, staticMiddleware} +} + +func (s staticMiddlewareStorage) InsertMiddleware(ndx int, m Middleware) error { + return errors.New("static middleware: read-only cannot create middleware") +} + +func (s staticMiddlewareStorage) GetMiddleware(ndx int) (Middleware, error) { + if ndx < 0 || ndx >= len(s.middleware) { + return Middleware{}, ErrOutOfRange + } + return s.middleware[ndx], nil +} + +func (s staticMiddlewareStorage) ListMiddleware() ([]Middleware, error) { + return s.middleware, nil +} + +func (s staticMiddlewareStorage) DeleteMiddleware(ndx int) error { + return errors.New("static middleware: read-only cannot delete middleware") +} + +func (s staticMiddlewareStorage) UpdateMiddleware(ndx int, updater func(m Middleware) (Middleware, error)) error { + return errors.New("static middleware: read-only cannot update middleware") +} diff --git a/storage/storage.go b/storage/storage.go index 06f718e10f..3849fe8817 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -18,6 +18,9 @@ var ( // ErrAlreadyExists is the error returned by storages if a resource ID is taken during a create. ErrAlreadyExists = errors.New("ID already exists") + + // ErrOutOfRange is the error returned by storages if you attempt to insert at an out-of-range location. + ErrOutOfRange = errors.New("out of range") ) // Kubernetes only allows lower case letters for names. @@ -72,6 +75,13 @@ type Storage interface { CreateDeviceRequest(d DeviceRequest) error CreateDeviceToken(d DeviceToken) error + // Insert a new Middleware into the global middleware list + // + // ndx = 0 means insert at start + // ndx = -1 means insert at end + // + InsertMiddleware(ndx int, m Middleware) error + // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound. GetAuthRequest(id string) (AuthRequest, error) @@ -84,11 +94,13 @@ type Storage interface { GetConnector(id string) (Connector, error) GetDeviceRequest(userCode string) (DeviceRequest, error) GetDeviceToken(deviceCode string) (DeviceToken, error) + GetMiddleware(ndx int) (Middleware, error) ListClients() ([]Client, error) ListRefreshTokens() ([]RefreshToken, error) ListPasswords() ([]Password, error) ListConnectors() ([]Connector, error) + ListMiddleware() ([]Middleware, error) // Delete methods MUST be atomic. DeleteAuthRequest(id string) error @@ -98,6 +110,7 @@ type Storage interface { DeletePassword(email string) error DeleteOfflineSessions(userID string, connID string) error DeleteConnector(id string) error + DeleteMiddleware(ndx int) error // Update methods take a function for updating an object then performs that update within // a transaction. "updater" functions may be called multiple times by a single update call. @@ -120,6 +133,7 @@ type Storage interface { UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error + UpdateMiddleware(ndx int, updater func(m Middleware) (Middleware, error)) error UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error // GarbageCollect deletes all expired AuthCodes, @@ -350,7 +364,23 @@ type Connector struct { ResourceVersion string `json:"resourceVersion"` // Config holds all the configuration information specific to the connector type. Since there // no generic struct we can use for this purpose, it is stored as a byte stream. - Config []byte `json:"email"` + Config []byte `json:"config"` + + // Middleware configured for this connector + Middleware []Middleware `json:"middleware"` +} + +// Middleware is an object that contains the metadata about middleware +type Middleware struct { + // The type of the middleware, e.g. "groups" + Type string `json:"type"` + + // ResourceVersion is the static versioning used to keep track of dynamic configuration + // changes to the middleware object made by the API calls. + ResourceVersion string `json:"resourceVersion"` + + // Config holds middleware-specific configuration, as a byte stream. + Config []byte `json:"config"` } // VerificationKey is a rotated signing key which can still be used to verify