Commit

fix tests
aniketmaurya committed Jan 9, 2025
1 parent b41042f commit 74771fc
Showing 5 changed files with 89 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/litserve/loops/base.py
@@ -158,7 +158,7 @@ def __call__(
         stream: bool,
         workers_setup_status: Dict[int, str],
         callback_runner: CallbackRunner,
-        socket: zmq.Socket,
+        socket: Optional[zmq.Socket],
     ):
         self.socket = socket
         if asyncio.iscoroutinefunction(self.run):
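The change above relaxes the loop entry point to accept an optional ZMQ socket. Every loop touched by this commit then applies the same dispatch rule: send the response over the socket when one is wired up, otherwise fall back to the per-consumer response queue. A minimal sketch of that rule as a standalone helper (the helper name and its use as a free function are illustrative; the diff inlines this branch at each call site):

from multiprocessing import Queue
from typing import Any, List, Optional

import zmq


def put_response(
    socket: Optional[zmq.Socket],
    response_queues: List[Queue],
    response_queue_id: int,
    uid: str,
    payload: Any,
) -> None:
    # Mirror of the branch added at each call site in this commit:
    # the ZMQ path pickles and sends the same (uid, payload) tuple that
    # the queue path enqueues, so consumers stay transport-agnostic.
    if socket:
        socket.send_pyobj((uid, payload))
    else:
        response_queues[response_queue_id].put((uid, payload))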
8 changes: 7 additions & 1 deletion src/litserve/loops/simple_loops.py
@@ -101,7 +101,13 @@ def run_single_loop(
                 "has been timed out. "
                 "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
             )
-            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
+            if socket:
+                socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
+            else:
+                response_queues[response_queue_id].put((
+                    uid,
+                    (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
+                ))
             continue
         try:
             context = {}
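Whichever transport carries it, the consumer receives the same (uid, (payload, status)) tuple, so a timed-out request still surfaces as an HTTP 504. A hedged sketch of draining the ZMQ side (the socket type, endpoint, and the LitAPIStatus import path are assumptions; the diff only shows the worker calling send_pyobj):

import zmq

from litserve.utils import LitAPIStatus  # import path assumed

ctx = zmq.Context.instance()
sock = ctx.socket(zmq.PULL)  # PUSH/PULL is one plausible pairing; not confirmed by this diff
sock.connect("tcp://127.0.0.1:5555")  # endpoint chosen for illustration

uid, (payload, status) = sock.recv_pyobj()  # counterpart of send_pyobj in the worker
if status == LitAPIStatus.ERROR:
    raise payload  # e.g. the HTTPException(504, "Request timed out") sent above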
35 changes: 27 additions & 8 deletions src/litserve/loops/streaming_loops.py
@@ -122,6 +122,7 @@ def run_batched_streaming_loop(
     max_batch_size: int,
     batch_timeout: float,
     callback_runner: CallbackRunner,
+    socket: Optional[zmq.Socket],
 ):
     while True:
         batches, timed_out_uids = collate_requests(
@@ -136,7 +137,13 @@
                 "has been timed out. "
                 "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
             )
-            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
+            if socket:
+                socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
+            else:
+                response_queues[response_queue_id].put((
+                    uid,
+                    (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
+                ))

         if not batches:
             continue
@@ -174,23 +181,35 @@ def run_batched_streaming_loop(
             for y_batch in y_enc_iter:
                 for response_queue_id, y_enc, uid in zip(response_queue_ids, y_batch, uids):
                     y_enc = lit_api.format_encoded_response(y_enc)
-                    response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
+                    if socket:
+                        socket.send_pyobj((uid, (y_enc, LitAPIStatus.OK)))
+                    else:
+                        response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

             for response_queue_id, uid in zip(response_queue_ids, uids):
-                response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
+                if socket:
+                    socket.send_pyobj((uid, ("", LitAPIStatus.FINISH_STREAMING)))
+                else:
+                    response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))

         except HTTPException as e:
-            response_queues[response_queue_id].put((
-                uid,
-                (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
-            ))
+            if socket:
+                socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
+            else:
+                response_queues[response_queue_id].put((
+                    uid,
+                    (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+                ))

         except Exception as e:
             logger.exception(
                 "LitAPI ran into an error while processing the streaming batched request.\n"
                 "Please check the error trace for more details."
             )
-            response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
+            if socket:
+                socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
+            else:
+                response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


 class StreamingLoop(DefaultLoop):
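For streaming, each encoded chunk travels as its own (uid, (chunk, OK)) message, and the worker closes every request's stream with a ("", FINISH_STREAMING) sentinel. A sketch of a queue-side consumer for that protocol (the function and its routing-by-uid are illustrative, not LitServe's actual dequeue code; LitAPIStatus import path assumed):

from queue import Queue

from litserve.utils import LitAPIStatus  # import path assumed


def collect_stream(response_queue: Queue, uid: str) -> str:
    # Accumulate chunks for a single request until its FINISH_STREAMING sentinel.
    chunks = []
    while True:
        msg_uid, (payload, status) = response_queue.get()
        if msg_uid != uid:
            continue  # real consumers route by uid instead of dropping
        if status == LitAPIStatus.ERROR:
            raise payload  # exceptions are shipped as payloads and re-raised here
        if status == LitAPIStatus.FINISH_STREAMING:
            return "".join(chunks)
        chunks.append(payload)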
1 change: 1 addition & 0 deletions tests/test_batch.py
@@ -197,6 +197,7 @@ def test_batched_loop():
         max_batch_size=2,
         batch_timeout=4,
         callback_runner=NOOP_CB_RUNNER,
+        socket=None,
     )

     lit_api_mock.batch.assert_called_once()
69 changes: 53 additions & 16 deletions tests/test_loops.py
@@ -22,13 +22,14 @@
 from unittest.mock import MagicMock, patch

 import pytest
+import zmq
 from fastapi import HTTPException
 from fastapi.testclient import TestClient

 import litserve as ls
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner
-from litserve.loops import LitLoop, Output, _BaseLoop, inference_worker
+from litserve.loops import LitLoop, Output, inference_worker
 from litserve.loops.base import DefaultLoop
 from litserve.loops.continuous_batching_loop import (
     ContinuousBatchingLoop,
@@ -69,7 +70,9 @@ def test_single_loop(loop_args):
     response_queues = [FakeResponseQueue()]

     with pytest.raises(StopIteration, match="exit loop"):
-        run_single_loop(lit_api_mock, None, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER)
+        run_single_loop(
+            lit_api_mock, None, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER, socket=None
+        )


 class FakeStreamResponseQueue:
@@ -111,7 +114,12 @@ def fake_encode(output):

     with pytest.raises(StopIteration, match="exit loop"):
         run_streaming_loop(
-            fake_stream_api, fake_stream_api, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER
+            fake_stream_api,
+            fake_stream_api,
+            requests_queue,
+            response_queues,
+            callback_runner=NOOP_CB_RUNNER,
+            socket=None,
         )

     fake_stream_api.predict.assert_called_once_with("Hello")
@@ -177,6 +185,7 @@ def fake_encode(output_iter):
         max_batch_size=2,
         batch_timeout=2,
         callback_runner=NOOP_CB_RUNNER,
+        socket=None,
     )
     fake_stream_api.predict.assert_called_once_with(["Hello", "World"])
     fake_stream_api.encode_response.assert_called_once()
@@ -193,6 +202,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
         workers_setup_status={},
         callback_runner=NOOP_CB_RUNNER,
         loop="auto",
+        use_zmq=False,
     )
     mock_batched_loop.assert_called_once()

@@ -204,6 +214,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
         workers_setup_status={},
         callback_runner=NOOP_CB_RUNNER,
         loop="auto",
+        use_zmq=False,
     )
     mock_single_loop.assert_called_once()

@@ -219,7 +230,7 @@ def test_run_single_loop():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER, None)
     )
     loop_thread.start()

@@ -249,7 +260,7 @@ def test_run_single_loop_timeout():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER, None)
     )
     loop_thread.start()

@@ -274,7 +285,7 @@ def test_run_batched_loop():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 1, NOOP_CB_RUNNER)
+        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 1, NOOP_CB_RUNNER, None)
     )
     loop_thread.start()

@@ -311,7 +322,7 @@ def test_run_batched_loop_timeout():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 0.001, NOOP_CB_RUNNER)
+        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 0.001, NOOP_CB_RUNNER, None)
     )
     loop_thread.start()

@@ -340,7 +351,7 @@ def test_run_streaming_loop():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER, None)
    )
     loop_thread.start()

@@ -370,7 +381,7 @@ def test_run_streaming_loop_timeout():

     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER, None)
     )
     loop_thread.start()

@@ -419,10 +430,36 @@ def off_test_run_batched_streaming_loop(openai_request_data):
     assert response[0] == {"role": "assistant", "content": "10 + 6 is equal to 16."}


-class TestLoop(_BaseLoop):
-    def __call__(self, *args, **kwargs):
+class TestLoop(LitLoop):
+    def __call__(
+        self,
+        lit_api: LitAPI,
+        lit_spec: Optional[LitSpec],
+        device: str,
+        worker_id: int,
+        request_queue: Queue,
+        response_queues: List[Queue],
+        max_batch_size: int,
+        batch_timeout: float,
+        stream: bool,
+        workers_setup_status: Dict[int, str],
+        callback_runner: CallbackRunner,
+        socket: Optional[zmq.Socket],
+    ):
         try:
-            self.run(*args, **kwargs)
+            self.run(
+                lit_api,
+                lit_spec,
+                device,
+                worker_id,
+                request_queue,
+                response_queues,
+                max_batch_size,
+                batch_timeout,
+                stream,
+                workers_setup_status,
+                callback_runner,
+            )
         except StopIteration as e:
             return e

@@ -462,7 +499,7 @@ def test_custom_loop():
     response_queues = [Queue()]
     request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))

-    loop(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 1, False, {}, NOOP_CB_RUNNER)
+    loop(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 1, False, {}, NOOP_CB_RUNNER, None)
     response = response_queues[0].get()
     assert response[0] == "UUID-001"
     assert response[1][0] == {"output": 16.0}
@@ -477,11 +514,11 @@ def load_cache(self, x):

 def test_loop_with_server():
     loop = TestLoop()

     lit_api = TestLitAPI()
-    server = ls.LitServer(lit_api, loop=loop, max_batch_size=1, batch_timeout=0.1)
+    server = ls.LitServer(lit_api, loop=loop)

     with wrap_litserve_start(server) as server, TestClient(server.app) as client:
-        response = client.post("/predict", json={"input": 4.0})
+        response = client.post("/predict", json={"input": 4.0}, timeout=1)
         assert response.json() == {"output": 1600.0}  # use LitAPI.load_cache to multiply the input by 10

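The tests pass socket=None and use_zmq=False throughout, exercising only the queue fallback; the ZMQ path presumably gets its socket from the server before the worker loop starts. A hedged sketch of that wiring (socket type, endpoint, and function name are all assumptions; the commit shows only that loops accept Optional[zmq.Socket]):

from typing import Optional

import zmq


def make_worker_socket(use_zmq: bool, endpoint: str = "tcp://127.0.0.1:5555") -> Optional[zmq.Socket]:
    # Returning None selects the queue fallback in every loop above.
    if not use_zmq:
        return None
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)  # PUSH/PULL assumed; the diff does not show the pattern
    sock.connect(endpoint)
    return sock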
