improve tests (#418)
* improve tests

* improve tests

* format

* fix

* split tests

* add license

* improve loop timeout tests

* bump version
aniketmaurya authored Jan 17, 2025
1 parent cd39be5 commit 61d7596
Showing 6 changed files with 144 additions and 111 deletions.
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6"
__version__ = "0.2.7.dev0"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
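
The version string moves from 0.2.6 to the 0.2.7.dev0 pre-release, following PEP 440's dev-release convention. A minimal sketch of what that convention implies (assumes the third-party packaging library is available):

from packaging.version import Version  # third-party "packaging", assumed installed

v = Version("0.2.7.dev0")
# A .dev0 build is a pre-release: it sorts after 0.2.6 but before the final 0.2.7.
assert v.is_devrelease
assert Version("0.2.6") < v < Version("0.2.7")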
13 changes: 13 additions & 0 deletions src/litserve/zmq_queue.py
@@ -1,3 +1,16 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import multiprocessing
62 changes: 40 additions & 22 deletions tests/test_lit_server.py
@@ -228,39 +228,58 @@ def test_start_server(mock_uvicon):
assert server.lit_spec.response_queue_id is not None, "response_queue_id must be generated"


@pytest.fixture
def server_for_api_worker_test(simple_litapi):
server = ls.LitServer(simple_litapi, devices=1)
server.verify_worker_status = MagicMock()
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()
return server


@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@patch("litserve.server.uvicorn")
def test_server_run_with_api_server_worker_type(mock_uvicorn):
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, devices=1)
server.verify_worker_status = MagicMock()
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
server.run(api_server_worker_type="invalid")
def test_server_run_with_api_server_worker_type(mock_uvicorn, server_for_api_worker_test):
server = server_for_api_worker_test

with pytest.raises(ValueError, match=r"must be greater than 0"):
server.run(num_api_servers=0)
server.run(api_server_worker_type="process", num_api_servers=10)
server.launch_inference_worker.assert_called_with(10)

server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()

# Running the method to test
server.run(api_server_worker_type=None)
server.launch_inference_worker.assert_called_with(1)
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@pytest.mark.parametrize(("api_server_worker_type", "num_api_workers"), [(None, 1), ("process", 1)])
@patch("litserve.server.uvicorn")
def test_server_run_with_process_api_worker(
mock_uvicorn, api_server_worker_type, num_api_workers, server_for_api_worker_test
):
server = server_for_api_worker_test

server.run(api_server_worker_type=api_server_worker_type, num_api_workers=num_api_workers)
server.launch_inference_worker.assert_called_with(num_api_workers)
actual = server._start_server.call_args
assert actual[0][4] == "process", "Server should run in process mode"
mock_uvicorn.Config.assert_called()


@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@patch("litserve.server.uvicorn")
def test_server_run_with_thread_api_worker(mock_uvicorn, server_for_api_worker_test):
server = server_for_api_worker_test
server.run(api_server_worker_type="thread")
server.launch_inference_worker.assert_called_with(1)
actual = server._start_server.call_args
assert actual[0][4] == "thread", "Server should run in thread mode"
assert server._start_server.call_args[0][4] == "thread", "Server should run in thread mode"
mock_uvicorn.Config.assert_called()

server.run(api_server_worker_type="process")
server.launch_inference_worker.assert_called_with(1)
actual = server._start_server.call_args
assert actual[0][4] == "process", "Server should run in process mode"

server.run(api_server_worker_type="process", num_api_servers=10)
server.launch_inference_worker.assert_called_with(10)
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
def test_server_run_with_invalid_api_worker(simple_litapi):
server = ls.LitServer(simple_litapi, devices=1)
server.verify_worker_status = MagicMock()
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
server.run(api_server_worker_type="invalid")

with pytest.raises(ValueError, match=r"must be greater than 0"):
server.run(num_api_servers=0)


@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows")
@@ -272,7 +291,6 @@ def test_server_run_windows(mock_uvicorn):
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()

# Running the method to test
server.run(api_server_worker_type=None)
actual = server._start_server.call_args
assert actual[0][4] == "thread", "Windows only supports thread mode"
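
These assertions read the worker type back out of the positional arguments recorded by the mocked _start_server. A standalone sketch of how MagicMock.call_args indexing works (the argument values here are illustrative, not the real _start_server signature):

from unittest.mock import MagicMock

start_server = MagicMock()
# Hypothetical positional call; only the indexing pattern matters.
start_server("app", "host", 8000, "sockets", "process")

args, kwargs = start_server.call_args
# call_args[0] is the positional-argument tuple, so index 4 is the fifth argument.
assert args[4] == "process"
assert start_server.call_args[0][4] == "process"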
73 changes: 39 additions & 34 deletions tests/test_loops.py
@@ -247,99 +247,104 @@ def test_run_single_loop():
def test_run_single_loop_timeout():
stream = io.StringIO()
ls.configure_logging(stream=stream)

lit_api = ls.test_examples.SimpleLitAPI()
lit_api.setup(None)
lit_api.request_timeout = 0.0001

request_queue = Queue()
request = (0, "UUID-001", time.monotonic(), {"input": 4.0})
time.sleep(0.1)
request_queue.put(request)
response_queues = [Queue()]
old_request = (0, "UUID-001", time.monotonic(), {"input": 4.0})
time.sleep(0.1) # Age the request
request_queue.put(old_request)

# Run the loop in a separate thread to allow it to be stopped
lit_loop = SingleLoop()
loop_thread = threading.Thread(
target=lit_loop.run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
)
loop_thread.start()

response_queue = response_queues[0]
_, (response, status) = response_queue.get()
assert isinstance(response, HTTPException)
assert response.status_code == 504
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()

request_queue.put((None, None, None, None))
loop_thread.join()
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
assert isinstance(response_queues[0].get()[1][0], HTTPException), "Timeout should return an HTTPException"


def test_run_batched_loop():
lit_api = ls.test_examples.SimpleBatchedAPI()
lit_api.setup(None)
lit_api.pre_setup(2, None)
assert lit_api.model is not None, "Setup must initialize the model"
lit_api.request_timeout = 1

request_queue = Queue()
# response_queue_id, uid, timestamp, x_enc
request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0}))
response_queues = [Queue()]

# Run the loop in a separate thread to allow it to be stopped
requests = [(0, "UUID-001", time.monotonic(), {"input": 4.0}), (0, "UUID-002", time.monotonic(), {"input": 5.0})]
for req in requests:
request_queue.put(req)

lit_loop = BatchedLoop()
loop_thread = threading.Thread(
target=lit_loop.run_batched_loop,
args=(lit_api, None, request_queue, response_queues, 2, 1, NOOP_CB_RUNNER),
)
loop_thread.start()

# Allow some time for the loop to process
time.sleep(1)
expected_responses = [
("UUID-001", ({"output": 16.0}, LitAPIStatus.OK)),
("UUID-002", ({"output": 25.0}, LitAPIStatus.OK)),
]

for expected in expected_responses:
actual = response_queues[0].get(timeout=10)
assert actual == expected, f"Expected {expected}, got {actual}"

# Stop the loop by putting a sentinel value in the queue
request_queue.put((None, None, None, None))
loop_thread.join()

response_1 = response_queues[0].get(timeout=10)
response_2 = response_queues[0].get(timeout=10)
assert response_1 == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
assert response_2 == ("UUID-002", ({"output": 25.0}, LitAPIStatus.OK))


def test_run_batched_loop_timeout():
stream = io.StringIO()
ls.configure_logging(stream=stream)

lit_api = ls.test_examples.SimpleBatchedAPI()
lit_api.setup(None)
lit_api.pre_setup(2, None)
assert lit_api.model is not None, "Setup must initialize the model"
lit_api.request_timeout = 0.1

request_queue = Queue()
# response_queue_id, uid, timestamp, x_enc
r1 = (0, "UUID-001", time.monotonic(), {"input": 4.0})
time.sleep(0.1)
request_queue.put(r1)
r2 = (0, "UUID-002", time.monotonic(), {"input": 5.0})
request_queue.put(r2)
response_queues = [Queue()]

# Run the loop in a separate thread to allow it to be stopped
# First request will time out, second will succeed
requests = [
(0, "UUID-001", time.monotonic() - 0.2, {"input": 4.0}), # Old request
(0, "UUID-002", time.monotonic(), {"input": 5.0}), # Fresh request
]
for req in requests:
request_queue.put(req)

lit_loop = BatchedLoop()
loop_thread = threading.Thread(
target=lit_loop.run_batched_loop,
args=(lit_api, None, request_queue, response_queues, 2, 0.001, NOOP_CB_RUNNER),
)
loop_thread.start()

# Allow some time for the loop to process
time.sleep(1)
response_queue = response_queues[0]

# First response should be timeout error
_, (response1, _) = response_queue.get(timeout=10)
assert isinstance(response1, HTTPException)
assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
resp1 = response_queues[0].get(timeout=10)[1]
resp2 = response_queues[0].get(timeout=10)[1]
assert isinstance(resp1[0], HTTPException), "First request was timed out"
assert resp2[0] == {"output": 25.0}, "Second request wasn't timed out"

# Stop the loop by putting a sentinel value in the queue
# Second response should succeed
_, (response2, _) = response_queue.get(timeout=10)
assert response2 == {"output": 25.0}

request_queue.put((None, None, None, None))
loop_thread.join()
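
Both timeout tests rely on each request tuple carrying its enqueue timestamp, so a stale entry can be detected against lit_api.request_timeout. A minimal sketch of the staleness check the loops are exercising (the tuple layout mirrors the "response_queue_id, uid, timestamp, x_enc" comment above; the timeout value is illustrative):

import time

REQUEST_TIMEOUT = 0.1  # seconds, illustrative; the tests set lit_api.request_timeout

# (response_queue_id, uid, timestamp, payload) -- same layout as in the tests above.
stale_request = (0, "UUID-001", time.monotonic() - 0.2, {"input": 4.0})
fresh_request = (0, "UUID-002", time.monotonic(), {"input": 5.0})

def timed_out(request, timeout=REQUEST_TIMEOUT):
    # The loop compares the request's age against the configured timeout.
    return time.monotonic() - request[2] > timeout

assert timed_out(stale_request)
assert not timed_out(fresh_request)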

13 changes: 13 additions & 0 deletions tests/test_schema.py
@@ -1,3 +1,16 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import os
92 changes: 38 additions & 54 deletions tests/test_simple.py
@@ -14,6 +14,7 @@
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack

import numpy as np
import pytest
@@ -159,14 +160,6 @@ async def test_timeout(use_zmq):
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
# Poll until the server is ready
for _ in range(10): # retry 10 times
try:
await ac.get("/health")
break
except Exception:
await asyncio.sleep(0.2)

response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0}))
await asyncio.sleep(0.0001)
response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0}))
@@ -182,8 +175,7 @@
@pytest.mark.flaky(retries=3)
@pytest.mark.parametrize("use_zmq", [True, False])
@pytest.mark.asyncio
async def test_batch_timeout(use_zmq):
# Scenario: first 2 requests finish as a batch and third request times out in queue
async def test_batch_timeout_with_concurrent_requests(use_zmq):
server = LitServer(
SlowBatchAPI(),
accelerator="cpu",
Expand All @@ -196,58 +188,50 @@ async def test_batch_timeout(use_zmq):
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
# wait for the server to be ready
for _ in range(10):
try:
await ac.get("/health")
break
except Exception:
await asyncio.sleep(0.2)

response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0}))
response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0}))
await asyncio.sleep(0.0001)
response3 = asyncio.create_task(ac.post("/predict", json={"input": 6.0}))
responses = await asyncio.gather(response1, response2, response3, return_exceptions=True)
assert responses[0].status_code == 200, (
"Batch: First request should complete since it's popped from the request queue."
)
assert responses[1].status_code == 200, (
"Batch: Second request should complete since it's popped from the request queue."
)
assert responses[2].status_code == 504, "Batch: Third request was delayed and should fail"

server1 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1)
server2 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False)
server3 = LitServer(
SlowBatchAPI(),
accelerator="cpu",
devices=1,
timeout=False,
max_batch_size=2,
batch_timeout=2,
fast_queue=use_zmq,
)
server4 = LitServer(
SlowBatchAPI(), accelerator="cpu", devices=1, timeout=-1, max_batch_size=2, batch_timeout=2, fast_queue=use_zmq
)
assert responses[0].status_code == 200, "First request in batch should complete"
assert responses[1].status_code == 200, "Second request in batch should complete"
assert responses[2].status_code == 504, "Third request should timeout"

with wrap_litserve_start(server1) as server1, wrap_litserve_start(server2) as server2, wrap_litserve_start(
server3
) as server3, wrap_litserve_start(server4) as server4, TestClient(server1.app) as client1, TestClient(
server2.app
) as client2, TestClient(server3.app) as client3, TestClient(server4.app) as client4:
response1 = client1.post("/predict", json={"input": 4.0})
assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled"

response2 = client2.post("/predict", json={"input": 4.0})
assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled"

response3 = client3.post("/predict", json={"input": 4.0})
assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled"

response4 = client4.post("/predict", json={"input": 4.0})
assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled"
@pytest.mark.parametrize("use_zmq", [True, False])
def test_server_with_disabled_timeout(use_zmq):
servers = [
LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1),
LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False),
LitServer(
SlowBatchAPI(),
accelerator="cpu",
devices=1,
timeout=False,
max_batch_size=2,
batch_timeout=2,
fast_queue=use_zmq,
),
LitServer(
SlowBatchAPI(),
accelerator="cpu",
devices=1,
timeout=-1,
max_batch_size=2,
batch_timeout=2,
fast_queue=use_zmq,
),
]

with ExitStack() as stack:
clients = [
stack.enter_context(TestClient(stack.enter_context(wrap_litserve_start(server)).app)) for server in servers
]

for i, client in enumerate(clients, 1):
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 200, f"Server {i} should complete request with disabled timeout"
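
The ExitStack refactor replaces a long chain of nested with statements, letting a variable number of servers and clients share one block and unwind in reverse order on exit. A self-contained sketch of the pattern (the managed helper is purely illustrative):

from contextlib import ExitStack, contextmanager

@contextmanager
def managed(name):
    # Illustrative resource that reports when it is opened and closed.
    print(f"open {name}")
    try:
        yield name
    finally:
        print(f"close {name}")

with ExitStack() as stack:
    # Everything entered here is closed automatically, last-in first-out.
    resources = [stack.enter_context(managed(i)) for i in range(3)]
    print(resources)  # [0, 1, 2]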


def test_concurrent_requests(lit_server):