diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 42ad5d724f3..8b1daf5f0c3 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -62,6 +62,13 @@ func WithMiddleware(m ...middleware.Middleware) ClientOption { } } +// WithStreamMiddleware with client stream middleware. +func WithStreamMiddleware(m ...middleware.Middleware) ClientOption { + return func(s *clientOptions) { + s.streamMiddleware = m + } +} + // WithDiscovery with client discovery. func WithDiscovery(d registry.Discovery) ClientOption { return func(o *clientOptions) { diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go index cdbee600709..029baa63af2 100644 --- a/transport/grpc/client_test.go +++ b/transport/grpc/client_test.go @@ -42,6 +42,17 @@ func TestWithMiddleware(t *testing.T) { } } +func TestWithStreamMiddleware(t *testing.T) { + o := &clientOptions{} + v := []middleware.Middleware{ + func(middleware.Handler) middleware.Handler { return nil }, + } + WithStreamMiddleware(v...)(o) + if !reflect.DeepEqual(v, o.streamMiddleware) { + t.Errorf("expect %v but got %v", v, o.streamMiddleware) + } +} + type mockRegistry struct{} func (m *mockRegistry) GetService(_ context.Context, _ string) ([]*registry.ServiceInstance, error) {