# Streamlit app: explorer for the U.S. federal agency AI-usage inventory.
import math
import altair as alt
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore, ElasticsearchDocumentStore, FAISSDocumentStore
from haystack.nodes import BM25Retriever
from haystack.pipelines import DocumentSearchPipeline
import pandas as pd
import streamlit as st
def load_data(file):
    """Read the agency AI-inventory TSV and tidy it for display.

    Reorders the columns, strips the stray carriage return from the
    'Other Notes' header (the file uses \\r\\n endings but is parsed with
    lineterminator='\\n'), and trims trailing whitespace from agency
    names, which otherwise produced a duplicate "DOC" entry.
    """
    ordered_cols = [
        'Agency', 'Name of Inventory Item',
        'Primary Type of AI',
        'Purpose of AI', 'Length of Usage',
        'Does it directly impact the public?',
        'Vendor System',
        'Description of Inventory Item',
        'Other Notes\r',
    ]
    frame = pd.read_csv(file, sep="\t", lineterminator='\n')
    frame = frame[ordered_cols]
    # Drop the trailing \r left on the last header field.
    frame = frame.rename(columns={'Other Notes\r': 'Other Notes'})
    # Trailing spaces in agency names caused duplicate categories.
    frame['Agency'] = frame['Agency'].apply(lambda name: name.rstrip())
    return frame
def plot_impact_by_agency(df):
    """Build a horizontal stacked bar chart of public-impact counts per agency.

    Counts entries per (Agency, impact level) and stacks them by impact
    category with a fixed red/orange/blue color scale.
    Returns an altair Chart.
    """
    df = df.copy()
    df = df.rename(columns={'Does it directly impact the public?': 'Impact on public'})
    impact_counts = df.groupby('Agency')['Impact on public'].value_counts()
    # Bug fix: the counts Series is named 'count' on pandas >= 2.0 but
    # 'Impact on public' on older versions, so the old rename-to-"Count"
    # and the lowercase alt.X("count") encoding disagreed depending on
    # the pandas version. Naming the Series explicitly before
    # reset_index() makes the column name deterministic (and avoids the
    # old-pandas name collision with the index level on reset_index).
    impact_counts = impact_counts.rename('Count')
    impact_counts = impact_counts.sort_index(level="Agency", ascending=False)
    impact_count_df = impact_counts.reset_index()
    domain = ['Direct impact', 'Indirect impact', 'No impact']
    range_ = ['red', 'darkorange', 'steelblue']
    chart = (
        alt.Chart(impact_count_df)
        .mark_bar(align="right")
        .encode(
            x=alt.X("Count", type="quantitative", title="Number of entries", axis=alt.Axis()),
            y=alt.Y("Agency", type="nominal", title="Agency",
                    axis=alt.Axis(labelLimit=300, labelFlushOffset=5)),
            color=alt.Color("Impact on public", scale=alt.Scale(domain=domain, range=range_)),
        )
    )
    return chart
def plot_count_by_category(column, df):
    """Return a horizontal Altair bar chart counting entries per value of `column`.

    Bars are sorted by descending count; y-axis labels are the column's
    distinct values.
    """
    counts = df[column].value_counts().sort_values(ascending=True)
    table = counts.to_frame().reset_index()
    table.columns = [column, "Count"]
    x_axis = alt.X("Count", type="quantitative", title="Number of entries",
                   axis=alt.Axis())
    y_axis = alt.Y(column, type="nominal", title="",
                   axis=alt.Axis(labelLimit=300, labelFlushOffset=5),
                   sort="-x")
    return alt.Chart(table).mark_bar(align="right").encode(x=x_axis, y=y_axis)
def filter_table(df, choices):
    """Filter `df` according to the user's dropdown selections.

    choices maps a column name to the list of values selected for it,
    e.g. {"Agency": ["USDA", "USDOC"]}. A selection containing
    "Select all" leaves that column unfiltered.
    """
    for column, desired_values in choices.items():
        if "Select all" in desired_values:
            continue
        df = df[df[column].isin(desired_values)]
    return df
def create_search_pipeline(df, col_list):
    """Index the given columns of `df` for BM25 search.

    Each cell in each column of `col_list` becomes one Haystack Document
    whose meta records the originating row index label and column header,
    so search hits can be mapped back to table cells.
    Returns a DocumentSearchPipeline over an in-memory store.
    """
    document_store = InMemoryDocumentStore(use_bm25=True)
    row_labels = list(df.index.values)
    docs = []
    for col in col_list:
        values = df[col].tolist()
        assert len(row_labels) == len(values)
        for label, value in zip(row_labels, values):
            docs.append(Document.from_dict({
                'content': value,
                'meta': {"index": label, "column_header": col},
            }))
    document_store.write_documents(docs)
    retriever = BM25Retriever(document_store=document_store)
    return DocumentSearchPipeline(retriever)
def run_search(text, _pipeline):
    """Run a BM25 query through `_pipeline` and map hits to table cells.

    Returns None for an empty query; otherwise a tuple
    (row_indices, column_headers) for the retrieved documents whose
    relevance score exceeds 0.5 (top 10 candidates considered).
    """
    if text == "":
        return None
    # Bug fix: the original body called the module-level global
    # `pipeline` instead of the `_pipeline` argument, so the function
    # only worked by accident when run inside the main script.
    res = _pipeline.run(query=text, params={"Retriever": {"top_k": 10}})
    relevant_results = [r for r in res['documents'] if r.score > 0.5]
    result_rows = [doc.meta['index'] for doc in relevant_results]
    result_cols = [doc.meta['column_header'] for doc in relevant_results]
    return (result_rows, result_cols)
