File size: 2,692 Bytes
b5d547f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from smolagents import Tool
from googleapiclient.discovery import build
import os


class GoogleSearchTool(Tool):
    """Tool that runs a Google Custom Search query and formats the hits as markdown.

    Each hit is rendered as ``[title](link)`` followed by its snippet on the
    next line; hits are separated by a blank line.

    Subclasses can add request parameters by overriding ``_collect_params``;
    any extra arguments given to ``forward`` are passed through to that hook.
    """

    name = "web_search"
    description = """Performs a google web search for query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform search.",
        },
    }
    output_type = "string"

    # forward() accepts *args/**kwargs so subclasses can declare extra inputs;
    # presumably this flag stops the Tool base class from checking forward's
    # signature against `inputs` -- confirm against the smolagents docs.
    skip_forward_signature_validation = True

    def __init__(
            self,
            api_key: str | None = None,
            search_engine_id: str | None = None,
            num_results: int = 10,
            **kwargs,
    ):
        """Create the Custom Search client.

        Args:
            api_key: Google API key. When ``None``, the
                ``GOOGLE_SEARCH_API_KEY`` environment variable is used.
            search_engine_id: Search engine ID (sent as ``cx``). When
                ``None``, the ``GOOGLE_SEARCH_ENGINE_ID`` environment
                variable is used.
            num_results: Value sent as the ``num`` request parameter. The
                Custom Search JSON API documents valid values as 1 to 10.
            **kwargs: Forwarded unchanged to the ``Tool`` constructor.

        Raises:
            ValueError: If the API key or the search engine ID is missing or
                empty after the environment fallback.
        """
        api_key = api_key if api_key is not None else os.getenv("GOOGLE_SEARCH_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_API_KEY environment variable."
            )
        search_engine_id = (
            search_engine_id
            if search_engine_id is not None
            else os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        )
        if not search_engine_id:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_ENGINE_ID environment variable."
            )

        self.cse = build("customsearch", "v1", developerKey=api_key).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters to merge into the search call.

        The base tool adds none; subclasses override this hook.
        """
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and return the hits formatted as markdown.

        Args:
            query: The search query.
            *args: Passed through to ``_collect_params``.
            **kwargs: Passed through to ``_collect_params``.

        Returns:
            One ``[title](link)`` line plus snippet line per hit, joined by
            blank lines, or ``"No results found."`` when the response has no
            hits.
        """
        params = {
            "q": query,
            "cx": self.cx,
            # Only request the fields that are actually rendered below.
            "fields": "items(title,link,snippet)",
            "num": self.num,
        }

        # Hook parameters win over the defaults above on key collisions.
        params = params | self._collect_params(*args, **kwargs)

        response = self.cse.list(**params).execute()
        # Treat a missing key and an empty list the same way.
        items = response.get("items")
        if not items:
            return "No results found."

        entries = []
        for item in items:
            # Individual hits may omit fields (e.g. no snippet); fall back
            # instead of failing the whole search with a KeyError.
            link = item.get("link", "")
            title = item.get("title") or link
            snippet = item.get("snippet", "")
            entries.append(f"[{title}]({link})\n{snippet}")
        return "\n\n".join(entries)


class GoogleSiteSearchTool(GoogleSearchTool):
    """Variant of the Google search tool that restricts hits to a single site.

    Declares an additional ``site`` input and supplies the matching request
    parameters through the ``_collect_params`` hook of the parent class.
    """

    name = "site_search"
    description = "Performs a google search within the website for query then returns top search results in markdown format."
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {"type": "string", "description": "The domain of the site on which to search."},
    }

    def _collect_params(self, site: str) -> dict:
        """Build the site-restriction parameters for the search request.

        Args:
            site: Domain to restrict the search to.

        Returns:
            A dict with ``siteSearch`` set to *site* and ``siteSearchFilter``
            set to ``"i"``.
        """
        site_params = {"siteSearch": site}
        # Per the Custom Search API reference, "i" includes (rather than
        # excludes) results from the given site.
        site_params["siteSearchFilter"] = "i"
        return site_params