hf-legisqa / doc_format_mod.py
gabrielaltay's picture
deterministic sort order
b2c7bf9
from collections import defaultdict
import json
from langchain.schema import Document
import streamlit as st
import utils_mod
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group docs by legis_id and return the groups in a deterministic order.

    Inside each legis_id group the docs are sorted by ``start_index``
    (ascending). The groups themselves are sorted by number of docs
    (descending); groups holding the same number of docs are ordered by
    ``legis_id`` (ascending) so that the overall order is reproducible.

    Args:
        docs: iterable of documents whose ``metadata`` mapping provides
            ``legis_id`` and ``start_index`` entries.

    Returns:
        doc_grps = [
            (legis_id, start_index sorted docs),  # group with the most docs
            (legis_id, start_index sorted docs),
            ...
            (legis_id, start_index sorted docs),  # group with the least docs
        ]
        An empty input yields an empty list.

    Raises:
        KeyError: if a doc's metadata lacks ``legis_id`` or ``start_index``.
    """
    docs_by_legis_id = defaultdict(list)

    # create legis_id groups
    for doc in docs:
        docs_by_legis_id[doc.metadata["legis_id"]].append(doc)

    # sort docs in each group by start index; the per-group lists were built
    # above, so sorting them in place never touches the caller's input
    for grp in docs_by_legis_id.values():
        grp.sort(key=lambda doc: doc.metadata["start_index"])

    # sort groups by number of docs (desc), then legis_id for determinism
    return sorted(
        docs_by_legis_id.items(),
        key=lambda item: (-len(item[1]), item[0]),
    )
def format_docs(docs: list[Document]) -> str:
    """Serialize docs as a JSON array with one object per legis_id group.

    The docs are grouped and ordered by ``group_docs``. Each group becomes
    an object whose ``legis_id``, ``title``, ``introduced_date`` and
    ``sponsor`` fields are read from the metadata of the group's first doc,
    and whose ``snippets`` field lists the ``page_content`` of every doc in
    the group.

    Returns:
        The JSON text, indented by 4 spaces.
    """
    summaries = []
    for _, chunks in group_docs(docs):
        # group-level fields are taken from the first chunk of the group
        meta = chunks[0].metadata
        summaries.append(
            {
                "legis_id": meta["legis_id"],
                "title": meta["title"],
                "introduced_date": meta["introduced_date"],
                "sponsor": meta["sponsor_full_name"],
                "snippets": [chunk.page_content for chunk in chunks],
            }
        )
    return json.dumps(summaries, indent=4)
def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one group of docs inside a Streamlit expander.

    The expander label is built from the first doc's metadata: the number
    of chunks, the legis_id, the title, a congress.gov link and a sponsor
    link. The expander body shows every chunk prefixed with its
    ``start_index``, with chunks separated by ``...`` lines, after the text
    has been passed through ``utils_mod.escape_markdown``.

    Args:
        legis_id: id of the group; not used by the body, which reads the
            id from the first doc's metadata instead.
        doc_grp: docs of the group; the first one supplies the label fields.
    """
    meta = doc_grp[0].metadata
    url = utils_mod.get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )

    # label fields, in the order the template below consumes them
    label_fields = (
        len(doc_grp),
        meta["legis_id"],
        meta["title"],
        f"[congress.gov]({url})",
        meta["sponsor_full_name"],
        meta["sponsor_bioguide_id"],
        utils_mod.get_sponsor_url(meta["sponsor_bioguide_id"]),
    )
    label = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(*label_fields)

    # one entry per chunk, each tagged with its integer start index
    chunk_texts = []
    for chunk in doc_grp:
        prefix = "[start_index={}] ".format(int(chunk.metadata["start_index"]))
        chunk_texts.append(prefix + chunk.page_content)

    with st.expander(label):
        body = "\n\n...\n\n".join(chunk_texts)
        st.write(utils_mod.escape_markdown(body))
def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
    """Render all retrieved docs, grouped by legis_id, in a bordered container.

    A heading is written first, followed by one expander per group as
    produced by ``group_docs`` and rendered by ``render_doc_grp``.

    Args:
        docs: retrieved docs to display.
        tag: optional label shown in parentheses in the heading; omitted
            from the heading when ``None``.
    """
    with st.container(border=True):
        grouped = group_docs(docs)
        heading = (
            "Retrieved Chunks\n\nleft click to expand, right click to follow links"
            if tag is None
            else f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
        )
        st.write(heading)
        for legis_id, chunks in grouped:
            render_doc_grp(legis_id, chunks)