multi-turn conversations (#63)
- Adds support for multi-turn conversations to all the model APIs that
support it. (The Gemini API is the only one that either does not accept a
list of messages, or I haven't figured out how to pass one yet.)
- Also adds support for configuring the Azure authorization scope via the
model config. (A usage sketch of both changes follows.)
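
A minimal usage sketch of both changes. The url and model_name values below are
placeholders, and not every field of AzureOpenAIModel appears in this diff:

    # Hypothetical configuration; auth_scope is the newly configurable field.
    model = AzureOpenAIModel(
        url="https://<your-resource>.openai.azure.com/",
        model_name="<your-deployment>",
        auth_scope="https://cognitiveservices.azure.com/.default",
    )

    # Multi-turn: prior turns are forwarded through generate(**kwargs)
    # to create_request as previous_messages.
    previous_messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
    response = model.generate(
        "And what is its population?",
        system_message="Answer concisely.",
        previous_messages=previous_messages,
    )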

---------

Co-authored-by: Safoora Yousefi <[email protected]>
safooray and Safoora Yousefi authored Dec 17, 2024
1 parent 81320d3 commit 4118424
Showing 1 changed file with 71 additions and 37 deletions.
eureka_ml_insights/models/models.py
@@ -13,7 +13,6 @@
 
 from eureka_ml_insights.secret_management import get_secret
 
-
 @dataclass
 class Model(ABC):
     """This class is used to define the structure of a model class.
@@ -91,15 +90,15 @@ class EndpointModel(Model):
     num_retries: int = 3
 
     @abstractmethod
-    def create_request(self, text_prompt, query_images=None, system_message=None):
+    def create_request(self, text_prompt, **kwargs):
         raise NotImplementedError
 
     @abstractmethod
     def get_response(self, request):
         # must return the model output and the response time
         raise NotImplementedError
 
-    def generate(self, query_text, query_images=None, system_message=None):
+    def generate(self, query_text, **kwargs):
         """
         Calls the endpoint to generate the model response.
         args:
@@ -111,7 +110,7 @@ def generate(self, query_text, query_images=None, system_message=None):
             and any other relevant information returned by the model.
         """
         response_dict = {}
-        request = self.create_request(query_text, query_images=query_images, system_message=system_message)
+        request = self.create_request(query_text, **kwargs)
         attempts = 0
         while attempts < self.num_retries:
             try:
@@ -159,15 +158,17 @@ class RestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
     presence_penalty: float = 0
     do_sample: bool = True
 
-    def create_request(self, text_prompt, query_images=None, system_message=None):
+    def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
         """Creates a request for the model."""
+        messages = []
+        if system_message:
+            messages.append({"role": "system", "content": system_message})
+        if previous_messages:
+            messages.extend(previous_messages)
+        messages.append({"role": "user", "content": text_prompt})
         data = {
             "input_data": {
-                "input_string": [
-                    {
-                        "role": "user",
-                        "content": text_prompt,
-                    }
-                ],
+                "input_string": messages,
                 "parameters": {
                     "temperature": self.temperature,
                     "top_p": self.top_p,
@@ -176,12 +177,8 @@ def create_request(self, text_prompt, query_images=None, system_message=None):
                 },
             }
         }
-        if system_message:
-            data["input_data"]["input_string"] = [{"role": "system", "content": system_message}] + data["input_data"][
-                "input_string"
-            ]
         if query_images:
-            raise NotImplementedError("Images are not supported for GCR endpoints yet.")
+            raise NotImplementedError("Images are not supported for RestEndpointModel endpoints yet.")
 
         body = str.encode(json.dumps(data))
         # The azureml-model-deployment header will force the request to go to a specific deployment.
@@ -220,6 +217,7 @@ class ServerlessAzureRestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
     url: str = None
     model_name: str = None
     stream: bool = False
+    auth_scope: str = "https://cognitiveservices.azure.com/.default"
 
     def __post_init__(self):
         try:
@@ -235,7 +233,7 @@ def __post_init__(self):
             }
         except ValueError:
             self.bearer_token_provider = get_bearer_token_provider(
-                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+                DefaultAzureCredential(), self.auth_scope
             )
             self.headers = {
                 "Content-Type": "application/json",
@@ -248,7 +246,7 @@ def __post_init__(self):
         }
 
     @abstractmethod
-    def create_request(self, text_prompt, query_images=None, system_message=None):
+    def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
         # Exact model parameters are model-specific.
         # The method cannot be implemented unless the model being deployed is known.
         raise NotImplementedError
@@ -288,13 +286,18 @@ class LlamaServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
     skip_special_tokens: bool = False
     ignore_eos: bool = False
 
-    def create_request(self, text_prompt, query_images=None, *args, **kwargs):
-        user_content = {"role": "user", "content": text_prompt}
+    def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
+        messages = []
+        if system_message:
+            messages.append({"role": "system", "content": system_message})
+        if previous_messages:
+            messages.extend(previous_messages)
+        user_content = text_prompt
         if query_images:
             if len(query_images) > 1:
                 raise ValueError("Llama vision model does not support more than 1 image.")
             encoded_images = self.base64encode(query_images)
-            user_content["content"] = [
+            user_content = [
                 {"type": "text", "text": text_prompt},
                 {
                     "type": "image_url",
@@ -303,9 +306,11 @@ def create_request(self, text_prompt, query_images=None, *args, **kwargs):
                     },
                 },
             ]
+        messages.append({"role": "user", "content": user_content})
+
 
         data = {
-            "messages": [user_content],
+            "messages": messages,
             "max_tokens": self.max_tokens,
             "temperature": self.temperature,
             "top_p": self.top_p,
@@ -337,9 +342,17 @@ def __post_init__(self):
         self.top_p = 1
         super().__post_init__()
 
-    def create_request(self, text_prompt, *args, **kwargs):
+    def create_request(self, text_prompt, query_images=None, system_message=None, previous_messages=None):
+        messages = []
+        if system_message:
+            messages.append({"role": "system", "content": system_message})
+        if previous_messages:
+            messages.extend(previous_messages)
+        if query_images:
+            raise NotImplementedError("Images are not supported for MistralServerlessAzureRestEndpointModel endpoints.")
+        messages.append({"role": "user", "content": text_prompt})
         data = {
-            "messages": [{"role": "user", "content": text_prompt}],
+            "messages": messages,
             "max_tokens": self.max_tokens,
             "temperature": self.temperature,
             "top_p": self.top_p,
@@ -358,14 +371,16 @@ class OpenAICommonRequestResponseMixIn:
     This mixin class defines the request and response handling for most OpenAI models.
     """
 
-    def create_request(self, prompt, query_images=None, system_message=None):
+    def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
         messages = []
         if system_message:
             messages.append({"role": "system", "content": system_message})
-        user_content = {"role": "user", "content": prompt}
+        if previous_messages:
+            messages.extend(previous_messages)
+        user_content = prompt
         if query_images:
             encoded_images = self.base64encode(query_images)
-            user_content["content"] = [
+            user_content = [
                 {"type": "text", "text": prompt},
                 {
                     "type": "image_url",
@@ -374,7 +389,7 @@ def create_request(self, prompt, query_images=None, system_message=None):
                     },
                 },
             ]
-        messages.append(user_content)
+        messages.append({"role": "user", "content": user_content})
         return {"messages": messages}
 
     def get_response(self, request):
@@ -404,7 +419,7 @@ def get_client(self):
         from openai import AzureOpenAI
 
         token_provider = get_bearer_token_provider(
-            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+            DefaultAzureCredential(), self.auth_scope
         )
         return AzureOpenAI(
             azure_endpoint=self.url,
@@ -449,6 +464,7 @@ class AzureOpenAIModel(OpenAICommonRequestResponseMixIn, AzureOpenAIClientMixIn,
     presence_penalty: float = 0
     seed: int = 0
     api_version: str = "2023-06-01-preview"
+    auth_scope: str = "https://cognitiveservices.azure.com/.default"
 
     def __post_init__(self):
         self.client = self.get_client()
@@ -473,8 +489,17 @@ def __post_init__(self):
 
 
 class OpenAIO1RequestResponseMixIn:
-    def create_request(self, prompt, *args, **kwargs):
-        messages = [{"role": "user", "content": prompt}]
+
+    def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
+        if system_message:
+            # system messages are not supported for OAI reasoning models
+            # https://platform.openai.com/docs/guides/reasoning
+            logging.warning("System messages are not supported for OAI reasoning models.")
+        messages = []
+        if previous_messages:
+            messages.extend(previous_messages)
+
+        messages.append({"role": "user", "content": prompt})
         return {"messages": messages}
 
     def get_response(self, request):
@@ -528,6 +553,8 @@ class AzureOpenAIO1Model(OpenAIO1RequestResponseMixIn, AzureOpenAIClientMixIn, E
     frequency_penalty: float = 0
     presence_penalty: float = 0
     api_version: str = "2023-06-01-preview"
+    auth_scope: str = "https://cognitiveservices.azure.com/.default"
+
 
     def __post_init__(self):
         self.client = self.get_client()
@@ -563,7 +590,13 @@ def __post_init__(self):
     def create_request(self, text_prompt, query_images=None, system_message=None):
         import google.generativeai as genai
 
-        self.model = genai.GenerativeModel(self.model_name, system_instruction=system_message)
+        if self.model_name == "gemini-1.0-pro":
+            if system_message:
+                logging.warning("System messages are not supported for Gemini 1.0 Pro.")
+            self.model = genai.GenerativeModel(self.model_name)
+        else:
+            self.model = genai.GenerativeModel(self.model_name, system_instruction=system_message)
+
         if query_images:
             return [text_prompt] + query_images
         else:
@@ -942,14 +975,15 @@ def __post_init__(self):
             timeout=self.timeout,
         )
 
-    def create_request(self, prompt, query_images=None, system_message=None):
+    def create_request(self, prompt, query_images=None, system_message=None, previous_messages=None):
         messages = []
 
-        user_content = {"role": "user", "content": prompt}
-
+        user_content = prompt
+        if previous_messages:
+            messages.extend(previous_messages)
         if query_images:
             encoded_images = self.base64encode(query_images)
-            user_content["content"] = [
+            user_content = [
                 {"type": "text", "text": prompt},
                 {
                     "type": "image",
@@ -960,7 +994,7 @@ def create_request(self, prompt, query_images=None, system_message=None):
                     },
                 },
             ]
-        messages.append(user_content)
+        messages.append({"role": "user", "content": user_content})
 
         if system_message:
             return {"messages": messages, "system": system_message}
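
As a sketch of the multi-turn flow these changes enable, each exchange is fed
back in as previous_messages on the next call. The "model_output" key is an
assumption here; the exact response-dict keys are not shown in this diff:

    history = []
    for turn in ["Hi there!", "Can you say more about that?"]:
        response = model.generate(turn, previous_messages=history)
        # Record both sides of the exchange so the next turn sees full context.
        history.append({"role": "user", "content": turn})
        history.append({"role": "assistant", "content": response["model_output"]})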
