Skip to content

Commit

Permalink
[Streaming][bugfix] handle TLS signalisation when TLS is disabled on …
Browse files Browse the repository at this point in the history
…client side

Tnis is an alternative to hashicorp#9494
  • Loading branch information
pierresouchay committed Jan 7, 2021
1 parent c817c29 commit ca2c3eb
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .changelog/9512.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
client: properly set GRPC over RPC magic numbers when encryption was not set or partially set in the cluster with streaming enabled
```
9 changes: 5 additions & 4 deletions agent/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)

type dialer func(context.Context, string) (net.Conn, error)

func NewClientConnPool(servers ServerLocator, tls TLSWrapper) *ClientConnPool {
// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC
func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool {
return &ClientConnPool{
dialer: newDialer(servers, tls),
dialer: newDialer(servers, tls, useTLSForDC),
servers: servers,
conns: make(map[string]*grpc.ClientConn),
}
Expand Down Expand Up @@ -74,7 +75,7 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error)

// newDialer returns a gRPC dialer function that conditionally wraps the connection
// with TLS based on the Server.useTLS value.
func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, string) (net.Conn, error) {
func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr)
Expand All @@ -88,7 +89,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context,
return nil, err
}

if server.UseTLS {
if server.UseTLS && useTLSForDC(server.Datacenter) {
if wrapper == nil {
conn.Close()
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
Expand Down
15 changes: 10 additions & 5 deletions agent/grpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ import (
"github.com/hashicorp/consul/tlsutil"
)

// useTLSForDcAlwaysTrue tell GRPC to always return the TLS is enabled
func useTLSForDcAlwaysTrue(_ string) bool {
return true
}

func TestNewDialer_WithTLSWrapper(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
Expand All @@ -37,7 +42,7 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
called = true
return conn, nil
}
dial := newDialer(builder, wrapper)
dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue)
ctx := context.Background()
conn, err := dial(ctx, lis.Addr().String())
require.NoError(t, err)
Expand All @@ -63,7 +68,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)

pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()))
pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS)

conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
Expand All @@ -82,7 +87,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
Expand Down Expand Up @@ -119,7 +124,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
Expand Down Expand Up @@ -168,7 +173,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {

res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)

for _, dc := range dcs {
name := "server-0-" + dc
Expand Down
2 changes: 1 addition & 1 deletion agent/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)

builder := resolver.NewServerResolverBuilder(resolver.Config{})
registerWithGRPC(builder)
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()))
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS)

d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)

Expand Down

0 comments on commit ca2c3eb

Please sign in to comment.