Feat: Add support for response_format field in OpenAI Spec (#197)
* feat: add response_format field to ChatCompletionRequest model

* feat: add openai_request_data_with_response_format fixture

* chore: Add TestAPIWithStructuredOutput class to openai_spec_example.py

* feat: Update litserve to use TestAPIWithToolCalls in default_openaispec_tools.py

* feat: Add default_openaispec_response_format.py to tests/e2e

* feat: Add test for openai_parity_with_response_format

* reverted change

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore: Update field name in JSONSchema model to adhere to naming convention

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aniket Maurya <[email protected]>
3 people authored Aug 14, 2024
1 parent f73d30d commit 8ff924a
Showing 6 changed files with 152 additions and 5 deletions.
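For context, a minimal client-side sketch (not part of this commit) of what the new response_format field enables. It mirrors the e2e test added below and assumes the example server is running locally on LitServe's default port:

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="lit")
response = client.chat.completions.create(
    model="lit",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "calendar_event",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "date": {"type": "string"},
                    "participants": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["name", "date", "participants"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
)
# The example API returns a canned event JSON string as message content.
print(response.choices[0].message.content)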
8 changes: 8 additions & 0 deletions src/litserve/examples/openai_spec_example.py
@@ -45,6 +45,14 @@ def encode_response(self, output):
)


class TestAPIWithStructuredOutput(TestAPI):
def encode_response(self, output):
yield ChatMessage(
role="assistant",
content='{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}',
)


class OpenAIBatchContext(ls.LitAPI):
def setup(self, device: str) -> None:
self.model = None
28 changes: 27 additions & 1 deletion src/litserve/specs/openai.py
@@ -20,7 +20,7 @@
import uuid
from collections import deque
from enum import Enum
-from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
+from typing import Annotated, AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union

from fastapi import BackgroundTasks, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
@@ -105,6 +105,31 @@ class ToolCall(BaseModel):
function: FunctionCall


class ResponseFormatText(BaseModel):
type: Literal["text"]


class ResponseFormatJSONObject(BaseModel):
type: Literal["json_object"]


class JSONSchema(BaseModel):
name: str
description: Optional[str] = None
schema_def: Optional[Dict[str, object]] = Field(None, alias="schema")
strict: Optional[bool] = False


class ResponseFormatJSONSchema(BaseModel):
json_schema: JSONSchema
type: Literal["json_schema"]


ResponseFormat = Annotated[
Union[ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema], "ResponseFormat"
]


class ChatMessage(BaseModel):
role: str
content: Union[str, List[Union[TextContent, ImageContent]]]
@@ -138,6 +163,7 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = ToolChoice.auto
response_format: Optional[ResponseFormat] = None


class ChatCompletionResponseChoice(BaseModel):
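A quick sketch (not part of the diff) of how the new models parse a request payload. It assumes pydantic's default alias handling, under which the wire key "schema" populates JSONSchema.schema_def:

from litserve.specs.openai import ResponseFormatJSONSchema

rf = ResponseFormatJSONSchema(
    type="json_schema",
    json_schema={"name": "calendar_event", "schema": {"type": "object"}, "strict": True},
)
# The alias maps the reserved-looking "schema" key onto schema_def.
assert rf.json_schema.schema_def == {"type": "object"}
assert rf.json_schema.strict is True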
40 changes: 40 additions & 0 deletions tests/conftest.py
@@ -224,3 +224,43 @@ def openai_request_data_with_tools():
"frequency_penalty": 0,
"user": "string",
}


@pytest.fixture()
def openai_request_data_with_response_format():
return {
"model": "lit",
"messages": [
{
"role": "system",
"content": "Extract the event information.",
},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "calendar_event",
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"date": {"type": "string"},
"participants": {"type": "array", "items": {"type": "string"}},
},
"required": ["name", "date", "participants"],
"additionalProperties": "false",
},
"strict": "true",
},
},
"temperature": 0.7,
"top_p": 1,
"n": 1,
"max_tokens": 0,
"stop": "string",
"stream": False,
"presence_penalty": 0,
"frequency_penalty": 0,
"user": "string",
}
7 changes: 7 additions & 0 deletions tests/e2e/default_openaispec_response_format.py
@@ -0,0 +1,7 @@
import litserve as ls
from litserve import OpenAISpec
from litserve.examples.openai_spec_example import TestAPIWithStructuredOutput

if __name__ == "__main__":
server = ls.LitServer(TestAPIWithStructuredOutput(), spec=OpenAISpec())
server.run()
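A hypothetical smoke test (not part of the commit) against this script once it is running, using httpx (already a test dependency here). It assumes LitServe's default port 8000 and that the remaining ChatCompletionRequest fields fall back to their defaults; the "json_object" variant exercises the second ResponseFormat union member:

import httpx

payload = {
    "model": "lit",
    "messages": [{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."}],
    "response_format": {"type": "json_object"},
}
resp = httpx.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])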
51 changes: 51 additions & 0 deletions tests/e2e/test_e2e.py
@@ -244,3 +244,54 @@ def test_e2e_openai_with_batching(openai_request_data):
assert response.choices[0].message.content == (
"Hi! It's nice to meet you. Is there something I can " "help you with or would you like to chat? "
), f"Server didn't return expected output OpenAI client output: {response}"


@e2e_from_file("tests/e2e/default_openaispec_response_format.py")
def test_openai_parity_with_response_format():
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="lit")
messages = [
{
"role": "system",
"content": "Extract the event information.",
},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
]
response_format = {
"type": "json_schema",
"json_schema": {
"name": "calendar_event",
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"date": {"type": "string"},
"participants": {"type": "array", "items": {"type": "string"}},
},
"required": ["name", "date", "participants"],
"additionalProperties": "false",
},
"strict": "true",
},
}
output = '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
response = client.chat.completions.create(
model="lit",
messages=messages,
response_format=response_format,
)
    assert response.choices[0].message.content == output, (
        f"Server didn't return expected output\nOpenAI client output: {response}"
    )

response = client.chat.completions.create(
model="lit",
messages=messages,
response_format=response_format,
stream=True,
)

expected_outputs = [output, None]
for r, expected_out in zip(response, expected_outputs):
        assert r.choices[0].delta.content == expected_out, (
            f"Server didn't return expected output.\nOpenAI client output: {r}"
        )
23 changes: 19 additions & 4 deletions tests/test_specs.py
@@ -12,21 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import litserve as ls
import pytest
from asgi_lifespan import LifespanManager
from fastapi import HTTPException
from httpx import AsyncClient
from litserve.examples.openai_spec_example import (
+    OpenAIBatchingWithUsage,
+    OpenAIWithUsage,
+    OpenAIWithUsageEncodeResponse,
    TestAPI,
    TestAPIWithCustomEncode,
+    TestAPIWithStructuredOutput,
    TestAPIWithToolCalls,
-    OpenAIWithUsage,
-    OpenAIBatchingWithUsage,
-    OpenAIWithUsageEncodeResponse,
)
from litserve.utils import wrap_litserve_start
from litserve.specs.openai import OpenAISpec, ChatMessage
-import litserve as ls


@pytest.mark.asyncio()
@@ -117,6 +118,20 @@ async def test_openai_spec_with_tools(openai_request_data_with_tools):
], "LitAPI predict response should match with the generated output"


@pytest.mark.asyncio()
async def test_openai_spec_with_response_format(openai_request_data_with_response_format):
spec = OpenAISpec()
server = ls.LitServer(TestAPIWithStructuredOutput(), spec=spec)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_response_format, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
assert (
resp.json()["choices"][0]["message"]["content"]
== '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
), "LitAPI predict response should match with the generated output"


class IncorrectAPI1(ls.LitAPI):
def setup(self, device):
self.model = None
