diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index a367758..2d0ac58 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -138,48 +138,62 @@ class GoogleSearchTool(Tool): } output_type = "string" - def __init__(self): + def __init__(self, provider: str = "serpapi"): super().__init__(self) import os - self.serpapi_key = os.getenv("SERPAPI_API_KEY") + self.provider = provider + if provider == "serpapi": + self.organic_key = "organic_results" + api_key_env_name = "SERPAPI_API_KEY" + else: + self.organic_key = "organic" + api_key_env_name = "SERPER_API_KEY" + self.api_key = os.getenv(api_key_env_name) + if self.api_key is None: + raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.") def forward(self, query: str, filter_year: Optional[int] = None) -> str: import requests - if self.serpapi_key is None: - raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.") - - params = { - "engine": "google", - "q": query, - "api_key": self.serpapi_key, - "google_domain": "google.com", - } + if self.provider == "serpapi": + params = { + "q": query, + "api_key": self.api_key, + "engine": "google", + "google_domain": "google.com", + } + base_url = "https://serpapi.com/search.json" + else: + params = { + "q": query, + "api_key": self.api_key, + } + base_url = "https://google.serper.dev/search" if filter_year is not None: params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}" - response = requests.get("https://serpapi.com/search.json", params=params) + response = requests.get(base_url, params=params) if response.status_code == 200: results = response.json() else: raise ValueError(response.json()) - if "organic_results" not in results.keys(): + if self.organic_key not in results.keys(): if filter_year is not None: raise Exception( f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." ) else: raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.") - if len(results["organic_results"]) == 0: + if len(results[self.organic_key]) == 0: year_filter_message = f" with filter year={filter_year}" if filter_year is not None else "" return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter." web_snippets = [] - if "organic_results" in results: - for idx, page in enumerate(results["organic_results"]): + if self.organic_key in results: + for idx, page in enumerate(results[self.organic_key]): date_published = "" if "date" in page: date_published = "\nDate published: " + page["date"] @@ -193,8 +207,6 @@ class GoogleSearchTool(Tool): snippet = "\n" + page["snippet"] redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}" - - redacted_version = redacted_version.replace("Your browser can't play this video.", "") web_snippets.append(redacted_version) return "## Search Results\n" + "\n\n".join(web_snippets)