Add Stream node constructor for sub-classing #442 #445

Open: wants to merge 4 commits into master (showing changes from 3 commits)
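For orientation, a minimal sketch of what this PR enables (it assumes the _new_node hook added below and the existing streamz Stream/map API; MyStream is a hypothetical subclass): a Stream subclass that overrides _new_node keeps every downstream node inside the subclass.

from streamz import Stream

class MyStream(Stream):
    def _new_node(self, cls, args, kwargs):
        # If the node class requested by the API is not already a
        # MyStream, build a one-off subclass that mixes MyStream in.
        if not issubclass(cls, MyStream):
            cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
        return cls(*args, **kwargs)

source = MyStream()
node = source.map(lambda x: x + 1)
assert isinstance(node, MyStream)  # the map node is also a MyStream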
20 changes: 20 additions & 0 deletions streamz/core.py
@@ -119,6 +119,22 @@ def __str__(self):

class APIRegisterMixin(object):

    def _new_node(self, cls, args, kwargs):
        """ Constructor for downstream nodes.

        Examples
        --------
        To provide inheritance through nodes:

        >>> class MyStream(Stream):
        ...
        ...     def _new_node(self, cls, args, kwargs):
        ...         if not issubclass(cls, MyStream):
        ...             cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
        ...         return cls(*args, **kwargs)
        """
        return cls(*args, **kwargs)

    @classmethod
    def register_api(cls, modifier=identity, attribute_name=None):
        """ Add callable to Stream API
@@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None):
        def _(func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                if modifier is not staticmethod and args:
                    self = args[0]
                    if isinstance(self, APIRegisterMixin):
                        return self._new_node(func, args, kwargs)
                return func(*args, **kwargs)
            name = attribute_name if attribute_name else func.__name__
            setattr(cls, name, modifier(wrapped))
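The dynamic-mixin call used in the _new_node docstring example (and in the test below) is plain type() class creation; a standalone sketch, with hypothetical Base and Mixin classes, shows the trick in isolation:

class Base:
    def greet(self):
        return "base"

class Mixin:
    extra = True

# Equivalent to re-declaring Base as `class Base(Base, Mixin)`: same
# name, same class attributes, with Mixin added to the bases.
Mixed = type(Base.__name__, (Base, Mixin), dict(Base.__dict__))

obj = Mixed()
assert obj.greet() == "base"
assert obj.extra is True
assert isinstance(obj, Mixin)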
29 changes: 29 additions & 0 deletions streamz/tests/test_core.py
@@ -1367,6 +1367,35 @@ class foo(NewStream):
    assert not hasattr(Stream(), 'foo')


def test_subclass_node():

    def add(x):
        return x + 1

    class MyStream(Stream):
        def _new_node(self, cls, args, kwargs):
            if not issubclass(cls, MyStream):
                cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
            return cls(*args, **kwargs)

    @MyStream.register_api()
    class foo(sz.sinks.sink):
        pass

    stream = MyStream()
    lst = list()

    node = stream.map(add)
    assert isinstance(node, sz.core.map)
    assert isinstance(node, MyStream)

    node = node.foo(lst.append)
    assert isinstance(node, sz.sinks.sink)
    assert isinstance(node, MyStream)

    stream.emit(100)
    assert lst == [101]


@gen_test()
def test_latest():
    source = Stream(asynchronous=True)
1 change: 1 addition & 0 deletions streamz/tests/test_dask.py
@@ -242,6 +242,7 @@ def test_buffer_sync(loop): # noqa: F811


@pytest.mark.xfail(reason='')
@pytest.mark.asyncio
async def test_stream_shares_client_loop(loop): # noqa: F811
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client: # noqa: F841
42 changes: 25 additions & 17 deletions streamz/tests/test_kafka.py
@@ -106,9 +106,9 @@ def split(messages):
    return parsed


+@pytest.mark.asyncio
 @flaky(max_runs=3, min_passes=1)
-@gen_test(timeout=60)
-def test_from_kafka():
+async def test_from_kafka():
     j = random.randint(0, 10000)
     ARGS = {'bootstrap.servers': 'localhost:9092',
             'group.id': 'streamz-test%i' % j}
@@ -117,30 +117,29 @@ def test_from_kafka():
         stream = Stream.from_kafka([TOPIC], ARGS, asynchronous=True)
         out = stream.sink_to_list()
         stream.start()
-        yield gen.sleep(1.1) # for loop to run
+        await asyncio.sleep(1.1) # for loop to run
         for i in range(10):
-            yield gen.sleep(0.1) # small pause ensures correct ordering
+            await asyncio.sleep(0.1) # small pause ensures correct ordering
             kafka.produce(TOPIC, b'value-%d' % i)
         kafka.flush()
         # it takes some time for messages to come back out of kafka
-        wait_for(lambda: len(out) == 10, 10, period=0.1)
+        await await_for(lambda: len(out) == 10, 10, period=0.1)
         assert out[-1] == b'value-9'

         kafka.produce(TOPIC, b'final message')
         kafka.flush()
-        wait_for(lambda: out[-1] == b'final message', 10, period=0.1)
+        await await_for(lambda: out[-1] == b'final message', 10, period=0.1)

         stream._close_consumer()
         kafka.produce(TOPIC, b'lost message')
         kafka.flush()
         # absolute sleep here, since we expect output list *not* to change
-        yield gen.sleep(1)
+        await asyncio.sleep(1)
         assert out[-1] == b'final message'
         stream._close_consumer()
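The conversion applied throughout this file follows one pattern: tornado-style generator coroutines driven by the gen_test helper become native coroutines collected by the pytest-asyncio plugin. In miniature (test_sleep is a hypothetical name; gen_test is the helper imported by this test module):

import asyncio
import pytest

# Before, tornado style:
#
#     @gen_test(timeout=60)
#     def test_sleep():
#         yield gen.sleep(0.1)

# After, pytest-asyncio style:
@pytest.mark.asyncio
async def test_sleep():
    await asyncio.sleep(0.1)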


 @flaky(max_runs=3, min_passes=1)
 @gen_test(timeout=60)
 def test_to_kafka():
     ARGS = {'bootstrap.servers': 'localhost:9092'}
     with kafka_service() as kafka:
@@ -150,43 +149,52 @@ def test_to_kafka():
         out = kafka.sink_to_list()

         for i in range(10):
-            yield source.emit(b'value-%d' % i)
+            source.emit(b'value-%d' % i)

         source.emit('final message')
         kafka.flush()
         wait_for(lambda: len(out) == 11, 10, period=0.1)
         assert out[-1] == b'final message'


+@pytest.mark.asyncio
 @flaky(max_runs=3, min_passes=1)
-@gen_test(timeout=60)
-def test_from_kafka_thread():
+async def test_from_kafka_thread():
     j = random.randint(0, 10000)
     ARGS = {'bootstrap.servers': 'localhost:9092',
             'group.id': 'streamz-test%i' % j}
+    print(".")
     with kafka_service() as kafka:
         kafka, TOPIC = kafka
-        stream = Stream.from_kafka([TOPIC], ARGS)
+        stream = Stream.from_kafka([TOPIC], ARGS, asynchronous=True)
+        print(".")
         out = stream.sink_to_list()
         stream.start()
-        yield gen.sleep(1.1)
+        await asyncio.sleep(1.1)
+        print(".")
         for i in range(10):
-            yield gen.sleep(0.1)
+            await asyncio.sleep(0.1)
             kafka.produce(TOPIC, b'value-%d' % i)
         kafka.flush()
+        print(".")
         # it takes some time for messages to come back out of kafka
-        yield await_for(lambda: len(out) == 10, 10, period=0.1)
+        await await_for(lambda: len(out) == 10, 10, period=0.1)
+        print(".")

         assert out[-1] == b'value-9'
         kafka.produce(TOPIC, b'final message')
         kafka.flush()
-        yield await_for(lambda: out[-1] == b'final message', 10, period=0.1)
+        print(".")
+        await await_for(lambda: out[-1] == b'final message', 10, period=0.1)
+        print(".")

         stream._close_consumer()
         kafka.produce(TOPIC, b'lost message')
         kafka.flush()
         # absolute sleep here, since we expect output list *not* to change
-        yield gen.sleep(1)
+        print(".")
+        await asyncio.sleep(1)
+        print(".")
         assert out[-1] == b'final message'
         stream._close_consumer()

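A note on the helpers used above: wait_for and await_for are polling utilities from streamz.utils_test. A minimal sketch of what the async variant does, under that assumption (not the library's exact code):

import asyncio
import time

async def await_for(predicate, timeout, period=0.1):
    # Re-check predicate() every `period` seconds until it is truthy,
    # raising if `timeout` seconds elapse first.
    deadline = time.monotonic() + timeout
    while not predicate():
        if time.monotonic() > deadline:
            raise TimeoutError('condition not met within %s seconds' % timeout)
        await asyncio.sleep(period)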