Skip to content

Commit

Permalink
Serverless fix (#52)
Browse files Browse the repository at this point in the history
**LlamaServerlessAzureRestEndpointModel**

Updates the type of Llama-specific parameters that were previously
declared as str to bool. Previously, the API would ignore parameters
that were set as str.

**MistralServerlessAzureRestEndpointModel**
Similar to the above, for the safe_prompt parameter

**ServerlessAzureRestEndpointModel**
Similar to the above, for the stream parameter
Also adds "extra-parameters": "pass-through" to keep the class
compatible with Llama 3.2

---------

Co-authored-by: Besmira Nushi <[email protected]>
  • Loading branch information
nushib and Besmira Nushi authored Dec 5, 2024
1 parent 7a27e6c commit c6a2dff
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions eureka_ml_insights/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,32 @@ class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
"""https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio"""
url: str = None
model_name: str = None
stream: str = "false"
stream: bool = False

def __post_init__(self):
try:
super().__post_init__()
self.headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self.api_key),
# The behavior of the API when extra parameters are indicated in the payload.
                # Using pass-through makes the API pass the parameter to the underlying model.
# Use this value when you want to pass parameters that you know the underlying model can support.
# https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-chat-completions?view=azureml-api-2
"extra-parameters": "pass-through"
}
except ValueError:
self.bearer_token_provider = get_bearer_token_provider(
AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
headers = {
self.headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self.bearer_token_provider()),
# The behavior of the API when extra parameters are indicated in the payload.
                # Using pass-through makes the API pass the parameter to the underlying model.
# Use this value when you want to pass parameters that you know the underlying model can support.
# https://learn.microsoft.com/en-us/azure/machine-learning/reference-model-inference-chat-completions?view=azureml-api-2
"extra-parameters": "pass-through"
}

@abstractmethod
Expand Down Expand Up @@ -264,7 +274,7 @@ def handle_request_error(self, e):

@dataclass
class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
"""Tested for Llama 3.1 405B Instruct deployments."""
"""Tested for Llama 3.1 405B Instruct deployments and Llama 3.2 90B Vision Instruct."""

"""See https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=llama-three for the api reference."""

Expand All @@ -273,10 +283,10 @@ class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
top_p: float = 0.95
frequency_penalty: float = 0
presence_penalty: float = 0
use_beam_search: str = "false"
use_beam_search: bool = False
best_of: int = 1
skip_special_tokens: str = "false"
ignore_eos: str = "false"
skip_special_tokens: bool = False
ignore_eos: bool = False

def create_request(self, text_prompt, query_images=None, *args, **kwargs):
user_content = {"role": "user", "content": text_prompt}
Expand Down Expand Up @@ -318,7 +328,7 @@ class MistralServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
temperature: float = 0
max_tokens: int = 2000
top_p: float = 1
safe_prompt: str = "false"
safe_prompt: bool = False

def __post_init__(self):
if self.temperature == 0 and self.top_p != 1:
Expand Down

0 comments on commit c6a2dff

Please sign in to comment.