warn users when predict/unbatch output length is not same as #requests (#408)

* add warning

* fix

* fix

* update error msg

* clean up

* bump version

* dev3
aniketmaurya authored Jan 10, 2025
1 parent 43692d4 commit 6f1d3df
Showing 4 changed files with 48 additions and 61 deletions.
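For context: in LitServe's batched loop, decoded requests are collated into a single batch, predict runs once on that batch, and unbatch must return exactly one output per request. A minimal sketch of an API that honors this contract, assuming a toy self.model callable and an {"input": ...} payload (both illustrative, not part of this commit):

import litserve as ls


class BatchedEchoAPI(ls.LitAPI):
    def setup(self, device):
        # Hypothetical model: doubles every input in the batch.
        self.model = lambda xs: [x * 2 for x in xs]

    def decode_request(self, request):
        return request["input"]

    def batch(self, inputs):
        # Collate the decoded requests into one batch (a plain list here).
        return list(inputs)

    def predict(self, x):
        # One prediction per request in the batch.
        return self.model(x)

    def unbatch(self, output):
        # Must return a list whose length equals the number of batched requests.
        return list(output)

    def encode_response(self, output):
        return {"output": output}

If predict or unbatch breaks this length invariant, the server previously mis-routed or dropped responses; with this commit it fails loudly instead.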
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6.dev2"
__version__ = "0.2.6.dev3"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
65 changes: 11 additions & 54 deletions src/litserve/loops/simple_loops.py
@@ -11,72 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import sys
import time
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional

import zmq
from fastapi import HTTPException
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import DefaultLoop
from litserve.loops.base import DefaultLoop, _inject_context, collate_requests
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, PickleableHTTPException

logger = logging.getLogger(__name__)
# FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing
MultiPartParser.max_file_size = sys.maxsize


def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
sig = inspect.signature(func)
if "context" in sig.parameters:
return func(*args, **kwargs, context=context)
return func(*args, **kwargs)


def collate_requests(
lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float
) -> Tuple[List, List]:
payloads = []
timed_out_uids = []
entered_at = time.monotonic()
end_time = entered_at + batch_timeout
apply_timeout = lit_api.request_timeout not in (-1, False)

if batch_timeout == 0:
while len(payloads) < max_batch_size:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait()
if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
timed_out_uids.append((response_queue_id, uid))
else:
payloads.append((response_queue_id, uid, x_enc))
except Empty:
break
return payloads, timed_out_uids

while time.monotonic() < end_time and len(payloads) < max_batch_size:
remaining_time = end_time - time.monotonic()
if remaining_time <= 0:
break

try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001))
if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
timed_out_uids.append((response_queue_id, uid))
else:
payloads.append((response_queue_id, uid, x_enc))

except Empty:
continue

return payloads, timed_out_uids


def run_single_loop(
@@ -199,8 +148,9 @@ def run_batched_loop(
continue
logger.debug(f"{len(batches)} batched requests received")
response_queue_ids, uids, inputs = zip(*batches)
num_inputs = len(inputs)
try:
contexts = [{}] * len(inputs)
contexts = [{}] * num_inputs
if hasattr(lit_spec, "populate_context"):
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)
@@ -224,6 +174,13 @@

outputs = lit_api.unbatch(y)

if len(outputs) != num_inputs:
logger.error(
"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
"Please check the predict/unbatch method of the LitAPI implementation."
)
raise HTTPException(500, "Batch size mismatch")

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
y_enc_list = []
for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
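The new guard in run_batched_loop turns a silent length mismatch into an explicit HTTP 500 for every request in the batch. The typical way to trip it, and the corresponding fix, sketched with illustrative method bodies:

# Buggy: wraps the entire batch in a one-element list, so the loop sees
# len(outputs) == 1 while it expected one output per request.
def unbatch(self, output):
    return [output]


# Correct: keep one element per request so len(outputs) == num_inputs.
def unbatch(self, output):
    return list(output)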
15 changes: 9 additions & 6 deletions src/litserve/loops/streaming_loops.py
@@ -148,8 +148,9 @@ def run_batched_streaming_loop(
if not batches:
continue
response_queue_ids, uids, inputs = zip(*batches)
num_inputs = len(inputs)
try:
contexts = [{}] * len(inputs)
contexts = [{}] * num_inputs
if hasattr(lit_spec, "populate_context"):
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)
@@ -196,10 +197,11 @@
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
@@ -209,7 +211,8 @@
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


class StreamingLoop(DefaultLoop):
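The streaming-loop change is related but separate: when a batch-level exception is raised, the error is now fanned out to every request in the batch instead of being put only on the single (response_queue_id, uid) pair left bound by the last loop iteration. The pattern in isolation, using the queue and status objects shown above:

# Send the same error to each request that was part of the failed batch.
for response_queue_id, uid in zip(response_queue_ids, uids):
    response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))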
27 changes: 27 additions & 0 deletions tests/test_batch.py
@@ -226,3 +226,30 @@ def test_collate_requests(batch_timeout, batch_size):
)
assert len(payloads) == batch_size, f"Should have {batch_size} payloads, got {len(payloads)}"
assert len(timed_out_uids) == 0, "No timed out uids"


class BatchSizeMismatchAPI(SimpleBatchLitAPI):
def predict(self, x):
assert len(x) == 2, "Expected two concurrent inputs to be batched"
return self.model(x) # returns a list with the same length as x

def unbatch(self, output):
return [output] # returns a list of length 1


@pytest.mark.asyncio
async def test_batch_size_mismatch():
api = BatchSizeMismatchAPI()
server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)

with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
assert response1.status_code == 500
assert response2.status_code == 500
assert response1.json() == {"detail": "Batch size mismatch"}, "unbatch returned a list of length 1 for a batch of 2"
assert response2.json() == {"detail": "Batch size mismatch"}, "unbatch returned a list of length 1 for a batch of 2"
