carolanderson's picture
update config
7106162
import math
import altair as alt
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore, ElasticsearchDocumentStore, FAISSDocumentStore
from haystack.nodes import BM25Retriever
from haystack.pipelines import DocumentSearchPipeline
import pandas as pd
import streamlit as st
@st.cache_data
def load_data(file):
    """Load the AI-inventory TSV into a DataFrame with a fixed column order.

    The file is parsed with the newline character as the only line
    terminator, so a file saved with CRLF line endings leaves a stray
    carriage return on the last header cell. Trailing carriage returns are
    stripped from every header before the columns are selected, so both
    CRLF and LF files load.

    Args:
        file: Path or buffer accepted by ``pandas.read_csv``.

    Returns:
        DataFrame containing the nine inventory columns in display order,
        with trailing whitespace removed from the 'Agency' values.

    Raises:
        KeyError: If an expected column is missing from the file.
    """
    df = pd.read_csv(file, sep="\t", lineterminator="\n")
    # Normalise headers instead of hard-coding a carriage return on one of them.
    df.columns = df.columns.str.rstrip("\r")
    # rearrange column order
    col_list = [
        'Agency',
        'Name of Inventory Item',
        'Primary Type of AI',
        'Purpose of AI',
        'Length of Usage',
        'Does it directly impact the public?',
        'Vendor System',
        'Description of Inventory Item',
        'Other Notes',
    ]
    df = df[col_list].copy()
    # Trailing spaces in agency names caused a duplicate instance of "DOC".
    # The .str accessor leaves missing agencies as NaN instead of raising.
    df['Agency'] = df['Agency'].str.rstrip()
    return df
@st.cache_data
def plot_impact_by_agency(df):
    """Build a bar chart of entry counts per agency, colored by public impact.

    Args:
        df: DataFrame with 'Agency' and 'Does it directly impact the public?'
            columns.

    Returns:
        Altair chart with agencies on the y axis, entry counts on the x axis
        and one color per impact level.
    """
    impact_col = 'Impact on public'
    # rename() already returns a new frame, so the caller's data is not modified.
    renamed = df.rename(columns={'Does it directly impact the public?': impact_col})
    impact_counts = renamed.groupby('Agency')[impact_col].value_counts()
    impact_counts = impact_counts.sort_index(level="Agency", ascending=False)
    # Name the counts explicitly. Older pandas names this Series after the
    # counted column (which collides with the index level on reset_index),
    # while pandas 2.0+ names it "count". The chart reads the "count" field.
    impact_count_df = impact_counts.rename("count").reset_index()

    domain = ['Direct impact', 'Indirect impact', 'No impact']
    range_ = ['red', 'darkorange', 'steelblue']
    chart = alt.Chart(impact_count_df).mark_bar(align="right").encode(
        x=alt.X("count", type="quantitative", title="Number of entries", axis=alt.Axis()),
        y=alt.Y("Agency", type="nominal", title="Agency",
                axis=alt.Axis(labelLimit=300, labelFlushOffset=5)),
        color=alt.Color(impact_col, scale=alt.Scale(domain=domain, range=range_)),
    )
    return chart
@st.cache_data
def plot_count_by_category(column, df):
    """Build a horizontal bar chart counting the entries for each value of *column*.

    Args:
        column: Name of the DataFrame column whose values are counted.
        df: DataFrame containing *column*.

    Returns:
        Altair bar chart with one bar per distinct value, sorted by
        descending count.
    """
    counts = df[column].value_counts().sort_values(ascending=True)
    table = counts.to_frame().reset_index()
    table.columns = [column, "Count"]

    x_encoding = alt.X("Count", type="quantitative", title="Number of entries", axis=alt.Axis())
    y_encoding = alt.Y(
        column,
        type="nominal",
        title="",
        axis=alt.Axis(labelLimit=300, labelFlushOffset=5),
        sort="-x",
    )
    return alt.Chart(table).mark_bar(align="right").encode(x=x_encoding, y=y_encoding)
