gabrielaltay commited on
Commit
80275c5
1 Parent(s): 7632f66
Files changed (2) hide show
  1. app.py +307 -8
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,13 +1,23 @@
 
 
 
 
 
 
 
 
1
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
2
- from langchain_community.vectorstores import Pinecone
3
  from langchain_community.vectorstores.utils import DistanceStrategy
4
  from langchain_openai import ChatOpenAI
5
- from pinecone import Pinecone as PineconeClient
 
6
  import streamlit as st
7
 
8
- st.set_page_config(layout="wide", page_title="LegisQA")
9
 
 
 
10
 
 
11
  CONGRESS_GOV_TYPE_MAP = {
12
  "hconres": "house-concurrent-resolution",
13
  "hjres": "house-joint-resolution",
@@ -25,6 +35,48 @@ OPENAI_CHAT_MODELS = [
25
  ]
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def load_bge_embeddings():
29
  model_name = "BAAI/bge-small-en-v1.5"
30
  model_kwargs = {"device": "cpu"}
@@ -40,9 +92,9 @@ def load_bge_embeddings():
40
 
41
  def load_pinecone_vectorstore():
42
  emb_fn = load_bge_embeddings()
43
- pc = PineconeClient(api_key=st.secrets["pinecone_api_key"])
44
  index = pc.Index(st.secrets["pinecone_index_name"])
45
- vectorstore = Pinecone(
46
  index=index,
47
  embedding=emb_fn,
48
  text_key="text",
@@ -51,8 +103,255 @@ def load_pinecone_vectorstore():
51
  return vectorstore
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  vectorstore = load_pinecone_vectorstore()
56
- query = st.text_area("Enter query")
57
- docs = vectorstore.similarity_search_with_score(query)
58
- st.write(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json
3
+
4
+ from langchain_core.documents import Document
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_core.runnables import RunnableParallel
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_core.output_parsers import StrOutputParser
9
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
10
  from langchain_community.vectorstores.utils import DistanceStrategy
11
  from langchain_openai import ChatOpenAI
12
+ from langchain_pinecone import PineconeVectorStore
13
+ from pinecone import Pinecone
14
  import streamlit as st
15
 
 
16
 
17
+ st.set_page_config(layout="wide", page_title="LegisQA")
18
+ SS = st.session_state
19
 
20
+ SEED = 292764
21
  CONGRESS_GOV_TYPE_MAP = {
22
  "hconres": "house-concurrent-resolution",
23
  "hjres": "house-joint-resolution",
 
35
  ]
36
 
37
 
38
+ PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
39
+ PROMPT_TEMPLATES = {
40
+ "v1": PREAMBLE
41
+ + """ If you don't know how to respond, just tell the user.
42
+
43
+ {context}
44
+
45
+ Question: {question}""",
46
+ "v2": PREAMBLE
47
+ + """ Each snippet starts with a header that includes a unique snippet number (snippet_num), a legis_id, and a title. Your response should reference particular snippets using legis_id and title. If you don't know how to respond, just tell the user.
48
+
49
+ {context}
50
+
51
+ Question: {question}""",
52
+ "v3": PREAMBLE
53
+ + """ Each excerpt starts with a header that includes a legis_id, and a title followed by one or more text snippets. When using text snippets in your response, you should mention the legis_id and title. If you don't know how to respond, just tell the user.
54
+
55
+ {context}
56
+
57
+ Question: {question}""",
58
+ "v4": PREAMBLE
59
+ + """ The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", and "snippets" keys. If a snippet is useful in writing part of your response, then mention the "title" and "legis_id" inline as you write. If you don't know how to respond, just tell the user.
60
+
61
+ {context}
62
+
63
+ Query: {question}""",
64
+ }
65
+
66
+
67
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov profile URL for a bill sponsor."""
    base = "https://bioguide.congress.gov/search/bio"
    return "/".join([base, bioguide_id])
69
+
70
+
71
def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Build the congress.gov bill URL for a piece of legislation.

    The legislation type (e.g. "hconres") is mapped to the long-form slug
    used by congress.gov via CONGRESS_GOV_TYPE_MAP.
    """
    slug = CONGRESS_GOV_TYPE_MAP[legis_type]
    congress = int(congress_num)
    number = int(legis_num)
    return f"https://www.congress.gov/bill/{congress}th-congress/{slug}/{number}"
74
+
75
+
76
def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Build the govtrack.us bill URL for a piece of legislation."""
    congress = int(congress_num)
    bill = f"{legis_type}{int(legis_num)}"
    return "/".join(["https://www.govtrack.us/congress/bills", str(congress), bill])
78
+
79
+
80
  def load_bge_embeddings():
81
  model_name = "BAAI/bge-small-en-v1.5"
82
  model_kwargs = {"device": "cpu"}
 
92
 
93
  def load_pinecone_vectorstore():
94
  emb_fn = load_bge_embeddings()
95
+ pc = Pinecone(api_key=st.secrets["pinecone_api_key"])
96
  index = pc.Index(st.secrets["pinecone_index_name"])
97
+ vectorstore = PineconeVectorStore(
98
  index=index,
99
  embedding=emb_fn,
100
  text_key="text",
 
103
  return vectorstore
104
 
105
 
106
def write_outreach_links():
    """Render sidebar links to hyperdemocracy project resources."""
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic"
    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
    hf_url = "https://huggingface.co/hyperdemocracy"
    st.subheader(":brain: Learn about [hyperdemocracy](https://hyperdemocracy.us)")
    st.subheader(f":world_map: Visualize with [nomic atlas]({nomic_url})")
    # bug fix: hf_url was written as a literal "(hf_url)" (missing f-string
    # braces), so the rendered markdown linked to the text "hf_url" instead
    # of the huggingface URL
    st.subheader(f":hugging_face: Explore the [huggingface datasets]({hf_url})")
114
+
115
+
116
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group retrieved docs by legis_id.

    Docs within a group are ordered by their chunk "start_index" metadata;
    the groups themselves are ordered by descending size (most chunks first,
    ties keeping first-seen order).
    """
    by_legis_id = defaultdict(list)
    for doc in docs:
        by_legis_id[doc.metadata["legis_id"]].append(doc)

    # sort each group's chunks by position in the bill text
    grouped = [
        (legis_id, sorted(grp, key=lambda d: d.metadata["start_index"]))
        for legis_id, grp in by_legis_id.items()
    ]

    # stable sort: larger groups first, insertion order preserved on ties
    grouped.sort(key=lambda pair: -len(pair[1]))
    return grouped
137
+
138
+
139
def format_docs_v1(docs):
    """Simple double new line join"""
    parts = []
    for doc in docs:
        parts.append(doc.page_content)
    return "\n\n".join(parts)
142
+
143
+
144
def format_docs_v2(docs):
    """Format with snippet_num, legis_id, and title"""
    template = "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n"
    snips = [
        template.format(
            idoc,
            doc.metadata["legis_id"],
            doc.metadata["title"],
            doc.page_content,
        )
        for idoc, doc in enumerate(docs)
    ]
    return "\n===\n".join(snips)
161
+
162
+
163
def format_docs_v3(docs):
    """Group docs by legis_id and format each group as header + snippets.

    Each group renders a "legis_id:/title:" header followed by its text
    snippets; groups are joined with "===" separator lines.
    """

    def format_header(doc):
        # header comes from the group's first (lowest start_index) chunk
        return "legis_id: {}\ntitle: {}".format(
            doc.metadata["legis_id"],
            doc.metadata["title"],
        )

    def format_content(doc):
        return "... {} ...\n".format(
            doc.page_content,
        )

    snips = []
    doc_grps = group_docs(docs)
    # cleanup: dropped the unused enumerate() index and unused legis_id
    # loop variable from the original implementation
    for _legis_id, doc_grp in doc_grps:
        head = format_header(doc_grp[0])
        contents = [format_content(doc) for doc in doc_grp]
        snips.append("{}\n\n{}".format(head, "\n".join(contents)))

    return "\n===\n".join(snips)
188
+
189
+
190
def format_docs_v4(docs):
    """JSON grouped"""
    grouped = group_docs(docs)
    payload = []
    for _legis_id, doc_grp in grouped:
        lead = doc_grp[0]
        payload.append(
            {
                "legis_id": lead.metadata["legis_id"],
                "title": lead.metadata["title"],
                "snippets": [doc.page_content for doc in doc_grp],
            }
        )
    return json.dumps(payload, indent=4)
203
+
204
+
205
+ DOC_FORMATTERS = {
206
+ "v1": format_docs_v1,
207
+ "v2": format_docs_v2,
208
+ "v3": format_docs_v3,
209
+ "v4": format_docs_v4,
210
+ }
211
+
212
+
213
def escape_markdown(text):
    """Backslash-escape Markdown control characters so *text* renders literally.

    Single-pass translate is equivalent to the original sequential replace
    because the escape character "\\" was replaced first there, so inserted
    backslashes were never re-escaped.
    """
    specials = r"\`*_{}[]()#+-.!$"
    table = {ord(ch): "\\" + ch for ch in specials}
    return text.translate(table)
218
+
219
+
220
# Sidebar UI: every widget stores its value in st.session_state (SS) via
# `key=`, which the chain-building code below reads back.
with st.sidebar:

    with st.container(border=True):
        write_outreach_links()

    # whether to escape markdown special chars in the LLM answer display
    st.checkbox("escape markdown in answer", key="response_escape_markdown")

    # LLM sampling parameters
    with st.expander("Generative Config"):
        st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
        st.slider(
            "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
        )
        st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")

    # retriever parameters: chunk count plus optional metadata filters
    with st.expander("Retrieval Config"):
        st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=40,
            value=10,
            key="n_ret_docs",
        )
        st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
        st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
        st.text_input("Congress (e.g. 118)", key="filter_congress_num")

    # prompt selection; the chosen template is editable in the text area
    with st.expander("Prompt Config"):
        st.selectbox(
            label="prompt version",
            options=PROMPT_TEMPLATES.keys(),
            index=3,
            key="prompt_version",
        )
        st.text_area(
            "prompt template",
            PROMPT_TEMPLATES[SS["prompt_version"]],
            height=300,
            key="prompt_template",
        )
259
+
260
+
261
# Chat model configured from the sidebar widgets; fixed seed (SEED) is
# passed through model_kwargs for more reproducible generations.
llm = ChatOpenAI(
    model_name=SS["model_name"],
    temperature=SS["temperature"],
    openai_api_key=st.secrets["openai_api_key"],
    model_kwargs={"top_p": SS["top_p"], "seed": SEED},
)
267
 
268
vectorstore = load_pinecone_vectorstore()
# context formatter is keyed by prompt version so template and formatting match
format_docs = DOC_FORMATTERS[SS["prompt_version"]]

# question entry form; query_submitted is True on the rerun triggered by Submit
with st.form("my_form"):
    st.text_area("Enter question:", key="query")
    query_submitted = st.form_submit_button("Submit")
274
+
275
+
276
def get_vectorstore_filter():
    """Collect non-empty sidebar filter widgets into a Pinecone metadata filter.

    Congress number is cast to int to match the stored metadata type.
    """
    legis_id = SS["filter_legis_id"]
    bioguide_id = SS["filter_bioguide_id"]
    congress_num = SS["filter_congress_num"]

    vs_filter = {}
    if legis_id != "":
        vs_filter["legis_id"] = legis_id
    if bioguide_id != "":
        vs_filter["sponsor_bioguide_id"] = bioguide_id
    if congress_num != "":
        vs_filter["congress_num"] = int(congress_num)
    return vs_filter
285
+
286
+
287
if query_submitted:

    # retriever applies both the chunk count and any metadata filters
    vs_filter = get_vectorstore_filter()
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
    )
    # user-editable template; must contain {context} and {question}
    prompt = PromptTemplate.from_template(SS["prompt_template"])
    # answer chain: format retrieved docs into the prompt, call the LLM,
    # strip the response down to a plain string
    rag_chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | prompt
        | llm
        | StrOutputParser()
    )
    # run retrieval and answer generation, keeping the source docs in the
    # output dict under "context" alongside "answer"
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=rag_chain_from_docs)
    out = rag_chain_with_source.invoke(SS["query"])
    # persist result in session state so it survives streamlit reruns
    SS["out"] = out
