Spaces:
Sleeping
Sleeping
import json | |
import streamlit as st | |
import os | |
from huggingface_hub import HfApi, login | |
from streamlit_cookies_manager import EncryptedCookieManager | |
import re | |
# Browser-tab metadata and a full-width layout for the whole app.
st.set_page_config(
    layout='wide',
    page_title="Holistic AI - LLM Risks",
    page_icon="👋",
)
# How long (in seconds) a downloaded dataset snapshot is reused before the
# Hub is contacted again, so dataset updates still appear without a restart.
_DATA_TTL_SECONDS = 600

# CSS for reducing the vertical spacing between <p> tags, justifying text, and ensuring equal height cards
_CARD_CSS = """
<style>
.card {
border: 1px solid #ddd;
border-radius: 10px;
padding: 10px;
margin: 10px;
height: 100%;
display: flex;
flex-direction: column;
justify-content: space-between;
box-sizing: border-box;
background-color: #e4e8f5;
}
.card h3 {
margin-top: 0;
background-color: #e4e8f5;
}
.card p {
margin: 2px 0;
padding: 0;
text-align: justify;
background-color: #e4e8f5;
}
.stApp {
max-width: 100%;
padding: 1rem;
}
.grid {
display: flex;
flex-wrap: wrap;
justify-content: space-between;
}
.grid-item {
flex: 1 0 23%; /* 4 items per row */
box-sizing: border-box;
margin: 1%;
display: flex;
}
.grid-item .card {
flex: 1;
display: flex;
flex-direction: column;
justify-content: space-between;
background-color: #e4e8f5;
}
@media (max-width: 1200px) {
.grid-item {
flex: 1 0 46%; /* 2 items per row */
}
}
@media (max-width: 768px) {
.grid-item {
flex: 1 0 96%; /* 1 item per row */
}
}
</style>
"""

# Gives the sidebar a plain white background.
_SIDEBAR_CSS = """
<style>
[data-testid="stSidebar"] {
background-color: white;
}
</style>
"""


def _camel_to_whitespace(camel_str):
    """Turn a camelCase identifier into a Title Case label.

    Example: ``"questionAnswering"`` -> ``"Question Answering"``.
    """
    spaced_str = re.sub(r'([A-Z])', r' \1', camel_str).lower()
    return spaced_str.strip().title()


def _escape(value):
    """Return ``value`` as text that is safe inside HTML content or attributes.

    The cards are rendered with ``unsafe_allow_html=True``, so dataset fields
    are escaped to be displayed verbatim instead of being parsed as markup.
    """
    import html

    return html.escape(str(value), quote=True)


def _example_card_html(news):
    """Return the card markup for one incident example.

    ``news`` is an entry with 'title', 'incident' and 'link' keys.
    """
    return (
        '<div class="grid-item"><div class="card">'
        f"<h3>{_escape(news['title'])}</h3>"
        f"<p>{_escape(news['incident'])}</p>"
        f'<a href="{_escape(news["link"])}" target="_blank">Read more</a>'
        "</div></div>"
    )


def _mitigator_card_html(news):
    """Return the card markup for one mitigator.

    ``news`` is an entry with 'title', 'recommendation', 'year' and 'link' keys.
    """
    return (
        '<div class="grid-item"><div class="card">'
        f"<h3>{_escape(news['title'])}</h3>"
        f"<p>{_escape(news['recommendation'])}</p>"
        f"<p><b>Year:</b> {_escape(news['year'])}</p>"
        f'<a href="{_escape(news["link"])}" target="_blank">Read more</a>'
        "</div></div>"
    )


def _render_cards(items, field, card_html, num_columns=3):
    """Render one header per risk, followed by that risk's cards in a grid.

    Args:
        items: annotation records, each with a 'risk' key and a list of card
            entries stored under ``field``.
        field: record key holding the entries ('examples' or 'mitigators').
        card_html: callable that turns one entry into its card markup.
        num_columns: number of cards per grid row.
    """
    if not items:
        st.write("No news found for the selected task and group.")
        return
    # Sorted for a stable section order; set iteration order of strings can
    # differ between interpreter processes (hash randomization).
    for risk in sorted({item['risk'] for item in items}, key=str):
        st.header(risk)
        entries = [
            entry
            for item in items
            if item['risk'] == risk
            for entry in item[field]
        ]
        grid = st.container()
        row = None
        for position, entry in enumerate(entries):
            col_index = position % num_columns
            # Open a new row only when a card needs it, so no empty trailing
            # row is emitted when len(entries) is a multiple of num_columns.
            if col_index == 0:
                row = grid.columns(num_columns)
            with row[col_index]:
                st.markdown(card_html(entry), unsafe_allow_html=True)


