File size: 2,692 Bytes
b5d547f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from smolagents import Tool
from googleapiclient.discovery import build
import os
class GoogleSearchTool(Tool):
    """Tool that runs a Google Custom Search query and returns markdown links.

    Credentials come from the constructor arguments or, when those are
    ``None``, from the ``GOOGLE_SEARCH_API_KEY`` and
    ``GOOGLE_SEARCH_ENGINE_ID`` environment variables.
    """

    name = "web_search"
    description = """Performs a google web search for query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform search.",
        },
    }
    output_type = "string"
    # forward() accepts *args/**kwargs so subclasses can declare extra inputs
    # that are routed to _collect_params(); presumably this flag is what lets
    # that generic signature pass the framework's checks (confirm in smolagents).
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Build the Custom Search client and store the query settings.

        Args:
            api_key: Google API key; falls back to the
                ``GOOGLE_SEARCH_API_KEY`` environment variable when ``None``.
            search_engine_id: Custom search engine id (``cx``); falls back to
                the ``GOOGLE_SEARCH_ENGINE_ID`` environment variable when
                ``None``.
            num_results: Value sent as the ``num`` request parameter.
            **kwargs: Forwarded unchanged to the parent ``Tool`` constructor.

        Raises:
            ValueError: If no API key or no search engine id is available.
        """
        api_key = api_key if api_key is not None else os.getenv("GOOGLE_SEARCH_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_API_KEY environment variable."
            )
        search_engine_id = (
            search_engine_id
            if search_engine_id is not None
            else os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        )
        if not search_engine_id:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_ENGINE_ID environment variable."
            )
        self.cse = build("customsearch", "v1", developerKey=api_key).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none.

        Subclasses override this to turn their additional inputs into
        request parameters that ``forward`` merges over the defaults.
        """
        return {}

    @staticmethod
    def _format_item(item: dict) -> str:
        """Render one search result as a markdown link plus optional snippet."""
        link = item.get("link", "")
        # Fall back to the URL as link text when a result carries no title.
        title = item.get("title") or link
        entry = f"[{title}]({link})"
        snippet = item.get("snippet")
        if snippet is None:
            # A missing snippet should not abort the whole search.
            return entry
        return f"{entry}\n{snippet}"

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and return the results as markdown.

        Args:
            query: The search query string.
            *args, **kwargs: Extra tool inputs, passed to ``_collect_params``.

        Returns:
            One ``[title](link)`` entry per result, each followed by its
            snippet when present, separated by blank lines; or
            ``"No results found."`` when the response holds no items.
        """
        params = {
            "q": query,
            "cx": self.cx,
            # Ask the API to return only the fields that are rendered below.
            "fields": "items(title,link,snippet)",
            "num": self.num,
        }
        params = params | self._collect_params(*args, **kwargs)
        response = self.cse.list(**params).execute()
        # Treat both a missing and an empty "items" entry as "no results".
        items = response.get("items")
        if not items:
            return "No results found."
        return "\n\n".join(self._format_item(item) for item in items)
class GoogleSiteSearchTool(GoogleSearchTool):
    """Google search tool variant that takes an additional ``site`` input."""

    name = "site_search"
    description = """Performs a google search within the website for query then returns top search results in markdown format."""
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {
            "type": "string",
            "description": "The domain of the site on which to search.",
        },
    }

    def _collect_params(self, site: str) -> dict:
        """Translate the ``site`` input into Custom Search request parameters.

        Args:
            site: The domain of the site on which to search.

        Returns:
            A dict holding ``siteSearch`` and ``siteSearchFilter``, which the
            inherited ``forward`` merges into the request parameters.
        """
        site_params = {"siteSearch": site}
        # "i" presumably selects "include only this site" (vs. exclude), per
        # the Custom Search API's siteSearchFilter values.
        site_params["siteSearchFilter"] = "i"
        return site_params
|