from collections import defaultdict
import json

from langchain.schema import Document
import streamlit as st

import utils_mod


def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group and sort docs.

    docs are grouped by legis_id
    inside a legis_id group, the docs are sorted by start_index
    overall the legis_id groups are sorted by number of docs (desc)

    doc_grps = [
        (legis_id, start_index sorted docs),  # group with the most docs
        (legis_id, start_index sorted docs),
        ...
        (legis_id, start_index sorted docs),  # group with the least docs
    ]
    """
    doc_grps = defaultdict(list)

    # create legis_id groups
    for doc in docs:
        doc_grps[doc.metadata["legis_id"]].append(doc)

    # sort docs in each group by start index
    for legis_id in doc_grps:
        doc_grps[legis_id] = sorted(
            doc_grps[legis_id],
            key=lambda x: x.metadata["start_index"],
        )

    # sort groups by number of docs
    doc_grps = sorted(
        doc_grps.items(),
        key=lambda x: (
            -len(x[1]),  # length of x[1] = number of chunks
            x[0],  # legis_id for deterministic sort
        ),
    )

    return doc_grps


def format_docs(docs: list[Document]) -> str:
    """Serialize grouped docs as a JSON string (one object per legis_id group)."""
    doc_grps = group_docs(docs)
    out = []
    for legis_id, doc_grp in doc_grps:
        dd = {
            "legis_id": doc_grp[0].metadata["legis_id"],
            "title": doc_grp[0].metadata["title"],
            "introduced_date": doc_grp[0].metadata["introduced_date"],
            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
            "snippets": [doc.page_content for doc in doc_grp],
        }
        out.append(dd)
    return json.dumps(out, indent=4)


def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one legis_id group of chunks inside a st.expander."""
    first_doc = doc_grp[0]

    congress_gov_url = utils_mod.get_congress_gov_url(
        first_doc.metadata["congress_num"],
        first_doc.metadata["legis_type"],
        first_doc.metadata["legis_num"],
    )
    congress_gov_link = f"[congress.gov]({congress_gov_url})"

    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        first_doc.metadata["legis_id"],
        first_doc.metadata["title"],
        congress_gov_link,
        first_doc.metadata["sponsor_full_name"],
        first_doc.metadata["sponsor_bioguide_id"],
        utils_mod.get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
    )
    doc_contents = [
        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
        for doc in doc_grp
    ]
    with st.expander(ref):
        st.write(utils_mod.escape_markdown("\n\n...\n\n".join(doc_contents)))


def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
    """Render all retrieved chunks, grouped by legis_id, inside a bordered container."""
    with st.container(border=True):
        doc_grps = group_docs(docs)
        if tag is None:
            st.write(
                "Retrieved Chunks\n\nleft click to expand, right click to follow links"
            )
        else:
            st.write(
                f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
            )
        for legis_id, doc_grp in doc_grps:
            render_doc_grp(legis_id, doc_grp)
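

# Example usage (illustrative sketch, not part of the original module): the
# Document metadata values below are made up, but they use the same keys this
# module reads (legis_id, title, introduced_date, sponsor_full_name,
# start_index). Running the module directly shows that group_docs/format_docs
# order chunks within a group by start_index and put the group with the most
# chunks first.
if __name__ == "__main__":
    _meta = {
        "legis_id": "118-hr-1234",
        "title": "Example Act of 2023",
        "introduced_date": "2023-03-01",
        "sponsor_full_name": "Jane Doe",
    }
    _docs = [
        Document(page_content="Section 2 text ...", metadata={**_meta, "start_index": 512}),
        Document(page_content="Section 1 text ...", metadata={**_meta, "start_index": 0}),
    ]
    # One group ("118-hr-1234") with two snippets, ordered Section 1 then Section 2.
    print(format_docs(_docs))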