Feat: Add support for response_format field in OpenAI Spec #197

Merged

Changes from 11 commits (17 commits total)
c24a220  feat: add response_format field to ChatCompletionRequest model (bhimrazy, Aug 7, 2024)
3d3a3db  feat: add openai_request_data_with_response_format fixture (bhimrazy, Aug 8, 2024)
f59f59e  chore: Add TestAPIWithStructuredOutput class to openai_spec_example.py (bhimrazy, Aug 8, 2024)
041c0b0  feat: Update litserve to use TestAPIWithToolCalls in default_openaisp… (bhimrazy, Aug 8, 2024)
e0f1a2f  feat: Add default_openaispec_response_format.py to tests/e2e (bhimrazy, Aug 8, 2024)
9001b0a  feat: Add test for openai_parity_with_response_format (bhimrazy, Aug 8, 2024)
a505e5c  feat: Add test for openai_parity_with_response_format (bhimrazy, Aug 8, 2024)
14fbf57  Merge branch 'main' into feat/add-response-format-support (bhimrazy, Aug 8, 2024)
ef36f96  reverted change (bhimrazy, Aug 8, 2024)
13e50cc  Merge branch 'feat/add-response-format-support' of github.com:bhimraz… (bhimrazy, Aug 8, 2024)
f041521  Merge branch 'main' into feat/add-response-format-support (bhimrazy, Aug 10, 2024)
c419ad8  Merge branch 'main' into feat/add-response-format-support (bhimrazy, Aug 12, 2024)
e95084e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 12, 2024)
4da081d  chore: Update field name in JSONSchema model to adhere to naming conv… (bhimrazy, Aug 12, 2024)
f368522  Merge branch 'main' into feat/add-response-format-support (aniketmaurya, Aug 12, 2024)
d040953  Merge branch 'main' into feat/add-response-format-support (aniketmaurya, Aug 12, 2024)
3bb25b2  Merge branch 'main' into feat/add-response-format-support (bhimrazy, Aug 14, 2024)
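
For orientation before the diffs, here is a minimal client-side sketch of what the new field enables, assuming a LitServe server with OpenAISpec is already running locally. The address, API key, model name, and schema mirror the e2e test added further down; they are illustrative, not prescriptive.

# Client sketch: request structured output from a locally running
# LitServe OpenAI-spec server (assumed at 127.0.0.1:8000, as in the e2e test).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="lit")

response = client.chat.completions.create(
    model="lit",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    # response_format payload of type "json_schema", the shape this PR adds
    # to ChatCompletionRequest; the schema itself is just an example.
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "calendar_event",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "date": {"type": "string"},
                    "participants": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["name", "date", "participants"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
)
print(response.choices[0].message.content)
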
8 changes: 8 additions & 0 deletions src/litserve/examples/openai_spec_example.py
@@ -45,6 +45,14 @@ def encode_response(self, output):
)


class TestAPIWithStructuredOutput(TestAPI):
def encode_response(self, output):
yield ChatMessage(
role="assistant",
content='{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}',
)


class OpenAIBatchContext(ls.LitAPI):
def setup(self, device: str) -> None:
self.model = None
28 changes: 27 additions & 1 deletion src/litserve/specs/openai.py
@@ -20,7 +20,7 @@
import uuid
from collections import deque
from enum import Enum
-from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
+from typing import Annotated, AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union

from fastapi import BackgroundTasks, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
@@ -105,6 +105,31 @@ class ToolCall(BaseModel):
function: FunctionCall


class ResponseFormatText(BaseModel):
type: Literal["text"]


class ResponseFormatJSONObject(BaseModel):
type: Literal["json_object"]


class JSONSchema(BaseModel):
name: str
description: Optional[str] = None
schema: Optional[Dict[str, object]] = None
strict: Optional[bool] = False


class ResponseFormatJSONSchema(BaseModel):
json_schema: JSONSchema
type: Literal["json_schema"]


ResponseFormat = Annotated[
Union[ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema], "ResponseFormat"
]


class ChatMessage(BaseModel):
role: str
content: Union[str, List[Union[TextContent, ImageContent]]]
@@ -138,6 +163,7 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = ToolChoice.auto
response_format: Optional[ResponseFormat] = None


class ChatCompletionResponseChoice(BaseModel):
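
For reference, a small validation sketch of the models added above, assuming this branch of litserve is installed so they are importable. Pydantic picks the union member whose Literal "type" field matches the payload.

# Validation sketch (assumes this branch is installed; field names follow
# the diff above, i.e. JSONSchema still exposes a "schema" field here).
from litserve.specs.openai import ChatCompletionRequest

request = ChatCompletionRequest(
    model="lit",
    messages=[{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."}],
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "calendar_event", "schema": {"type": "object"}, "strict": True},
    },
)
# "type": "json_schema" only validates against ResponseFormatJSONSchema,
# so the nested payload is parsed into the JSONSchema model defined above.
print(type(request.response_format).__name__)  # ResponseFormatJSONSchema
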
40 changes: 40 additions & 0 deletions tests/conftest.py
@@ -236,3 +236,43 @@ def openai_request_data_with_tools():
"frequency_penalty": 0,
"user": "string",
}


@pytest.fixture()
def openai_request_data_with_response_format():
return {
"model": "lit",
"messages": [
{
"role": "system",
"content": "Extract the event information.",
},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "calendar_event",
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"date": {"type": "string"},
"participants": {"type": "array", "items": {"type": "string"}},
},
"required": ["name", "date", "participants"],
"additionalProperties": "false",
},
"strict": "true",
},
},
"temperature": 0.7,
"top_p": 1,
"n": 1,
"max_tokens": 0,
"stop": "string",
"stream": False,
"presence_penalty": 0,
"frequency_penalty": 0,
"user": "string",
}
7 changes: 7 additions & 0 deletions tests/e2e/default_openaispec_response_format.py
@@ -0,0 +1,7 @@
import litserve as ls
from litserve import OpenAISpec
from litserve.examples.openai_spec_example import TestAPIWithStructuredOutput

if __name__ == "__main__":
server = ls.LitServer(TestAPIWithStructuredOutput(), spec=OpenAISpec())
server.run()
51 changes: 51 additions & 0 deletions tests/e2e/test_e2e.py
@@ -244,3 +244,54 @@ def test_e2e_openai_with_batching(openai_request_data):
assert response.choices[0].message.content == (
"Hi! It's nice to meet you. Is there something I can " "help you with or would you like to chat? "
), f"Server didn't return expected output OpenAI client output: {response}"


@e2e_from_file("tests/e2e/default_openaispec_response_format.py")
def test_openai_parity_with_response_format():
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="lit")
messages = [
{
"role": "system",
"content": "Extract the event information.",
},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
]
response_format = {
"type": "json_schema",
"json_schema": {
"name": "calendar_event",
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"date": {"type": "string"},
"participants": {"type": "array", "items": {"type": "string"}},
},
"required": ["name", "date", "participants"],
"additionalProperties": "false",
},
"strict": "true",
},
}
output = '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
response = client.chat.completions.create(
model="lit",
messages=messages,
response_format=response_format,
)
assert response.choices[0].message.content == output, (
f"Server didn't return expected output" f"\nOpenAI client output: {response}"
)

response = client.chat.completions.create(
model="lit",
messages=messages,
response_format=response_format,
stream=True,
)

expected_outputs = [output, None]
for r, expected_out in zip(response, expected_outputs):
assert r.choices[0].delta.content == expected_out, (
f"Server didn't return expected output.\n" f"OpenAI client output: {r}"
)
26 changes: 21 additions & 5 deletions tests/test_specs.py
@@ -12,21 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import litserve as ls
import pytest
from asgi_lifespan import LifespanManager
from fastapi import HTTPException
from httpx import AsyncClient
from litserve.examples.openai_spec_example import (
+    OpenAIBatchingWithUsage,
+    OpenAIWithUsage,
+    OpenAIWithUsageEncodeResponse,
    TestAPI,
    TestAPIWithCustomEncode,
+    TestAPIWithStructuredOutput,
    TestAPIWithToolCalls,
-    OpenAIWithUsage,
-    OpenAIBatchingWithUsage,
-    OpenAIWithUsageEncodeResponse,
)
+from litserve.specs.openai import ChatMessage, OpenAISpec

from tests.conftest import wrap_litserve_start
-from litserve.specs.openai import OpenAISpec, ChatMessage
-import litserve as ls


@pytest.mark.asyncio()
@@ -117,6 +119,20 @@ async def test_openai_spec_with_tools(openai_request_data_with_tools):
], "LitAPI predict response should match with the generated output"


@pytest.mark.asyncio()
async def test_openai_spec_with_response_format(openai_request_data_with_response_format):
spec = OpenAISpec()
server = ls.LitServer(TestAPIWithStructuredOutput(), spec=spec)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_response_format, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
assert (
resp.json()["choices"][0]["message"]["content"]
== '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
), "LitAPI predict response should match with the generated output"


class IncorrectAPI1(ls.LitAPI):
def setup(self, device):
self.model = None