diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
index fe21c134f621..278e448e9885 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
@@ -12,7 +12,10 @@
 from semantic_kernel.kernel import Kernel
 from semantic_kernel.functions.kernel_plugin import KernelPlugin
 from typing_extensions import AsyncGenerator, Union
-from ._kernel_function_from_tool import KernelFunctionFromTool
+from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool
+from semantic_kernel.contents.function_call_content import FunctionCallContent
+from autogen_core import FunctionCall
+from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
 
 
 class SKChatCompletionAdapter(ChatCompletionClient):
@@ -76,7 +79,7 @@ def _build_execution_settings(self, extra_create_args: Mapping[str, Any], tools:
-        # If tools are available, configure function choice behavior with auto_invoke disabled
+        # If tools are available, configure function choice behavior; auto_invoke is off unless requested
         function_choice_behavior = None
         if tools:
-            function_choice_behavior = FunctionChoiceBehavior.NoneInvoke()
+            function_choice_behavior = FunctionChoiceBehavior.Auto(auto_invoke=extra_create_args.get("auto_invoke", False))
 
         # Create settings with remaining args as extension_data
         settings = PromptExecutionSettings(
@@ -111,6 +114,28 @@ def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSch
             kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
             self._tools_plugin.functions[tool.name] = kernel_function
 
+    def _process_tool_calls(self, result: ChatMessageContent) -> list[FunctionCall]:
+        """Process tool calls from SK ChatMessageContent"""
+        function_calls = []
+        for item in result.items:
+            if isinstance(item, FunctionCallContent):
+                # Extract plugin name and function name
+                plugin_name = item.plugin_name or ""
+                function_name = item.function_name or item.name
+                if plugin_name:
+                    full_name = f"{plugin_name}-{function_name}"
+                else:
+                    full_name = function_name
+
+                function_calls.append(
+                    FunctionCall(
+                        id=item.id,
+                        name=full_name,
+                        arguments=item.arguments or "{}"
+                    )
+                )
+        return function_calls
+
     async def create(
         self,
         messages: Sequence[LLMMessage],
@@ -150,10 +175,19 @@ async def create(
         self._total_prompt_tokens += prompt_tokens
         self._total_completion_tokens += completion_tokens
+
+        # Process content based on whether there are tool calls
+        content: Union[str, list[FunctionCall]]
+        if any(isinstance(item, FunctionCallContent) for item in result[0].items):
+            content = self._process_tool_calls(result[0])
+            finish_reason = "function_calls"
+        else:
+            content = result[0].content
+            finish_reason = "stop"
 
         return CreateResult(
-            content=result[0].content,
-            finish_reason="stop",
+            content=content,
+            finish_reason=finish_reason,
             usage=RequestUsage(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens
@@ -169,12 +203,70 @@ async def create_stream(
         self,
         messages: Sequence[LLMMessage],
         tools: Sequence[Tool | ToolSchema] = [],
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
-        # Very similar to create(), but orchestrates streaming.
-        # 1. Convert messages -> ChatHistory
-        # 2. Possibly set function-calling if needed
-        # 3. Build generator that yields str segments or a final CreateResult
-        #    from SK's get_streaming_chat_message_contents(...)
-        raise NotImplementedError("create_stream is not implemented")
+        if "kernel" not in extra_create_args:
+            raise ValueError("kernel is required in extra_create_args")
+
+        kernel = extra_create_args["kernel"]
+        if not isinstance(kernel, Kernel):
+            raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")
+
+        chat_history = self._convert_to_chat_history(messages)
+        settings = self._build_execution_settings(extra_create_args, tools)
+        self._sync_tools_with_kernel(kernel, tools)
+
+        prompt_tokens = 0
+        completion_tokens = 0
+        accumulated_content = ""
+
+        async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
+            chat_history,
+            settings=settings,
+            kernel=kernel
+        ):
+            for msg in streaming_messages:
+                if not isinstance(msg, StreamingChatMessageContent):
+                    continue
+
+                # Track token usage
+                if msg.metadata and 'usage' in msg.metadata:
+                    usage = msg.metadata['usage']
+                    prompt_tokens = getattr(usage, 'prompt_tokens', 0)
+                    completion_tokens = getattr(usage, 'completion_tokens', 0)
+
+                # Check for function calls
+                if any(isinstance(item, FunctionCallContent) for item in msg.items):
+                    function_calls = self._process_tool_calls(msg)
+                    # Update running totals before the early return, mirroring create()
+                    self._total_prompt_tokens += prompt_tokens
+                    self._total_completion_tokens += completion_tokens
+                    yield CreateResult(
+                        content=function_calls,
+                        finish_reason="function_calls",
+                        usage=RequestUsage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens
+                        ),
+                        cached=False
+                    )
+                    return
+
+                # Handle text content
+                if msg.content:
+                    accumulated_content += msg.content
+                    yield msg.content
+
+        # Final yield if there was text content
+        if accumulated_content:
+            self._total_prompt_tokens += prompt_tokens
+            self._total_completion_tokens += completion_tokens
+            yield CreateResult(
+                content=accumulated_content,
+                finish_reason="stop",
+                usage=RequestUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens
+                ),
+                cached=False
+            )
 
     def actual_usage(self) -> RequestUsage:
         return RequestUsage(
diff --git a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
index 80944df6b778..c79568f4bc12 100644
--- a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
@@ -28,20 +28,20 @@ def __init__(self):
     async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult:
         return CalculatorResult(result=args.a + args.b)
 
-@pytest.mark.asyncio
-async def test_sk_chat_completion_with_tools():
-    # Set up Azure OpenAI client with token auth
-    deployment_name = "gpt-4o-mini"
-    endpoint = "https://.openai.azure.com/"
-    api_version = "2024-07-18"
-
-    # Create SK client
-    sk_client = AzureChatCompletion(
+@pytest.fixture
+def sk_client():
+    deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
+    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+    api_key = os.getenv("AZURE_OPENAI_API_KEY")
+
+    return AzureChatCompletion(
         deployment_name=deployment_name,
         endpoint=endpoint,
-        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+        api_key=api_key,
     )
-
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_with_tools(sk_client):
     # Create adapter
     adapter = SKChatCompletionAdapter(sk_client)
 
@@ -63,7 +63,31 @@ async def test_sk_chat_completion_with_tools():
         tools=[tool],
         extra_create_args={"kernel": kernel}
     )
+
+    # Verify response
+    assert isinstance(result.content, list)
+    assert result.finish_reason == "function_calls"
+    assert result.usage.prompt_tokens >= 0
+    assert result.usage.completion_tokens >= 0
+    assert not result.cached
 
+@pytest.mark.asyncio
+async def test_sk_chat_completion_without_tools(sk_client):
+    # Create adapter and kernel
+    adapter = SKChatCompletionAdapter(sk_client)
+    kernel = Kernel(memory=NullMemory())
+
+    # Test messages
+    messages = [
+        SystemMessage(content="You are a helpful assistant."),
+        UserMessage(content="Say hello!", source="user"),
+    ]
+
+    # Call create without tools
+    result = await adapter.create(
+        messages=messages,
+        extra_create_args={"kernel": kernel}
+    )
     # Verify response
     assert isinstance(result.content, str)
@@ -71,3 +95,70 @@ async def test_sk_chat_completion_with_tools():
     assert result.usage.prompt_tokens >= 0
     assert result.usage.completion_tokens >= 0
     assert not result.cached
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_stream_with_tools(sk_client):
+    # Create adapter and kernel
+    adapter = SKChatCompletionAdapter(sk_client)
+    kernel = Kernel(memory=NullMemory())
+
+    # Create calculator tool
+    tool = CalculatorTool()
+
+    # Test messages
+    messages = [
+        SystemMessage(content="You are a helpful assistant."),
+        UserMessage(content="What is 2 + 2?", source="user"),
+    ]
+
+    # Call create_stream with tool
+    response_chunks = []
+    async for chunk in adapter.create_stream(
+        messages=messages,
+        tools=[tool],
+        extra_create_args={"kernel": kernel}
+    ):
+        response_chunks.append(chunk)
+
+    # Verify response
+    assert len(response_chunks) > 0
+    final_chunk = response_chunks[-1]
+    assert isinstance(final_chunk.content, list)  # Function calls
+    assert final_chunk.finish_reason == "function_calls"
+    assert final_chunk.usage.prompt_tokens >= 0
+    assert final_chunk.usage.completion_tokens >= 0
+    assert not final_chunk.cached
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_stream_without_tools(sk_client):
+    # Create adapter and kernel
+    adapter = SKChatCompletionAdapter(sk_client)
+    kernel = Kernel(memory=NullMemory())
+
+    # Test messages
+    messages = [
+        SystemMessage(content="You are a helpful assistant."),
+        UserMessage(content="Say hello!", source="user"),
+    ]
+
+    # Call create_stream without tools
+    response_chunks = []
+    async for chunk in adapter.create_stream(
+        messages=messages,
+        extra_create_args={"kernel": kernel}
+    ):
+        response_chunks.append(chunk)
+
+    # Verify response
+    assert len(response_chunks) > 0
+    # All chunks except the last should be strings
+    for chunk in response_chunks[:-1]:
+        assert isinstance(chunk, str)
+
+    # Final chunk should be a CreateResult
+    final_chunk = response_chunks[-1]
+    assert isinstance(final_chunk.content, str)
+    assert final_chunk.finish_reason == "stop"
+    assert final_chunk.usage.prompt_tokens >= 0
+    assert final_chunk.usage.completion_tokens >= 0
+    assert not final_chunk.cached
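
---

For reviewers who want to exercise the new `create_stream` path locally, here is a minimal sketch. It assumes the adapter is re-exported as `autogen_ext.models.semantic_kernel.SKChatCompletionAdapter` (the diff only shows the private module path) and that the autogen message types live in `autogen_core.models`; the three Azure OpenAI environment variables match the `sk_client` test fixture above. Adjust the imports to the actual package layout if they differ.

```python
import asyncio
import os

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.memory.null_memory import NullMemory

# Assumed public import paths; see note above.
from autogen_core.models import CreateResult, SystemMessage, UserMessage
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter


async def main() -> None:
    # Same environment variables as the sk_client fixture in the tests.
    sk_client = AzureChatCompletion(
        deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
        endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
    )
    adapter = SKChatCompletionAdapter(sk_client)
    kernel = Kernel(memory=NullMemory())

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Say hello!", source="user"),
    ]

    # Without tools, the stream yields str deltas followed by one
    # final CreateResult carrying finish_reason and token usage.
    async for chunk in adapter.create_stream(
        messages=messages,
        extra_create_args={"kernel": kernel},
    ):
        if isinstance(chunk, CreateResult):
            print(f"\n[{chunk.finish_reason}] usage={chunk.usage}")
        else:
            print(chunk, end="", flush=True)


asyncio.run(main())
```

With tools passed in, the same loop instead ends at the first `CreateResult` with `finish_reason="function_calls"` and a `list[FunctionCall]` as its content, which is what the two streaming tests assert.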