@st.cache_data
def filter_table(df, choices):
    """Keep only the rows that match the user's dropdown selections.

    Args:
        df: DataFrame to filter.
        choices: dict mapping a column name to the list of values selected
            for it, e.g. {"Agency": ["USDA", "USDOC"]}. A list containing
            "Select all" leaves that column unfiltered.

    Returns:
        The filtered DataFrame (the input itself when nothing is filtered).
    """
    filtered = df
    for column, wanted in choices.items():
        if "Select all" in wanted:
            continue
        filtered = filtered[filtered[column].isin(wanted)]
    return filtered
@st.cache_data
def create_search_pipeline(df, col_list):
    """Index every cell of the chosen columns and return a BM25 search pipeline.

    Each cell becomes one document whose metadata records the row index and
    the column header it came from.

    Args:
        df: DataFrame holding the text to index.
        col_list: Names of the columns whose cells are indexed.

    Returns:
        DocumentSearchPipeline backed by an in-memory BM25 document store.
    """
    row_labels = list(df.index.values)
    documents = []
    for header in col_list:
        cell_values = df[header].tolist()
        assert len(row_labels) == len(cell_values)
        documents.extend(
            Document.from_dict({
                'content': cell,
                'meta': {"index": label, "column_header": header},
            })
            for label, cell in zip(row_labels, cell_values)
        )

    store = InMemoryDocumentStore(use_bm25=True)
    store.write_documents(documents)
    return DocumentSearchPipeline(BM25Retriever(document_store=store))
@st.cache_data
def run_search(text, _pipeline):
    """Run a keyword query and return the locations of the matching cells.

    Args:
        text: Query string. An empty string skips the search.
        _pipeline: Search pipeline exposing ``run(query=..., params=...)``.
            The leading underscore excludes it from Streamlit's cache hashing.

    Returns:
        None for an empty query. Otherwise a ``(rows, columns)`` tuple of
        parallel lists holding the row index and column header of every
        retrieved document whose score is above 0.5.
    """
    if text == "":
        return None
    # Use the pipeline that was passed in rather than a module-level global.
    res = _pipeline.run(query=text, params={"Retriever": {"top_k": 10}})
    relevant_results = [doc for doc in res['documents'] if doc.score > 0.5]
    result_rows = [doc.meta['index'] for doc in relevant_results]
    result_cols = [doc.meta['column_header'] for doc in relevant_results]
    return (result_rows, result_cols)
@st.cache_data
def produce_table(df, table_indices):
    """Build the search-result table and a matching frame of cell styles.

    Args:
        df: DataFrame that was searched.
        table_indices: ``(rows, columns)`` pair of parallel sequences, where
            each row is an index label of *df* and each column is a column
            name. Together they identify the cells to highlight.

    Returns:
        None when *table_indices* is empty or None. Otherwise a
        ``(result_df, color_df)`` tuple: the hit rows (each once, in
        first-hit order) and a same-shaped frame of CSS strings that is
        'background-color: yellow' for hit cells and '' elsewhere.
    """
    if not table_indices:
        return None
    hit_rows, hit_cols = table_indices[0], table_indices[1]

    # De-duplicate by row label, not by cell content, so that distinct rows
    # with identical text are both kept. dict.fromkeys keeps first-seen order.
    unique_rows = list(dict.fromkeys(hit_rows))
    # The row values are index labels, so select with .loc (as the
    # highlighting below does) rather than positionally.
    result_df = df.loc[unique_rows, :]

    # highlight the cells found in search
    color_df = pd.DataFrame('', index=result_df.index, columns=result_df.columns)
    for row, col in zip(hit_rows, hit_cols):
        color_df.loc[row, col] = 'background-color: yellow'
    return result_df, color_df
@st.cache_data
def convert_df(df):
    """Serialize *df* to tab-separated text (no index column) as UTF-8 bytes."""
    tsv_text = df.to_csv(sep="\t", index=False)
    return tsv_text.encode('utf-8')
