Skip to content

Commit

Permalink
integrate zmq (#403)
Browse files Browse the repository at this point in the history
* integrate zmq

* backward compatibility

* fix typing

* fix tests

* update tests

* update

* delete socket

* terminate context

* fix random port

* add todo

* disable zmq

* use ipc

* fix windows CI

* clean up

* add tests

* omit windows
  • Loading branch information
aniketmaurya authored Jan 10, 2025
1 parent 501639b commit 43692d4
Show file tree
Hide file tree
Showing 12 changed files with 292 additions and 68 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
fastapi >=0.100
uvicorn[standard] >=0.29.0
pyzmq >=22.0.0
21 changes: 19 additions & 2 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import zmq
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
Expand Down Expand Up @@ -128,6 +129,9 @@ def run(
"""

zmq_ctx: Optional[zmq.Context] = None
socket: Optional[zmq.Socket] = None

# Hook with no behaviour in this implementation: the body is a bare `pass`,
# so both arguments are accepted and ignored.
# NOTE(review): presumably meant to be overridden and called before worker
# setup (inferred from the name only) — confirm against the caller.
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
pass

Expand Down Expand Up @@ -155,7 +159,9 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
self.socket = socket
if asyncio.iscoroutinefunction(self.run):
event_loop = asyncio.new_event_loop()

Expand Down Expand Up @@ -244,12 +250,23 @@ def populate_context(self, lit_spec: LitSpec, request: Any):
def put_response(
    self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
    """Deliver one response for request ``uid`` through exactly one transport.

    When ``self.socket`` is set, the payload is sent with ``send_pyobj`` and
    the queues are left untouched. Otherwise the payload is put, without
    blocking, on ``response_queues[response_queue_id]``.

    Args:
        response_queues: Per-consumer response queues (used only without a socket).
        response_queue_id: Index into ``response_queues``.
        uid: Identifier of the request being answered.
        response_data: Payload to deliver.
        status: Status delivered alongside the payload.
    """
    payload = (uid, (response_data, status))
    if self.socket:
        self.socket.send_pyobj(payload)
    else:
        # block=False: raise queue.Full instead of stalling the worker.
        response_queues[response_queue_id].put(payload, block=False)

def put_error_response(
    self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
    """Deliver ``error`` for request ``uid`` through exactly one transport.

    The payload is ``(uid, (error, LitAPIStatus.ERROR))``. When ``self.socket``
    is set it is sent with ``send_pyobj`` and the queues are left untouched.
    Otherwise it is put, without blocking, on
    ``response_queues[response_queue_id]``.

    Args:
        response_queues: Per-consumer response queues (used only without a socket).
        response_queue_id: Index into ``response_queues``.
        uid: Identifier of the failed request.
        error: Exception to report back.
    """
    payload = (uid, (error, LitAPIStatus.ERROR))
    if self.socket:
        self.socket.send_pyobj(payload)
    else:
        # block=False: raise queue.Full instead of stalling the worker.
        response_queues[response_queue_id].put(payload, block=False)

def __del__(self):
    """Close the ZeroMQ socket and terminate its context, if any were attached.

    Does nothing when no socket is set. The context is looked up under both
    ``zmq_ctx`` (the attribute declared on the class, default ``None``) and
    ``zmq_context`` (the attribute the inference worker assigns), and is only
    terminated when one of them is actually set. This keeps the finalizer
    from raising ``AttributeError`` on ``None.term()``.
    """
    # getattr: a finalizer may run on a partially initialised instance.
    sock = getattr(self, "socket", None)
    if sock is None:
        return
    sock.close()

    ctx = getattr(self, "zmq_ctx", None)
    if ctx is None:
        ctx = getattr(self, "zmq_context", None)
    if ctx is not None:
        ctx.term()


class DefaultLoop(LitLoop):
Expand Down
17 changes: 17 additions & 0 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from queue import Queue
from typing import Dict, List, Optional, Union

import zmq
import zmq.asyncio

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import _BaseLoop
Expand Down Expand Up @@ -51,6 +54,8 @@ def inference_worker(
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
loop: Union[str, _BaseLoop],
use_zmq: bool,
zmq_addr: Optional[str],
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
try:
Expand All @@ -73,6 +78,14 @@ def inference_worker(
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

socket = None
if use_zmq:
ctx = zmq.Context()
socket = ctx.socket(zmq.PUB)
logger.debug(f"Inference worker binding to {zmq_addr}")
socket.bind(zmq_addr)
loop.zmq_context = ctx

loop(
lit_api,
lit_spec,
Expand All @@ -85,4 +98,8 @@ def inference_worker(
stream,
workers_setup_status,
callback_runner,
socket,
)
if use_zmq:
socket.close()
loop.zmq_context.term()
62 changes: 47 additions & 15 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union

import zmq
from fastapi import HTTPException
from starlette.formparsers import MultiPartParser

Expand Down Expand Up @@ -84,6 +85,7 @@ def run_single_loop(
request_queue: Queue,
response_queues: List[Queue],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
while True:
try:
Expand All @@ -99,7 +101,13 @@ def run_single_loop(
"has been timed out. "
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
))
continue
try:
context = {}
Expand Down Expand Up @@ -129,22 +137,30 @@ def run_single_loop(
y,
)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
if socket:
socket.send_pyobj((uid, (y_enc, LitAPIStatus.OK)))
else:
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

except HTTPException as e:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
"LitAPI ran into an error while processing the request uid=%s.\n"
"Please check the error trace for more details.",
uid,
)
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


def run_batched_loop(
Expand All @@ -155,6 +171,7 @@ def run_batched_loop(
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
while True:
batches, timed_out_uids = collate_requests(
Expand All @@ -170,7 +187,13 @@ def run_batched_loop(
"has been timed out. "
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
))

if not batches:
continue
Expand Down Expand Up @@ -213,18 +236,24 @@ def run_batched_loop(

except HTTPException as e:
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
"LitAPI ran into an error while processing the batched request.\n"
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


class SingleLoop(DefaultLoop):
Expand All @@ -241,8 +270,9 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner, socket)


class BatchedLoop(DefaultLoop):
Expand All @@ -259,6 +289,7 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_batched_loop(
lit_api,
Expand All @@ -268,4 +299,5 @@ def __call__(
max_batch_size,
batch_timeout,
callback_runner,
socket,
)
Loading

0 comments on commit 43692d4

Please sign in to comment.