david-oplatka commited on
Commit
8dcd782
·
1 Parent(s): 9af859a

Add Metadata Filters

Browse files
Files changed (1) hide show
  1. agent.py +34 -0
agent.py CHANGED
@@ -16,6 +16,16 @@ def create_assistant_tools(cfg):
16
 
17
  class QueryCFPBComplaints(BaseModel):
18
  query: str = Field(description="The user query.")
 
 
 
 
 
 
 
 
 
 
19
 
20
  vec_factory = VectaraToolFactory(
21
  vectara_api_key=cfg.api_keys,
@@ -39,6 +49,30 @@ def create_assistant_tools(cfg):
39
  include_citations = True,
40
  )
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  tools_factory = ToolsFactory()
43
 
44
  db_tools = tools_factory.database_tools(
 
16
 
17
  class QueryCFPBComplaints(BaseModel):
18
  query: str = Field(description="The user query.")
19
+ Company: Optional[str] = Field(
20
+ default=None,
21
+ description="The company that the complaint is about.",
22
+ examples=['CAPITAL ONE FINANCIAL CORPORATION', 'BANK OF AMERICA, NATIONAL ASSOCIATION', 'CITIBANK, N.A.', 'WELLS FARGO & COMPANY', 'JPMORGAN CHASE & CO.']
23
+ )
24
+ State: Optional[str] = Field(
25
+ default=None,
26
+ descripition="The two-character state code where the consumer lives.",
27
+ examples=['CA', 'FL', 'NY', 'TX', 'GA']
28
+ )
29
 
30
  vec_factory = VectaraToolFactory(
31
  vectara_api_key=cfg.api_keys,
 
49
  include_citations = True,
50
  )
51
 
52
+ # ask_complaints = vec_factory.create_rag_tool(
53
+ # tool_name = "ask_complaints",
54
+ # tool_description = """
55
+ # Given a user query,
56
+ # returns a response to a user question about customer complaints for bank services.
57
+ # """,
58
+ # tool_args_schema = QueryCFPBComplaints,
59
+ # reranker = "chain", rerank_k = 100,
60
+ # rerank_chain = [
61
+ # {
62
+ # "type": "slingshot",
63
+ # "cutoff": 0.2
64
+ # },
65
+ # {
66
+ # "type": "mmr",
67
+ # "diversity_bias": 0.4,
68
+ # "limit": 30
69
+ # }
70
+ # ],
71
+ # n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
72
+ # vectara_summarizer = summarizer,
73
+ # include_citations = True,
74
+ # )
75
+
76
  tools_factory = ToolsFactory()
77
 
78
  db_tools = tools_factory.database_tools(