From 6f1d3df35be9b6680a7e1d7c5f0238c40a8338b0 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 10 Jan 2025 14:38:35 +0000
Subject: [PATCH] warn users when predict/unbatch output length is not same as #requests (#408)

* add warning

* fix

* fix

* update error msg

* clean up

* bump version

* dev3
---
 src/litserve/__about__.py             |  2 +-
 src/litserve/loops/simple_loops.py    | 65 +++++----------------
 src/litserve/loops/streaming_loops.py | 15 ++++---
 tests/test_batch.py                   | 27 +++++++++++
 4 files changed, 48 insertions(+), 61 deletions(-)

diff --git a/src/litserve/__about__.py b/src/litserve/__about__.py
index cdc5141a..a4b80811 100644
--- a/src/litserve/__about__.py
+++ b/src/litserve/__about__.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.2.6.dev2"
+__version__ = "0.2.6.dev3"
 __author__ = "Lightning-AI et al."
 __author_email__ = "community@lightning.ai"
 __license__ = "Apache-2.0"
diff --git a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py
index d7677e88..54576d40 100644
--- a/src/litserve/loops/simple_loops.py
+++ b/src/litserve/loops/simple_loops.py
@@ -11,72 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
-import sys
 import time
 from queue import Empty, Queue
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional
 
 import zmq
 from fastapi import HTTPException
-from starlette.formparsers import MultiPartParser
 
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner, EventTypes
-from litserve.loops.base import DefaultLoop
+from litserve.loops.base import DefaultLoop, _inject_context, collate_requests
 from litserve.specs.base import LitSpec
 from litserve.utils import LitAPIStatus, PickleableHTTPException
 
 logger = logging.getLogger(__name__)
 
-# FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing
-MultiPartParser.max_file_size = sys.maxsize
-
-
-def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
-    sig = inspect.signature(func)
-    if "context" in sig.parameters:
-        return func(*args, **kwargs, context=context)
-    return func(*args, **kwargs)
-
-
-def collate_requests(
-    lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float
-) -> Tuple[List, List]:
-    payloads = []
-    timed_out_uids = []
-    entered_at = time.monotonic()
-    end_time = entered_at + batch_timeout
-    apply_timeout = lit_api.request_timeout not in (-1, False)
-
-    if batch_timeout == 0:
-        while len(payloads) < max_batch_size:
-            try:
-                response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait()
-                if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
-                    timed_out_uids.append((response_queue_id, uid))
-                else:
-                    payloads.append((response_queue_id, uid, x_enc))
-            except Empty:
-                break
-        return payloads, timed_out_uids
-
-    while time.monotonic() < end_time and len(payloads) < max_batch_size:
-        remaining_time = end_time - time.monotonic()
-        if remaining_time <= 0:
-            break
-
-        try:
-            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001))
-            if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
-                timed_out_uids.append((response_queue_id, uid))
-            else:
-                payloads.append((response_queue_id, uid, x_enc))
-
-        except Empty:
-            continue
-
-    return payloads, timed_out_uids
-
 
 def run_single_loop(
@@ -199,8 +148,9 @@ def run_batched_loop(
             continue
         logger.debug(f"{len(batches)} batched requests received")
         response_queue_ids, uids, inputs = zip(*batches)
+        num_inputs = len(inputs)
         try:
-            contexts = [{}] * len(inputs)
+            contexts = [{}] * num_inputs
             if hasattr(lit_spec, "populate_context"):
                 for input, context in zip(inputs, contexts):
                     lit_spec.populate_context(context, input)
@@ -224,6 +174,13 @@
 
         outputs = lit_api.unbatch(y)
 
+        if len(outputs) != num_inputs:
+            logger.error(
+                f"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
+                "Please check the predict/unbatch method of the LitAPI implementation."
+            )
+            raise HTTPException(500, "Batch size mismatch")
+
         callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
         y_enc_list = []
         for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py
index 62cabf09..8e4e05af 100644
--- a/src/litserve/loops/streaming_loops.py
+++ b/src/litserve/loops/streaming_loops.py
@@ -148,8 +148,9 @@ def run_batched_streaming_loop(
         if not batches:
             continue
         response_queue_ids, uids, inputs = zip(*batches)
+        num_inputs = len(inputs)
         try:
-            contexts = [{}] * len(inputs)
+            contexts = [{}] * num_inputs
             if hasattr(lit_spec, "populate_context"):
                 for input, context in zip(inputs, contexts):
                     lit_spec.populate_context(context, input)
@@ -196,10 +197,11 @@
             if socket:
                 socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
             else:
-                response_queues[response_queue_id].put((
-                    uid,
-                    (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
-                ))
+                for response_queue_id, uid in zip(response_queue_ids, uids):
+                    response_queues[response_queue_id].put((
+                        uid,
+                        (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+                    ))
 
         except Exception as e:
             logger.exception(
@@ -209,7 +211,8 @@
             if socket:
                 socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
             else:
-                response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
+                for response_queue_id, uid in zip(response_queue_ids, uids):
+                    response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
 
 
 class StreamingLoop(DefaultLoop):
diff --git a/tests/test_batch.py b/tests/test_batch.py
index faad1ed1..7c576efa 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -226,3 +226,30 @@ def test_collate_requests(batch_timeout, batch_size):
     )
     assert len(payloads) == batch_size, f"Should have {batch_size} payloads, got {len(payloads)}"
     assert len(timed_out_uids) == 0, "No timed out uids"
+
+
+class BatchSizeMismatchAPI(SimpleBatchLitAPI):
+    def predict(self, x):
+        assert len(x) == 2, "Expected two concurrent inputs to be batched"
+        return self.model(x)  # returns a list of the same length as x
+
+    def unbatch(self, output):
+        return [output]  # returns a list of length 1
+
+
+@pytest.mark.asyncio
+async def test_batch_size_mismatch():
+    api = BatchSizeMismatchAPI()
+    server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)
+
+    with wrap_litserve_start(server) as server:
+        async with LifespanManager(server.app) as manager, AsyncClient(
+            transport=ASGITransport(app=manager.app), base_url="http://test"
+        ) as ac:
+            response1 = ac.post("/predict", json={"input": 4.0})
+            response2 = ac.post("/predict", json={"input": 5.0})
+            response1, response2 = await asyncio.gather(response1, response2)
+            assert response1.status_code == 500
+            assert response2.status_code == 500
+            assert response1.json() == {"detail": "Batch size mismatch"}, "unbatch a list of length 1 when batch size is 2"
+            assert response2.json() == {"detail": "Batch size mismatch"}, "unbatch a list of length 1 when batch size is 2"
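
Note on the contract this patch enforces: when batching is enabled, LitAPI.predict/unbatch must return exactly one element per request in the batch, otherwise the batched loop now logs an error and answers with HTTP 500 "Batch size mismatch". A minimal sketch of a conforming implementation follows; the class name and toy model are hypothetical and not part of this patch, only the LitAPI hooks (setup, decode_request, predict, unbatch, encode_response) are real litserve API.

    import litserve as ls

    class DoubleBatchAPI(ls.LitAPI):
        def setup(self, device):
            # hypothetical model: doubles every value in the batch
            self.model = lambda batch: [2 * x for x in batch]

        def decode_request(self, request):
            return request["input"]

        def predict(self, x):
            # x is the list of decoded inputs when max_batch_size > 1
            return self.model(x)

        def unbatch(self, output):
            # one output per batched request, so the new length check passes
            return list(output)

        def encode_response(self, output):
            return {"output": output}

    # usage sketch: serve with batching enabled
    # ls.LitServer(DoubleBatchAPI(), max_batch_size=2, batch_timeout=0.05).run()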