ofermend commited on
Commit
79067ba
·
1 Parent(s): 4f39567
Files changed (2) hide show
  1. agent.py +65 -50
  2. requirements.txt +1 -1
agent.py CHANGED
@@ -24,29 +24,17 @@ get_headers = {
24
  "Connection": "keep-alive",
25
  }
26
 
27
- def create_assistant_tools(cfg, agent_config):
28
-
29
- class QueryHackerNews(BaseModel):
30
- query: str = Field(..., description="The user query.")
31
-
32
- vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
33
- vectara_corpus_key=cfg.corpus_key)
34
- summarizer = 'vectara-summary-ext-24-05-med-omni'
35
- ask_hackernews = vec_factory.create_rag_tool(
36
- tool_name = "ask_hackernews",
37
- tool_description = """
38
- Provides information on any topic or query, based on relevant hacker news stories.
39
- """,
40
- tool_args_schema = QueryHackerNews,
41
- reranker = "multilingual_reranker_v1", rerank_k = 100,
42
- n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.0,
43
- summary_num_results = 10,
44
- vectara_summarizer = summarizer,
45
- include_citations = True,
46
- verbose=True
47
- )
48
 
49
  def get_top_stories(
 
50
  n_stories: int = Field(default=10, description="The number of top stories to return.")
51
  ) -> list[str]:
52
  """
@@ -58,6 +46,7 @@ def create_assistant_tools(cfg, agent_config):
58
  return top_stories[:n_stories]
59
 
60
  def get_show_stories(
 
61
  n_stories: int = Field(default=10, description="The number of top SHOW HN stories to return.")
62
  ) -> list[str]:
63
  """
@@ -69,6 +58,7 @@ def create_assistant_tools(cfg, agent_config):
69
  return top_stories[:n_stories]
70
 
71
  def get_ask_stories(
 
72
  n_stories: int = Field(default=10, description="The number of top ASK HN stories to return.")
73
  ) -> list[str]:
74
  """
@@ -80,16 +70,17 @@ def create_assistant_tools(cfg, agent_config):
80
  return top_stories[:n_stories]
81
 
82
  def get_story_details(
 
83
  story_id: str = Field(..., description="The story ID.")
84
  ) -> Tuple[str, str]:
85
  """
86
  Get the title, url and external link of a story from hacker news.
87
  Returns:
88
- - The title of the story (str)
89
- - The main URL of the story (str)
90
- - The external link pointed to in the story (str)
91
- - The author of the story
92
- - The number of descendants (comments + replies) of the story
93
  """
94
  db_url = 'https://hacker-news.firebaseio.com/v0/'
95
  story = requests.get(f"{db_url}item/{story_id}.json").json()
@@ -97,6 +88,7 @@ def create_assistant_tools(cfg, agent_config):
97
  return story['title'], story_url, story['url'], story['by'], story['descendants']
98
 
99
  def get_story_text(
 
100
  story_id: str = Field(..., description="The story ID.")
101
  ) -> str:
102
  """
@@ -112,6 +104,7 @@ def create_assistant_tools(cfg, agent_config):
112
  return text
113
 
114
  def whats_new(
 
115
  n_stories: int = Field(default=10, description="The number of new stories to return.")
116
  ) -> list[str]:
117
  """
@@ -119,31 +112,53 @@ def create_assistant_tools(cfg, agent_config):
119
  by summarizing the content and comments of top stories.
120
  Returns a string with the summary.
121
  """
122
- stories = get_top_stories(n_stories)
123
- texts = [get_story_text(story_id) for story_id in stories[:n_stories]]
124
  all_stories = '---------\n\n'.join(texts)
125
- summarize_text = ToolsCatalog(agent_config).summarize_text
126
  return summarize_text(all_stories)
