# Final_Assignment_Template/tools/google_search_tools.py
# Uploaded by onkar127 — commit b5d547f ("base upload", verified), 2.69 kB.
from smolagents import Tool
from googleapiclient.discovery import build
import os
class GoogleSearchTool(Tool):
    """smolagents tool that runs a Google Custom Search query.

    Results are returned as a markdown string: for each hit a
    ``[title](link)`` line followed by its snippet, with hits separated by a
    blank line. Subclasses can send extra request parameters by declaring
    additional ``inputs`` and overriding :meth:`_collect_params`.
    """

    name = "web_search"
    description = """Performs a google web search for query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform search.",
        },
    }
    output_type = "string"
    # forward() takes *args/**kwargs so subclasses can accept extra inputs;
    # presumably this flag stops smolagents from requiring forward()'s
    # signature to match ``inputs`` exactly.
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Create the Custom Search client.

        Args:
            api_key: Google API key. If None, it is read from the
                ``GOOGLE_SEARCH_API_KEY`` environment variable.
            search_engine_id: Search engine ID, sent as ``cx``. If None, it
                is read from the ``GOOGLE_SEARCH_ENGINE_ID`` environment
                variable.
            num_results: Number of results to request per query. Must be
                between 1 and 10 inclusive (the API's per-request limit).
            **kwargs: Passed through to ``Tool.__init__``.

        Raises:
            ValueError: If ``num_results`` is out of range, or if the API key
                or search engine ID is missing from both the argument and
                the environment.
        """
        # Fail fast here rather than on every forward() call: the API
        # rejects an out-of-range ``num`` with an HTTP 400 error.
        if not 1 <= num_results <= 10:
            raise ValueError(
                f"num_results must be between 1 and 10, got {num_results}."
            )
        api_key = api_key if api_key is not None else os.getenv("GOOGLE_SEARCH_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_API_KEY environment variable."
            )
        search_engine_id = (
            search_engine_id
            if search_engine_id is not None
            else os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        )
        if not search_engine_id:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_ENGINE_ID environment variable."
            )
        self.cse = build("customsearch", "v1", developerKey=api_key).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra ``cse().list`` parameters. The base tool adds none.

        Subclasses override this to turn their extra tool inputs (passed
        through from ``forward``) into request parameters.
        """
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run ``query`` and return the results as markdown.

        Args:
            query: The search query text.
            *args, **kwargs: Extra tool inputs, handed to
                :meth:`_collect_params`.

        Returns:
            One ``[title](link)\\nsnippet`` entry per result, joined by blank
            lines, or ``"No results found."`` when the response has no items.
        """
        params = {
            "q": query,
            "cx": self.cx,
            # Ask the API to return only the fields we render.
            "fields": "items(title,link,snippet)",
            "num": self.num,
        }
        params = params | self._collect_params(*args, **kwargs)
        response = self.cse.list(**params).execute()
        # Treat a missing "items" key and an empty list the same way.
        items = response.get("items") or []
        if not items:
            return "No results found."
        # Use .get(): a result is not guaranteed to include every requested
        # field (e.g. some hits have no snippet), and indexing would raise
        # KeyError.
        return "\n\n".join(
            f"[{item.get('title', '')}]({item.get('link', '')})\n"
            f"{item.get('snippet', '')}"
            for item in items
        )
class GoogleSiteSearchTool(GoogleSearchTool):
    """Google Custom Search restricted to a single website.

    Works like ``GoogleSearchTool`` but takes an extra ``site`` input. That
    value is sent as ``siteSearch``, together with ``siteSearchFilter="i"``
    (include only that site), to limit the results.
    """

    name = "site_search"
    description = """Performs a google search within the website for query then returns top search results in markdown format."""
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {"type": "string", "description": "The domain of the site on which to search."},
    }

    def _collect_params(self, site: str) -> dict:
        """Build the site-restriction parameters for ``cse().list``.

        Args:
            site: Domain to restrict the search to, taken from the ``site``
                tool input.

        Returns:
            A dict with ``siteSearch`` set to ``site`` and
            ``siteSearchFilter`` set to ``"i"``.
        """
        restriction = dict(siteSearch=site)
        # "i" = include results from `site` only ("e" would exclude it).
        restriction["siteSearchFilter"] = "i"
        return restriction