Skip to content

Commit

Permalink
move built in loops inside classes (#409)
Browse files Browse the repository at this point in the history
* move implementation inside the classes

* fix tests

* fix tests
  • Loading branch information
aniketmaurya authored Jan 13, 2025
1 parent 6f1d3df commit 23bcb30
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 271 deletions.
2 changes: 1 addition & 1 deletion src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def put_error_response(

def __del__(self):
if self.socket:
self.socket.close()
self.socket.close(linger=0)
self.zmq_ctx.term()


Expand Down
1 change: 1 addition & 0 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def inference_worker(
socket = ctx.socket(zmq.PUB)
logger.debug(f"Inference worker binding to {zmq_addr}")
socket.bind(zmq_addr)
loop.socket = socket
loop.zmq_context = ctx

loop(
Expand Down
279 changes: 193 additions & 86 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,90 +28,6 @@
logger = logging.getLogger(__name__)


def run_single_loop(
    lit_api: LitAPI,
    lit_spec: LitSpec,
    request_queue: Queue,
    response_queues: List[Queue],
    callback_runner: CallbackRunner,
    socket: Optional[zmq.Socket],
):
    """Serve requests one at a time: decode -> predict -> encode, forever.

    Pulls ``(response_queue_id, uid, timestamp, x_enc)`` tuples off
    ``request_queue`` and emits ``(uid, (payload, status))`` results — over the
    ZMQ ``socket`` when one is provided, otherwise onto
    ``response_queues[response_queue_id]``.

    Args:
        lit_api: User API providing ``decode_request`` / ``predict`` /
            ``encode_response`` and the ``request_timeout`` setting.
        lit_spec: Spec object; when it defines ``populate_context`` it seeds the
            per-request context passed to every API hook.
        request_queue: Source of incoming requests.
        response_queues: Per-consumer response queues, indexed by
            ``response_queue_id``.
        callback_runner: Fires BEFORE_/AFTER_ events around each API hook.
        socket: Optional ZMQ socket; when set, responses bypass the queues.

    Note: this loop has no exit condition — it runs until the worker process
    is terminated.
    """
    while True:
        try:
            # Short timeout keeps the loop responsive to shutdown;
            # Empty/ValueError simply retries the get.
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
        except (Empty, ValueError):
            continue

        # A timeout of falsy or -1 disables queue-wait timeouts entirely.
        if (lit_api.request_timeout and lit_api.request_timeout != -1) and (
            time.monotonic() - timestamp > lit_api.request_timeout
        ):
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            # Report 504 on whichever transport is active, then move on.
            if socket:
                socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
            else:
                response_queues[response_queue_id].put((
                    uid,
                    (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
                ))
            continue
        try:
            # Per-request context shared by decode/predict/encode hooks.
            context = {}
            if hasattr(lit_spec, "populate_context"):
                lit_spec.populate_context(context, x_enc)

            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
            x = _inject_context(
                context,
                lit_api.decode_request,
                x_enc,
            )
            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y = _inject_context(
                context,
                lit_api.predict,
                x,
            )
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
            y_enc = _inject_context(
                context,
                lit_api.encode_response,
                y,
            )
            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
            # Success: ship the encoded response on the active transport.
            if socket:
                socket.send_pyobj((uid, (y_enc, LitAPIStatus.OK)))
            else:
                response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

        except HTTPException as e:
            # Deliberate HTTP errors raised by the API are forwarded as-is;
            # PickleableHTTPException survives the process/socket boundary.
            if socket:
                socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
            else:
                response_queues[response_queue_id].put((
                    uid,
                    (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
                ))

        except Exception as e:
            # Unexpected failure: log the trace here, forward the exception
            # object so the caller can surface it; the loop keeps serving.
            logger.exception(
                "LitAPI ran into an error while processing the request uid=%s.\n"
                "Please check the error trace for more details.",
                uid,
            )
            if socket:
                socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
            else:
                response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


def run_batched_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
Expand Down Expand Up @@ -214,6 +130,96 @@ def run_batched_loop(


class SingleLoop(DefaultLoop):
    def run_single_loop(
        self,
        lit_api: LitAPI,
        lit_spec: LitSpec,
        request_queue: Queue,
        response_queues: List[Queue],
        callback_runner: CallbackRunner,
        socket: Optional[zmq.Socket],
    ):
        """Serve requests one at a time: decode -> predict -> encode, forever.

        Pulls ``(response_queue_id, uid, timestamp, x_enc)`` tuples off
        ``request_queue`` and delivers ``(uid, (payload, status))`` results via
        ``self.put_response``.

        Args:
            lit_api: User API providing ``decode_request`` / ``predict`` /
                ``encode_response`` and the ``request_timeout`` setting.
            lit_spec: Spec object; when it defines ``populate_context`` it
                seeds the per-request context passed to every API hook.
            request_queue: Source of incoming requests.
            response_queues: Per-consumer response queues, indexed by
                ``response_queue_id``.
            callback_runner: Fires BEFORE_/AFTER_ events around each API hook.
            socket: Accepted for interface compatibility; not referenced in
                this body — response routing is delegated to
                ``self.put_response``.

        Note: this loop has no exit condition — it runs until the worker
        process is terminated.
        """
        while True:
            try:
                # Short timeout keeps the loop responsive to shutdown;
                # Empty/ValueError simply retries the get.
                response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
            except (Empty, ValueError):
                continue

            # A timeout of falsy or -1 disables queue-wait timeouts entirely.
            if (lit_api.request_timeout and lit_api.request_timeout != -1) and (
                time.monotonic() - timestamp > lit_api.request_timeout
            ):
                logger.error(
                    f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                    "has been timed out. "
                    "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
                )
                self.put_response(
                    response_queues=response_queues,
                    response_queue_id=response_queue_id,
                    uid=uid,
                    response_data=(HTTPException(504, "Request timed out")),
                    status=LitAPIStatus.ERROR,
                )
                continue
            try:
                # Per-request context shared by decode/predict/encode hooks.
                context = {}
                if hasattr(lit_spec, "populate_context"):
                    lit_spec.populate_context(context, x_enc)

                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
                x = _inject_context(
                    context,
                    lit_api.decode_request,
                    x_enc,
                )
                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
                y = _inject_context(
                    context,
                    lit_api.predict,
                    x,
                )
                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
                y_enc = _inject_context(
                    context,
                    lit_api.encode_response,
                    y,
                )
                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
                # Success: deliver the encoded response.
                self.put_response(
                    response_queues=response_queues,
                    response_queue_id=response_queue_id,
                    uid=uid,
                    response_data=y_enc,
                    status=LitAPIStatus.OK,
                )

            except HTTPException as e:
                # Deliberate HTTP errors raised by the API are forwarded as-is;
                # PickleableHTTPException survives the process boundary.
                self.put_response(
                    response_queues=response_queues,
                    response_queue_id=response_queue_id,
                    uid=uid,
                    response_data=PickleableHTTPException.from_exception(e),
                    status=LitAPIStatus.ERROR,
                )

            except Exception as e:
                # Unexpected failure: log the trace here, forward the exception
                # object to the caller; the loop keeps serving.
                logger.exception(
                    "LitAPI ran into an error while processing the request uid=%s.\n"
                    "Please check the error trace for more details.",
                    uid,
                )
                self.put_response(
                    response_queues=response_queues,
                    response_queue_id=response_queue_id,
                    uid=uid,
                    response_data=e,
                    status=LitAPIStatus.ERROR,
                )

def __call__(
self,
lit_api: LitAPI,
Expand All @@ -229,10 +235,111 @@ def __call__(
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner, socket)
self.run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner, socket)


class BatchedLoop(DefaultLoop):
def run_batched_loop(
self,
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
while True:
batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)

for response_queue_id, uid in timed_out_uids:
logger.error(
f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
"has been timed out. "
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
if socket:
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
))

if not batches:
continue
logger.debug(f"{len(batches)} batched requests received")
response_queue_ids, uids, inputs = zip(*batches)
num_inputs = len(inputs)
try:
contexts = [{}] * num_inputs
if hasattr(lit_spec, "populate_context"):
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

x = lit_api.batch(x)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
y = _inject_context(contexts, lit_api.predict, x)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

outputs = lit_api.unbatch(y)

if len(outputs) != num_inputs:
logger.error(
"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
"Please check the predict/unbatch method of the LitAPI implementation."
)
raise HTTPException(500, "Batch size mismatch")

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
y_enc_list = []
for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
y_enc = _inject_context(context, lit_api.encode_response, y)
y_enc_list.append((response_queue_id, uid, y_enc))
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

for response_queue_id, uid, y_enc in y_enc_list:
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

except HTTPException as e:
for response_queue_id, uid in zip(response_queue_ids, uids):
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
"LitAPI ran into an error while processing the batched request.\n"
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))

def __call__(
self,
lit_api: LitAPI,
Expand All @@ -248,7 +355,7 @@ def __call__(
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_batched_loop(
self.run_batched_loop(
lit_api,
lit_spec,
request_queue,
Expand Down
Loading

0 comments on commit 23bcb30

Please sign in to comment.