diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 1b83a7b..fd0b152 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -16,6 +16,8 @@ TestModel, ) +from azure.identity import DefaultAzureCredential + from .config import ModelConfig # For models that require secret keys, you can store the keys in a json file and provide the path to the file @@ -32,6 +34,7 @@ "key_name": "your_openai_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, } OAI_O1_PREVIEW_CONFIG = ModelConfig( @@ -96,6 +99,7 @@ "key_name": "your_gemini_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, } GEMINI_V15_PRO_CONFIG = ModelConfig( @@ -119,6 +123,7 @@ "key_name": "your_claude_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, } CLAUDE_3_OPUS_CONFIG = ModelConfig( @@ -168,6 +173,7 @@ "key_name": "your_llama_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, }, "model_name": "meta-llama-3-1-70b-instruct", }, @@ -181,6 +187,7 @@ "key_name": "your_llama_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, }, "model_name": "Meta-Llama-3-1-405B-Instruct", }, @@ -195,6 +202,7 @@ "key_name": "your_mistral_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, }, "model_name": "Mistral-large-2407", }, diff --git a/eureka_ml_insights/data_utils/data.py b/eureka_ml_insights/data_utils/data.py index 55fd556..d954135 100644 --- a/eureka_ml_insights/data_utils/data.py +++ b/eureka_ml_insights/data_utils/data.py @@ -8,7 +8,6 @@ import jsonlines import pandas as pd -from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobClient, ContainerClient from datasets import load_dataset from PIL import Image @@ -234,6 +233,7 @@ def __init__( path, account_url, blob_container, + credential_func:callable=lambda _: None, total_lines=None, image_column_names=None, image_column_search_regex="image", @@ -260,7 +260,7 @@ def __init__( self.container_client = ContainerClient( account_url=self.account_url, container_name=self.blob_container, - credential=DefaultAzureCredential(), + credential=credential_func(), logger=self.logger, ) @@ -312,13 +312,13 @@ def read(self): class AzureBlobReader: """Reads an Azure storage blob from a full URL to a str""" - def read_azure_blob(self, blob_url) -> str: + def read_azure_blob(self, blob_url, credential_func:callable=lambda _: None) -> str: """ Reads an Azure storage blob.. args: blob_url: str, The Azure storage blob full URL. """ - blob_client = BlobClient.from_blob_url(blob_url, credential=DefaultAzureCredential(), logger=self.logger) + blob_client = BlobClient.from_blob_url(blob_url, credential=credential_func(), logger=self.logger) # real all the bytes from the blob file = blob_client.download_blob().readall() file = file.decode("utf-8") @@ -336,6 +336,7 @@ def __init__( account_url: str, blob_container: str, blob_name: str, + credential_func:callable=lambda _: None, ): """ Initializes an AzureJsonReader. @@ -346,10 +347,11 @@ def __init__( """ self.blob_url = f"{account_url}/{blob_container}/{blob_name}" super().__init__(self.blob_url) + self.credential_func = credential_func self.logger = AzureStorageLogger().get_logger() def read(self) -> dict: - file = super().read_azure_blob(self.blob_url) + file = super().read_azure_blob(self.blob_url, credential_func=self.credential_func) if self.format == ".json": data = json.loads(file) elif self.format == ".jsonl": @@ -464,6 +466,7 @@ def __init__( account_url: str, blob_container: str, blob_name: str, + credential_func:callable = lambda _: None, format: str = None, transform: Optional[DFTransformBase] = None, **kwargs, @@ -480,10 +483,11 @@ def __init__( """ self.blob_url = f"{account_url}/{blob_container}/{blob_name}" super().__init__(self.blob_url, format, transform, **kwargs) + self.credential_func = credential_func self.logger = AzureStorageLogger().get_logger() def _load_dataset(self) -> pd.DataFrame: - file = super().read_azure_blob(self.blob_url) + file = super().read_azure_blob(self.blob_url, credential_func=self.credential_func) if self.format == ".jsonl": jlr = jsonlines.Reader(file.splitlines()) df = pd.DataFrame(jlr.iter(skip_empty=True, skip_invalid=True)) diff --git a/eureka_ml_insights/metrics/kitab_metrics.py b/eureka_ml_insights/metrics/kitab_metrics.py index c11aafb..d8f0a94 100644 --- a/eureka_ml_insights/metrics/kitab_metrics.py +++ b/eureka_ml_insights/metrics/kitab_metrics.py @@ -17,7 +17,6 @@ ServiceRequestError, ServiceResponseError, ) -from azure.identity import DefaultAzureCredential from fuzzywuzzy import fuzz from eureka_ml_insights.metrics import CompositeMetric @@ -38,12 +37,14 @@ def __init__(self, temp_path_names, azure_lang_service_config): "https://huggingface.co/datasets/microsoft/kitab/raw/main/code/utils/gpt_4_name_data_processed.csv", temp_path_names, ) + self.credential_func = azure_lang_service_config["secret_key_params"].get("credential_func", lambda _: None) # requires an Azure Cognitive Services Endpoint # (ref: https://learn.microsoft.com/en-us/azure/ai-services/language-service/) self.key = get_secret( key_name=azure_lang_service_config["secret_key_params"].get("key_name", None), local_keys_path=azure_lang_service_config["secret_key_params"].get("local_keys_path", None), key_vault_url=azure_lang_service_config["secret_key_params"].get("key_vault_url", None), + credential_func=self.credential_func ) self.endpoint = azure_lang_service_config["url"] self.text_analytics_credential = self.get_verified_credential() @@ -58,11 +59,11 @@ def get_verified_credential(self): logging.info(f"Failed to create the TextAnalyticsClient using AzureKeyCredential") logging.info("The error is caused by: {}".format(e)) try: - text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=DefaultAzureCredential()) + text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=self.credential_func()) text_analytics_client.recognize_entities(["New York City"], model_version=model_version) - return DefaultAzureCredential() + return self.credential_func() except Exception as e: - logging.info(f"Failed to create the TextAnalyticsClient using DefaultAzureCredential") + logging.info(f"Failed to create the TextAnalyticsClient using provided credential func") logging.info("The error is caused by: {}".format(e)) return None diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py index 5681171..21ac1c0 100644 --- a/eureka_ml_insights/models/models.py +++ b/eureka_ml_insights/models/models.py @@ -9,7 +9,7 @@ import anthropic import tiktoken -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from azure.identity import get_bearer_token_provider from eureka_ml_insights.secret_management import get_secret @@ -222,6 +222,8 @@ class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn): stream: bool = False def __post_init__(self): + if self.secret_key_params is None: + raise ValueError("secret_key_params must be provided.") try: super().__post_init__() self.headers = { @@ -235,7 +237,7 @@ def __post_init__(self): } except ValueError: self.bearer_token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + self.secret_key_params["credential_func"](), "https://cognitiveservices.azure.com/.default" ) self.headers = { "Content-Type": "application/json", @@ -400,11 +402,15 @@ def get_response(self, request): class AzureOpenAIClientMixIn: """This mixin provides some methods to interact with Azure OpenAI models.""" + def __post_init__(self): + if self.secret_key_params is None: + raise ValueError("secret_key_params must be provided.") + def get_client(self): from openai import AzureOpenAI token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + self.secret_key_params["credential_func"](), "https://cognitiveservices.azure.com/.default" ) return AzureOpenAI( azure_endpoint=self.url, diff --git a/eureka_ml_insights/secret_management/secret_key_utils.py b/eureka_ml_insights/secret_management/secret_key_utils.py index ab4c61c..3a0e903 100644 --- a/eureka_ml_insights/secret_management/secret_key_utils.py +++ b/eureka_ml_insights/secret_management/secret_key_utils.py @@ -3,13 +3,12 @@ import os from typing import Dict, Optional -from azure.identity import DefaultAzureCredential, DeviceCodeCredential from azure.keyvault.secrets import SecretClient logging.basicConfig(level=logging.INFO, format="%(filename)s - %(funcName)s - %(message)s") -def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None) -> Optional[str]: +def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None, credential_func=lambda _: None) -> Optional[str]: """This function retrieves a key from key vault or if it is locally cached in a file. args: key_name: str, the name of the key to retrieve. @@ -41,7 +40,7 @@ def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url: f"Key [{key_name}] not found in local keys file [{local_keys_path}] and key_vault_url is not provided." ) else: - key_value = get_key_from_azure(key_name, key_vault_url) + key_value = get_key_from_azure(key_name, key_vault_url, credential_func=credential_func) # if the key still wasn't found, raise an error if key_value is None: @@ -59,7 +58,7 @@ def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url: return key_value -def get_key_from_azure(key_name: str, key_vault_url: str) -> Optional[str]: +def get_key_from_azure(key_name: str, key_vault_url: str, credential_func=lambda _: None) -> Optional[str]: """This function retrieves a key from azure key vault. args: key_name: str, the name of the key to retrieve. @@ -69,23 +68,14 @@ def get_key_from_azure(key_name: str, key_vault_url: str) -> Optional[str]: """ logging.getLogger("azure").setLevel(logging.ERROR) try: - logging.info(f"Trying to get the key from Azure Key Vault {key_vault_url} using DefaultAzureCredential") - credential = DefaultAzureCredential(additionally_allowed_tenants=["*"]) + logging.info("Trying to get the key from Azure Key Vault using provided func") + credential = credential_func(additionally_allowed_tenants=["*"]) client = SecretClient(vault_url=key_vault_url, credential=credential) retrieved_key = client.get_secret(key_name) return retrieved_key.value except Exception as e: - logging.info(f"Failed to get the key from Azure Key Vault {key_vault_url} using DefaultAzureCredential") + logging.info("Failed to get the key from Azure Key Vault using provided func") logging.info("The error is caused by: {}".format(e)) - try: - logging.info(f"Trying to get the key from Azure Key Vault {key_vault_url} using DeviceCodeCredential") - credential = DeviceCodeCredential(additionally_allowed_tenants=["*"]) - client = SecretClient(vault_url=key_vault_url, credential=credential) - retrieved_key = client.get_secret(key_name) - return retrieved_key.value - except Exception as e: - logging.error("Failed to get the key from Azure Key Vault using DeviceCodeCredential") - logging.error("The error is caused by: {}".format(e)) return None diff --git a/eureka_ml_insights/user_configs/kitab.py b/eureka_ml_insights/user_configs/kitab.py index 0decdab..a193ed0 100644 --- a/eureka_ml_insights/user_configs/kitab.py +++ b/eureka_ml_insights/user_configs/kitab.py @@ -35,6 +35,8 @@ ) from eureka_ml_insights.configs import ExperimentConfig +from azure.identity import DefaultAzureCredential + # Example template for an Azure Language Service config # required for running entity recognition for evaluating human and city name AZURE_LANG_SERVICE_CONFIG = { @@ -43,6 +45,7 @@ "key_name": "your_azure_lang_service_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, + "credential_func": DefaultAzureCredential, }, } diff --git a/tests/test_utils.py b/tests/test_utils.py index 3a71510..afc659f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ ) from eureka_ml_insights.metrics import ClassicMetric, CompositeMetric +from azure.identity import DefaultAzureCredential class TestModel: def __init__(self, model_name="generic_test_model"): @@ -265,5 +266,5 @@ def __init__(self, path, n_iter, image_column_names=None): class TestAzureMMDataLoader(EarlyStoppableIterable, AzureMMDataLoader): def __init__(self, path, n_iter, account_url, blob_container, image_column_names=None): - super().__init__(path, account_url, blob_container, image_column_names=image_column_names) + super().__init__(path, account_url, blob_container, credential_func=DefaultAzureCredential, image_column_names=image_column_names) self.n_iter = n_iter