Skip to content

Commit

Permalink
Merge branch 'main' into malicious-openai
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimitrov authored Feb 4, 2025
2 parents 50634ce + ece8831 commit 048f8cb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 44 deletions.
2 changes: 1 addition & 1 deletion api/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AddProviderEndpointRequest"
"$ref": "#/components/schemas/ProviderEndpoint"
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def configure_auth_material(
)
async def update_provider_endpoint(
provider_id: UUID,
request: v1_models.AddProviderEndpointRequest,
request: v1_models.ProviderEndpoint,
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
try:
Expand Down
77 changes: 35 additions & 42 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def add_endpoint(
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def update_endpoint(
self, endpoint: apimodelsv1.AddProviderEndpointRequest
self, endpoint: apimodelsv1.ProviderEndpoint
) -> apimodelsv1.ProviderEndpoint:
"""Update an endpoint."""

Expand All @@ -134,12 +134,40 @@ async def update_endpoint(
if founddbe is None:
raise ProviderNotFoundError("Provider not found")

models = []
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def configure_auth_material(
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
):
"""Add an API key."""
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
raise ValueError("API key must be provided for API auth type")
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
raise ValueError("API key provided for non-API auth type")

dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
if dbendpoint is None:
raise ProviderNotFoundError("Provider not found")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

endpoint = apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
endpoint.auth_type = config.auth_type
provider_registry = get_provider_registry()
prov = endpoint.get_from_registry(provider_registry)

models = []
if config.auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

Expand All @@ -154,56 +182,21 @@ async def update_endpoint(
for model in models_set - models_in_db_set:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=founddbe.id,
provider_endpoint_id=dbendpoint.id,
name=model,
)
)

# Remove the models that are in the DB but not in the provider
for model in models_in_db_set - models_set:
await self._db_writer.delete_provider_model(
founddbe.id,
dbendpoint.id,
model,
)

dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

# If an API key was provided or we've changed the auth type, we update the auth material
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)

# a model might have been deleted, let's repopulate the cache
await self._ws_crud.repopulate_mux_cache()

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def configure_auth_material(
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
):
"""Add an API key."""
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
raise ValueError("API key must be provided for API auth type")
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
raise ValueError("API key provided for non-API auth type")

dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
if dbendpoint is None:
raise ProviderNotFoundError("Provider not found")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

async def delete_endpoint(self, provider_id: UUID):
"""Delete an endpoint."""

Expand Down

0 comments on commit 048f8cb

Please sign in to comment.