diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 48c9c19..92aeaa4 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -206,10 +206,10 @@ class Tool: assert getattr(self, "output_type", None) in AUTHORIZED_TYPES - # Validate forward function signature, except for PipelineTool + # Validate forward function signature, except for Tools that use a "generic" signature (PipelineTool, SpaceToolWrapper) if not ( - hasattr(self, "is_pipeline_tool") - and getattr(self, "is_pipeline_tool") is True + hasattr(self, "skip_forward_signature_validation") + and getattr(self, "skip_forward_signature_validation") is True ): signature = inspect.signature(self.forward) @@ -575,6 +575,9 @@ class Tool: from gradio_client import Client, handle_file class SpaceToolWrapper(Tool): + + skip_forward_signature_validation = True + def __init__( self, space_id: str, @@ -1098,7 +1101,7 @@ class PipelineTool(Tool): name = "pipeline" inputs = {"prompt": str} output_type = str - is_pipeline_tool = True + skip_forward_signature_validation = True def __init__( self,