if __name__ == "__main__":
    # Streamlit page: intro text, two charts, a filterable table with
    # download, and a keyword search with highlighted results.
    input_file = "Agency Inventory AI Usage - Sheet1.tsv"
    st.markdown("# U.S. Federal Government Use of AI")
    main_text = """
The data visualized here come from a report by Anna Blue, a 2023 Social Impact Fellow
at the [Responsible AI Institute](https://www.responsible.ai). The report was released in May 2023. Some agencies have
released updated inventories since then, which are not reflected here.
Anna's report consolidated and annotated data released by individual government agencies in compliance with
Executive Order 13960, which requires federal agencies to produce an annual inventory of their AI usage.
See her [blog post](https://www.responsible.ai/post/federal-government-ai-use-cases) for additional details,
including links to the original data sources.
"""
    st.markdown(main_text)
    df = load_data(input_file)

    # Plot stacked bar chart of impact on the public by agency
    st.subheader("Impact of systems on the public, by agency")
    stacked_bar_chart = plot_impact_by_agency(df)
    st.altair_chart(stacked_bar_chart, use_container_width=True)

    # Plot counts by category, allowing user to select category
    st.subheader("Number of entries by category")
    no_filter_cols = ['Name of Inventory Item', 'Description of Inventory Item', "Other Notes"]
    filter_cols = [c for c in df.columns.unique() if c not in no_filter_cols]
    column = st.selectbox("Choose what to plot", filter_cols)
    count_chart = plot_count_by_category(column, df)
    st.altair_chart(count_chart, use_container_width=True)

    # Table with filters for user browsing
    st.subheader("Explore the entries")
    st.write("Use the menus to filter the table. You can download the filtered table below.")
    filter_names = ["Agency", "Primary Type of AI", "Purpose of AI", "Length of Usage",
                    "Does it directly impact the public?", "Vendor System"]
    c1, c2 = st.columns((1, 1))
    filter_dict = {}
    # First three menus go in the left column, the remaining ones in the right.
    for container, names in ((c1, filter_names[:3]), (c2, filter_names[3:])):
        with container:
            for filter_name in names:
                options = ["Select all"] + list(df[filter_name].unique())
                filter_dict[filter_name] = st.multiselect(filter_name, options, default="Select all")
    filtered_df = filter_table(df, filter_dict)
    st.write(filtered_df)

    # Download filtered table
    st.write("Download current table as TSV (tab-separated values) file")
    default_table_name = "table.tsv"
    table_output_file = st.text_input(
        "Enter name for the file to be downloaded",
        value=default_table_name,
        key="table_file_name",
    )
    # text_input returns a string (possibly empty), never None, so fall back
    # to the default name when the box has been cleared.
    st.download_button(
        "Download",
        convert_df(filtered_df),
        file_name=table_output_file or default_table_name,
        key="download_table",
    )

    # Text search
    st.subheader("Search the data")
    st.markdown("""
This will search text in the following columns:
* Name of Inventory Item
* Primary Type of AI
* Purpose of AI
* Description of Inventory Item
* Other Notes
This is a keyword search based on the BM25 algorithm.
Yellow highlighting indicates text retrieved in the search.
A download button will appear after you run a search.
""")
    searchable_cols = ['Name of Inventory Item',
                       'Primary Type of AI',
                       'Purpose of AI',
                       'Description of Inventory Item',
                       'Other Notes']
    pipeline = create_search_pipeline(df, searchable_cols)
    input_text = st.text_input("Enter text", "")
    if input_text:
        search_hits = run_search(input_text, pipeline)
        table = produce_table(df, search_hits)
        if table is None or table[0].empty:
            # Say so explicitly instead of rendering an empty styled table.
            st.write("No matching entries found.")
        else:
            result_df, color_df = table
            # axis=None makes the style function receive the whole frame at
            # once, so it can return the precomputed per-cell styles.
            st.dataframe(result_df.style.apply(lambda _: color_df, axis=None))
            st.write("Download search results as TSV (tab-separated values) file")
            default_search_name = "search_results.tsv"
            search_output_file = st.text_input(
                "Enter name for the file to be downloaded",
                value=default_search_name,
                key="search_file_name",
            )
            st.download_button(
                "Download",
                convert_df(result_df),
                file_name=search_output_file or default_search_name,
                key="download_search",
            )