diff --git a/streamz/core.py b/streamz/core.py index f381b632..6e9f4209 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -119,6 +119,22 @@ def __str__(self): class APIRegisterMixin(object): + def _new_node(self, cls, args, kwargs): + """ Constructor for downstream nodes. + + Examples + -------- + To provide inheritance through nodes : + + >>> class MyStream(Stream): + >>> + >>> def _new_node(self, cls, args, kwargs): + >>> if not issubclass(cls, MyStream): + >>> cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__)) + >>> return cls(*args, **kwargs) + """ + return cls(*args, **kwargs) + @classmethod def register_api(cls, modifier=identity, attribute_name=None): """ Add callable to Stream API @@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None): def _(func): @functools.wraps(func) def wrapped(*args, **kwargs): + if modifier is not staticmethod and args: + self = args[0] + if isinstance(self, APIRegisterMixin): + return self._new_node(func, args, kwargs) return func(*args, **kwargs) name = attribute_name if attribute_name else func.__name__ setattr(cls, name, modifier(wrapped)) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 56a661d7..336d8500 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -1367,6 +1367,35 @@ class foo(NewStream): assert not hasattr(Stream(), 'foo') +def test_subclass_node(): + + def add(x) : return x + 1 + + class MyStream(Stream): + def _new_node(self, cls, args, kwargs): + if not issubclass(cls, MyStream): + cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__)) + return cls(*args, **kwargs) + + @MyStream.register_api() + class foo(sz.sinks.sink): + pass + + stream = MyStream() + lst = list() + + node = stream.map(add) + assert isinstance(node, sz.core.map) + assert isinstance(node, MyStream) + + node = node.foo(lst.append) + assert isinstance(node, sz.sinks.sink) + assert isinstance(node, MyStream) + + stream.emit(100) + assert lst == [ 101 ] 
+ + @gen_test() def test_latest(): source = Stream(asynchronous=True) diff --git a/streamz/tests/test_dask.py b/streamz/tests/test_dask.py index e99b1722..d4a67f91 100644 --- a/streamz/tests/test_dask.py +++ b/streamz/tests/test_dask.py @@ -242,6 +242,7 @@ def test_buffer_sync(loop): # noqa: F811 @pytest.mark.xfail(reason='') +@pytest.mark.asyncio async def test_stream_shares_client_loop(loop): # noqa: F811 with cluster() as (s, [a, b]): with Client(s['address'], loop=loop) as client: # noqa: F841 diff --git a/streamz/tests/test_kafka.py b/streamz/tests/test_kafka.py index 27755655..edeaac69 100644 --- a/streamz/tests/test_kafka.py +++ b/streamz/tests/test_kafka.py @@ -70,7 +70,7 @@ def predicate(): return b'kafka entered RUNNING state' in out except subprocess.CalledProcessError: pass - wait_for(predicate, 10, period=0.1) + wait_for(predicate, 30, period=0.1) return cid @@ -106,9 +106,9 @@ def split(messages): return parsed +@pytest.mark.asyncio @flaky(max_runs=3, min_passes=1) -@gen_test(timeout=60) -def test_from_kafka(): +async def test_from_kafka(): j = random.randint(0, 10000) ARGS = {'bootstrap.servers': 'localhost:9092', 'group.id': 'streamz-test%i' % j} @@ -117,30 +117,29 @@ def test_from_kafka(): stream = Stream.from_kafka([TOPIC], ARGS, asynchronous=True) out = stream.sink_to_list() stream.start() - yield gen.sleep(1.1) # for loop to run + await asyncio.sleep(1.1) # for loop to run for i in range(10): - yield gen.sleep(0.1) # small pause ensures correct ordering + await asyncio.sleep(0.1) # small pause ensures correct ordering kafka.produce(TOPIC, b'value-%d' % i) kafka.flush() # it takes some time for messages to come back out of kafka - wait_for(lambda: len(out) == 10, 10, period=0.1) + await await_for(lambda: len(out) == 10, 10, period=0.1) assert out[-1] == b'value-9' kafka.produce(TOPIC, b'final message') kafka.flush() - wait_for(lambda: out[-1] == b'final message', 10, period=0.1) + await await_for(lambda: out[-1] == b'final message', 10, 
period=0.1) stream._close_consumer() kafka.produce(TOPIC, b'lost message') kafka.flush() # absolute sleep here, since we expect output list *not* to change - yield gen.sleep(1) + await asyncio.sleep(1) assert out[-1] == b'final message' stream._close_consumer() @flaky(max_runs=3, min_passes=1) -@gen_test(timeout=60) def test_to_kafka(): ARGS = {'bootstrap.servers': 'localhost:9092'} with kafka_service() as kafka: @@ -150,7 +149,7 @@ def test_to_kafka(): out = kafka.sink_to_list() for i in range(10): - yield source.emit(b'value-%d' % i) + source.emit(b'value-%d' % i) source.emit('final message') kafka.flush() @@ -158,35 +157,44 @@ def test_to_kafka(): assert out[-1] == b'final message' +@pytest.mark.asyncio @flaky(max_runs=3, min_passes=1) -@gen_test(timeout=60) -def test_from_kafka_thread(): +async def test_from_kafka_thread(): j = random.randint(0, 10000) ARGS = {'bootstrap.servers': 'localhost:9092', 'group.id': 'streamz-test%i' % j} + print(".") with kafka_service() as kafka: kafka, TOPIC = kafka - stream = Stream.from_kafka([TOPIC], ARGS) + stream = Stream.from_kafka([TOPIC], ARGS, asynchronous=True) + print(".") out = stream.sink_to_list() stream.start() - yield gen.sleep(1.1) + await asyncio.sleep(1.1) + print(".") for i in range(10): - yield gen.sleep(0.1) + await asyncio.sleep(0.1) kafka.produce(TOPIC, b'value-%d' % i) kafka.flush() + print(".") # it takes some time for messages to come back out of kafka - yield await_for(lambda: len(out) == 10, 10, period=0.1) + await await_for(lambda: len(out) == 10, 10, period=0.1) + print(".") assert out[-1] == b'value-9' kafka.produce(TOPIC, b'final message') kafka.flush() - yield await_for(lambda: out[-1] == b'final message', 10, period=0.1) + print(".") + await await_for(lambda: out[-1] == b'final message', 10, period=0.1) + print(".") stream._close_consumer() kafka.produce(TOPIC, b'lost message') kafka.flush() # absolute sleep here, since we expect output list *not* to change - yield gen.sleep(1) + print(".") + 
await asyncio.sleep(1) + print(".") assert out[-1] == b'final message' stream._close_consumer()