Skip to content

Commit

Permalink
fix custom exceptions (#425)
Browse files Browse the repository at this point in the history
* always pickle exceptions

* fix

* fix

* fix
  • Loading branch information
aniketmaurya authored Feb 11, 2025
1 parent 38d7169 commit ccc322a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import asyncio
import inspect
import logging
import pickle
import signal
import sys
import time
Expand Down Expand Up @@ -256,6 +257,7 @@ def put_response(
def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
error = pickle.dumps(error)
self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR)

def __del__(self):
Expand Down
7 changes: 3 additions & 4 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,11 @@ def run_single_loop(
"Please check the error trace for more details.",
uid,
)
self.put_response(
self.put_error_response(
response_queues=response_queues,
response_queue_id=response_queue_id,
uid=uid,
response_data=e,
status=LitAPIStatus.ERROR,
error=e,
)

def __call__(
Expand Down Expand Up @@ -226,7 +225,7 @@ def run_batched_loop(
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
self.put_error_response(response_queues, response_queue_id, uid, e)

def __call__(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/litserve/loops/streaming_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run_streaming_loop(
"Please check the error trace for more details.",
uid,
)
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
self.put_error_response(response_queues, response_queue_id, uid, e)

def __call__(
self,
Expand Down Expand Up @@ -207,7 +207,7 @@ def run_batched_streaming_loop(
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
self.put_error_response(response_queues, response_queue_id, uid, e)

def __call__(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def put(self, item, block=False, timeout=None):
response, status = args
if status == LitAPIStatus.FINISH_STREAMING:
raise StopIteration("interrupt iteration")
if status == LitAPIStatus.ERROR and isinstance(response, StopIteration):
if status == LitAPIStatus.ERROR:
assert self.count // 2 == self.num_streamed_outputs, (
f"Loop count must have incremented for {self.num_streamed_outputs} times."
)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,21 @@ def test_concurrent_requests(lit_server):
assert response.json() == {"output": i**2}, "Server returns square of the input number"
count += 1
assert count == n_requests


class CustomError(Exception):
def __init__(self, arg1, arg2, arg3):
super().__init__("Test exception")


class ExceptionAPI(SimpleLitAPI):
def predict(self, x):
raise CustomError("This", "is", "a test")


def test_exception():
server = LitServer(ExceptionAPI(), accelerator="cpu", devices=1)
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 500
assert response.json() == {"detail": "Internal Server Error"}

0 comments on commit ccc322a

Please sign in to comment.