305
+
306
+
307
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's retrieved chunks as a collapsible reference block.

    The expander label carries the chunk count, bill id/title, external
    links (congress.gov, govtrack.us), and a sponsor link; the body shows
    each chunk prefixed with its start_index.
    """
    lead = doc_grp[0]
    meta = lead.metadata

    congress_gov_url = get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )
    gov_track_url = get_govtrack_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )
    congress_gov_link = f"[congress.gov]({congress_gov_url})"
    gov_track_link = f"[govtrack.us]({gov_track_url})"

    ref = "{} chunks from {}\n\n{}\n\n{} | {}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        meta["legis_id"],
        meta["title"],
        congress_gov_link,
        gov_track_link,
        meta["sponsor_full_name"],
        meta["sponsor_bioguide_id"],
        get_sponsor_url(meta["sponsor_bioguide_id"]),
    )

    doc_contents = []
    for doc in doc_grp:
        prefix = "[start_index={}] ".format(int(doc.metadata["start_index"]))
        doc_contents.append(prefix + doc.page_content)

    with st.expander(ref):
        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
340
+
341
+
342
# Render the most recent RAG result (persisted in session state so it
# survives widget-triggered reruns).
out = SS.get("out")
if out:

    # answer text, optionally markdown-escaped per the sidebar checkbox
    if SS["response_escape_markdown"]:
        st.info(escape_markdown(out["answer"]))
    else:
        st.info(out["answer"])

    # one expander per bill, largest groups first
    doc_grps = group_docs(out["context"])
    for legis_id, doc_grp in doc_grps:
        write_doc_grp(legis_id, doc_grp)

    # show exactly what the LLM saw as {context}
    with st.expander("Debug doc format"):

        st.text_area("formatted docs", value=format_docs(out["context"]), height=600)
        # st.write(json.loads(format_docs(out["context"])))
requirements.txt CHANGED
@@ -33,6 +33,7 @@ jsonschema-specifications==2023.12.1
33
  langchain-community==0.0.24
34
  langchain-core==0.1.26
35
  langchain-openai==0.0.7
 
36
  langsmith==0.1.7
37
  markdown-it-py==3.0.0
38
  MarkupSafe==2.1.5
 
33
  langchain-community==0.0.24
34
  langchain-core==0.1.26
35
  langchain-openai==0.0.7
36
+ langchain-pinecone==0.0.3
37
  langsmith==0.1.7
38
  markdown-it-py==3.0.0
39
  MarkupSafe==2.1.5