127
-
128
- tools_factory = ToolsFactory()
129
- return (
130
- [ask_hackernews] +
131
- [tools_factory.create_tool(tool) for tool in
132
- [
133
- get_top_stories,
134
- get_show_stories,
135
- get_ask_stories,
136
- get_story_details,
137
- get_story_text,
138
- whats_new,
139
- ]
140
- ] +
141
- tools_factory.get_llama_index_tools(
142
- "tavily_research", "TavilyToolSpec",
143
- tool_name_prefix="tavily", api_key=cfg.tavily_api_key
144
- ) +
145
- tools_factory.standard_tools()
146
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def initialize_agent(_cfg, agent_progress_callback = None):
149
  bot_instructions = """
@@ -158,7 +173,7 @@ def initialize_agent(_cfg, agent_progress_callback = None):
158
  """
159
  agent_config = AgentConfig()
160
  agent = Agent(
161
- tools=create_assistant_tools(_cfg, agent_config),
162
  topic="hacker news",
163
  custom_instructions=bot_instructions,
164
  agent_progress_callback=agent_progress_callback,
 
24
  "Connection": "keep-alive",
25
  }
26
 
27
+ class AgentTools:
28
+ def __init__(self, _cfg, agent_config):
29
+ self.tools_factory = ToolsFactory()
30
+ self.agent_config = agent_config
31
+ self.cfg = _cfg
32
+ self.vec_factory = VectaraToolFactory(vectara_api_key=_cfg.api_key,
33
+ vectara_corpus_key=_cfg.corpus_key)
34
+
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def get_top_stories(
37
+ self,
38
  n_stories: int = Field(default=10, description="The number of top stories to return.")
39
  ) -> list[str]:
40
  """
 
46
  return top_stories[:n_stories]
47
 
48
  def get_show_stories(
49
+ self,
50
  n_stories: int = Field(default=10, description="The number of top SHOW HN stories to return.")
51
  ) -> list[str]:
52
  """
 
58
  return top_stories[:n_stories]
59
 
60
  def get_ask_stories(
61
+ self,
62
  n_stories: int = Field(default=10, description="The number of top ASK HN stories to return.")
63
  ) -> list[str]:
64
  """
 
70
  return top_stories[:n_stories]
71
 
72
  def get_story_details(
73
+ self,
74
  story_id: str = Field(..., description="The story ID.")
75
  ) -> Tuple[str, str]:
76
  """
77
  Get the title, url and external link of a story from hacker news.
78
  Returns:
79
+ - The title of the story (str)
80
+ - The main URL of the story (str)
81
+ - The external link pointed to in the story (str)
82
+ - The author of the story
83
+ - The number of descendants (comments + replies) of the story
84
  """
85
  db_url = 'https://hacker-news.firebaseio.com/v0/'
86
  story = requests.get(f"{db_url}item/{story_id}.json").json()
 
88
  return story['title'], story_url, story['url'], story['by'], story['descendants']
89
 
90
  def get_story_text(
91
+ self,
92
  story_id: str = Field(..., description="The story ID.")
93
  ) -> str:
94
  """
 
104
  return text
105
 
106
  def whats_new(
107
+ self,
108
  n_stories: int = Field(default=10, description="The number of new stories to return.")
109
  ) -> list[str]:
110
  """
 
112
  by summarizing the content and comments of top stories.
113
  Returns a string with the summary.
114
  """
115
+ stories = self.get_top_stories(n_stories)
116
+ texts = [self.get_story_text(story_id) for story_id in stories[:n_stories]]
117
  all_stories = '---------\n\n'.join(texts)
118
+ summarize_text = ToolsCatalog(AgentConfig()).summarize_text
119
  return summarize_text(all_stories)
120
+
121
+
122
+ def get_tools(self):
123
+ class QueryHackerNews(BaseModel):
124
+ query: str = Field(..., description="The user query.")
125
+
126
+ vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
127
+ vectara_corpus_key=cfg.corpus_key)
128
+ summarizer = 'vectara-summary-ext-24-05-med-omni'
129
+ ask_hackernews = vec_factory.create_rag_tool(
130
+ tool_name = "ask_hackernews",
131
+ tool_description = """
132
+ Provides information on any topic or query, based on relevant hacker news stories.
133
+ """,
134
+ tool_args_schema = QueryHackerNews,
135
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
136
+ n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.0,
137
+ summary_num_results = 10,
138
+ vectara_summarizer = summarizer,
139
+ include_citations = True,
140
+ verbose=True
141
+ )
142
+
143
+ tools_factory = ToolsFactory()
144
+ return (
145
+ [ask_hackernews] +
146
+ [tools_factory.create_tool(tool) for tool in
147
+ [
148
+ self.get_top_stories,
149
+ self.get_show_stories,
150
+ self.get_ask_stories,
151
+ self.get_story_details,
152
+ self.get_story_text,
153
+ self.whats_new,
154
+ ]
155
+ ] +
156
+ tools_factory.get_llama_index_tools(
157
+ "tavily_research", "TavilyToolSpec",
158
+ tool_name_prefix="tavily", api_key=self.cfg.tavily_api_key
159
+ ) +
160
+ tools_factory.standard_tools()
161
+ )
162
 
163
  def initialize_agent(_cfg, agent_progress_callback = None):
164
  bot_instructions = """
 
173
  """
174
  agent_config = AgentConfig()
175
  agent = Agent(
176
+ tools=AgentTools(_cfg, agent_config).get_tools(),
177
  topic="hacker news",
178
  custom_instructions=bot_instructions,
179
  agent_progress_callback=agent_progress_callback,
requirements.txt CHANGED
@@ -6,5 +6,5 @@ streamlit_feedback==0.1.3
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
- vectara-agentic==0.2.0
10
 
 
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
+ vectara-agentic==0.2.1
10