def program():
    """Render the LLM-risks explorer.

    Loads the ``holistic-ai/LLM-Risks`` dataset snapshot from the Hugging Face
    Hub, lets the user pick a task and a risk group in the sidebar, and shows
    the matching incident examples and mitigators as card grids in two tabs.
    """
    dataset_name = "holistic-ai/LLM-Risks"

    # Streamlit re-executes the script on every widget interaction; caching
    # avoids re-contacting the Hub and re-parsing the JSON on each rerun.
    @st.cache_data(ttl=_DATA_TTL_SECONDS, show_spinner=False)
    def load_annotations(repo_id):
        token = os.getenv("HF_TOKEN")
        if token:
            # Per the huggingface_hub docs, login() without a token prompts
            # interactively, which a server process cannot answer; without
            # HF_TOKEN the snapshot is fetched anonymously instead.
            login(token)
        repo_path = HfApi().snapshot_download(repo_id=repo_id, repo_type="dataset")
        json_path = os.path.join(repo_path, "risk_annotation_consolidated.json")
        with open(json_path, encoding="utf-8") as file:
            return json.load(file)

    data = load_annotations(dataset_name)
    if not data:
        # Without any record the task selectbox would be empty and return None.
        st.warning("No risk annotations found in the dataset.")
        return

    # Sorted so the task list keeps the same order across app restarts.
    task_names = sorted({item['task'] for item in data})
    task_2_task_string = {task: _camel_to_whitespace(task) for task in task_names}
    task_string_2_task = {task_string: task for task, task_string in task_2_task_string.items()}
    task_strings = [task_2_task_string[task] for task in task_names]

    # Sidebar filters.
    # NOTE(review): the original indentation was lost in this copy; everything
    # down to the "Risk Group" caption is assumed to belong to the sidebar.
    with st.sidebar:
        st.sidebar.image("hai_logo.png", width=150, use_column_width=True)
        st.header("Filters")
        selected_task_string = st.selectbox("Select a Task", task_strings)
        selected_task = task_string_2_task[selected_task_string]
        # Filter data based on selected task
        filtered_data_by_task = [item for item in data if item['task'] == selected_task]
        groups = sorted({item['group'] for item in filtered_data_by_task}, key=str)
        selected_group = st.selectbox("Select a Risk Group", groups)
        # Filter data based on selected group
        filtered_data_by_group = [
            item for item in filtered_data_by_task if item['group'] == selected_group
        ]
        st.divider()
        st.sidebar.markdown(f"**Task**: {selected_task_string}")
        st.sidebar.markdown(f"**Risk Group**: {selected_group}")

    st.markdown(_CARD_CSS, unsafe_allow_html=True)
    # Apply the style to the sidebar.
    st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)

    examples_tab, mitigators_tab = st.tabs(["Examples", "Mitigators"])
    with examples_tab:
        _render_cards(filtered_data_by_group, 'examples', _example_card_html)
    with mitigators_tab:
        _render_cards(filtered_data_by_group, 'mitigators', _mitigator_card_html)
# Shared password that unlocks the app; main() compares the login input to it.
SECRET_KEY = os.getenv('SECRET_KEY')

# Encrypted browser-cookie store (prefix "login", encryption password taken
# from COOKIES_PASSWORD); main() uses it to remember a successful login.
cookies = EncryptedCookieManager(
    password=os.getenv('COOKIES_PASSWORD'),
    prefix="login",
)

# Halt this script run while the cookie manager is not ready yet
# (presumably the browser has not reported its cookies back so far).
if not cookies.ready():
    st.stop()
def main():
    """Show a password gate, then the LLM-risks explorer once authenticated.

    A successful login is recorded under the ``"authenticated"`` key of the
    encrypted ``cookies`` manager, so later runs skip the password prompt.
    """
    import hmac

    # App title
    st.title("LLM Mitigation")

    if cookies.get("authenticated"):
        program()
        return

    # Password input for the shared secret.
    user_key = st.text_input("Password:", type="password")
    if st.button("Login"):
        # Never authenticate when SECRET_KEY is unset or empty (an empty env var
        # would otherwise accept an empty password). compare_digest avoids
        # leaking, via timing, how much of the password matched; comparing
        # bytes also supports non-ASCII passwords.
        if SECRET_KEY and hmac.compare_digest(user_key.encode("utf-8"), SECRET_KEY.encode("utf-8")):
            cookies["authenticated"] = "True"
            # st.experimental_rerun was replaced by st.rerun in newer Streamlit
            # releases; use whichever this installation provides.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        else:
            # Message reads: "Access denied. Wrong password."
            st.error("Acceso denegado. Clave incorrecta.")


if __name__ == "__main__":
    main()