Skip to content

Commit

Permalink
add test for litserve.server (#25)
Browse files Browse the repository at this point in the history
* add test for litserve.server

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* fix

* add tests

* test inference_worker

* fix dispose pipe

* test single loop

* clean

* change filename

* fix "import file mismatch"

* fix "import file mismatch"

* fix "import file mismatch"

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
aniketmaurya and pre-commit-ci[bot] authored Apr 10, 2024
1 parent edbeb43 commit 144b5ff
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Limit test collection to files named test_*.py
# (per the commit notes, this addresses pytest "import file mismatch" errors).
[pytest]
python_files = test_*.py
5 changes: 3 additions & 2 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def new_pipe(self):
return pipe_s, pipe_r

def dispose_pipe(self, pipe_s, pipe_r):
    """Return a pipe pair to the reuse pool, or drop it if the pool is full.

    Args:
        pipe_s: sending end of the connection pair.
        pipe_r: receiving end of the connection pair.
    """
    # At capacity: discard the pair instead of growing the pool unboundedly.
    # (>= so the pool never exceeds max_pool_size entries.)
    if len(self.pipe_pool) >= self.max_pool_size:
        return
    # Store both ends as one tuple so a later new_pipe() can pop them together.
    self.pipe_pool.append((pipe_s, pipe_r))

def device_identifiers(self, accelerator, device):
if isinstance(device, Sequence):
Expand Down Expand Up @@ -270,6 +270,7 @@ async def data_reader():
else:
data = await data_reader()

self.dispose_pipe(read, write)
if type(data) == HTTPException:
raise data

Expand Down
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from litserve.server import LitServer
import pytest
from litserve.api import LitAPI
from fastapi import Request, Response
from fastapi.testclient import TestClient


class SimpleLitAPI(LitAPI):
    """Minimal LitAPI used by the test fixtures: squares its numeric input."""

    def setup(self, device):
        # The "model" is just a squaring function; nothing is loaded from disk.
        self.model = lambda value: value**2

    def decode_request(self, request: Request):
        # Extract the raw value from the incoming JSON payload.
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output) -> Response:
        # Wrap the prediction in the schema the test client expects.
        return {"output": output}


@pytest.fixture()
def simple_litapi():
    """Provide a fresh SimpleLitAPI instance for each test."""
    return SimpleLitAPI()


@pytest.fixture()
def lit_server(simple_litapi):
    """Provide a LitServer wired to the simple API on a single CPU device."""
    return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)


@pytest.fixture()
def sync_testclient(lit_server):
    """Yield a synchronous FastAPI TestClient bound to the server's app."""
    # The context manager runs the app's startup/shutdown events around the test.
    with TestClient(lit_server.app) as client:
        yield client
86 changes: 86 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from multiprocessing import Pipe, Manager
from unittest.mock import patch, MagicMock
from litserve.server import inference_worker, run_single_loop
from litserve.server import LitServer

import pytest


def test_new_pipe(lit_server):
    """After the pool is drained, new_pipe() must still produce pipe pairs."""
    capacity = lit_server.max_pool_size
    # Drain every pre-created pipe pair out of the pool.
    for _ in range(capacity):
        lit_server.new_pipe()

    assert not lit_server.pipe_pool
    # An empty pool still yields a fresh (send, recv) pair.
    assert len(lit_server.new_pipe()) == 2


def test_dispose_pipe(lit_server):
    """Pipes returned beyond the pool's capacity are silently dropped."""
    overshoot = lit_server.max_pool_size + 10
    for _ in range(overshoot):
        lit_server.dispose_pipe(*Pipe())
    # The pool must be capped at max_pool_size no matter how many are returned.
    assert len(lit_server.pipe_pool) == lit_server.max_pool_size


def test_index(sync_testclient):
    """The root endpoint reports that the server is running."""
    response = sync_testclient.get("/")
    assert response.text == "litserve running"


@patch("litserve.server.lifespan")
def test_device_identifiers(lifespan_mock, simple_litapi):
    """Check how LitServer expands accelerator/device specs into identifiers."""
    # NOTE(review): lifespan is patched, presumably so constructing LitServer
    # does not start real workers — confirm against server.py.
    server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
    # A scalar device yields one id; a sequence yields one id per element.
    assert server.device_identifiers("cpu", 1) == ["cpu:1"]
    assert server.device_identifiers("cpu", [1, 2]) == ["cpu:1", "cpu:2"]

    # CPU servers expose a bare "cpu" entry.
    server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
    assert server.app.devices == ["cpu"]

    # CUDA devices are nested one list per device, zero-indexed by default.
    server = LitServer(simple_litapi, accelerator="cuda", devices=1, timeout=10)
    assert server.app.devices == [["cuda:0"]]

    # Explicit device ids are preserved: expected [["cuda:1"], ["cuda:2"]].
    server = LitServer(simple_litapi, accelerator="cuda", devices=[1, 2], timeout=10)
    assert server.app.devices[0][0] == "cuda:1"
    assert server.app.devices[1][0] == "cuda:2"


@patch("litserve.server.run_batched_loop")
@patch("litserve.server.run_single_loop")
def test_inference_worker(mock_single_loop, mock_batched_loop):
    """inference_worker picks the batched or single loop based on batch size."""
    dummy_args = [MagicMock()] * 5

    # A batch size greater than one routes to the batched loop.
    inference_worker(*dummy_args, max_batch_size=2, batch_timeout=0)
    mock_batched_loop.assert_called_once()

    # A batch size of exactly one routes to the single-request loop.
    inference_worker(*dummy_args, max_batch_size=1, batch_timeout=0)
    mock_single_loop.assert_called_once()


@pytest.fixture()
def loop_args():
    """Provide a mocked LitAPI plus a queue/buffer pre-loaded with two requests."""
    # NOTE(review): this local import shadows the module-level Pipe/Manager
    # imports; Queue is only available through it.
    from multiprocessing import Manager, Queue, Pipe

    requests_queue = Queue()
    request_buffer = Manager().dict()
    # Two pending request ids on the queue...
    requests_queue.put(1)
    requests_queue.put(2)
    read, write = Pipe()
    # ...each mapped to a (payload, response-pipe) pair in the shared buffer.
    request_buffer[1] = {"input": 4.0}, write
    request_buffer[2] = {"input": 5.0}, write

    lit_api_mock = MagicMock()
    # decode_request mirrors SimpleLitAPI: pull out the "input" field.
    lit_api_mock.decode_request = MagicMock(side_effect=lambda x: x["input"])
    return lit_api_mock, requests_queue, request_buffer


class FakePipe:
    """Pipe stand-in whose send() always raises, letting tests break out of
    a worker loop at the first attempted response."""

    def send(self, item):
        raise StopIteration("exit loop")


def test_single_loop(simple_litapi, loop_args):
    """run_single_loop processes queued requests until FakePipe stops it."""
    lit_api_mock, requests_queue, request_buffer = loop_args
    lit_api_mock.unbatch.side_effect = None
    # Rebuild the buffer with FakePipe endpoints: the fixture's buffer holds
    # real pipe ends, but here the first send() must raise StopIteration so
    # the loop terminates.
    request_buffer = Manager().dict()
    request_buffer[1] = {"input": 4.0}, FakePipe()
    request_buffer[2] = {"input": 5.0}, FakePipe()

    # The loop is expected to surface FakePipe's exception unchanged.
    with pytest.raises(StopIteration, match="exit loop"):
        run_single_loop(lit_api_mock, requests_queue, request_buffer)

0 comments on commit 144b5ff

Please sign in to comment.