From ccc322a1aa468464e35514f086ba9d9706cebfce Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 11 Feb 2025 11:02:49 +0000 Subject: [PATCH] fix custom exceptions (#425) * always pickle exceptions * fix * fix * fix --- src/litserve/loops/base.py | 2 ++ src/litserve/loops/simple_loops.py | 7 +++---- src/litserve/loops/streaming_loops.py | 4 ++-- tests/test_loops.py | 2 +- tests/test_simple.py | 18 ++++++++++++++++++ 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index cd747000..67172e86 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -14,6 +14,7 @@ import asyncio import inspect import logging +import pickle import signal import sys import time @@ -256,6 +257,7 @@ def put_response( def put_error_response( self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception ) -> None: + error = pickle.dumps(error) self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR) def __del__(self): diff --git a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py index 9419b463..5c0cdea5 100644 --- a/src/litserve/loops/simple_loops.py +++ b/src/litserve/loops/simple_loops.py @@ -109,12 +109,11 @@ def run_single_loop( "Please check the error trace for more details.", uid, ) - self.put_response( + self.put_error_response( response_queues=response_queues, response_queue_id=response_queue_id, uid=uid, - response_data=e, - status=LitAPIStatus.ERROR, + error=e, ) def __call__( @@ -226,7 +225,7 @@ def run_batched_loop( "Please check the error trace for more details." ) for response_queue_id, uid in zip(response_queue_ids, uids): - self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR) + self.put_error_response(response_queues, response_queue_id, uid, e) def __call__( self, diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index 52997e0b..767cef0d 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -102,7 +102,7 @@ def run_streaming_loop( "Please check the error trace for more details.", uid, ) - self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR) + self.put_error_response(response_queues, response_queue_id, uid, e) def __call__( self, @@ -207,7 +207,7 @@ def run_batched_streaming_loop( "Please check the error trace for more details." ) for response_queue_id, uid in zip(response_queue_ids, uids): - self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR) + self.put_error_response(response_queues, response_queue_id, uid, e) def __call__( self, diff --git a/tests/test_loops.py b/tests/test_loops.py index 3376dd77..d24b44ac 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -132,7 +132,7 @@ def put(self, item, block=False, timeout=None): response, status = args if status == LitAPIStatus.FINISH_STREAMING: raise StopIteration("interrupt iteration") - if status == LitAPIStatus.ERROR and isinstance(response, StopIteration): + if status == LitAPIStatus.ERROR: assert self.count // 2 == self.num_streamed_outputs, ( f"Loop count must have incremented for {self.num_streamed_outputs} times." ) diff --git a/tests/test_simple.py b/tests/test_simple.py index 432db1e6..c7d19060 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -244,3 +244,21 @@ def test_concurrent_requests(lit_server): assert response.json() == {"output": i**2}, "Server returns square of the input number" count += 1 assert count == n_requests + + +class CustomError(Exception): + def __init__(self, arg1, arg2, arg3): + super().__init__("Test exception") + + +class ExceptionAPI(SimpleLitAPI): + def predict(self, x): + raise CustomError("This", "is", "a test") + + +def test_exception(): + server = LitServer(ExceptionAPI(), accelerator="cpu", devices=1) + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + response = client.post("/predict", json={"input": 4.0}) + assert response.status_code == 500 + assert response.json() == {"detail": "Internal Server Error"}