diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py
index 7b72988..17ac457 100644
--- a/eureka_ml_insights/configs/model_configs.py
+++ b/eureka_ml_insights/configs/model_configs.py
@@ -3,18 +3,24 @@
 You can also add your custom models here by following the same pattern as the existing configs.
 """
 from eureka_ml_insights.models import (
-    ClaudeModels,
-    GeminiModels,
-    LlamaServerlessAzureRestEndpointModels,
-    LLaVA,
-    LLaVAHuggingFaceMM,
-    MistralServerlessAzureRestEndpointModels,
-    OpenAIModelsOAI,
-    RestEndpointModels,
+    AzureOpenAIO1Model,
+    ClaudeModel,
+    DirectOpenAIModel,
+    DirectOpenAIO1Model,
+    GeminiModel,
+    LlamaServerlessAzureRestEndpointModel,
+    LLaVAHuggingFaceModel,
+    LLaVAModel,
+    MistralServerlessAzureRestEndpointModel,
+    RestEndpointModel,
 )
 
 from .config import ModelConfig
 
+# For models that require secret keys, you can store the keys in a json file and provide the path to that file
+# in the secret_key_params dictionary, or you can provide the key name and key vault URL to fetch the key from Azure Key Vault.
+# You don't need to provide both key_vault_url and local_keys_path; provide whichever one matches your setup.
+
 # OpenAI models
 OPENAI_SECRET_KEY_PARAMS = {
     "key_name": "your_openai_secret_key_name",
@@ -22,8 +28,25 @@
     "key_vault_url": None,
 }
 
+OAI_O1_PREVIEW_CONFIG = ModelConfig(
+    DirectOpenAIO1Model,
+    {
+        "model_name": "o1-preview",
+        "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
+    },
+)
+
+OAI_O1_PREVIEW_AZURE_CONFIG = ModelConfig(
+    AzureOpenAIO1Model,
+    {
+        "model_name": "o1-preview",
+        "url": "your/endpoint/url",
+        "api_version": "2024-08-01-preview",
+    },
+)
+
 OAI_GPT4_1106_PREVIEW_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-1106-preview",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -31,7 +54,7 @@
 )
 
 OAI_GPT4V_1106_VISION_PREVIEW_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-1106-vision-preview",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -39,7 +62,7 @@
 )
 
 OAI_GPT4V_TURBO_2024_04_09_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4-turbo-2024-04-09",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -47,7 +70,7 @@
 )
 
 OAI_GPT4O_2024_05_13_CONFIG = ModelConfig(
-    OpenAIModelsOAI,
+    DirectOpenAIModel,
     {
         "model_name": "gpt-4o-2024-05-13",
         "secret_key_params": OPENAI_SECRET_KEY_PARAMS,
@@ -63,7 +86,7 @@
 }
 
 GEMINI_V15_PRO_CONFIG = ModelConfig(
-    GeminiModels,
+    GeminiModel,
     {
         "model_name": "gemini-1.5-pro",
         "secret_key_params": GEMINI_SECRET_KEY_PARAMS,
@@ -71,7 +94,7 @@
 )
 
 GEMINI_V1_PRO_CONFIG = ModelConfig(
-    GeminiModels,
+    GeminiModel,
     {
         "model_name": "gemini-1.0-pro",
         "secret_key_params": GEMINI_SECRET_KEY_PARAMS,
@@ -86,7 +109,7 @@
 }
 
 CLAUDE_3_OPUS_CONFIG = ModelConfig(
-    ClaudeModels,
+    ClaudeModel,
     {
         "model_name": "claude-3-opus-20240229",
         "secret_key_params": CLAUDE_SECRET_KEY_PARAMS,
@@ -94,7 +117,7 @@
 )
 
 CLAUDE_3_5_SONNET_CONFIG = ModelConfig(
-    ClaudeModels,
+    ClaudeModel,
     {
         "secret_key_params": CLAUDE_SECRET_KEY_PARAMS,
         "model_name": "claude-3-5-sonnet-20240620",
@@ -103,29 +126,29 @@
 
 # LLAVA models
 LLAVAHF_V16_34B_CONFIG = ModelConfig(
-    LLaVAHuggingFaceMM,
+    LLaVAHuggingFaceModel,
     {"model_name": "llava-hf/llava-v1.6-34b-hf", "use_flash_attn": True},
 )
 
 LLAVAHF_V15_7B_CONFIG = ModelConfig(
-    LLaVAHuggingFaceMM,
+    LLaVAHuggingFaceModel,
     {"model_name": "llava-hf/llava-1.5-7b-hf", "use_flash_attn": True},
 )
 
 LLAVA_V16_34B_CONFIG = ModelConfig(
-    LLaVA,
+    LLaVAModel,
{"model_name": "liuhaotian/llava-v1.6-34b", "use_flash_attn": True}, ) LLAVA_V15_7B_CONFIG = ModelConfig( - LLaVA, + LLaVAModel, {"model_name": "liuhaotian/llava-v1.5-7b", "use_flash_attn": True}, ) # Llama models LLAMA3_1_70B_INSTRUCT_CONFIG = ModelConfig( - RestEndpointModels, + RestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { @@ -138,7 +161,7 @@ ) LLAMA3_1_405B_INSTRUCT_CONFIG = ModelConfig( - LlamaServerlessAzureRestEndpointModels, + LlamaServerlessAzureRestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { @@ -152,7 +175,7 @@ # Mistral Endpoints AIF_NT_MISTRAL_LARGE_2_2407_CONFIG = ModelConfig( - MistralServerlessAzureRestEndpointModels, + MistralServerlessAzureRestEndpointModel, { "url": "your/endpoint/url", "secret_key_params": { diff --git a/eureka_ml_insights/models/__init__.py b/eureka_ml_insights/models/__init__.py index c7a6926..fc68cc6 100644 --- a/eureka_ml_insights/models/__init__.py +++ b/eureka_ml_insights/models/__init__.py @@ -1,32 +1,31 @@ from .models import ( - ClaudeModels, - GeminiModels, - HuggingFaceLM, - LlamaServerlessAzureRestEndpointModels, - LLaVA, - LLaVAHuggingFaceMM, - MistralServerlessAzureRestEndpointModels, - OpenAIModelsMixIn, - OpenAIModelsAzure, - OpenAIModelsOAI, - OpenAIO1Direct, - Phi3HF, - KeyBasedAuthentication, - EndpointModels, - RestEndpointModels + AzureOpenAIModel, + AzureOpenAIO1Model, + ClaudeModel, + DirectOpenAIModel, + DirectOpenAIO1Model, + GeminiModel, + HuggingFaceModel, + LlamaServerlessAzureRestEndpointModel, + LLaVAHuggingFaceModel, + LLaVAModel, + MistralServerlessAzureRestEndpointModel, + Phi3HFModel, + RestEndpointModel, ) __all__ = [ - OpenAIModelsMixIn, - OpenAIO1Direct, - HuggingFaceLM, - LLaVAHuggingFaceMM, - Phi3HF, - OpenAIModelsOAI, - OpenAIModelsAzure, - GeminiModels, - ClaudeModels, - MistralServerlessAzureRestEndpointModels, - LlamaServerlessAzureRestEndpointModels, - LLaVA, + AzureOpenAIO1Model, + DirectOpenAIO1Model, + HuggingFaceModel, + LLaVAHuggingFaceModel, + Phi3HFModel, + DirectOpenAIModel, + AzureOpenAIModel, + GeminiModel, + ClaudeModel, + MistralServerlessAzureRestEndpointModel, + LlamaServerlessAzureRestEndpointModel, + LLaVAModel, + RestEndpointModel, ] diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py index c5ab5a6..5ecfd6d 100644 --- a/eureka_ml_insights/models/models.py +++ b/eureka_ml_insights/models/models.py @@ -62,7 +62,7 @@ def base64encode(self, query_images): @dataclass -class KeyBasedAuthentication: +class KeyBasedAuthMixIn: """This class is used to handle key-based authentication for models.""" api_key: str = None @@ -85,7 +85,7 @@ def get_api_key(self): @dataclass -class EndpointModels(Model): +class EndpointModel(Model): """This class is used to interact with API-based models.""" num_retries: int = 3 @@ -149,7 +149,7 @@ def handle_request_error(self, e): @dataclass -class RestEndpointModels(EndpointModels, KeyBasedAuthentication): +class RestEndpointModel(EndpointModel, KeyBasedAuthMixIn): url: str = None model_name: str = None temperature: float = 0 @@ -206,14 +206,14 @@ def get_response(self, request): def handle_request_error(self, e): if isinstance(e, urllib.error.HTTPError): logging.info("The request failed with status code: " + str(e.code)) - # Print the headers - they include the requert ID and the timestamp, which are useful for debugging. + # Print the headers - they include the request ID and the timestamp, which are useful for debugging. 
             logging.info(e.info())
             logging.info(e.read().decode("utf8", "ignore"))
         return False
 
 
 @dataclass
-class ServerlessAzureRestEndpointModels(EndpointModels, KeyBasedAuthentication):
+class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
     """This class can be used for serverless Azure model deployments."""
     """https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio"""
@@ -222,11 +222,20 @@ class ServerlessAzureRestEndpointModels(EndpointModels, KeyBasedAuthentication):
     stream: str = "false"
 
     def __post_init__(self):
-        super().__post_init__()
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": ("Bearer " + self.api_key),
-        }
+        try:
+            super().__post_init__()
+            self.headers = {
+                "Content-Type": "application/json",
+                "Authorization": ("Bearer " + self.api_key),
+            }
+        except ValueError:
+            self.bearer_token_provider = get_bearer_token_provider(
+                AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
+            )
+            self.headers = {
+                "Content-Type": "application/json",
+                "Authorization": ("Bearer " + self.bearer_token_provider()),
+            }
 
     @abstractmethod
     def create_request(self, text_prompt, query_images=None, system_message=None):
@@ -252,7 +261,7 @@ def handle_request_error(self, e):
 
 
 @dataclass
-class LlamaServerlessAzureRestEndpointModels(ServerlessAzureRestEndpointModels, KeyBasedAuthentication):
+class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
     """Tested for Llama 3.1 405B Instruct deployments."""
     """See https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=llama-three for the api reference."""
@@ -285,7 +294,7 @@ def create_request(self, text_prompt, *args):
 
 
 @dataclass
-class MistralServerlessAzureRestEndpointModels(ServerlessAzureRestEndpointModels, KeyBasedAuthentication):
+class MistralServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
     """Tested for Mistral Large 2 2407 deployments."""
     """See https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-mistral?tabs=mistral-large#mistral-chat-api for the api reference."""
@@ -317,26 +326,11 @@ def create_request(self, text_prompt, *args):
 
 
 @dataclass
-class OpenAIModelsMixIn(EndpointModels):
+class OpenAICommonRequestResponseMixIn:
     """
-    This class defines the request and response handling for OpenAI models.
-    This is an abstract class and should not be used directly. Child classes should implement the get_client
-    method and handle_request_error method.
+    This mixin class defines the request and response handling for most OpenAI models.
""" - model_name: str = None - temperature: float = 0 - max_tokens: int = 2000 - top_p: float = 0.95 - frequency_penalty: float = 0 - presence_penalty: float = 0 - seed: int = 0 - api_version: str = "2023-06-01-preview" - - @abstractmethod - def get_client(self): - raise NotImplementedError - def create_request(self, prompt, query_images=None, system_message=None): messages = [] if system_message: @@ -373,19 +367,9 @@ def get_response(self, request): self.model_output = openai_response["choices"][0]["message"]["content"] self.response_time = end_time - start_time - @abstractmethod - def handle_request_error(self, e): - raise NotImplementedError - - -@dataclass -class OpenAIModelsAzure(OpenAIModelsMixIn): - """This class is used to interact with Azure OpenAI models.""" - - url: str = None - def __post_init__(self): - self.client = self.get_client() +class AzureOpenAIClientMixIn: + """This mixin provides some methods to interact with Azure OpenAI models.""" def get_client(self): from openai import AzureOpenAI @@ -406,13 +390,8 @@ def handle_request_error(self, e): return False -@dataclass -class OpenAIModelsOAI(OpenAIModelsMixIn, KeyBasedAuthentication): - """This class is used to interact with OpenAI models dirctly (not through Azure)""" - - def __post_init__(self): - super().__post_init__() - self.client = self.get_client() +class DirectOpenAIClientMixIn(KeyBasedAuthMixIn): + """This mixin class provides some methods for using OpenAI models dirctly (not through Azure)""" def get_client(self): from openai import OpenAI @@ -427,30 +406,43 @@ def handle_request_error(self, e): @dataclass -class OpenAIO1Direct(EndpointModels, KeyBasedAuthentication): +class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel): + """This class is used to interact with Azure OpenAI models.""" + + url: str = None model_name: str = None - temperature: float = 1 - # Not used currently, because the API throws: - # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'" - # although this argument is documented in the OpenAI API documentation. 
-    max_completion_tokens: int = 2000
-    top_p: float = 1
-    seed: int = 0
+    temperature: float = 0
+    max_tokens: int = 2000
+    top_p: float = 0.95
     frequency_penalty: float = 0
     presence_penalty: float = 0
+    seed: int = 0
+    api_version: str = "2023-06-01-preview"
 
     def __post_init__(self):
-        super().__post_init__()
         self.client = self.get_client()
 
-    def get_client(self):
-        from openai import OpenAI
-
-        return OpenAI(
-            api_key=self.api_key,
-        )
 
+@dataclass
+class DirectOpenAIModel(OpenAICommonRequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel):
+    """This class is used to interact with OpenAI models directly (not through Azure)"""
+
+    model_name: str = None
+    temperature: float = 0
+    max_tokens: int = 2000
+    top_p: float = 0.95
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+    seed: int = 0
+    api_version: str = "2023-06-01-preview"
 
-    def create_request(self, prompt, *kwargs):
+    def __post_init__(self):
+        self.api_key = self.get_api_key()
+        self.client = self.get_client()
+
+
+class OpenAIO1RequestResponseMixIn:
+    def create_request(self, prompt, *args, **kwargs):
         messages = [{"role": "user", "content": prompt}]
         return {"messages": messages}
@@ -470,13 +462,46 @@ def get_response(self, request):
         self.model_output = openai_response["choices"][0]["message"]["content"]
         self.response_time = end_time - start_time
 
-    def handle_request_error(self, e):
-        logging.warning(e)
-        return False
+
+@dataclass
+class DirectOpenAIO1Model(OpenAIO1RequestResponseMixIn, DirectOpenAIClientMixIn, EndpointModel):
+    model_name: str = None
+    temperature: float = 1
+    # Not used currently, because the API throws:
+    # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'"
+    # although this argument is documented in the OpenAI API documentation.
+    max_completion_tokens: int = 2000
+    top_p: float = 1
+    seed: int = 0
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+
+    def __post_init__(self):
+        self.api_key = self.get_api_key()
+        self.client = self.get_client()
+
+
+@dataclass
+class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
+    url: str = None
+    model_name: str = None
+    temperature: float = 1
+    # Not used currently, because the API throws:
+    # "Completions.create() got an unexpected keyword argument 'max_completion_tokens'"
+    # although this argument is documented in the OpenAI API documentation.
+    max_completion_tokens: int = 2000
+    top_p: float = 1
+    seed: int = 0
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
+    api_version: str = "2023-06-01-preview"
+
+    def __post_init__(self):
+        self.client = self.get_client()
 
 
 @dataclass
-class GeminiModels(EndpointModels, KeyBasedAuthentication):
+class GeminiModel(EndpointModel, KeyBasedAuthMixIn):
     """This class is used to interact with Gemini models through the python api."""
 
     timeout: int = 60
@@ -549,7 +574,7 @@ def handle_request_error(self, e):
 
 
 @dataclass
-class HuggingFaceLM(Model):
+class HuggingFaceModel(Model):
     """This class is used to run a self-hosted language model via HuggingFace apis."""
 
     model_name: str = None
@@ -647,7 +672,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class Phi3HF(HuggingFaceLM):
+class Phi3HFModel(HuggingFaceModel):
     """This class is used to run a self-hosted PHI3 model via HuggingFace apis."""
 
     def __post_init__(self):
@@ -664,7 +689,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class LLaVAHuggingFaceMM(HuggingFaceLM):
+class LLaVAHuggingFaceModel(HuggingFaceModel):
     """This class is used to run a self-hosted LLaVA model via HuggingFace apis."""
 
     quantize: bool = False
@@ -786,7 +811,7 @@ def model_template_fn(self, text_prompt, system_message=None):
 
 
 @dataclass
-class LLaVA(LLaVAHuggingFaceMM):
+class LLaVAModel(LLaVAHuggingFaceModel):
     """This class is used to run a self-hosted LLaVA model via the LLaVA package."""
 
     model_base: str = None
@@ -857,7 +882,7 @@ def _generate(self, text_prompt, query_images=None, system_message=None):
 
 
 @dataclass
-class ClaudeModels(EndpointModels, KeyBasedAuthentication):
+class ClaudeModel(EndpointModel, KeyBasedAuthMixIn):
     """This class is used to interact with Claude models through the python api."""
 
     model_name: str = None
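
Note for reviewers: a minimal usage sketch of the renamed config pattern, not part of the patch. The `ModelConfig(model_class, init_args)` shape and the `secret_key_params` fields come from model_configs.py above; the config name, key name, file path, and the absolute import path `eureka_ml_insights.configs.config` are placeholders/assumptions.

```python
# Usage sketch only. Key name and path below are placeholders.
from eureka_ml_insights.configs.config import ModelConfig
from eureka_ml_insights.models import DirectOpenAIModel

MY_GPT4O_CONFIG = ModelConfig(
    DirectOpenAIModel,
    {
        "model_name": "gpt-4o-2024-05-13",
        "secret_key_params": {
            "key_name": "my_openai_secret_key_name",  # name of the key in the json file or vault
            "local_keys_path": "keys/keys.json",      # set this OR key_vault_url, not both
            "key_vault_url": None,
        },
    },
)
```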
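The core of the models.py refactor is that each concrete endpoint class is now a thin composition: `EndpointModel` contributes the shared retry/generate machinery, a request/response mixin contributes `create_request`/`get_response`, and a client mixin contributes `get_client`/`handle_request_error` (with key handling via `KeyBasedAuthMixIn` on the direct-OpenAI path). A hypothetical new variant would be assembled the same way; the class name below is made up, and the mixins are imported from `models.models` on the assumption that `__init__.py` does not re-export them.

```python
# Hypothetical composition mirroring AzureOpenAIO1Model above; illustration only.
from dataclasses import dataclass

from eureka_ml_insights.models.models import (
    AzureOpenAIClientMixIn,
    EndpointModel,
    OpenAIO1RequestResponseMixIn,
)


@dataclass
class MyAzureO1Endpoint(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, EndpointModel):
    url: str = None
    model_name: str = None
    api_version: str = "2023-06-01-preview"

    def __post_init__(self):
        # The Azure client mixin resolves credentials inside get_client(); the
        # direct-OpenAI mixins instead set self.api_key = self.get_api_key() first.
        self.client = self.get_client()
```

Listing the mixins before `EndpointModel` matters: Python's MRO then resolves the request/response and client methods ahead of any base-class placeholders.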