def produce_table(df, table_indices):
    """Build the search-result table and a parallel cell-styling frame.

    table_indices: (row_labels, column_headers) as returned by run_search.
    Returns (result_df, color_df), where color_df holds CSS strings that
    highlight each matched cell (for DataFrame.style.apply), or None when
    table_indices is falsy.
    """
    if not table_indices:
        return None
    rows, cols = table_indices
    # Bug fix: run_search stores df.index *labels*, and the highlighting
    # loop below addresses cells by label with .loc — but rows were
    # selected positionally with .iloc, which is only correct for a
    # default RangeIndex. Use label-based .loc consistently.
    result_df = df.loc[rows, :]
    # A row matched in several columns appears once per match; keep one.
    result_df = result_df.drop_duplicates()
    color_df = result_df.copy()
    color_df.loc[:, :] = ''
    for row, col in zip(rows, cols):
        color_df.loc[row, col] = 'background-color: yellow'
    return result_df, color_df
def convert_df(df):
    """Serialize `df` as UTF-8 encoded TSV bytes (no index column), for st.download_button."""
    tsv_text = df.to_csv(sep="\t", index=False)
    return tsv_text.encode('utf-8')
if __name__ == "__main__":
    input_file = "Agency Inventory AI Usage - Sheet1.tsv"

    # --- Intro ---
    st.markdown("# U.S. Federal Government Use of AI")
    main_text = """
The data visualized here come from a report by Anna Blue, a 2023 Social Impact Fellow
at the [Responsible AI Institute](https://www.responsible.ai). The report was released in May 2023. Some agencies have
released updated inventories since then, which are not reflected here.
Anna's report consolidated and annotated data released by individual government agencies in compliance with
Executive Order 13960, which requires federal agencies to produce an annual inventory of their AI usage.
See her [blog post](https://www.responsible.ai/post/federal-government-ai-use-cases) for additional details,
including links to the original data sources.
"""
    st.markdown(main_text)

    df = load_data(input_file)

    # --- Stacked bar chart: impact on the public, by agency ---
    st.subheader("Impact of systems on the public, by agency")
    stacked_bar_chart = plot_impact_by_agency(df)
    st.altair_chart(stacked_bar_chart, use_container_width=True)

    # --- Counts by category, with user-selected category ---
    st.subheader("Number of entries by category")
    # Free-text columns are not useful as categorical plots.
    no_filter_cols = ['Name of Inventory Item', 'Description of Inventory Item', "Other Notes"]
    filter_cols = [c for c in df.columns.unique() if c not in no_filter_cols]
    column = st.selectbox("Choose what to plot", filter_cols)
    count_chart = plot_count_by_category(column, df)
    st.altair_chart(count_chart, use_container_width=True)

    # --- Filterable table for browsing ---
    st.subheader("Explore the entries")
    st.write("Use the menus to filter the table. You can download the filtered table below.")
    filter_names = ["Agency", "Primary Type of AI", "Purpose of AI", "Length of Usage",
                    "Does it directly impact the public?", "Vendor System"]
    c1, c2 = st.columns((1, 1))
    filter_dict = {}
    # Split the six filter menus evenly across two columns.
    with c1:
        for filter_name in filter_names[:3]:
            menu = st.multiselect(filter_name, ["Select all"] + list(df[filter_name].unique()), default="Select all")
            filter_dict[filter_name] = menu
    with c2:
        for filter_name in filter_names[3:]:
            menu = st.multiselect(filter_name, ["Select all"] + list(df[filter_name].unique()), default="Select all")
            filter_dict[filter_name] = menu
    filtered_df = filter_table(df, filter_dict)
    st.write(filtered_df)

    # --- Download the filtered table ---
    st.write("Download current table as TSV (tab-separated values) file")
    table_output_file = st.text_input("Enter name for the file to be downloaded", value="table.tsv")
    # Bug fix: st.text_input returns a string (possibly empty), never
    # None, so the old `is not None` guard was always true. Guard on
    # truthiness so an empty filename hides the download button.
    if table_output_file:
        csv = convert_df(filtered_df)
        st.download_button("Download", csv, file_name=table_output_file)

    # --- Text search over the free-text columns ---
    st.subheader("Search the data")
    st.markdown("""
This will search text in the following columns:
* Name of Inventory Item
* Primary Type of AI
* Purpose of AI
* Description of Inventory Item
* Other Notes
This is a keyword search based on the BM25 algorithm.
Yellow highlighting indicates text retrieved in the search.
A download button will appear after you run a search.
""")
    searchable_cols = ['Name of Inventory Item',
                       'Primary Type of AI',
                       'Purpose of AI',
                       'Description of Inventory Item',
                       'Other Notes']
    pipeline = create_search_pipeline(df, searchable_cols)
    input_text = st.text_input("Enter text", "")
    if input_text:
        result_rows, result_cols = run_search(input_text, pipeline)
        result_df, color_df = produce_table(df, (result_rows, result_cols))
        # Apply the parallel color frame to highlight matched cells.
        st.dataframe(result_df.style.apply(lambda x: color_df, axis=None))
        st.write("Download search results as TSV (tab-separated values) file")
        search_output_file = st.text_input("Enter name for the file to be downloaded", value="search_results.tsv")
        csv = convert_df(result_df)
        st.download_button("Download", csv, file_name=search_output_file)