-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Move _get_base_url to the base provider In order to properly support "muxing providers" like openrouter, we'll have to tell litellm (or in future a native implementation), what server do we want to proxy to. We were already doing that with Vllm, but since are about to do the same for OpenRouter, let's move the `_get_base_url` method to the base provider. * Add an openrouter provider OpenRouter is a "muxing provider" which itself provides access to multiple models and providers. It speaks a dialect of the OpenAI protocol, but for our purposes, we can say it's OpenAI. There are some differences in handling the requests, though: 1) we need to know where to forward the request to, by default this is `https://openrouter.ai/api/v1`, this is done by setting the base_url parameter 2) we need to prefix the model with `openrouter/`. This is a lite-LLM-ism (see https://docs.litellm.ai/docs/providers/openrouter) which we'll be able to remove once we ditch litellm Initially I was considering just exposing the OpenAI provider on an additional route and handling the prefix based on the route, but I think having an explicit provider class is better as it allows us to handle any differences in OpenRouter dialect easily in the future. Related: #878 * Add a special ProviderType for openrouter We can later alias it to openai if we decide to merge them. * Add tests for the openrouter provider * ProviderType was reversed, thanks Alejandro --------- Co-authored-by: Radoslav Dimitrov <[email protected]>
- Loading branch information
Showing
11 changed files
with
170 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from codegate.providers.openai.provider import OpenAIProvider | ||
|
||
__all__ = ["OpenAIProvider"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import json | ||
|
||
from fastapi import Header, HTTPException, Request | ||
|
||
from codegate.clients.detector import DetectClient | ||
from codegate.pipeline.factory import PipelineFactory | ||
from codegate.providers.openai import OpenAIProvider | ||
|
||
|
||
class OpenRouterProvider(OpenAIProvider): | ||
def __init__(self, pipeline_factory: PipelineFactory): | ||
super().__init__(pipeline_factory) | ||
|
||
@property | ||
def provider_route_name(self) -> str: | ||
return "openrouter" | ||
|
||
def _setup_routes(self): | ||
@self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions") | ||
@self.router.post(f"/{self.provider_route_name}/chat/completions") | ||
@DetectClient() | ||
async def create_completion( | ||
request: Request, | ||
authorization: str = Header(..., description="Bearer token"), | ||
): | ||
if not authorization.startswith("Bearer "): | ||
raise HTTPException(status_code=401, detail="Invalid authorization header") | ||
|
||
api_key = authorization.split(" ")[1] | ||
body = await request.body() | ||
data = json.loads(body) | ||
|
||
base_url = self._get_base_url() | ||
data["base_url"] = base_url | ||
|
||
# litellm workaround - add openrouter/ prefix to model name to make it openai-compatible | ||
# once we get rid of litellm, this can simply be removed | ||
original_model = data.get("model", "") | ||
if not original_model.startswith("openrouter/"): | ||
data["model"] = f"openrouter/{original_model}" | ||
|
||
return await self.process_request( | ||
data, | ||
api_key, | ||
request.url.path, | ||
request.state.detected_client, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import json | ||
from unittest.mock import AsyncMock, MagicMock | ||
|
||
import pytest | ||
from fastapi import HTTPException | ||
from fastapi.requests import Request | ||
|
||
from codegate.config import DEFAULT_PROVIDER_URLS | ||
from codegate.pipeline.factory import PipelineFactory | ||
from codegate.providers.openrouter.provider import OpenRouterProvider | ||
|
||
|
||
@pytest.fixture | ||
def mock_factory(): | ||
return MagicMock(spec=PipelineFactory) | ||
|
||
|
||
@pytest.fixture | ||
def provider(mock_factory): | ||
return OpenRouterProvider(mock_factory) | ||
|
||
|
||
def test_get_base_url(provider): | ||
"""Test that _get_base_url returns the correct OpenRouter API URL""" | ||
assert provider._get_base_url() == DEFAULT_PROVIDER_URLS["openrouter"] | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_model_prefix_added(): | ||
"""Test that model name gets prefixed with openrouter/ when not already present""" | ||
mock_factory = MagicMock(spec=PipelineFactory) | ||
provider = OpenRouterProvider(mock_factory) | ||
provider.process_request = AsyncMock() | ||
|
||
# Mock request | ||
mock_request = MagicMock(spec=Request) | ||
mock_request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode()) | ||
mock_request.url.path = "/openrouter/chat/completions" | ||
mock_request.state.detected_client = "test-client" | ||
|
||
# Get the route handler function | ||
route_handlers = [ | ||
route for route in provider.router.routes if route.path == "/openrouter/chat/completions" | ||
] | ||
create_completion = route_handlers[0].endpoint | ||
|
||
await create_completion(request=mock_request, authorization="Bearer test-token") | ||
|
||
# Verify process_request was called with prefixed model | ||
call_args = provider.process_request.call_args[0] | ||
assert call_args[0]["model"] == "openrouter/gpt-4" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_model_prefix_preserved(): | ||
"""Test that model name is not modified when openrouter/ prefix is already present""" | ||
mock_factory = MagicMock(spec=PipelineFactory) | ||
provider = OpenRouterProvider(mock_factory) | ||
provider.process_request = AsyncMock() | ||
|
||
# Mock request | ||
mock_request = MagicMock(spec=Request) | ||
mock_request.body = AsyncMock(return_value=json.dumps({"model": "openrouter/gpt-4"}).encode()) | ||
mock_request.url.path = "/openrouter/chat/completions" | ||
mock_request.state.detected_client = "test-client" | ||
|
||
# Get the route handler function | ||
route_handlers = [ | ||
route for route in provider.router.routes if route.path == "/openrouter/chat/completions" | ||
] | ||
create_completion = route_handlers[0].endpoint | ||
|
||
await create_completion(request=mock_request, authorization="Bearer test-token") | ||
|
||
# Verify process_request was called with unchanged model name | ||
call_args = provider.process_request.call_args[0] | ||
assert call_args[0]["model"] == "openrouter/gpt-4" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invalid_auth_header(): | ||
"""Test that invalid authorization header format raises HTTPException""" | ||
mock_factory = MagicMock(spec=PipelineFactory) | ||
provider = OpenRouterProvider(mock_factory) | ||
|
||
mock_request = MagicMock(spec=Request) | ||
|
||
# Get the route handler function | ||
route_handlers = [ | ||
route for route in provider.router.routes if route.path == "/openrouter/chat/completions" | ||
] | ||
create_completion = route_handlers[0].endpoint | ||
|
||
with pytest.raises(HTTPException) as exc_info: | ||
await create_completion(request=mock_request, authorization="InvalidToken") | ||
|
||
assert exc_info.value.status_code == 401 | ||
assert exc_info.value.detail == "Invalid authorization header" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters