diff --git a/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py b/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py index 77fc0b831427..50a8c5280935 100644 --- a/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py +++ b/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py @@ -42,7 +42,7 @@ class ImportFromModule: module: str imports: Tuple[Union[str, Alias], ...] - ## backward compatibility + # backward compatibility def __init__( self, module: str, @@ -214,3 +214,11 @@ def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str content += " ..." return content + + +def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + return _to_code(func) + + +def import_to_str(im: Import) -> str: + return _import_to_str(im) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py index 7c4042e9afd6..b484ef84f3e9 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_base.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired from .. import CancellationToken +from .._component_config import ComponentBase from .._function_utils import normalize_annotated_type T = TypeVar("T", bound=BaseModel, contravariant=True) @@ -56,7 +57,9 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: ... StateT = TypeVar("StateT", bound=BaseModel) -class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]): +class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]): + component_type = "tool" + def __init__( self, args_type: Type[ArgsT], @@ -132,7 +135,7 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: pass -class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]): +class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]): def __init__( self, args_type: Type[ArgsT], @@ -144,6 +147,8 @@ def __init__( super().__init__(args_type, return_type, name, description) self._state_type = state_type + component_type = "tool" + @abstractmethod def save_state(self) -> StateT: ... diff --git a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py index 026fc845e9c2..6b861292f249 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py @@ -1,18 +1,32 @@ import asyncio import functools -from typing import Any, Callable +from textwrap import dedent +from typing import Any, Callable, Sequence from pydantic import BaseModel +from typing_extensions import Self from .. import CancellationToken +from .._component_config import Component from .._function_utils import ( args_base_model_from_signature, get_typed_signature, ) +from ..code_executor._func_with_reqs import Import, import_to_str, to_code from ._base import BaseTool -class FunctionTool(BaseTool[BaseModel, BaseModel]): +class FunctionToolConfig(BaseModel): + """Configuration for a function tool.""" + + source_code: str + name: str + description: str + global_imports: Sequence[Import] + has_cancellation_support: bool + + +class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]): """ Create custom tools by wrapping standard Python functions. @@ -64,8 +78,14 @@ async def example(): asyncio.run(example()) """ - def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None: + component_provider_override = "autogen_core.tools.FunctionTool" + component_config_schema = FunctionToolConfig + + def __init__( + self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = [] + ) -> None: self._func = func + self._global_imports = global_imports signature = get_typed_signature(func) func_name = name or func.__name__ args_model = args_base_model_from_signature(func_name + "args", signature) @@ -98,3 +118,44 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A result = await future return result + + def _to_config(self) -> FunctionToolConfig: + return FunctionToolConfig( + source_code=dedent(to_code(self._func)), + global_imports=self._global_imports, + name=self.name, + description=self.description, + has_cancellation_support=self._has_cancellation_support, + ) + + @classmethod + def _from_config(cls, config: FunctionToolConfig) -> Self: + exec_globals: dict[str, Any] = {} + + # Execute imports first + for import_stmt in config.global_imports: + import_code = import_to_str(import_stmt) + try: + exec(import_code, exec_globals) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import {import_code}: Module not found. Please ensure the module is installed." + ) from e + except ImportError as e: + raise ImportError(f"Failed to import {import_code}: {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e + + # Execute function code + try: + exec(config.source_code, exec_globals) + func_name = config.source_code.split("def ")[1].split("(")[0] + except Exception as e: + raise ValueError(f"Could not compile and load function: {e}") from e + + # Get function and verify it's callable + func: Callable[..., Any] = exec_globals[func_name] + if not callable(func): + raise TypeError(f"Expected function but got {type(func)}") + + return cls(func, "", None) diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py index d59fde59c1b6..1f78e907a447 100644 --- a/python/packages/autogen-core/tests/test_component_config.py +++ b/python/packages/autogen-core/tests/test_component_config.py @@ -4,9 +4,11 @@ from typing import Any, Dict import pytest -from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel +from autogen_core import CancellationToken, Component, ComponentBase, ComponentLoader, ComponentModel from autogen_core._component_config import _type_to_provider_str # type: ignore +from autogen_core.code_executor import ImportFromModule from autogen_core.models import ChatCompletionClient +from autogen_core.tools import FunctionTool from autogen_test_utils import MyInnerComponent, MyOuterComponent from pydantic import BaseModel, ValidationError from typing_extensions import Self @@ -283,3 +285,68 @@ def test_component_version_from_dict() -> None: assert comp.info == "test" assert comp.__class__ == ComponentNonOneVersionWithUpgrade assert comp.dump_component().version == 2 + + +@pytest.mark.asyncio +async def test_function_tool() -> None: + """Test FunctionTool with different function types and features.""" + + # Test sync and async functions + def sync_func(x: int, y: str) -> str: + return y * x + + async def async_func(x: float, y: float, cancellation_token: CancellationToken) -> float: + if cancellation_token.is_cancelled(): + raise Exception("Cancelled") + return x + y + + # Create tools with different configurations + sync_tool = FunctionTool( + func=sync_func, description="Multiply string", global_imports=[ImportFromModule("typing", ("Dict",))] + ) + invalid_import_sync_tool = FunctionTool( + func=sync_func, description="Multiply string", global_imports=[ImportFromModule("invalid_module (", ("Dict",))] + ) + + invalid_import_config = invalid_import_sync_tool.dump_component() + # check that invalid import raises an error + with pytest.raises(RuntimeError): + _ = FunctionTool.load_component(invalid_import_config, FunctionTool) + + async_tool = FunctionTool( + func=async_func, + description="Add numbers", + name="custom_adder", + global_imports=[ImportFromModule("autogen_core", ("CancellationToken",))], + ) + + # Test serialization and config + + sync_config = sync_tool.dump_component() + assert isinstance(sync_config, ComponentModel) + assert sync_config.config["name"] == "sync_func" + assert len(sync_config.config["global_imports"]) == 1 + assert not sync_config.config["has_cancellation_support"] + + async_config = async_tool.dump_component() + assert async_config.config["name"] == "custom_adder" + assert async_config.config["has_cancellation_support"] + + # Test deserialization and execution + loaded_sync = FunctionTool.load_component(sync_config, FunctionTool) + loaded_async = FunctionTool.load_component(async_config, FunctionTool) + + # Test execution and validation + token = CancellationToken() + assert await loaded_sync.run_json({"x": 2, "y": "test"}, token) == "testtest" + assert await loaded_async.run_json({"x": 1.5, "y": 2.5}, token) == 4.0 + + # Test error cases + with pytest.raises(ValueError): + # Type error + await loaded_sync.run_json({"x": "invalid", "y": "test"}, token) + + cancelled_token = CancellationToken() + cancelled_token.cancel() + with pytest.raises(Exception, match="Cancelled"): + await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)