Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ConnectionType(ctx) for called methods to use #118

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
)

func init() {
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
panic(err)
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
panic(err)
}
}

debugTrace = true
Expand Down Expand Up @@ -497,15 +499,17 @@ func TestParallelRPC(t *testing.T) {
type CtxHandler struct {
lk sync.Mutex

cancelled bool
i int
cancelled bool
i int
connectionType ConnectionType
}

func (h *CtxHandler) Test(ctx context.Context) {
h.lk.Lock()
defer h.lk.Unlock()
timeout := time.After(300 * time.Millisecond)
h.i++
h.connectionType = GetConnectionType(ctx)

select {
case <-timeout:
Expand Down Expand Up @@ -543,6 +547,9 @@ func TestCtx(t *testing.T) {
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
rvagg marked this conversation as resolved.
Show resolved Hide resolved
}

serverHandler.cancelled = false

Expand All @@ -564,6 +571,9 @@ func TestCtx(t *testing.T) {
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
}

serverHandler.lk.Unlock()
closer()
Expand Down Expand Up @@ -598,6 +608,9 @@ func TestCtxHttp(t *testing.T) {
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
if serverHandler.connectionType != ConnectionTypeHTTP {
t.Error("wrong connection type")
}

serverHandler.cancelled = false

Expand All @@ -619,6 +632,10 @@ func TestCtxHttp(t *testing.T) {
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
// connection type should have switched to WS
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
}

serverHandler.lk.Unlock()
closer()
Expand Down Expand Up @@ -1007,10 +1024,12 @@ func TestChanClientReceiveAll(t *testing.T) {
}

func TestControlChanDeadlock(t *testing.T) {
_ = logging.SetLogLevel("rpc", "error")
defer func() {
_ = logging.SetLogLevel("rpc", "debug")
}()
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
_ = logging.SetLogLevel("rpc", "error")
defer func() {
_ = logging.SetLogLevel("rpc", "DEBUG")
}()
}

for r := 0; r < 20; r++ {
testControlChanDeadlock(t)
Expand Down
27 changes: 27 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@ const (
rpcInvalidParams = -32602
)

// ConnectionType indicates the type of connection, this is set in the context and can be retrieved
// with GetConnectionType.
type ConnectionType string

const (
// ConnectionTypeUnknown indicates that the connection type cannot be determined, likely because
// it hasn't passed through an RPCServer.
ConnectionTypeUnknown ConnectionType = "unknown"
// ConnectionTypeHTTP indicates that the connection is an HTTP connection.
ConnectionTypeHTTP ConnectionType = "http"
// ConnectionTypeWS indicates that the connection is a WebSockets connection.
ConnectionTypeWS ConnectionType = "websockets"
)

var connectionTypeCtxKey = &struct{ name string }{"jsonrpc-connection-type"}

// GetConnectionType returns the connection type of the request if it was set by an RPCServer.
// A connection type of ConnectionTypeUnknown means that the connection type was not set.
func GetConnectionType(ctx context.Context) ConnectionType {
if v := ctx.Value(connectionTypeCtxKey); v != nil {
return v.(ConnectionType)
}
return ConnectionTypeUnknown
}

// RPCServer provides a jsonrpc 2.0 http server handler
type RPCServer struct {
*handler
Expand Down Expand Up @@ -97,10 +122,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {

h := strings.ToLower(r.Header.Get("Connection"))
if strings.Contains(h, "upgrade") {
ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeWS)
s.handleWS(ctx, w, r)
return
}

ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeHTTP)
s.handleReader(ctx, r.Body, w, rpcError)
}

Expand Down