Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate zmq #403

Merged
merged 17 commits into from
Jan 10, 2025
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
fastapi >=0.100
uvicorn[standard] >=0.29.0
pyzmq >=22.0.0
21 changes: 19 additions & 2 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import zmq
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
Expand Down Expand Up @@ -128,6 +129,9 @@ def run(

"""

zmq_ctx: Optional[zmq.Context] = None
socket: Optional[zmq.Socket] = None

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
    """Accept ``lit_api`` and ``spec`` and do nothing.

    NOTE(review): presumably an overridable hook that runs before worker
    setup -- the call site is not visible in this hunk, confirm against it.
    """
    pass

Expand Down Expand Up @@ -155,7 +159,9 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
self.socket = socket
if asyncio.iscoroutinefunction(self.run):
event_loop = asyncio.new_event_loop()

Expand Down Expand Up @@ -244,12 +250,23 @@ def populate_context(self, lit_spec: LitSpec, request: Any):
def put_response(
    self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
    """Deliver the response for request ``uid`` exactly once.

    When a ZeroMQ socket has been attached to this loop (``self.socket``),
    the payload is published through it with ``send_pyobj`` (so it must be
    picklable). Otherwise it is placed, without blocking, on the response
    queue selected by ``response_queue_id``.

    Args:
        response_queues: Per-consumer response queues; only used when no
            socket is attached.
        response_queue_id: Index into ``response_queues``.
        uid: Identifier of the request this response belongs to.
        response_data: The response payload.
        status: Status marker delivered alongside the payload.

    Raises:
        queue.Full: If the target queue is bounded and full (``block=False``).
    """
    payload = (uid, (response_data, status))

    # Exactly one transport is used per response. The unconditional queue
    # ``put`` that preceded this branch was the pre-ZMQ code path; keeping it
    # would deliver every response twice.
    if self.socket is not None:
        self.socket.send_pyobj(payload)
        return

    response_queues[response_queue_id].put(payload, block=False)

def put_error_response(
    self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
    """Deliver ``error`` for request ``uid`` exactly once, tagged ``LitAPIStatus.ERROR``.

    When a ZeroMQ socket has been attached to this loop (``self.socket``),
    the payload is published through it with ``send_pyobj`` (so ``error``
    must be picklable). Otherwise it is placed, without blocking, on the
    response queue selected by ``response_queue_id``.

    Args:
        response_queues: Per-consumer response queues; only used when no
            socket is attached.
        response_queue_id: Index into ``response_queues``.
        uid: Identifier of the request that failed.
        error: The exception to report to the consumer.

    Raises:
        queue.Full: If the target queue is bounded and full (``block=False``).
    """
    payload = (uid, (error, LitAPIStatus.ERROR))

    # Exactly one transport is used per error. The unconditional queue
    # ``put`` that preceded this branch was the pre-ZMQ code path; keeping it
    # would report every error twice.
    if self.socket is not None:
        self.socket.send_pyobj(payload)
        return

    response_queues[response_queue_id].put(payload, block=False)

def __del__(self):
    """Best-effort release of the ZeroMQ socket and context held by this loop.

    Nothing happens when no socket was attached. When one was, it is closed
    and the owning context (if any) is terminated unless it is already
    closed.
    """
    # getattr: a finalizer can run on an instance whose attributes were never
    # set, and must not raise in that case.
    socket = getattr(self, "socket", None)
    if socket is None:
        return
    socket.close()

    # The class declares ``zmq_ctx`` but ``inference_worker`` stores the
    # context on the loop as ``zmq_context``; calling ``self.zmq_ctx.term()``
    # unconditionally therefore hit ``None`` and raised AttributeError.
    # Accept either attribute and skip when there is no context.
    ctx = getattr(self, "zmq_ctx", None)
    if ctx is None:
        ctx = getattr(self, "zmq_context", None)

    # The worker may already have terminated the context itself after the
    # loop returned; do not terminate it a second time.
    if ctx is not None and not getattr(ctx, "closed", False):
        ctx.term()


class DefaultLoop(LitLoop):
Expand Down
17 changes: 17 additions & 0 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from queue import Queue
from typing import Dict, List, Optional, Union

import zmq
import zmq.asyncio

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import _BaseLoop
Expand Down Expand Up @@ -51,6 +54,8 @@ def inference_worker(
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
loop: Union[str, _BaseLoop],
use_zmq: bool,
zmq_addr: Optional[str],
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
try:
Expand All @@ -73,6 +78,14 @@ def inference_worker(
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

socket = None
if use_zmq:
ctx = zmq.Context()
socket = ctx.socket(zmq.PUB)
logger.debug(f"Inference worker binding to {zmq_addr}")
socket.bind(zmq_addr)
loop.zmq_context = ctx

loop(
lit_api,
lit_spec,
Expand All @@ -85,4 +98,8 @@ def inference_worker(
stream,
workers_setup_status,
callback_runner,
socket,
)
if use_zmq:
socket.close()
loop.zmq_context.term()
62 changes: 47 additions & 15 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union

import zmq
from fastapi import HTTPException
from starlette.formparsers import MultiPartParser

Expand Down Expand Up @@ -84,6 +85,7 @@ def run_single_loop(
request_queue: Queue,
response_queues: List[Queue],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
while True:
try:
Expand All @@ -99,7 +101,13 @@ def run_single_loop(
"has been timed out. "
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
))
continue
try:
context = {}
Expand Down Expand Up @@ -129,22 +137,30 @@ def run_single_loop(
y,
)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
if socket:
socket.send_pyobj((uid, (y_enc, LitAPIStatus.OK)))
else:
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

except HTTPException as e:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
"LitAPI ran into an error while processing the request uid=%s.\n"
"Please check the error trace for more details.",
uid,
)
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


def run_batched_loop(
Expand All @@ -155,6 +171,7 @@ def run_batched_loop(
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
while True:
batches, timed_out_uids = collate_requests(
Expand All @@ -170,7 +187,13 @@ def run_batched_loop(
"has been timed out. "
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
))

if not batches:
continue
Expand Down Expand Up @@ -213,18 +236,24 @@ def run_batched_loop(

except HTTPException as e:
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))
if socket:
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((
uid,
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
))

except Exception as e:
logger.exception(
"LitAPI ran into an error while processing the batched request.\n"
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
if socket:
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


class SingleLoop(DefaultLoop):
Expand All @@ -241,8 +270,9 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner, socket)


class BatchedLoop(DefaultLoop):
Expand All @@ -259,6 +289,7 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
run_batched_loop(
lit_api,
Expand All @@ -268,4 +299,5 @@ def __call__(
max_batch_size,
batch_timeout,
callback_runner,
socket,
)
Loading
Loading