Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reuseport functionality to websocket protocol (v2) #2719

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion p2p/transport/websocket/addrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
}

func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil)
wt := &WebsocketTransport{}
ln, err := wt.newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)
Expand Down
22 changes: 14 additions & 8 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

type listener struct {
nl net.Listener
nl manet.Listener
server http.Server
// The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS,
// so we can't rely on checking if server.TLSConfig is set.
Expand All @@ -40,7 +40,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {

// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
func (t *WebsocketTransport) newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
Expand All @@ -50,11 +50,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
}

lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
var nl manet.Listener
if !t.UseReuseport() {
nl, err = manet.Listen(a)
} else {
nl, err = t.reuse.Listen(a)
// Fallback to regular listener in case of an error.
if err != nil {
nl, err = manet.Listen(a)
}
}
nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -88,10 +93,11 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {

func (l *listener) serve() {
defer close(l.closed)
list := manet.NetListener(l.nl)
if !l.isWss {
l.server.Serve(l.nl)
l.server.Serve(list)
} else {
l.server.ServeTLS(l.nl, "", "")
l.server.ServeTLS(list, "", "")
}
}

Expand Down
9 changes: 9 additions & 0 deletions p2p/transport/websocket/reuseport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package websocket

import (
"github.com/libp2p/go-reuseport"
)

func reuseportIsAvailable() bool {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this is just overhead - I added it for parity with the TCP implementation that includes checking of a LIBP2P_TCP_REUSEPORT environment variable to change the setting.

return reuseport.Available()
}
65 changes: 62 additions & 3 deletions p2p/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"

ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
Expand Down Expand Up @@ -80,13 +81,23 @@ func WithTLSConfig(conf *tls.Config) Option {
}
}

func EnableReuseport() Option {
return func(t *WebsocketTransport) error {
t.enableReuseport = true
return nil
}
}
Comment on lines +84 to +89
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the opposite option from the TCP implementation - there we have it enabled by default, here it needs to be explicitly enabled. I figured it might be better to go the more conservative route.


// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct {
upgrader transport.Upgrader
rcmgr network.ResourceManager

tlsClientConf *tls.Config
tlsConf *tls.Config

enableReuseport bool // Explicitly enable reuseport.
reuse reuseport.Transport
}

var _ transport.Transport = (*WebsocketTransport)(nil)
Expand Down Expand Up @@ -188,6 +199,32 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
}
isWss := wsurl.Scheme == "wss"
dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second}
dialer.NetDialContext = func(ctx context.Context, network string, address string) (net.Conn, error) {

tcpAddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}

maddr, err := manet.FromNetAddr(tcpAddr)
if err != nil {
return nil, err
}

var conn manet.Conn
if t.UseReuseport() {
conn, err = t.reuse.DialContext(ctx, maddr)
} else {
var d manet.Dialer
conn, err = d.DialContext(ctx, maddr)
}
if err != nil {
return nil, err
}

return conn, nil
}

if isWss {
sni := ""
sni, err = raddr.ValueForProtocol(ma.P_SNI)
Expand All @@ -202,12 +239,29 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
ipAddr := wsurl.Host
// Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution.
// We set the `.Host` to the sni field so that the host header gets properly set.
dialer.NetDial = func(network, address string) (net.Conn, error) {
dialer.NetDialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, ipAddr)
if err != nil {
return nil, err
}
return net.DialTCP("tcp", nil, tcpAddr)

maddr, err := manet.FromNetAddr(tcpAddr)
if err != nil {
return nil, err
}

var conn manet.Conn
if t.UseReuseport() {
conn, err = t.reuse.DialContext(ctx, maddr)
} else {
var d manet.Dialer
conn, err = d.DialContext(ctx, maddr)
}
if err != nil {
return nil, err
}

return conn, nil
}
wsurl.Host = sni + ":" + wsurl.Port()
} else {
Expand All @@ -229,7 +283,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
}

func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
l, err := newListener(a, t.tlsConf)
l, err := t.newListener(a, t.tlsConf)
if err != nil {
return nil, err
}
Expand All @@ -244,3 +298,8 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error)
}
return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil
}

// UseReuseport returns true if reuseport is enabled and available.
func (t *WebsocketTransport) UseReuseport() bool {
return t.enableReuseport && reuseportIsAvailable()
}
151 changes: 151 additions & 0 deletions p2p/transport/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (
"math/big"
"net"
"net/http"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand All @@ -32,6 +34,7 @@ import (
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -549,3 +552,151 @@ func TestResolveMultiaddr(t *testing.T) {
})
}
}

