fix close pipes (#108)
aniketmaurya authored May 23, 2024
1 parent 5268206 commit 41a256d
Showing 3 changed files with 12 additions and 35 deletions.
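
This commit removes the server's bounded pipe pool (`initial_pool_size`, `max_pool_size`, `pipe_pool`) and instead creates a fresh `multiprocessing.Pipe()` per request, closing both connection ends once the response has been consumed. A minimal sketch of why this matters (illustrative, not from the commit): each pipe holds two OS file descriptors until both `Connection` ends are closed, so pipes that are neither reused nor closed keep descriptors open.

```python
# Minimal sketch (not from this commit): each multiprocessing.Pipe()
# holds two OS file descriptors until both Connection ends are closed.
from multiprocessing import Pipe

reader, writer = Pipe()          # two Connection objects, two open fds
writer.send({"output": "done"})  # duplex by default: either end can send
print(reader.recv())             # {'output': 'done'}

# Releasing the descriptors requires closing *both* ends, which is
# exactly what the new close_pipe() helper does.
reader.close()
writer.close()
assert reader.closed and writer.closed
```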
28 changes: 10 additions & 18 deletions src/litserve/server.py
@@ -291,10 +291,7 @@ def __init__(
         self.max_batch_size = max_batch_size
         self.timeout = timeout
         self.batch_timeout = batch_timeout
-        initial_pool_size = 100
-        self.max_pool_size = 1000
         self.stream = stream
-        self.pipe_pool = [Pipe() for _ in range(initial_pool_size)]
         self._connector = _Connector(accelerator=accelerator, devices=devices)

         specs = spec if spec is not None else []
@@ -379,17 +376,12 @@ async def lifespan(self, app: FastAPI):
             logging.info(f"terminating worker worker_id={worker_id}")
             process.terminate()

-    def new_pipe(self):
-        try:
-            pipe_s, pipe_r = self.pipe_pool.pop()
-        except IndexError:
-            pipe_s, pipe_r = Pipe()
-        return pipe_s, pipe_r
+    def new_pipe(self) -> tuple:
+        return Pipe()

-    def dispose_pipe(self, pipe_s, pipe_r):
-        if len(self.pipe_pool) >= self.max_pool_size:
-            return
-        self.pipe_pool.append((pipe_s, pipe_r))
+    def close_pipe(self, pipe_s, pipe_r):
+        pipe_s.close()
+        pipe_r.close()

     def device_identifiers(self, accelerator, device):
         if isinstance(device, Sequence):
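
A sketch of the resulting per-request lifecycle: `fake_worker` below is a hypothetical stand-in for the inference worker, while `new_pipe` and `close_pipe` mirror the methods added in this diff.

```python
# Sketch of the per-request pipe lifecycle after this change.
# fake_worker is a hypothetical stand-in for the inference worker;
# new_pipe()/close_pipe() mirror the methods added in the diff.
from multiprocessing import Pipe, Process


def new_pipe() -> tuple:
    return Pipe()


def close_pipe(pipe_s, pipe_r):
    pipe_s.close()
    pipe_r.close()


def fake_worker(conn):
    conn.send(("prediction", "OK"))  # worker writes one response
    conn.close()


if __name__ == "__main__":
    read, write = new_pipe()
    worker = Process(target=fake_worker, args=(write,))
    worker.start()
    response, status = read.recv()
    close_pipe(read, write)  # no pooling: both ends closed after one use
    worker.join()
    print(response, status)  # prediction OK
```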
@@ -418,10 +410,10 @@ async def stream_from_pipe(self, read, write):
             if read.poll(LONG_TIMEOUT):
                 response, status = read.recv()
                 if status == LitAPIStatus.FINISH_STREAMING:
-                    self.dispose_pipe(read, write)
+                    self.close_pipe(read, write)
                     return
                 elif status == LitAPIStatus.ERROR:
-                    self.dispose_pipe(read, write)
+                    self.close_pipe(read, write)
                     logger.error(
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
@@ -442,10 +434,10 @@ async def data_streamer(self, read, write):
             if read.poll():
                 response, status = read.recv()
                 if status == LitAPIStatus.FINISH_STREAMING:
-                    self.dispose_pipe(read, write)
+                    self.close_pipe(read, write)
                     return
                 if status == LitAPIStatus.ERROR:
-                    self.dispose_pipe(read, write)
+                    self.close_pipe(read, write)
                     logger.error(
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
@@ -487,7 +479,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks)
                 )
             else:
                 data = await wait_for_queue_timeout(self.data_reader(read), self.timeout, uid, self.request_buffer)
-            self.dispose_pipe(read, write)
+            self.close_pipe(read, write)

             response, status = data
             if status == LitAPIStatus.ERROR:
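
Both streaming helpers now follow the same consume-then-close pattern: poll the read end, forward each chunk, and close both ends as soon as a FINISH_STREAMING or ERROR status arrives. A condensed, self-contained sketch of that loop (the status constants and the producer are simplified stand-ins, not litserve's actual `LitAPIStatus`):

```python
# Condensed sketch of the consume-then-close pattern now shared by
# stream_from_pipe() and data_streamer(). FINISH_STREAMING/ERROR are
# simplified stand-ins for litserve's LitAPIStatus values.
from multiprocessing import Pipe

STREAMING, FINISH_STREAMING, ERROR = "STREAMING", "FINISH_STREAMING", "ERROR"


def drain_stream(read, write, timeout=5.0):
    while read.poll(timeout):
        response, status = read.recv()
        if status in (FINISH_STREAMING, ERROR):
            # Close both ends the moment the stream finishes or fails,
            # instead of returning the pipe to a pool.
            read.close()
            write.close()
            return
        yield response


read, write = Pipe()
for chunk in ("Hello", ", ", "world"):
    write.send((chunk, STREAMING))
write.send((None, FINISH_STREAMING))
print("".join(drain_stream(read, write)))  # Hello, world
```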
2 changes: 1 addition & 1 deletion src/litserve/specs/openai.py
@@ -176,7 +176,7 @@ async def chat_completion(
         responses = await self.get_from_pipe(uids, pipes)

         for read, write in pipes:
-            self._server.dispose_pipe(read, write)
+            self._server.close_pipe(read, write)

         usage = UsageInfo()
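
In the OpenAI spec, one chat completion can hold several pipes, and the loop above now closes each of them after the responses are gathered rather than returning them to a pool. A minimal illustration (the fan-out is simulated; only the cleanup loop mirrors the diff):

```python
# Minimal illustration of the cleanup in chat_completion(): one pipe per
# generated choice, all closed once responses are collected. The fan-out
# is simulated here; only the closing loop mirrors the diff.
from multiprocessing import Pipe

pipes = [Pipe() for _ in range(3)]  # e.g. a request with n=3 choices

for read, write in pipes:
    write.send({"role": "assistant", "content": "hi"})  # simulated worker output

responses = [read.recv() for read, _ in pipes]

for read, write in pipes:  # previously dispose_pipe(): return to pool
    read.close()
    write.close()
```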
17 changes: 1 addition & 16 deletions tests/test_lit_server.py
@@ -39,21 +39,6 @@
 from litserve.server import LitServer


-def test_new_pipe(lit_server):
-    pool_size = lit_server.max_pool_size
-    for _ in range(pool_size):
-        lit_server.new_pipe()
-
-    assert len(lit_server.pipe_pool) == 0, "All available pipes from the pipe_pool were used up, which makes it empty"
-    assert len(lit_server.new_pipe()) == 2, "lit_server.new_pipe() always must return a tuple of read and write pipes"
-
-
-def test_dispose_pipe(lit_server):
-    for i in range(lit_server.max_pool_size + 10):
-        lit_server.dispose_pipe(*Pipe())
-    assert len(lit_server.pipe_pool) == lit_server.max_pool_size, "pipe_pool size must be less than max_pool_size"
-
-
 def test_index(sync_testclient):
     assert sync_testclient.get("/").text == "litserve running"

@@ -88,7 +73,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):

 @pytest.fixture()
 def loop_args():
-    from multiprocessing import Manager, Queue, Pipe
+    from multiprocessing import Manager, Queue

     requests_queue = Queue()
     request_buffer = Manager().dict()
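
The two pool-based tests are deleted without a replacement in this commit. A sketch of what a test for the new behavior might look like (hypothetical, not part of the commit; it assumes the `lit_server` fixture already used in this module):

```python
# Hypothetical replacement test (not part of this commit): checks that
# new_pipe() returns a two-Connection tuple and close_pipe() closes both
# ends. Assumes the lit_server fixture defined in this module.
def test_new_and_close_pipe(lit_server):
    pipe = lit_server.new_pipe()
    assert len(pipe) == 2, "new_pipe() must return a (read, write) pair"

    read, write = pipe
    lit_server.close_pipe(read, write)
    assert read.closed and write.closed, "close_pipe() must close both ends"
```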
