From d95e2362336af494f6bcb0162be16c158820a714 Mon Sep 17 00:00:00 2001
From: Safoora Yousefi
Date: Fri, 18 Oct 2024 08:36:10 +0000
Subject: [PATCH 1/3] reorg +azure o1 model

---
 eureka_ml_insights/configs/model_configs.py |  65 +++++---
 eureka_ml_insights/models/__init__.py       |  54 +++----
 eureka_ml_insights/models/models.py         | 168 +++++++++++---------
 3 files changed, 161 insertions(+), 126 deletions(-)

diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py
index 7b72988..f3588d8 100644
--- a/eureka_ml_insights/configs/model_configs.py
+++ b/eureka_ml_insights/configs/model_configs.py
@@ -3,14 +3,16 @@
 You can also add your custom models here by following the same pattern as the existing configs.
 """
 from eureka_ml_insights.models import (
-    ClaudeModels,
-    GeminiModels,
-    LlamaServerlessAzureRestEndpointModels,
-    LLaVA,
-    LLaVAHuggingFaceMM,
-    MistralServerlessAzureRestEndpointModels,
-    OpenAIModelsOAI,
-    RestEndpointModels,
+    ClaudeModel,
+    GeminiModel,
+    LlamaServerlessAzureRestEndpointModel,
+    LLaVAModel,
+    LLaVAHuggingFaceModel,
+    MistralServerlessAzureRestEndpointModel,
+    DirectOpenAIModel,
+    AzureOpenAIO1Model,
+    DirectOpenAIO1Model,
+    RestEndpointModel,
 )
 
 from .config import ModelConfig
@@ -22,8 +24,25 @@
     "key_vault_url": None,
 }
 
+OAI_O1_PREVIEW_CONFIG = ModelConfig(
+    DirectOpenAIO1Model,
+    {
+        "model_name": "o1-preview",
+        "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
+    },
+)
+
+OAI_O1_PREVIEW_AZURE_CONFIG = ModelConfig(
+    AzureOpenAIO1Model,
+    {
+        "model_name": "o1-preview",
+        "url": "your/endpoint/url",
+        "api_version": "2024-08-01-preview",
+    }
+)
+
 OAI_GPT4_1106_PREVIEW_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-1106-preview",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -31,7 +50,7 @@
 )
 
 OAI_GPT4V_1106_VISION_PREVIEW_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-1106-vision-preview",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -39,7 +58,7 @@
 )
 
 OAI_GPT4V_TURBO_2024_04_09_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-turbo-2024-04-09",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -47,7 +66,7 @@
 )
 
 OAI_GPT4O_2024_05_13_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4o-2024-05-13",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -63,7 +82,7 @@
 }
 
 GEMINI_V15_PRO_CONFIG = ModelConfig(
-    GeminiModels,
+    GeminiModel,
     {
         "model_name": "gemini-1.5-pro",
         "secret_key_params": GEMINI_SECRET_KEY_PARAMS,
@@ -71,7 +90,7 @@
 )
 
 GEMINI_V1_PRO_CONFIG = ModelConfig(
-    GeminiModels,
+    GeminiModel,
     {
         "model_name": "gemini-1.0-pro",
         "secret_key_params": GEMINI_SECRET_KEY_PARAMS,
@@ -86,7 +105,7 @@
 }
 
 CLAUDE_3_OPUS_CONFIG = ModelConfig(
-    ClaudeModels,
+    ClaudeModel,
     {
         "model_name": "claude-3-opus-20240229",
         "secret_key_params": CLAUDE_SECRET_KEY_PARAMS,
@@ -94,7 +113,7 @@
 )
 
 CLAUDE_3_5_SONNET_CONFIG = ModelConfig(
-    ClaudeModels,
+    ClaudeModel,
     {
         "secret_key_params": CLAUDE_SECRET_KEY_PARAMS,
         "model_name": "claude-3-5-sonnet-20240620",
@@ -103,29 +122,29 @@
 
 # LLAVA models
 LLAVAHF_V16_34B_CONFIG = ModelConfig(
-    LLaVAHuggingFaceMM,
+    LLaVAHuggingFaceModel,
     {"model_name": "llava-hf/llava-v1.6-34b-hf", "use_flash_attn": True},
 )
 
 LLAVAHF_V15_7B_CONFIG = ModelConfig(
-    LLaVAHuggingFaceMM,
+    LLaVAHuggingFaceModel,
     {"model_name": "llava-hf/llava-1.5-7b-hf", "use_flash_attn": True},
 )
 
 LLAVA_V16_34B_CONFIG = ModelConfig(
-    LLaVA,
+    LLaVAModel,
     {"model_name": "liuhaotian/llava-v1.6-34b", "use_flash_attn": True},
 )
"use_flash_attn": True}, ) LLAVA_V15_7B_CONFIG = ModelConfig( - LLaVA, + LLaVAModel, {"model_name": "liuhaotian/llava-v1.5-7b", "use_flash_attn": True}, ) # Llama models LLAMA3_1_70B_INSTRUCT_CONFIG = ModelConfig( - RestEndpointModels, + RestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { @@ -138,7 +157,7 @@ ) LLAMA3_1_405B_INSTRUCT_CONFIG = ModelConfig( - LlamaServerlessAzureRestEndpointModels, + LlamaServerlessAzureRestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { @@ -152,7 +171,7 @@ # Mistral Endpoints AIF_NT_MISTRAL_LARGE_2_2407_CONFIG = ModelConfig( - MistralServerlessAzureRestEndpointModels, + MistralServerlessAzureRestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { diff --git a/eureka_ml_insights/models/__init__.py b/eureka_ml_insights/models/__init__.py index c7a6926..a945901 100644 --- a/eureka_ml_insights/models/__init__.py +++ b/eureka_ml_insights/models/__init__.py @@ -1,32 +1,32 @@ from .models import ( - ClaudeModels, - GeminiModels, - HuggingFaceLM, - LlamaServerlessAzureRestEndpointModels, - LLaVA, - LLaVAHuggingFaceMM, - MistralServerlessAzureRestEndpointModels, - OpenAIModelsMixIn, - OpenAIModelsAzure, - OpenAIModelsOAI, - OpenAIO1Direct, - Phi3HF, - KeyBasedAuthentication, - EndpointModels, - RestEndpointModels + ClaudeModel, + GeminiModel, + HuggingFaceModel, + LlamaServerlessAzureRestEndpointModel, + LLaVAModel, + LLaVAHuggingFaceModel, + MistralServerlessAzureRestEndpointModel, + AzureOpenAIModel, + DirectOpenAIModel, + AzureOpenAIO1Model, + DirectOpenAIO1Model, + Phi3HFModel, + KeyBasedAuthMixIn, + EndpointModel, + RestEndpointModel ) __all__ = [ - OpenAIModelsMixIn, - OpenAIO1Direct, - HuggingFaceLM, - LLaVAHuggingFaceMM, - Phi3HF, - OpenAIModelsOAI, - OpenAIModelsAzure, - GeminiModels, - ClaudeModels, - MistralServerlessAzureRestEndpointModels, - LlamaServerlessAzureRestEndpointModels, - LLaVA, + AzureOpenAIO1Model, + DirectOpenAIO1Model, + HuggingFaceModel, + LLaVAHuggingFaceModel, + Phi3HFModel, + DirectOpenAIModel, + AzureOpenAIModel, + GeminiModel, + ClaudeModel, + MistralServerlessAzureRestEndpointModel, + LlamaServerlessAzureRestEndpointModel, + LLaVAModel, ] diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py index c5ab5a6..ab6e187 100644 --- a/eureka_ml_insights/models/models.py +++ b/eureka_ml_insights/models/models.py @@ -62,7 +62,7 @@ def base64encode(self, query_images): @dataclass -class KeyBasedAuthentication: +class KeyBasedAuthMixIn: """This class is used to handle key-based authentication for models.""" api_key: str = None @@ -73,6 +73,7 @@ def __post_init__(self): raise ValueError("Either api_key or secret_key_params must be provided.") self.api_key = self.get_api_key() + def get_api_key(self): """ This method is used to get the api_key for the models that require key-based authentication. @@ -80,12 +81,12 @@ def get_api_key(self): if api_key is not directly provided, secret_key_params must be provided to get the api_key using GetKey method. 
""" if self.api_key is None: - self.api_key = GetKey(**self.secret_key_params) + self.api_key = GetKey(**self.secret_key_params) return self.api_key @dataclass -class EndpointModels(Model): +class EndpointModel(Model): """This class is used to interact with API-based models.""" num_retries: int = 3 @@ -149,7 +150,7 @@ def handle_request_error(self, e): @dataclass -class RestEndpointModels(EndpointModels, KeyBasedAuthentication): +class RestEndpointModel(EndpointModel, KeyBasedAuthMixIn): url: str = None model_name: str = None temperature: float = 0 @@ -206,14 +207,14 @@ def get_response(self, request): def handle_request_error(self, e): if isinstance(e, urllib.error.HTTPError): logging.info("The request failed with status code: " + str(e.code)) - # Print the headers - they include the requert ID and the timestamp, which are useful for debugging. + # Print the headers - they include the request ID and the timestamp, which are useful for debugging. logging.info(e.info()) logging.info(e.read().decode("utf8", "ignore")) return False @dataclass -class ServerlessAzureRestEndpointModels(EndpointModels, KeyBasedAuthentication): +class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn): """This class can be used for serverless Azure model deployments.""" """https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio""" @@ -222,11 +223,19 @@ class ServerlessAzureRestEndpointModels(EndpointModels, KeyBasedAuthentication): stream: str = "false" def __post_init__(self): - super().__post_init__() - self.headers = { - "Content-Type": "application/json", - "Authorization": ("Bearer " + self.api_key), - } + try: + super().__post_init__() + self.headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + self.api_key), + } + except ValueError: + self.bearer_token_provider = get_bearer_token_provider(AzureCliCredential(), "https://cognitiveservices.azure.com/.default") + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + self.bearer_token_provider()), + } + @abstractmethod def create_request(self, text_prompt, query_images=None, system_message=None): @@ -252,7 +261,7 @@ def handle_request_error(self, e): @dataclass -class LlamaServerlessAzureRestEndpointModels(ServerlessAzureRestEndpointModels, KeyBasedAuthentication): +class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel): """Tested for Llama 3.1 405B Instruct deployments.""" """See https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=llama-three for the api reference.""" @@ -285,7 +294,7 @@ def create_request(self, text_prompt, *args): @dataclass -class MistralServerlessAzureRestEndpointModels(ServerlessAzureRestEndpointModels, KeyBasedAuthentication): +class MistralServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel): """Tested for Mistral Large 2 2407 deployments.""" """See https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-mistral?tabs=mistral-large#mistral-chat-api for the api reference.""" @@ -317,26 +326,11 @@ def create_request(self, text_prompt, *args): @dataclass -class OpenAIModelsMixIn(EndpointModels): +class OpenAICommonRequestResponseMixIn(): """ - This class defines the request and response handling for OpenAI models. - This is an abstract class and should not be used directly. Child classes should implement the get_client - method and handle_request_error method. + This mixin class defines the request and response handling for most OpenAI models. 
""" - model_name: str = None - temperature: float = 0 - max_tokens: int = 2000 - top_p: float = 0.95 - frequency_penalty: float = 0 - presence_penalty: float = 0 - seed: int = 0 - api_version: str = "2023-06-01-preview" - - @abstractmethod - def get_client(self): - raise NotImplementedError - def create_request(self, prompt, query_images=None, system_message=None): messages = [] if system_message: @@ -373,20 +367,9 @@ def get_response(self, request): self.model_output = openai_response["choices"][0]["message"]["content"] self.response_time = end_time - start_time - @abstractmethod - def handle_request_error(self, e): - raise NotImplementedError - - -@dataclass -class OpenAIModelsAzure(OpenAIModelsMixIn): - """This class is used to interact with Azure OpenAI models.""" - - url: str = None - - def __post_init__(self): - self.client = self.get_client() +class AzureOpenAIClientMixIn(): + """This mixin provides some methods to interact with Azure OpenAI models.""" def get_client(self): from openai import AzureOpenAI @@ -406,14 +389,8 @@ def handle_request_error(self, e): return False -@dataclass -class OpenAIModelsOAI(OpenAIModelsMixIn, KeyBasedAuthentication): - """This class is used to interact with OpenAI models dirctly (not through Azure)""" - - def __post_init__(self): - super().__post_init__() - self.client = self.get_client() - +class DirectOpenAIClientMixIn(KeyBasedAuthMixIn): + """This mixin class provides some methods for using OpenAI models dirctly (not through Azure)""" def get_client(self): from openai import OpenAI @@ -425,32 +402,40 @@ def handle_request_error(self, e): logging.warning(e) return False - @dataclass -class OpenAIO1Direct(EndpointModels, KeyBasedAuthentication): +class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel): + """This class is used to interact with Azure OpenAI models.""" + url: str = None model_name: str = None - temperature: float = 1 - # Not used currently, because the API throws: - # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'" - # although this argument is documented in the OpenAI API documentation. 
-    max_completion_tokens: int = 2000
-    top_p: float = 1
-    seed: int = 0
+    temperature: float = 0
+    max_tokens: int = 2000
+    top_p: float = 0.95
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    seed: int = 0
+    api_version: str = "2023-06-01-preview"
 
     def __post_init__(self):
-        super().__post_init__()
         self.client = self.get_client()
 
-    def get_client(self):
-        from openai import OpenAI
+
+@dataclass
+class DirectOpenAIModel(OpenAICommonRequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel):
+    """This class is used to interact with OpenAI models directly (not through Azure)"""
+    model_name: str = None
+    temperature: float = 0
+    max_tokens: int = 2000
+    top_p: float = 0.95
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+    seed: int = 0
+    api_version: str = "2023-06-01-preview"
 
-        return OpenAI(
-            api_key=self.api_key,
-        )
+    def __post_init__(self):
+        super().__post_init__()
+        self.client = self.get_client()
 
-    def create_request(self, prompt, *kwargs):
+class OpenAIO1RequestResponseMixIn():
+    def create_request(self, prompt, *args, **kwargs):
         messages = [{"role": "user", "content": prompt}]
         return {"messages": messages}
@@ -470,13 +455,44 @@ def get_response(self, request):
         self.model_output = openai_response["choices"][0]["message"]["content"]
         self.response_time = end_time - start_time
 
-    def handle_request_error(self, e):
-        logging.warning(e)
-        return False
+
+@dataclass
+class DirectOpenAIO1Model(OpenAIO1RequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel):
+    model_name: str = None
+    temperature: float = 1
+    # Not used currently, because the API throws:
+    # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'"
+    # although this argument is documented in the OpenAI API documentation.
+    max_completion_tokens: int = 2000
+    top_p: float = 1
+    seed: int = 0
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+
+    def __post_init__(self):
+        self.client = super().get_client()
+
+@dataclass
+class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
+    url: str = None
+    model_name: str = None
+    temperature: float = 1
+    # Not used currently, because the API throws:
+    # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'"
+    # although this argument is documented in the OpenAI API documentation.
+    max_completion_tokens: int = 2000
+    top_p: float = 1
+    seed: int = 0
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+    api_version: str = "2023-06-01-preview"
+
+
+    def __post_init__(self):
+        self.client = super().get_client()
 
 
 @dataclass
-class GeminiModels(EndpointModels, KeyBasedAuthentication):
+class GeminiModel(EndpointModel, KeyBasedAuthMixIn):
     """This class is used to interact with Gemini models through the python api."""
 
     timeout: int = 60
@@ -549,7 +565,7 @@ def handle_request_error(self, e):
 
 
 @dataclass
-class HuggingFaceLM(Model):
+class HuggingFaceModel(Model):
     """This class is used to run a self-hosted language model via HuggingFace apis."""
 
     model_name: str = None
@@ -647,7 +663,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class Phi3HF(HuggingFaceLM):
+class Phi3HFModel(HuggingFaceModel):
     """This class is used to run a self-hosted PHI3 model via HuggingFace apis."""
 
     def __post_init__(self):
@@ -664,7 +680,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class LLaVAHuggingFaceMM(HuggingFaceLM):
+class LLaVAHuggingFaceModel(HuggingFaceModel):
     """This class is used to run a self-hosted LLaVA model via HuggingFace apis."""
 
     quantize: bool = False
@@ -786,7 +802,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class LLaVA(LLaVAHuggingFaceMM):
+class LLaVAModel(LLaVAHuggingFaceModel):
     """This class is used to run a self-hosted LLaVA model via the LLaVA package."""
 
     model_base: str = None
@@ -857,7 +873,7 @@ def _generate(self, text_prompt, query_images=None, system_message=None):
 
 
 @dataclass
-class ClaudeModels(EndpointModels, KeyBasedAuthentication):
+class ClaudeModel(EndpointModel, KeyBasedAuthMixIn):
     """This class is used to interact with Claude models through the python api."""
 
     model_name: str = None

From ac1d40c193ed6a89e803fb3e412fd5724798f68a Mon Sep 17 00:00:00 2001
From: Safoora Yousefi
Date: Fri, 18 Oct 2024 17:53:33 +0000
Subject: [PATCH 2/3] bug fix

---
 eureka_ml_insights/models/models.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py
index ab6e187..2c3a32b 100644
--- a/eureka_ml_insights/models/models.py
+++ b/eureka_ml_insights/models/models.py
@@ -431,7 +431,7 @@ class DirectOpenAIModel(OpenAICommonRequestResponseMixI
     api_version: str = "2023-06-01-preview"
 
     def __post_init__(self):
-        super().__post_init__()
+        self.api_key = self.get_api_key()
         self.client = self.get_client()
 
 class OpenAIO1RequestResponseMixIn():
@@ -469,7 +469,8 @@ class DirectOpenAIO1Model(OpenAIO1RequestResponseMixIn, DirectOpenAIClientMixIn,
     presence_penalty: float = 0
 
     def __post_init__(self):
-        self.client = super().get_client()
+        self.api_key = self.get_api_key()
+        self.client = self.get_client()
 
 @dataclass
 class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
@@ -488,7 +489,7 @@ class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, E
 
 
     def __post_init__(self):
-        self.client = super().get_client()
+        self.client = self.get_client()
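[Editor's note, outside the patch series] PATCH 2/3 above replaces the `super().__post_init__()` and `super().get_client()` calls with explicit `self.get_api_key()` / `self.get_client()` calls in the OpenAI model classes. Below is a minimal, self-contained sketch of why the explicit call is the safer choice under this kind of mixin layout; the toy class names are illustrative stand-ins, not the eureka_ml_insights classes:

    from dataclasses import dataclass


    class AuthMixIn:
        """Toy stand-in for KeyBasedAuthMixIn: resolves an API key on demand."""

        def get_api_key(self):
            return "key-from-vault"


    @dataclass
    class Endpoint:
        """Toy stand-in for EndpointModel, the @dataclass base of the real models."""

        num_retries: int = 3


    @dataclass
    class DirectModel(AuthMixIn, Endpoint):
        """Composed the same way as the Direct*Model classes in this series."""

        model_name: str = None

        def __post_init__(self):
            # Neither toy base defines __post_init__, so super().__post_init__()
            # would raise AttributeError here. Calling the helper directly keeps
            # initialization independent of where this class sits in the MRO.
            self.api_key = self.get_api_key()


    model = DirectModel(model_name="o1-preview")
    print(model.api_key)  # key-from-vault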
From e4be6a755306942e66b535986bb83ca0b58aefe5 Mon Sep 17 00:00:00 2001
From: Safoora Yousefi
Date: Fri, 18 Oct 2024 18:09:12 +0000
Subject: [PATCH 3/3] formatting and tests

---
 eureka_ml_insights/configs/model_configs.py | 14 +++++++-----
 eureka_ml_insights/models/__init__.py       | 15 ++++++-------
 eureka_ml_insights/models/models.py         | 24 ++++++++++++++-------
 3 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py
index f3588d8..17ac457 100644
--- a/eureka_ml_insights/configs/model_configs.py
+++ b/eureka_ml_insights/configs/model_configs.py
@@ -3,20 +3,24 @@
 You can also add your custom models here by following the same pattern as the existing configs.
 """
 from eureka_ml_insights.models import (
+    AzureOpenAIO1Model,
     ClaudeModel,
+    DirectOpenAIModel,
+    DirectOpenAIO1Model,
     GeminiModel,
     LlamaServerlessAzureRestEndpointModel,
-    LLaVAModel,
     LLaVAHuggingFaceModel,
+    LLaVAModel,
     MistralServerlessAzureRestEndpointModel,
-    DirectOpenAIModel,
-    AzureOpenAIO1Model,
-    DirectOpenAIO1Model,
     RestEndpointModel,
 )
 
 from .config import ModelConfig
 
+# For models that require secret keys, you can store the keys in a JSON file and provide the path to the file
+# in the secret_key_params dictionary. Alternatively, provide the key name and key vault URL to fetch the key from Azure Key Vault.
+# You don't need to provide both key_vault_url and local_keys_path; provide whichever one fits your setup.
+
 # OpenAI models
 OPENAI_SECRET_KEY_PARAMS = {
     "key_name": "your_openai_secret_key_name",
@@ -38,7 +42,7 @@
         "model_name": "o1-preview",
         "url": "your/endpoint/url",
         "api_version": "2024-08-01-preview",
-    }
+    },
 )
 
 OAI_GPT4_1106_PREVIEW_CONFIG = ModelConfig(
diff --git a/eureka_ml_insights/models/__init__.py b/eureka_ml_insights/models/__init__.py
index a945901..fc68cc6 100644
--- a/eureka_ml_insights/models/__init__.py
+++ b/eureka_ml_insights/models/__init__.py
@@ -1,19 +1,17 @@
 from .models import (
+    AzureOpenAIModel,
+    AzureOpenAIO1Model,
     ClaudeModel,
+    DirectOpenAIModel,
+    DirectOpenAIO1Model,
     GeminiModel,
     HuggingFaceModel,
     LlamaServerlessAzureRestEndpointModel,
-    LLaVAModel,
     LLaVAHuggingFaceModel,
+    LLaVAModel,
     MistralServerlessAzureRestEndpointModel,
-    AzureOpenAIModel,
-    DirectOpenAIModel,
-    AzureOpenAIO1Model,
-    DirectOpenAIO1Model,
     Phi3HFModel,
-    KeyBasedAuthMixIn,
-    EndpointModel,
-    RestEndpointModel
+    RestEndpointModel,
 )
 
 __all__ = [
@@ -29,4 +27,5 @@
     MistralServerlessAzureRestEndpointModel,
     LlamaServerlessAzureRestEndpointModel,
     LLaVAModel,
+    RestEndpointModel,
 ]
diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py
index 2c3a32b..5ecfd6d 100644
--- a/eureka_ml_insights/models/models.py
+++ b/eureka_ml_insights/models/models.py
@@ -73,7 +73,6 @@ def __post_init__(self):
             raise ValueError("Either api_key or secret_key_params must be provided.")
         self.api_key = self.get_api_key()
 
-
     def get_api_key(self):
         """
         This method is used to get the api_key for the models that require key-based authentication.
         if api_key is not directly provided, secret_key_params must be provided to get the api_key using GetKey method.
""" if self.api_key is None: - self.api_key = GetKey(**self.secret_key_params) + self.api_key = GetKey(**self.secret_key_params) return self.api_key @@ -230,12 +229,13 @@ def __post_init__(self): "Authorization": ("Bearer " + self.api_key), } except ValueError: - self.bearer_token_provider = get_bearer_token_provider(AzureCliCredential(), "https://cognitiveservices.azure.com/.default") + self.bearer_token_provider = get_bearer_token_provider( + AzureCliCredential(), "https://cognitiveservices.azure.com/.default" + ) headers = { "Content-Type": "application/json", "Authorization": ("Bearer " + self.bearer_token_provider()), } - @abstractmethod def create_request(self, text_prompt, query_images=None, system_message=None): @@ -326,7 +326,7 @@ def create_request(self, text_prompt, *args): @dataclass -class OpenAICommonRequestResponseMixIn(): +class OpenAICommonRequestResponseMixIn: """ This mixin class defines the request and response handling for most OpenAI models. """ @@ -368,8 +368,9 @@ def get_response(self, request): self.response_time = end_time - start_time -class AzureOpenAIClientMixIn(): +class AzureOpenAIClientMixIn: """This mixin provides some methods to interact with Azure OpenAI models.""" + def get_client(self): from openai import AzureOpenAI @@ -391,6 +392,7 @@ def handle_request_error(self, e): class DirectOpenAIClientMixIn(KeyBasedAuthMixIn): """This mixin class provides some methods for using OpenAI models dirctly (not through Azure)""" + def get_client(self): from openai import OpenAI @@ -402,9 +404,11 @@ def handle_request_error(self, e): logging.warning(e) return False + @dataclass class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel): """This class is used to interact with Azure OpenAI models.""" + url: str = None model_name: str = None temperature: float = 0 @@ -418,9 +422,11 @@ class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn, def __post_init__(self): self.client = self.get_client() + @dataclass class DirectOpenAIModel(OpenAICommonRequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel): """This class is used to interact with OpenAI models dirctly (not through Azure)""" + model_name: str = None temperature: float = 0 max_tokens: int = 2000 @@ -434,7 +440,8 @@ def __post_init__(self): self.api_key = self.get_api_key() self.client = self.get_client() -class OpenAIO1RequestResponseMixIn(): + +class OpenAIO1RequestResponseMixIn: def create_request(self, prompt, *args, **kwargs): messages = [{"role": "user", "content": prompt}] return {"messages": messages} @@ -455,6 +462,7 @@ def get_response(self, request): self.model_output = openai_response["choices"][0]["message"]["content"] self.response_time = end_time - start_time + @dataclass class DirectOpenAIO1Model(OpenAIO1RequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel): model_name: str = None @@ -472,6 +480,7 @@ def __post_init__(self): self.api_key = self.get_api_key() self.client = self.get_client() + @dataclass class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel): url: str = None @@ -487,7 +496,6 @@ class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, E presence_penalty: float = 0 api_version: str = "2023-06-01-preview" - def __post_init__(self): self.client = self.get_client()