Skip to content

Commit

Permalink
feat: add workspace preferred model (#217)
Browse files Browse the repository at this point in the history
* feat(mux): wip model overrides workspace update

* fetch providers from BE (currently mocked)

* render multiple overrides

* fix form submit

* feat: add/remove model ovveride row

* add sorting feature

* fix sortable partially

* fix sorting

* .

* fix scrolling

* actually save filter to zustand store

* fix label

* fix label layout

* remove console logs

* remove unnecessary div

* extract <SortableArea />

* use renderprop in <SortableArea />

* simplify drag behavior

* simplify folder structure

* fix: add key to array item

* test: add workspace model override

* fix: type

* fix: mock types

* chore: update api

* feat: add discard button

* refactor: configure preferred model

* hide preferred model section

* chore: update api

* refactor: notification message

---------

Co-authored-by: Daniel Kantor <[email protected]>
  • Loading branch information
peppescg and kantord authored Jan 29, 2025
1 parent aebe051 commit 90dfbe9
Show file tree
Hide file tree
Showing 13 changed files with 724 additions and 190 deletions.
98 changes: 61 additions & 37 deletions src/api/generated/@tanstack/react-query.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import {
healthCheckHealthGet,
v1ListProviderEndpoints,
v1AddProviderEndpoint,
v1ListAllModelsForAllProviders,
v1ListModelsByProvider,
v1GetProviderEndpoint,
v1UpdateProviderEndpoint,
v1DeleteProviderEndpoint,
v1ListModelsByProvider,
v1ListAllModelsForAllProviders,
v1ConfigureAuthMaterial,
v1ListWorkspaces,
v1CreateWorkspace,
v1ListActiveWorkspaces,
Expand All @@ -36,14 +37,17 @@ import type {
V1AddProviderEndpointData,
V1AddProviderEndpointError,
V1AddProviderEndpointResponse,
V1ListModelsByProviderData,
V1GetProviderEndpointData,
V1UpdateProviderEndpointData,
V1UpdateProviderEndpointError,
V1UpdateProviderEndpointResponse,
V1DeleteProviderEndpointData,
V1DeleteProviderEndpointError,
V1DeleteProviderEndpointResponse,
V1ListModelsByProviderData,
V1ConfigureAuthMaterialData,
V1ConfigureAuthMaterialError,
V1ConfigureAuthMaterialResponse,
V1CreateWorkspaceData,
V1CreateWorkspaceError,
V1CreateWorkspaceResponse,
Expand Down Expand Up @@ -190,6 +194,48 @@ export const v1AddProviderEndpointMutation = (
return mutationOptions;
};

export const v1ListAllModelsForAllProvidersQueryKey = (
options?: OptionsLegacyParser,
) => [createQueryKey("v1ListAllModelsForAllProviders", options)];

export const v1ListAllModelsForAllProvidersOptions = (
options?: OptionsLegacyParser,
) => {
return queryOptions({
queryFn: async ({ queryKey, signal }) => {
const { data } = await v1ListAllModelsForAllProviders({
...options,
...queryKey[0],
signal,
throwOnError: true,
});
return data;
},
queryKey: v1ListAllModelsForAllProvidersQueryKey(options),
});
};

export const v1ListModelsByProviderQueryKey = (
options: OptionsLegacyParser<V1ListModelsByProviderData>,
) => [createQueryKey("v1ListModelsByProvider", options)];

export const v1ListModelsByProviderOptions = (
options: OptionsLegacyParser<V1ListModelsByProviderData>,
) => {
return queryOptions({
queryFn: async ({ queryKey, signal }) => {
const { data } = await v1ListModelsByProvider({
...options,
...queryKey[0],
signal,
throwOnError: true,
});
return data;
},
queryKey: v1ListModelsByProviderQueryKey(options),
});
};

export const v1GetProviderEndpointQueryKey = (
options: OptionsLegacyParser<V1GetProviderEndpointData>,
) => [createQueryKey("v1GetProviderEndpoint", options)];
Expand Down Expand Up @@ -251,46 +297,24 @@ export const v1DeleteProviderEndpointMutation = (
return mutationOptions;
};

export const v1ListModelsByProviderQueryKey = (
options: OptionsLegacyParser<V1ListModelsByProviderData>,
) => [createQueryKey("v1ListModelsByProvider", options)];

export const v1ListModelsByProviderOptions = (
options: OptionsLegacyParser<V1ListModelsByProviderData>,
) => {
return queryOptions({
queryFn: async ({ queryKey, signal }) => {
const { data } = await v1ListModelsByProvider({
...options,
...queryKey[0],
signal,
throwOnError: true,
});
return data;
},
queryKey: v1ListModelsByProviderQueryKey(options),
});
};

export const v1ListAllModelsForAllProvidersQueryKey = (
options?: OptionsLegacyParser,
) => [createQueryKey("v1ListAllModelsForAllProviders", options)];

export const v1ListAllModelsForAllProvidersOptions = (
options?: OptionsLegacyParser,
export const v1ConfigureAuthMaterialMutation = (
options?: Partial<OptionsLegacyParser<V1ConfigureAuthMaterialData>>,
) => {
return queryOptions({
queryFn: async ({ queryKey, signal }) => {
const { data } = await v1ListAllModelsForAllProviders({
const mutationOptions: UseMutationOptions<
V1ConfigureAuthMaterialResponse,
V1ConfigureAuthMaterialError,
OptionsLegacyParser<V1ConfigureAuthMaterialData>
> = {
mutationFn: async (localOptions) => {
const { data } = await v1ConfigureAuthMaterial({
...options,
...queryKey[0],
signal,
...localOptions,
throwOnError: true,
});
return data;
},
queryKey: v1ListAllModelsForAllProvidersQueryKey(options),
});
};
return mutationOptions;
};

export const v1ListWorkspacesQueryKey = (options?: OptionsLegacyParser) => [
Expand Down
84 changes: 52 additions & 32 deletions src/api/generated/sdk.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import type {
V1AddProviderEndpointData,
V1AddProviderEndpointError,
V1AddProviderEndpointResponse,
V1ListAllModelsForAllProvidersError,
V1ListAllModelsForAllProvidersResponse,
V1ListModelsByProviderData,
V1ListModelsByProviderError,
V1ListModelsByProviderResponse,
V1GetProviderEndpointData,
V1GetProviderEndpointError,
V1GetProviderEndpointResponse,
Expand All @@ -23,11 +28,9 @@ import type {
V1DeleteProviderEndpointData,
V1DeleteProviderEndpointError,
V1DeleteProviderEndpointResponse,
V1ListModelsByProviderData,
V1ListModelsByProviderError,
V1ListModelsByProviderResponse,
V1ListAllModelsForAllProvidersError,
V1ListAllModelsForAllProvidersResponse,
V1ConfigureAuthMaterialData,
V1ConfigureAuthMaterialError,
V1ConfigureAuthMaterialResponse,
V1ListWorkspacesError,
V1ListWorkspacesResponse,
V1CreateWorkspaceData,
Expand Down Expand Up @@ -131,6 +134,42 @@ export const v1AddProviderEndpoint = <ThrowOnError extends boolean = false>(
});
};

/**
* List All Models For All Providers
* List all models for all providers.
*/
export const v1ListAllModelsForAllProviders = <
ThrowOnError extends boolean = false,
>(
options?: OptionsLegacyParser<unknown, ThrowOnError>,
) => {
return (options?.client ?? client).get<
V1ListAllModelsForAllProvidersResponse,
V1ListAllModelsForAllProvidersError,
ThrowOnError
>({
...options,
url: "/api/v1/provider-endpoints/models",
});
};

/**
* List Models By Provider
* List models by provider.
*/
export const v1ListModelsByProvider = <ThrowOnError extends boolean = false>(
options: OptionsLegacyParser<V1ListModelsByProviderData, ThrowOnError>,
) => {
return (options?.client ?? client).get<
V1ListModelsByProviderResponse,
V1ListModelsByProviderError,
ThrowOnError
>({
...options,
url: "/api/v1/provider-endpoints/{provider_id}/models",
});
};

/**
* Get Provider Endpoint
* Get a provider endpoint by ID.
Expand Down Expand Up @@ -183,38 +222,19 @@ export const v1DeleteProviderEndpoint = <ThrowOnError extends boolean = false>(
};

/**
* List Models By Provider
* List models by provider.
*/
export const v1ListModelsByProvider = <ThrowOnError extends boolean = false>(
options: OptionsLegacyParser<V1ListModelsByProviderData, ThrowOnError>,
) => {
return (options?.client ?? client).get<
V1ListModelsByProviderResponse,
V1ListModelsByProviderError,
ThrowOnError
>({
...options,
url: "/api/v1/provider-endpoints/{provider_name}/models",
});
};

/**
* List All Models For All Providers
* List all models for all providers.
* Configure Auth Material
* Configure auth material for a provider.
*/
export const v1ListAllModelsForAllProviders = <
ThrowOnError extends boolean = false,
>(
options?: OptionsLegacyParser<unknown, ThrowOnError>,
export const v1ConfigureAuthMaterial = <ThrowOnError extends boolean = false>(
options: OptionsLegacyParser<V1ConfigureAuthMaterialData, ThrowOnError>,
) => {
return (options?.client ?? client).get<
V1ListAllModelsForAllProvidersResponse,
V1ListAllModelsForAllProvidersError,
return (options?.client ?? client).put<
V1ConfigureAuthMaterialResponse,
V1ConfigureAuthMaterialError,
ThrowOnError
>({
...options,
url: "/api/v1/provider-endpoints/models",
url: "/api/v1/provider-endpoints/{provider_id}/auth-material",
});
};

Expand Down
55 changes: 38 additions & 17 deletions src/api/generated/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ export type CodeSnippet = {
libraries?: Array<string>;
};

/**
* Represents a request to configure auth material for a provider.
*/
export type ConfigureAuthMaterial = {
auth_type: ProviderAuthType;
api_key?: string | null;
};

/**
* Represents a conversation.
*/
Expand Down Expand Up @@ -84,7 +92,8 @@ export type ListWorkspacesResponse = {
*/
export type ModelByProvider = {
name: string;
provider: string;
provider_id: string;
provider_name: string;
};

/**
Expand All @@ -99,10 +108,10 @@ export enum MuxMatcherType {
* Represents a mux rule for a provider.
*/
export type MuxRule = {
provider: string;
provider_id: string;
model: string;
matcher_type: MuxMatcherType;
matcher: string | null;
matcher?: string | null;
};

/**
Expand All @@ -120,12 +129,12 @@ export enum ProviderAuthType {
* so we can use this for muxing messages.
*/
export type ProviderEndpoint = {
id: number;
id?: string | null;
name: string;
description?: string;
provider_type: ProviderType;
endpoint: string;
auth_type: ProviderAuthType;
auth_type?: ProviderAuthType | null;
};

/**
Expand All @@ -135,8 +144,9 @@ export enum ProviderType {
OPENAI = "openai",
ANTHROPIC = "anthropic",
VLLM = "vllm",
LLAMACPP = "llamacpp",
OLLAMA = "ollama",
LM_STUDIO = "lm_studio",
LLAMACPP = "llamacpp",
}

/**
Expand Down Expand Up @@ -216,9 +226,23 @@ export type V1AddProviderEndpointResponse = ProviderEndpoint;

export type V1AddProviderEndpointError = HTTPValidationError;

export type V1ListAllModelsForAllProvidersResponse = Array<ModelByProvider>;

export type V1ListAllModelsForAllProvidersError = unknown;

export type V1ListModelsByProviderData = {
path: {
provider_id: string;
};
};

export type V1ListModelsByProviderResponse = Array<ModelByProvider>;

export type V1ListModelsByProviderError = HTTPValidationError;

export type V1GetProviderEndpointData = {
path: {
provider_id: number;
provider_id: string;
};
};

Expand All @@ -229,7 +253,7 @@ export type V1GetProviderEndpointError = HTTPValidationError;
export type V1UpdateProviderEndpointData = {
body: ProviderEndpoint;
path: {
provider_id: number;
provider_id: string;
};
};

Expand All @@ -239,27 +263,24 @@ export type V1UpdateProviderEndpointError = HTTPValidationError;

export type V1DeleteProviderEndpointData = {
path: {
provider_id: number;
provider_id: string;
};
};

export type V1DeleteProviderEndpointResponse = unknown;

export type V1DeleteProviderEndpointError = HTTPValidationError;

export type V1ListModelsByProviderData = {
export type V1ConfigureAuthMaterialData = {
body: ConfigureAuthMaterial;
path: {
provider_name: string;
provider_id: string;
};
};

export type V1ListModelsByProviderResponse = Array<ModelByProvider>;

export type V1ListModelsByProviderError = HTTPValidationError;
export type V1ConfigureAuthMaterialResponse = void;

export type V1ListAllModelsForAllProvidersResponse = Array<ModelByProvider>;

export type V1ListAllModelsForAllProvidersError = unknown;
export type V1ConfigureAuthMaterialError = HTTPValidationError;

export type V1ListWorkspacesResponse = ListWorkspacesResponse;

Expand Down
Loading

0 comments on commit 90dfbe9

Please sign in to comment.