Add an OpenRouter provider (#921)
* Move _get_base_url to the base provider

In order to properly support "muxing providers" like OpenRouter, we'll
have to tell litellm (or, in the future, a native implementation) which
server we want to proxy to. We were already doing that for vLLM, but
since we are about to do the same for OpenRouter, let's move the
`_get_base_url` method to the base provider.
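
To make the intent concrete, here is a minimal, self-contained sketch of
the pattern (stand-in classes and a plain dict instead of CodeGate's real
`Config`; names are illustrative only):

```python
# Sketch only: every provider resolves its upstream URL from a single
# config mapping, keyed by its own route name. Subclasses that need extra
# formatting (like vLLM) build on the shared lookup instead of redoing it.
PROVIDER_URLS = {
    "openrouter": "https://openrouter.ai/api/v1",
    "vllm": "http://localhost:8000",
}


class BaseProvider:
    provider_route_name = ""

    def _get_base_url(self) -> str:
        # In CodeGate this reads Config.get_config().provider_urls instead.
        return PROVIDER_URLS.get(self.provider_route_name, "")


class VLLMProvider(BaseProvider):
    provider_route_name = "vllm"

    def _get_base_url(self) -> str:
        # Reuse the shared lookup, then apply vLLM-specific /v1 handling.
        base_url = super()._get_base_url().rstrip("/")
        return base_url if base_url.endswith("/v1") else f"{base_url}/v1"


class OpenRouterProvider(BaseProvider):
    provider_route_name = "openrouter"  # the shared lookup is already enough


print(VLLMProvider()._get_base_url())        # http://localhost:8000/v1
print(OpenRouterProvider()._get_base_url())  # https://openrouter.ai/api/v1
```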

* Add an openrouter provider

OpenRouter is a "muxing provider": a single endpoint that routes to many
models from many upstream providers. It speaks a dialect of the OpenAI
protocol, so for our purposes we can treat it as OpenAI.

There are some differences in handling the requests, though:
1) we need to know where to forward the request; by default this is
   `https://openrouter.ai/api/v1`, and we pass it via the base_url
   parameter
2) we need to prefix the model with `openrouter/`. This is a
   LiteLLM-ism (see https://docs.litellm.ai/docs/providers/openrouter)
   which we'll be able to remove once we ditch litellm. Both tweaks are
   sketched below.
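
Both tweaks happen on the incoming request body before we hand it to
litellm. A rough sketch (the helper name here is made up for
illustration; the real logic lives in the new provider's route handler):

```python
import json

OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"


def adapt_for_openrouter(body: bytes) -> dict:
    """Illustrative only: apply the two OpenRouter-specific tweaks."""
    data = json.loads(body)
    # 1) tell litellm which server to proxy to
    data["base_url"] = OPENROUTER_BASE_URL
    # 2) litellm picks the OpenRouter backend from the model prefix
    model = data.get("model", "")
    if not model.startswith("openrouter/"):
        data["model"] = f"openrouter/{model}"
    return data


print(adapt_for_openrouter(b'{"model": "anthropic/claude-3.5-sonnet"}'))
# {'model': 'openrouter/anthropic/claude-3.5-sonnet', 'base_url': 'https://openrouter.ai/api/v1'}
```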

Initially I was considering just exposing the OpenAI provider on an
additional route and handling the prefix based on the route, but I think
an explicit provider class is better: it lets us handle any differences
in the OpenRouter dialect more easily in the future.
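
From a client's point of view nothing else changes: you keep sending
OpenAI-style requests, just pointed at CodeGate's new route. A usage
sketch (the localhost address and port are assumptions for a local
CodeGate instance; the API key placeholder is not real):

```python
import requests

# Assumed local CodeGate address; adjust host/port to your deployment.
CODEGATE_OPENROUTER_URL = "http://localhost:8989/openrouter/chat/completions"

resp = requests.post(
    CODEGATE_OPENROUTER_URL,
    headers={"Authorization": "Bearer <your-openrouter-api-key>"},
    json={
        # No openrouter/ prefix needed; the provider adds it before calling litellm.
        "model": "anthropic/claude-3.5-sonnet",
        "messages": [{"role": "user", "content": "Hello from CodeGate"}],
    },
    timeout=60,
)
print(resp.json())
```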

Related: #878

* Add a special ProviderType for openrouter

We can later alias it to openai if we decide to merge them.
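
For reference, merging them later would lean on how Python enums treat
duplicate values: give two names the same value and the second becomes an
alias of the first. A small, self-contained sketch of that behaviour (not
the actual `ProviderType` definition):

```python
from enum import Enum


class ProviderType(str, Enum):
    openai = "openai"
    # A second name with the same value is an alias, not a new member.
    openrouter = "openai"


assert ProviderType.openrouter is ProviderType.openai
assert ProviderType("openai") is ProviderType.openai
print(list(ProviderType))  # [<ProviderType.openai: 'openai'>] -- aliases are not iterated
```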

* Add tests for the openrouter provider

* ProviderType was reversed, thanks Alejandro

---------

Co-authored-by: Radoslav Dimitrov <[email protected]>
jhrozek and rdimitrov authored Feb 6, 2025
1 parent 37168b5 commit 1be0bfe
Showing 11 changed files with 170 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/codegate/config.py
@@ -17,6 +17,7 @@
# Default provider URLs
DEFAULT_PROVIDER_URLS = {
    "openai": "https://api.openai.com/v1",
    "openrouter": "https://openrouter.ai/api/v1",
    "anthropic": "https://api.anthropic.com/v1",
    "vllm": "http://localhost:8000",  # Base URL without /v1 path
    "ollama": "http://localhost:11434",  # Default Ollama server URL
1 change: 1 addition & 0 deletions src/codegate/db/models.py
@@ -128,6 +128,7 @@ class ProviderType(str, Enum):
    ollama = "ollama"
    lm_studio = "lm_studio"
    llamacpp = "llamacpp"
    openrouter = "openai"


class GetPromptWithOutputsRow(BaseModel):
2 changes: 2 additions & 0 deletions src/codegate/muxing/adapter.py
@@ -106,6 +106,8 @@ def __init__(self):
            db_models.ProviderType.anthropic: self._format_antropic,
            # Our Lllamacpp provider emits OpenAI chunks
            db_models.ProviderType.llamacpp: self._format_openai,
            # OpenRouter is a dialect of OpenAI
            db_models.ProviderType.openrouter: self._format_openai,
        }

    def _format_ollama(self, chunk: str) -> str:
2 changes: 2 additions & 0 deletions src/codegate/providers/__init__.py
@@ -2,13 +2,15 @@
from codegate.providers.base import BaseProvider
from codegate.providers.ollama.provider import OllamaProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.openrouter.provider import OpenRouterProvider
from codegate.providers.registry import ProviderRegistry
from codegate.providers.vllm.provider import VLLMProvider

__all__ = [
"BaseProvider",
"ProviderRegistry",
"OpenAIProvider",
"OpenRouterProvider",
"AnthropicProvider",
"VLLMProvider",
"OllamaProvider",
8 changes: 8 additions & 0 deletions src/codegate/providers/base.py
@@ -12,6 +12,7 @@

from codegate.clients.clients import ClientType
from codegate.codegate_logging import setup_logging
from codegate.config import Config
from codegate.db.connection import DbRecorder
from codegate.pipeline.base import (
    PipelineContext,
@@ -88,6 +89,13 @@ async def process_request(
    def provider_route_name(self) -> str:
        pass

    def _get_base_url(self) -> str:
        """
        Get the base URL from config with proper formatting
        """
        config = Config.get_config()
        return config.provider_urls.get(self.provider_route_name) if config else ""

    async def _run_output_stream_pipeline(
        self,
        input_context: PipelineContext,
3 changes: 3 additions & 0 deletions src/codegate/providers/openai/__init__.py
@@ -0,0 +1,3 @@
from codegate.providers.openai.provider import OpenAIProvider

__all__ = ["OpenAIProvider"]
47 changes: 47 additions & 0 deletions src/codegate/providers/openrouter/provider.py
@@ -0,0 +1,47 @@
import json

from fastapi import Header, HTTPException, Request

from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.openai import OpenAIProvider


class OpenRouterProvider(OpenAIProvider):
    def __init__(self, pipeline_factory: PipelineFactory):
        super().__init__(pipeline_factory)

    @property
    def provider_route_name(self) -> str:
        return "openrouter"

    def _setup_routes(self):
        @self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions")
        @self.router.post(f"/{self.provider_route_name}/chat/completions")
        @DetectClient()
        async def create_completion(
            request: Request,
            authorization: str = Header(..., description="Bearer token"),
        ):
            if not authorization.startswith("Bearer "):
                raise HTTPException(status_code=401, detail="Invalid authorization header")

            api_key = authorization.split(" ")[1]
            body = await request.body()
            data = json.loads(body)

            base_url = self._get_base_url()
            data["base_url"] = base_url

            # litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
            # once we get rid of litellm, this can simply be removed
            original_model = data.get("model", "")
            if not original_model.startswith("openrouter/"):
                data["model"] = f"openrouter/{original_model}"

            return await self.process_request(
                data,
                api_key,
                request.url.path,
                request.state.detected_client,
            )
4 changes: 1 addition & 3 deletions src/codegate/providers/vllm/provider.py
@@ -9,7 +9,6 @@

from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
-from codegate.config import Config
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
@@ -39,8 +38,7 @@ def _get_base_url(self) -> str:
"""
Get the base URL from config with proper formatting
"""
config = Config.get_config()
base_url = config.provider_urls.get("vllm") if config else ""
base_url = super()._get_base_url()
if base_url:
base_url = base_url.rstrip("/")
# Add /v1 if not present
5 changes: 5 additions & 0 deletions src/codegate/server.py
@@ -18,6 +18,7 @@
from codegate.providers.lm_studio.provider import LmStudioProvider
from codegate.providers.ollama.provider import OllamaProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.openrouter.provider import OpenRouterProvider
from codegate.providers.registry import ProviderRegistry, get_provider_registry
from codegate.providers.vllm.provider import VLLMProvider

@@ -75,6 +76,10 @@ async def log_user_agent(request: Request, call_next):
        ProviderType.openai,
        OpenAIProvider(pipeline_factory),
    )
    registry.add_provider(
        ProviderType.openrouter,
        OpenRouterProvider(pipeline_factory),
    )
    registry.add_provider(
        ProviderType.anthropic,
        AnthropicProvider(
98 changes: 98 additions & 0 deletions tests/providers/openrouter/test_openrouter_provider.py
@@ -0,0 +1,98 @@
import json
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import HTTPException
from fastapi.requests import Request

from codegate.config import DEFAULT_PROVIDER_URLS
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.openrouter.provider import OpenRouterProvider


@pytest.fixture
def mock_factory():
    return MagicMock(spec=PipelineFactory)


@pytest.fixture
def provider(mock_factory):
    return OpenRouterProvider(mock_factory)


def test_get_base_url(provider):
    """Test that _get_base_url returns the correct OpenRouter API URL"""
    assert provider._get_base_url() == DEFAULT_PROVIDER_URLS["openrouter"]


@pytest.mark.asyncio
async def test_model_prefix_added():
    """Test that model name gets prefixed with openrouter/ when not already present"""
    mock_factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(mock_factory)
    provider.process_request = AsyncMock()

    # Mock request
    mock_request = MagicMock(spec=Request)
    mock_request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode())
    mock_request.url.path = "/openrouter/chat/completions"
    mock_request.state.detected_client = "test-client"

    # Get the route handler function
    route_handlers = [
        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
    ]
    create_completion = route_handlers[0].endpoint

    await create_completion(request=mock_request, authorization="Bearer test-token")

    # Verify process_request was called with prefixed model
    call_args = provider.process_request.call_args[0]
    assert call_args[0]["model"] == "openrouter/gpt-4"


@pytest.mark.asyncio
async def test_model_prefix_preserved():
    """Test that model name is not modified when openrouter/ prefix is already present"""
    mock_factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(mock_factory)
    provider.process_request = AsyncMock()

    # Mock request
    mock_request = MagicMock(spec=Request)
    mock_request.body = AsyncMock(return_value=json.dumps({"model": "openrouter/gpt-4"}).encode())
    mock_request.url.path = "/openrouter/chat/completions"
    mock_request.state.detected_client = "test-client"

    # Get the route handler function
    route_handlers = [
        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
    ]
    create_completion = route_handlers[0].endpoint

    await create_completion(request=mock_request, authorization="Bearer test-token")

    # Verify process_request was called with unchanged model name
    call_args = provider.process_request.call_args[0]
    assert call_args[0]["model"] == "openrouter/gpt-4"


@pytest.mark.asyncio
async def test_invalid_auth_header():
    """Test that invalid authorization header format raises HTTPException"""
    mock_factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(mock_factory)

    mock_request = MagicMock(spec=Request)

    # Get the route handler function
    route_handlers = [
        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
    ]
    create_completion = route_handlers[0].endpoint

    with pytest.raises(HTTPException) as exc_info:
        await create_completion(request=mock_request, authorization="InvalidToken")

    assert exc_info.value.status_code == 401
    assert exc_info.value.detail == "Invalid authorization header"
4 changes: 2 additions & 2 deletions tests/test_server.py
@@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa
    # Verify all providers were registered
    registry_instance = mock_registry.return_value
    assert (
-        registry_instance.add_provider.call_count == 6
-    )  # openai, anthropic, llamacpp, vllm, ollama, lm_studio
+        registry_instance.add_provider.call_count == 7
+    )  # openai, anthropic, llamacpp, vllm, ollama, lm_studio, openrouter

    # Verify specific providers were registered
    provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list]
