Spaces:
Running
Running
update
Browse files- agent.py +90 -83
- requirements.txt +1 -1
agent.py
CHANGED
@@ -10,94 +10,101 @@ from dotenv import load_dotenv
|
|
10 |
load_dotenv(override=True)
|
11 |
|
12 |
from vectara_agentic.agent import Agent
|
|
|
13 |
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
vec_factory_1 = VectaraToolFactory(vectara_api_key=cfg.api_keys[0],
|
21 |
-
vectara_corpus_key=cfg.corpus_keys[0])
|
22 |
-
|
23 |
-
summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
|
24 |
-
|
25 |
-
ask_vehicles = vec_factory_1.create_rag_tool(
|
26 |
-
tool_name = "ask_vehicles",
|
27 |
-
tool_description = """
|
28 |
-
Given a user query,
|
29 |
-
returns a response to a user question about electric vehicles.
|
30 |
-
""",
|
31 |
-
tool_args_schema = QueryElectricCars,
|
32 |
-
reranker = "chain", rerank_k = 100,
|
33 |
-
rerank_chain = [
|
34 |
-
{
|
35 |
-
"type": "slingshot",
|
36 |
-
"cutoff": 0.2
|
37 |
-
},
|
38 |
-
{
|
39 |
-
"type": "mmr",
|
40 |
-
"diversity_bias": 0.1
|
41 |
-
}
|
42 |
-
],
|
43 |
-
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
44 |
-
summary_num_results = 5,
|
45 |
-
vectara_summarizer = summarizer,
|
46 |
-
include_citations = False,
|
47 |
-
)
|
48 |
-
|
49 |
-
vec_factory_2 = VectaraToolFactory(vectara_api_key=cfg.api_keys[1],
|
50 |
-
vectara_corpus_key=cfg.corpus_keys[1])
|
51 |
-
|
52 |
-
|
53 |
-
class QueryEVLaws(BaseModel):
|
54 |
-
query: str = Field(description="The user query")
|
55 |
-
state: Optional[str] = Field(default=None,
|
56 |
-
description="The two digit state code. Optional.",
|
57 |
-
examples=['CA', 'US', 'WA'])
|
58 |
-
policy_type: Optional[str] = Field(default=None,
|
59 |
-
description="The type of policy. Optional",
|
60 |
-
examples = ['Laws and Regulations', 'State Incentives', 'Incentives', 'Utility / Private Incentives', 'Programs'])
|
61 |
-
|
62 |
|
63 |
-
ask_policies = vec_factory_2.create_rag_tool(
|
64 |
-
tool_name = "ask_policies",
|
65 |
-
tool_description = """
|
66 |
-
Given a user query,
|
67 |
-
returns a response to a user question about electric vehicles incentives and regulations, in the United States.
|
68 |
-
You can ask this tool any question about laws passed by states or the federal government related to electric vehicles.
|
69 |
-
""",
|
70 |
-
tool_args_schema = QueryEVLaws,
|
71 |
-
reranker = "chain", rerank_k = 100,
|
72 |
-
rerank_chain = [
|
73 |
-
{
|
74 |
-
"type": "slingshot",
|
75 |
-
"cutoff": 0.2
|
76 |
-
},
|
77 |
-
{
|
78 |
-
"type": "mmr",
|
79 |
-
"diversity_bias": 0.1
|
80 |
-
}
|
81 |
-
],
|
82 |
-
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
83 |
-
summary_num_results = 10,
|
84 |
-
vectara_summarizer = summarizer,
|
85 |
-
include_citations = False,
|
86 |
-
)
|
87 |
|
88 |
-
|
|
|
|
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
def initialize_agent(_cfg, agent_progress_callback=None):
|
103 |
electric_vehicle_bot_instructions = """
|
@@ -106,7 +113,7 @@ def initialize_agent(_cfg, agent_progress_callback=None):
|
|
106 |
"""
|
107 |
|
108 |
agent = Agent(
|
109 |
-
tools=
|
110 |
topic="Electric vehicles in the United States",
|
111 |
custom_instructions=electric_vehicle_bot_instructions,
|
112 |
agent_progress_callback=agent_progress_callback
|
|
|
10 |
load_dotenv(override=True)
|
11 |
|
12 |
from vectara_agentic.agent import Agent
|
13 |
+
from vectara_agentic.agent_config import AgentConfig
|
14 |
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
|
15 |
|
16 |
+
class AgentTools:
|
17 |
+
def __init__(self, _cfg, agent_config):
|
18 |
+
self.tools_factory = ToolsFactory()
|
19 |
+
self.agent_config = agent_config
|
20 |
+
self.cfg = _cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
def get_tools(self):
|
24 |
+
class QueryElectricCars(BaseModel):
|
25 |
+
query: str = Field(description="The user query.")
|
26 |
|
27 |
+
vec_factory_1 = VectaraToolFactory(vectara_api_key=self.cfg.api_keys[0],
|
28 |
+
vectara_corpus_key=self.cfg.corpus_keys[0])
|
29 |
+
|
30 |
+
summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
|
31 |
+
|
32 |
+
ask_vehicles = vec_factory_1.create_rag_tool(
|
33 |
+
tool_name = "ask_vehicles",
|
34 |
+
tool_description = """
|
35 |
+
Given a user query,
|
36 |
+
returns a response to a user question about electric vehicles.
|
37 |
+
""",
|
38 |
+
tool_args_schema = QueryElectricCars,
|
39 |
+
reranker = "chain", rerank_k = 100,
|
40 |
+
rerank_chain = [
|
41 |
+
{
|
42 |
+
"type": "slingshot",
|
43 |
+
"cutoff": 0.2
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"type": "mmr",
|
47 |
+
"diversity_bias": 0.1
|
48 |
+
}
|
49 |
+
],
|
50 |
+
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
51 |
+
summary_num_results = 5,
|
52 |
+
vectara_summarizer = summarizer,
|
53 |
+
include_citations = False,
|
54 |
+
)
|
55 |
+
|
56 |
+
vec_factory_2 = VectaraToolFactory(vectara_api_key=self.cfg.api_keys[1],
|
57 |
+
vectara_corpus_key=self.cfg.corpus_keys[1])
|
58 |
+
|
59 |
|
60 |
+
class QueryEVLaws(BaseModel):
|
61 |
+
query: str = Field(description="The user query")
|
62 |
+
state: Optional[str] = Field(default=None,
|
63 |
+
description="The two digit state code. Optional.",
|
64 |
+
examples=['CA', 'US', 'WA'])
|
65 |
+
policy_type: Optional[str] = Field(default=None,
|
66 |
+
description="The type of policy. Optional",
|
67 |
+
examples = ['Laws and Regulations', 'State Incentives', 'Incentives', 'Utility / Private Incentives', 'Programs'])
|
68 |
+
|
69 |
+
|
70 |
+
ask_policies = vec_factory_2.create_rag_tool(
|
71 |
+
tool_name = "ask_policies",
|
72 |
+
tool_description = """
|
73 |
+
Given a user query,
|
74 |
+
returns a response to a user question about electric vehicles incentives and regulations, in the United States.
|
75 |
+
You can ask this tool any question about laws passed by states or the federal government related to electric vehicles.
|
76 |
+
""",
|
77 |
+
tool_args_schema = QueryEVLaws,
|
78 |
+
reranker = "chain", rerank_k = 100,
|
79 |
+
rerank_chain = [
|
80 |
+
{
|
81 |
+
"type": "slingshot",
|
82 |
+
"cutoff": 0.2
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"type": "mmr",
|
86 |
+
"diversity_bias": 0.1
|
87 |
+
}
|
88 |
+
],
|
89 |
+
n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
|
90 |
+
summary_num_results = 10,
|
91 |
+
vectara_summarizer = summarizer,
|
92 |
+
include_citations = False,
|
93 |
+
)
|
94 |
+
|
95 |
+
tools_factory = ToolsFactory()
|
96 |
+
|
97 |
+
db_tools = tools_factory.database_tools(
|
98 |
+
tool_name_prefix = "ev",
|
99 |
+
content_description = 'Electric Vehicles in the state of Washington and other population information',
|
100 |
+
sql_database = SQLDatabase(create_engine('sqlite:///ev_database.db')),
|
101 |
+
)
|
102 |
+
|
103 |
+
return (tools_factory.standard_tools() +
|
104 |
+
tools_factory.guardrail_tools() +
|
105 |
+
db_tools +
|
106 |
+
[ask_vehicles, ask_policies]
|
107 |
+
)
|
108 |
|
109 |
def initialize_agent(_cfg, agent_progress_callback=None):
|
110 |
electric_vehicle_bot_instructions = """
|
|
|
113 |
"""
|
114 |
|
115 |
agent = Agent(
|
116 |
+
tools=AgentTools(_cfg, AgentConfig()).get_tools(),
|
117 |
topic="Electric vehicles in the United States",
|
118 |
custom_instructions=electric_vehicle_bot_instructions,
|
119 |
agent_progress_callback=agent_progress_callback
|
requirements.txt
CHANGED
@@ -7,4 +7,4 @@ langdetect==1.0.9
|
|
7 |
langcodes==3.4.0
|
8 |
datasets==2.19.2
|
9 |
uuid==1.30
|
10 |
-
vectara-agentic==0.2.
|
|
|
7 |
langcodes==3.4.0
|
8 |
datasets==2.19.2
|
9 |
uuid==1.30
|
10 |
+
vectara-agentic==0.2.1
|