func TestReusePortOnDial(t *testing.T) {

// Create an endpoint that will accept connections.
// We'll use this to verify that the party initiating the connection reused port.
serverID, cu := newUpgrader(t)
transport, err := New(cu, &network.NullResourceManager{})
require.NoError(t, err)

server, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer server.Close()

// Create an endpoint that will initiate connection.
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, EnableReuseport())
require.NoError(t, err)

// Start listening.
listener, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer listener.Close()

// Take a note of the multiaddress on which we listen. This should be the address from which we dial too.
expectedAddr := listener.Multiaddr()

done := make(chan struct{})
go func() {
defer close(done)

conn, err := server.Accept()
require.NoError(t, err)
defer conn.Close()

// The meat of this test - verify that the connection was received from the same port as the listen port recorded above.
remote := conn.RemoteMultiaddr()
require.Equal(t, expectedAddr, remote)
}()

conn, err := tpt.Dial(context.Background(), server.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()

<-done
}

func TestReusePortOnListen(t *testing.T) {

const (
// how many connections we try to establish.
connectionCount = 20
)

// Create an endpoint that will accept connections.
// We'll use this to verify that the party initiating the connection reused port.
_, cu := newUpgrader(t)
tpt, err := New(cu, &network.NullResourceManager{}, EnableReuseport())
require.NoError(t, err)

listener1, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)

// Get the port on which we should start the second listener
addr, ok := listener1.Addr().(*net.TCPAddr)
require.True(t, ok)

port := addr.Port
listener2, err := tpt.maListen(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/ws", port)))
require.NoError(t, err)

listeners := []manet.Listener{listener1, listener2}

// Record which listener accepted how many connections.
requestCount := make(map[int]int)
var lock sync.Mutex

var connsHandled sync.WaitGroup
connsHandled.Add(connectionCount)
// For both listeners spin up goroutines to accept incoming connections.
for i, listener := range listeners {
for j := 0; j < connectionCount; j++ {
go func(index int, listener manet.Listener) {

conn, err := listener.Accept()
if err != nil {
// Stop condition - this happens when the listener is closed.
require.ErrorIs(t, err, transport.ErrListenerClosed)
return
}
Comment on lines +638 to +643
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to just start listeners above and defer Close() calls to it. Then when the test function completes and execution "falls off the end" the listener.Close() call will unblock this and Accept() call will fail with an error.

Still, I think I prefer that one can see what will happen during the test just by reading it. To me it's better than "something will happen after the test is done too". This is a bit ugly but explicit. I'm open to any criticism :)

defer conn.Close()
connsHandled.Done()

// Record which listener accepted the connection.
lock.Lock()
defer lock.Unlock()
requestCount[index]++
}(i, listener)
}
}

// Create a different transport as you cannot self-dial using reuseport.
tpt2, err := New(cu, &network.NullResourceManager{})
require.NoError(t, err)

var dialers sync.WaitGroup
dialers.Add(connectionCount)

for i := 0; i < connectionCount; i++ {
go func() {
defer dialers.Done()
conn, err := tpt2.maDial(context.Background(), listener1.Multiaddr())
require.NoError(t, err)
defer conn.Close()
}()
}

// Wait for all dialers to complete.
dialers.Wait()

// Wait for listeners to complete their part.
connsHandled.Wait()

// Cancel listeners to unblock any further pending accepts.
listener1.Close()
listener2.Close()

// For Windows we can't make any assumptions with regards to connection distribution:
// "Once the second socket has successfully bound, the behavior for all sockets bound to that port is indeterminate.
// For example, if all of the sockets on the same port provide TCP service, any incoming TCP connection requests over
// the port cannot be guaranteed to be handled by the correct socket — the behavior is non-deterministic."
// => https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse

// For MacOS (FreeBSD) it's the last socket to bind that receives the connections. Anegdotal evidence but:
// "Ironically it's the BSD semantics which support seamless server restarts. In my tests OS X's behavior (which I presume
// is identical to FreeBSD and other BSDs) is that the last socket to bind is the only one to receive new connections."
// => https://lwn.net/Articles/542629/
// On FreeBSD it's the SO_REUSEPORT_LB variant that provides load balancing.

// For Linux only - verify that both listeners handled some connections.
if runtime.GOOS == "linux" {
// We're not trying to verify an even distribution as it's not a perfect world.
require.NotZero(t, requestCount[0], "first listener accepted no connections")
require.NotZero(t, requestCount[1], "second listener accepted no connections")
}

total := requestCount[0] + requestCount[1]
require.Equal(t, connectionCount, total, "not all requests were handled")
}