Python: Upgrading Mistral AI Connector to Version 1.0 #9542

Draft
Wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -84,7 +84,7 @@ milvus = [
"milvus >= 2.3,<2.3.8; platform_system != 'Windows'"
]
mistralai = [
"mistralai >= 0.4,< 2.0"
"mistralai >= 1.0,< 2.0"
]
ollama = [
"ollama ~= 0.2"
@@ -27,7 +27,6 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings):

response_format: dict[Literal["type"], Literal["text", "json_object"]] | None = None
messages: list[dict[str, Any]] | None = None
safe_mode: bool = False
safe_prompt: bool = False
max_tokens: int | None = Field(None, gt=0)
seed: int | None = None
@@ -43,5 +42,3 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings):
None,
description="Do not set this manually. It is set by the service based on the function choice configuration.",
)
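For callers, the only visible change in this file is that safe_mode is gone; safe_prompt remains. A small illustrative sketch (field values are arbitrary, and the import path is the assumed package export):

from semantic_kernel.connectors.ai.mistral_ai import MistralAIChatPromptExecutionSettings  # assumed export path

settings = MistralAIChatPromptExecutionSettings(
    safe_prompt=True,   # still supported
    max_tokens=512,
    # safe_mode=True,   # removed by this PR along with the SDK upgrade
)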


@@ -3,7 +3,7 @@
from abc import ABC
from typing import ClassVar

from mistralai.async_client import MistralAsyncClient
from mistralai import Mistral

from semantic_kernel.kernel_pydantic import KernelBaseModel

@@ -13,4 +13,4 @@ class MistralAIBase(KernelBaseModel, ABC):

MODEL_PROVIDER_NAME: ClassVar[str] = "mistralai"

async_client: MistralAsyncClient
async_client: Mistral
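Since the base class now types async_client as Mistral, callers that build their own client should construct the 1.x object and hand it to the connector. A hedged sketch, assuming the connector class is exported from semantic_kernel.connectors.ai.mistral_ai and using a placeholder model id:

from mistralai import Mistral
from semantic_kernel.connectors.ai.mistral_ai import MistralAIChatCompletion  # assumed export path

client = Mistral(api_key="<MISTRAL_API_KEY>")
chat_service = MistralAIChatCompletion(
    ai_model_id="mistral-large-latest",  # placeholder model id
    async_client=client,                 # the 1.x client is now the only typed option on the base class
)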
@@ -10,13 +10,14 @@
else:
from typing_extensions import override # pragma: no cover

from mistralai import Mistral
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import (
from mistralai.models import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionChunk,
CompletionResponseStreamChoice,
DeltaMessage,
)
from pydantic import ValidationError
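For reviewers tracking the renames, the substitutions applied throughout this file are roughly the following (derived from the import and type-hint changes in this diff):

# 0.x model/type                       -> 1.x replacement used in this file
# ChatCompletionResponseChoice         -> ChatCompletionChoice
# ChatCompletionStreamResponse         -> CompletionChunk (delivered as the .data of each stream event)
# ChatCompletionResponseStreamChoice   -> CompletionResponseStreamChoice
# ChatMessage                          -> AssistantMessage (DeltaMessage keeps its name)
# MistralAsyncClient                   -> Mistral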
@@ -63,7 +64,7 @@ def __init__(
ai_model_id: str | None = None,
service_id: str | None = None,
api_key: str | None = None,
async_client: MistralAsyncClient | None = None,
async_client: MistralAsyncClient | Mistral | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -93,8 +94,13 @@ def __init__(
if not mistralai_settings.chat_model_id:
raise ServiceInitializationError("The MistralAI chat model ID is required.")

if not async_client:
async_client = MistralAsyncClient(
# ensure backwards compatibility with MistralAsyncClient
if not async_client or isinstance(async_client, MistralAsyncClient):
if isinstance(async_client, MistralAsyncClient):
logger.warning(
"MistralAIChatCompletion: The MistralAsyncClient is deprecated, please use Mistral instead."
)
async_client = Mistral(
api_key=mistralai_settings.api_key.get_secret_value(),
)
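Reviewer note on the compatibility branch above: when a legacy MistralAsyncClient is passed in, the connector only logs a warning and then builds a fresh Mistral client from the configured API key, so any custom configuration on the caller's old client (base URL, timeouts, retries) is not carried over. The supported patterns after this PR are, roughly (placeholder values):

# let the connector build the client from the api_key / env settings
service = MistralAIChatCompletion(ai_model_id="mistral-large-latest", api_key="<MISTRAL_API_KEY>")

# or hand it a pre-configured 1.x client
service = MistralAIChatCompletion(
    ai_model_id="mistral-large-latest",
    async_client=Mistral(api_key="<MISTRAL_API_KEY>"),
)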

@@ -135,7 +141,7 @@ async def _inner_get_chat_message_contents(
settings.messages = self._prepare_chat_history_for_request(chat_history)

try:
response = await self.async_client.chat(**settings.prepare_settings_dict())
response = await self.async_client.chat.complete_async(**settings.prepare_settings_dict())
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
@@ -160,26 +166,27 @@ async def _inner_get_streaming_chat_message_contents(
settings.messages = self._prepare_chat_history_for_request(chat_history)

try:
response = self.async_client.chat_stream(**settings.prepare_settings_dict())
response = await self.async_client.chat.stream_async(**settings.prepare_settings_dict())
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex
async for chunk in response:
if len(chunk.choices) == 0:
if len(chunk.data.choices) == 0:
continue
chunk_metadata = self._get_metadata_from_response(chunk)
chunk_metadata = self._get_metadata_from_response(chunk.data)
yield [
self._create_streaming_chat_message_content(chunk, choice, chunk_metadata) for choice in chunk.choices
self._create_streaming_chat_message_content(chunk.data, choice, chunk_metadata)
for choice in chunk.data.choices
]

# endregion
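The two call sites above map the old flat methods onto the namespaced 1.x API: chat(...) becomes chat.complete_async(...), and chat_stream(...) becomes an awaited chat.stream_async(...) whose events wrap each chunk in .data. A rough standalone sketch of that raw SDK surface, assuming a valid API key and using placeholder model and messages:

import asyncio
from mistralai import Mistral

async def demo() -> None:
    client = Mistral(api_key="<MISTRAL_API_KEY>")

    # non-streaming (was `await client.chat(...)` in 0.x)
    response = await client.chat.complete_async(
        model="mistral-large-latest",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

    # streaming (was `client.chat_stream(...)` in 0.x); note the extra `.data` level
    stream = await client.chat.stream_async(
        model="mistral-large-latest",
        messages=[{"role": "user", "content": "Hello"}],
    )
    async for event in stream:
        if event.data.choices:
            print(event.data.choices[0].delta.content or "", end="")

asyncio.run(demo())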

# region content conversion to SK

def _create_chat_message_content(
self, response: ChatCompletionResponse, choice: ChatCompletionResponseChoice, response_metadata: dict[str, Any]
self, response: ChatCompletionResponse, choice: ChatCompletionChoice, response_metadata: dict[str, Any]
) -> "ChatMessageContent":
"""Create a chat message content object from a choice."""
metadata = self._get_metadata_from_chat_choice(choice)
@@ -201,8 +208,8 @@ def _create_chat_message_content(

def _create_streaming_chat_message_content(
self,
chunk: ChatCompletionStreamResponse,
choice: ChatCompletionResponseStreamChoice,
chunk: CompletionChunk,
choice: CompletionResponseStreamChoice,
chunk_metadata: dict[str, Any],
) -> StreamingChatMessageContent:
"""Create a streaming chat message content object from a choice."""
@@ -224,9 +231,7 @@ def _create_streaming_chat_message_content(
items=items,
)

def _get_metadata_from_response(
self, response: ChatCompletionResponse | ChatCompletionStreamResponse
) -> dict[str, Any]:
def _get_metadata_from_response(self, response: ChatCompletionResponse | CompletionChunk) -> dict[str, Any]:
"""Get metadata from a chat response."""
metadata: dict[str, Any] = {
"id": response.id,
@@ -244,19 +249,19 @@ def _get_metadata_from_response(
return metadata

def _get_metadata_from_chat_choice(
self, choice: ChatCompletionResponseChoice | ChatCompletionResponseStreamChoice
self, choice: ChatCompletionChoice | CompletionResponseStreamChoice
) -> dict[str, Any]:
"""Get metadata from a chat choice."""
return {
"logprobs": getattr(choice, "logprobs", None),
}

def _get_tool_calls_from_chat_choice(
self, choice: ChatCompletionResponseChoice | ChatCompletionResponseStreamChoice
self, choice: ChatCompletionChoice | CompletionResponseStreamChoice
) -> list[FunctionCallContent]:
"""Get tool calls from a chat choice."""
content: ChatMessage | DeltaMessage
content = choice.message if isinstance(choice, ChatCompletionResponseChoice) else choice.delta
content: AssistantMessage | DeltaMessage
content = choice.message if isinstance(choice, ChatCompletionChoice) else choice.delta
if content.tool_calls is None:
return []

@@ -9,8 +9,9 @@

import logging

from mistralai import Mistral
from mistralai.async_client import MistralAsyncClient
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.models import EmbeddingResponse
from numpy import array, ndarray
from pydantic import ValidationError

@@ -33,7 +34,7 @@ def __init__(
ai_model_id: str | None = None,
api_key: str | None = None,
service_id: str | None = None,
async_client: MistralAsyncClient | None = None,
async_client: Mistral | MistralAsyncClient | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -68,9 +69,15 @@ def __init__(
if not mistralai_settings.embedding_model_id:
raise ServiceInitializationError("The MistralAI embedding model ID is required.")

if not async_client:
async_client = MistralAsyncClient(api_key=mistralai_settings.api_key.get_secret_value())

# ensure backwards compatibility with MistralAsyncClient
if not async_client or isinstance(async_client, MistralAsyncClient):
if isinstance(async_client, MistralAsyncClient):
logger.warning(
"MistralAIChatCompletion: The MistralAsyncClient is deprecated, please use Mistral instead."
)
async_client = Mistral(
api_key=mistralai_settings.api_key.get_secret_value(),
)
super().__init__(
service_id=service_id or mistralai_settings.embedding_model_id,
ai_model_id=ai_model_id or mistralai_settings.embedding_model_id,
@@ -96,8 +103,8 @@ async def generate_raw_embeddings(
) -> Any:
"""Generate embeddings from the Mistral AI service."""
try:
embedding_response: EmbeddingResponse = await self.async_client.embeddings(
model=self.ai_model_id, input=texts
embedding_response: EmbeddingResponse = await self.async_client.embeddings.create_async(
model=self.ai_model_id, inputs=texts
)
except Exception as ex:
raise ServiceResponseException(
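The embedding call changes in the same spirit: the method moves under the embeddings namespace with an async suffix, and the keyword argument is renamed from input to inputs. A minimal sketch against the raw 1.x client (the embedding model id is a placeholder):

import asyncio
from mistralai import Mistral
from mistralai.models import EmbeddingResponse

async def embed_demo() -> list[list[float]]:
    client = Mistral(api_key="<MISTRAL_API_KEY>")
    response: EmbeddingResponse = await client.embeddings.create_async(
        model="mistral-embed",                 # placeholder embedding model id
        inputs=["first text", "second text"],  # was `input=` in the 0.x SDK
    )
    return [item.embedding for item in response.data]

vectors = asyncio.run(embed_demo())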
@@ -2,7 +2,7 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from mistralai.async_client import MistralAsyncClient
from mistralai import Mistral

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
@@ -37,37 +37,40 @@ def mock_settings() -> MistralAIChatPromptExecutionSettings:


@pytest.fixture
def mock_mistral_ai_client_completion() -> MistralAsyncClient:
client = MagicMock(spec=MistralAsyncClient)
def mock_mistral_ai_client_completion() -> Mistral:
client = MagicMock(spec=Mistral)
client.chat = AsyncMock()

chat_completion_response = AsyncMock()
choices = [MagicMock(finish_reason="stop", message=MagicMock(role="assistant", content="Test"))]
chat_completion_response.choices = choices
client.chat.return_value = chat_completion_response
client.chat.complete_async.return_value = chat_completion_response
return client


@pytest.fixture
def mock_mistral_ai_client_completion_stream() -> MistralAsyncClient:
client = MagicMock(spec=MistralAsyncClient)
def mock_mistral_ai_client_completion_stream() -> Mistral:
client = MagicMock(spec=Mistral)
client.chat = AsyncMock()
chat_completion_response = MagicMock()
choices = [
MagicMock(finish_reason="stop", delta=MagicMock(role="assistant", content="Test")),
MagicMock(finish_reason="stop", delta=MagicMock(role="assistant", content="Test", tool_calls=None)),
]
chat_completion_response.choices = choices
chat_completion_response.data.choices = choices
chat_completion_response_empty = MagicMock()
chat_completion_response_empty.data.choices = []
generator_mock = MagicMock()
generator_mock.__aiter__.return_value = [chat_completion_response_empty, chat_completion_response]
client.chat_stream.return_value = generator_mock
client.chat.stream_async.return_value = generator_mock
return client
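A note on the fixture shape: with the 1.x client the production code reaches through client.chat.complete_async / client.chat.stream_async, so the mocks hang an AsyncMock off the chat attribute; a plain MagicMock(spec=Mistral) alone would hand back non-awaitable children for those nested calls. The same pattern reduced to its core (fake_response is a hypothetical stand-in):

from unittest.mock import AsyncMock, MagicMock
from mistralai import Mistral

fake_response = MagicMock()                       # hypothetical canned response object
client = MagicMock(spec=Mistral)
client.chat = AsyncMock()                         # nested attributes now produce awaitable calls
client.chat.complete_async.return_value = fake_response
# code under test can then safely do:
#   response = await client.chat.complete_async(model=..., messages=...)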


@pytest.mark.asyncio
async def test_complete_chat_contents(
kernel: Kernel,
mock_settings: MistralAIChatPromptExecutionSettings,
mock_mistral_ai_client_completion: MistralAsyncClient,
mock_mistral_ai_client_completion: Mistral,
):
chat_history = MagicMock()
arguments = KernelArguments()
@@ -145,7 +148,7 @@ async def test_complete_chat_contents_function_call_behavior_tool_call(
@pytest.mark.asyncio
async def test_complete_chat_contents_function_call_behavior_without_kernel(
mock_settings: MistralAIChatPromptExecutionSettings,
mock_mistral_ai_client_completion: MistralAsyncClient,
mock_mistral_ai_client_completion: Mistral,
):
chat_history = MagicMock()
chat_completion_base = MistralAIChatCompletion(
@@ -162,7 +165,7 @@ async def test_complete_chat_contents_function_call_behavior_without_kernel(
async def test_complete_chat_stream_contents(
kernel: Kernel,
mock_settings: MistralAIChatPromptExecutionSettings,
mock_mistral_ai_client_completion_stream: MistralAsyncClient,
mock_mistral_ai_client_completion_stream: Mistral,
):
chat_history = MagicMock()
arguments = KernelArguments()
@@ -248,8 +251,9 @@ async def test_complete_chat_contents_streaming_function_call_behavior_tool_call
async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings):
chat_history = MagicMock()
arguments = KernelArguments()
client = MagicMock(spec=MistralAsyncClient)
client.chat.side_effect = Exception("Test Exception")
client = MagicMock(spec=Mistral)
client.chat = MagicMock()
client.chat.complete_async.side_effect = Exception("Test Exception")

chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id", service_id="test", api_key="", async_client=client
@@ -265,8 +269,9 @@ async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAI
async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings):
chat_history = MagicMock()
arguments = KernelArguments()
client = MagicMock(spec=MistralAsyncClient)
client.chat_stream.side_effect = Exception("Test Exception")
client = MagicMock(spec=Mistral)
client.chat = MagicMock()
client.chat.stream_async.side_effect = Exception("Test Exception")

chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id", service_id="test", api_key="", async_client=client
@@ -284,7 +289,7 @@ def test_mistral_ai_chat_completion_init(mistralai_unit_test_env) -> None:
mistral_ai_chat_completion = MistralAIChatCompletion()

assert mistral_ai_chat_completion.ai_model_id == mistralai_unit_test_env["MISTRALAI_CHAT_MODEL_ID"]
assert mistral_ai_chat_completion.async_client._api_key == mistralai_unit_test_env["MISTRALAI_API_KEY"]
assert mistral_ai_chat_completion.async_client.sdk_configuration.security.api_key == mistralai_unit_test_env["MISTRALAI_API_KEY"]
assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase)


@@ -298,7 +303,7 @@ def test_mistral_ai_chat_completion_init_constructor(mistralai_unit_test_env) ->
)

assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id"
assert mistral_ai_chat_completion.async_client._api_key == "overwrite_api_key"
assert mistral_ai_chat_completion.async_client.sdk_configuration.security.api_key == "overwrite_api_key"
assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase)


@@ -322,7 +327,7 @@ def test_mistral_ai_chat_completion_init_hybrid(mistralai_unit_test_env) -> None
env_file_path="test.env",
)
assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id"
assert mistral_ai_chat_completion.async_client._api_key == "test_api_key"
assert mistral_ai_chat_completion.async_client.sdk_configuration.security.api_key == "test_api_key"


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True)
@@ -351,8 +356,8 @@ async def test_with_different_execution_settings(kernel: Kernel, mock_mistral_ai
await chat_completion_base.get_chat_message_contents(
chat_history=chat_history, settings=settings, kernel=kernel, arguments=arguments
)
assert mock_mistral_ai_client_completion.chat.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion.chat.call_args.kwargs["seed"] == 2
assert mock_mistral_ai_client_completion.chat.complete_async.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion.chat.complete_async.call_args.kwargs["seed"] == 2


@pytest.mark.asyncio
Expand All @@ -373,5 +378,5 @@ async def test_with_different_execution_settings_stream(
chat_history, settings, kernel=kernel, arguments=arguments
):
continue
assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["seed"] == 2
assert mock_mistral_ai_client_completion_stream.chat.stream_async.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion_stream.chat.stream_async.call_args.kwargs["seed"] == 2