Skip to content

Commit

Permalink
Merge branch 'main' into fix/add-missing-callback
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy authored Jan 13, 2025
2 parents b4a2004 + 5aef366 commit e0787e7
Show file tree
Hide file tree
Showing 18 changed files with 787 additions and 502 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for review when someone opens a pull request.
* @lantiga @aniketmaurya @ethanwharris @Andrei-Aksionov @borda
* @lantiga @aniketmaurya @ethanwharris @Andrei-Aksionov @borda @justusschock @tchaton

# CI/CD and configs
/.github/ @borda
Expand Down
34 changes: 19 additions & 15 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
import asyncio
import inspect
import logging
import signal
import sys
import time
from abc import ABC
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import zmq
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus
from litserve.zmq_queue import Producer

logger = logging.getLogger(__name__)
# FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing
Expand Down Expand Up @@ -129,9 +130,6 @@ def run(
"""

zmq_ctx: Optional[zmq.Context] = None
socket: Optional[zmq.Socket] = None

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
    """Hook that receives the LitAPI and the optional spec; does nothing here.

    The body is intentionally empty, so a call has no effect and returns ``None``.

    NOTE(review): presumably meant to be overridden by subclasses and invoked
    before worker setup — confirm against the caller.
    """

Expand Down Expand Up @@ -159,9 +157,7 @@ def __call__(
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
socket: Optional[zmq.Socket],
):
self.socket = socket
if asyncio.iscoroutinefunction(self.run):
event_loop = asyncio.new_event_loop()

Expand Down Expand Up @@ -226,7 +222,9 @@ def run(

class LitLoop(_BaseLoop):
def __init__(self):
    """Start with an empty context and no producer, then install signal handlers."""
    # Per-loop mapping; starts out empty.
    self._context = {}
    # No producer attached yet: put_response falls back to the response
    # queues, and __del__ skips the close, while this is None.
    self.producer: Optional[Producer] = None
    self._setup_signal_handlers()

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
batches, timed_out_uids = collate_requests(
Expand All @@ -250,23 +248,29 @@ def populate_context(self, lit_spec: LitSpec, request: Any):
def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
if self.socket:
self.socket.send_pyobj((uid, (response_data, status)))
if self.producer:
self.producer.put((uid, (response_data, status)), consumer_id=response_queue_id)
else:
response_queues[response_queue_id].put((uid, (response_data, status)), block=False)

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
if self.socket:
self.socket.send_pyobj((uid, (error, LitAPIStatus.ERROR)))
else:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)), block=False)
self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR)

def __del__(self):
if self.socket:
self.socket.close()
self.zmq_ctx.term()
if self.producer:
self.producer.close()

def _setup_signal_handlers(self):
    """Install SIGINT/SIGTERM handlers that close the producer and exit.

    On either signal the handler closes ``self.producer`` (when one is
    attached) and terminates the process with ``sys.exit(0)``.

    ``signal.signal`` may only be called from the main thread of the main
    interpreter and raises ``ValueError`` anywhere else. Because this method
    runs from ``__init__``, letting that error escape would make the loop
    impossible to construct off the main thread, so it is logged and the
    previously installed handlers are left untouched.
    """

    def cleanup_handler(signum=None, frame=None):
        # Log through the module logger rather than the root logger so the
        # record carries this module's name and honours its configuration.
        logger.debug("Worker process received shutdown signal")
        if self.producer:
            self.producer.close()
        sys.exit(0)

    try:
        signal.signal(signal.SIGINT, cleanup_handler)
        signal.signal(signal.SIGTERM, cleanup_handler)
    except ValueError as err:
        # Raised when called outside the main thread; the loop stays usable,
        # it just will not close the producer on SIGINT/SIGTERM.
        logger.warning("Skipping signal handler registration: %s", err)


class DefaultLoop(LitLoop):
Expand Down
17 changes: 4 additions & 13 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@
from queue import Queue
from typing import Dict, List, Optional, Union

import zmq
import zmq.asyncio

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import _BaseLoop
from litserve.loops.simple_loops import BatchedLoop, SingleLoop
from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
from litserve.specs.base import LitSpec
from litserve.utils import WorkerSetupStatus
from litserve.zmq_queue import Producer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,13 +76,10 @@ def inference_worker(
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

socket = None
if use_zmq:
ctx = zmq.Context()
socket = ctx.socket(zmq.PUB)
logger.debug(f"Inference worker binding to {zmq_addr}")
socket.bind(zmq_addr)
loop.zmq_context = ctx
producer = Producer(address=zmq_addr)
producer.wait_for_subscribers(timeout=5)
loop.producer = producer

loop(
lit_api,
Expand All @@ -98,8 +93,4 @@ def inference_worker(
stream,
workers_setup_status,
callback_runner,
socket,
)
if use_zmq:
socket.close()
loop.zmq_context.term()
Loading

0 comments on commit e0787e7

Please sign in to comment.