# LLM-Risks / app.py
# Last commit 19c49e9 ("ko") by crismunoz · 8.59 kB
import streamlit as st
# Page-wide settings: browser-tab title, favicon and wide layout. Streamlit
# expects this to be the first Streamlit command of the script, which is why
# it sits directly after `import streamlit as st`.
st.set_page_config(
    layout='wide',
    page_icon="👋",
    page_title="Holistic AI - LLM Risks",
)
import hmac
import html
import json
import os
import re

from huggingface_hub import HfApi, login
from streamlit_cookies_manager import EncryptedCookieManager
def _camel_to_whitespace(camel_str):
    """Turn a CamelCase task id into a Title Case label for display.

    A space is inserted before every uppercase letter, so ``"TextGeneration"``
    becomes ``"Text Generation"`` (acronyms are split letter by letter).
    """
    spaced_str = re.sub(r'([A-Z])', r' \1', camel_str).lower()
    return spaced_str.strip().title()


def _unique_in_order(values):
    """Return the distinct *values* in first-seen order.

    Unlike ``list(set(...))`` the result order is deterministic, so selectbox
    options (and hence the default selection) and section order do not
    change between server restarts.
    """
    return list(dict.fromkeys(values))


def _load_risk_data(dataset_name="holistic-ai/LLM-Risks"):
    """Download the dataset snapshot and return the parsed consolidated JSON.

    The caller treats the result as a list of dicts with ``task``, ``group``,
    ``risk``, ``examples`` and ``mitigators`` keys; that schema is not
    validated here.
    """
    token = os.getenv("HF_TOKEN")
    api = HfApi()
    if token:
        # login() without a token falls back to an interactive prompt, which
        # nobody can answer on a server; without HF_TOKEN the download below
        # is attempted without logging in.
        login(token)
    # NOTE(review): this runs on every Streamlit rerun (every widget
    # interaction); consider caching it (e.g. st.cache_data) if the deployed
    # Streamlit version supports that.
    repo_path = api.snapshot_download(repo_id=dataset_name, repo_type="dataset")
    json_path = os.path.join(repo_path, "risk_annotation_consolidated.json")
    with open(json_path, encoding="utf-8") as file:
        return json.load(file)


def _escape(value):
    """HTML-escape a dataset value (quotes included) before embedding it in markup."""
    return html.escape(str(value), quote=True)


def _example_card_html(example):
    """Build the card markup for one entry of an item's ``examples`` list."""
    return (
        '<div class="grid-item"><div class="card">'
        f"<h3>{_escape(example['title'])}</h3>"
        f"<p>{_escape(example['incident'])}</p>"
        f'<a href="{_escape(example["link"])}" target="_blank">Read more</a>'
        "</div></div>"
    )


def _mitigator_card_html(mitigator):
    """Build the card markup for one entry of an item's ``mitigators`` list."""
    return (
        '<div class="grid-item"><div class="card">'
        f"<h3>{_escape(mitigator['title'])}</h3>"
        f"<p>{_escape(mitigator['recommendation'])}</p>"
        f"<p><b>Year:</b> {_escape(mitigator['year'])}</p>"
        f'<a href="{_escape(mitigator["link"])}" target="_blank">Read more</a>'
        "</div></div>"
    )


def _render_card_grid(items, list_key, card_html, num_columns=3):
    """Render ``item[list_key]`` entries as card grids, one section per risk.

    For each distinct ``item['risk']`` (first-seen order) a header is written,
    followed by the matching entries laid out ``num_columns`` per row; each
    entry's markup comes from *card_html*.
    """
    if not items:
        st.write("No news found for the selected task and group.")
        return
    for risk in _unique_in_order(item['risk'] for item in items):
        st.header(risk)
        cards = [entry for item in items if item['risk'] == risk for entry in item[list_key]]
        grid = st.container()
        for start in range(0, len(cards), num_columns):
            # A row is created only when it will hold at least one card, so no
            # empty trailing row of columns is emitted.
            row = grid.columns(num_columns)
            for column, entry in zip(row, cards[start:start + num_columns]):
                with column:
                    st.markdown(card_html(entry), unsafe_allow_html=True)


def _inject_styles():
    """Inject the page CSS: card/grid styling plus a white sidebar background."""
    # CSS for reducing the vertical spacing between <p> tags, justifying text, and ensuring equal height cards
    st.markdown("""
<style>
.card {
border: 1px solid #ddd;
border-radius: 10px;
padding: 10px;
margin: 10px;
height: 100%;
display: flex;
flex-direction: column;
justify-content: space-between;
box-sizing: border-box;
background-color: #e4e8f5;
}
.card h3 {
margin-top: 0;
background-color: #e4e8f5;
}
.card p {
margin: 2px 0;
padding: 0;
text-align: justify;
background-color: #e4e8f5;
}
.stApp {
max-width: 100%;
padding: 1rem;
}
.grid {
display: flex;
flex-wrap: wrap;
justify-content: space-between;
}
.grid-item {
flex: 1 0 23%; /* 4 items per row */
box-sizing: border-box;
margin: 1%;
display: flex;
}
.grid-item .card {
flex: 1;
display: flex;
flex-direction: column;
justify-content: space-between;
background-color: #e4e8f5;
}
@media (max-width: 1200px) {
.grid-item {
flex: 1 0 46%; /* 2 items per row */
}
}
@media (max-width: 768px) {
.grid-item {
flex: 1 0 96%; /* 1 item per row */
}
}
</style>
""", unsafe_allow_html=True)
    sidebar_style = """
<style>
[data-testid="stSidebar"] {
background-color: white;
}
</style>
"""
    # Apply the style to the sidebar
    st.markdown(sidebar_style, unsafe_allow_html=True)


def program():
    """Render the LLM risks explorer (shown once the user is authenticated).

    Loads the consolidated risk annotations, lets the user pick a task and a
    risk group in the sidebar, and shows the matching examples and mitigators
    as grids of cards, one section per risk.
    """
    data = _load_risk_data()

    task_names = _unique_in_order(item['task'] for item in data)
    task_2_task_string = {task: _camel_to_whitespace(task) for task in task_names}
    task_string_2_task = {label: task for task, label in task_2_task_string.items()}
    task_strings = [task_2_task_string[task] for task in task_names]

    # Sidebar filters
    with st.sidebar:
        st.image("hai_logo.png", width=150, use_column_width=True)
        st.header("Filters")
        selected_task_string = st.selectbox("Select a Task", task_strings)
        # .get(): with no data the selectbox has no options and returns None,
        # which then simply yields empty filters instead of a KeyError.
        selected_task = task_string_2_task.get(selected_task_string)
        filtered_data_by_task = [item for item in data if item['task'] == selected_task]
        groups = _unique_in_order(item['group'] for item in filtered_data_by_task)
        selected_group = st.selectbox("Select a Risk Group", groups)
        filtered_data_by_group = [
            item for item in filtered_data_by_task if item['group'] == selected_group
        ]
        st.divider()
        st.markdown(f"**Task**: {selected_task_string}")
        st.markdown(f"**Risk Group**: {selected_group}")

    _inject_styles()

    examples_tab, mitigators_tab = st.tabs(["Examples", "Mitigators"])
    with examples_tab:
        _render_card_grid(filtered_data_by_group, 'examples', _example_card_html)
    with mitigators_tab:
        _render_card_grid(filtered_data_by_group, 'mitigators', _mitigator_card_html)
# Shared password that main() compares the login form against
# (None when the SECRET_KEY environment variable is unset).
SECRET_KEY = os.getenv('SECRET_KEY')

# Encrypted cookie store (prefix "login") that main() uses to remember a
# successful login; its encryption password is read from COOKIES_PASSWORD.
cookies = EncryptedCookieManager(
    password=os.getenv('COOKIES_PASSWORD'),
    prefix="login",
)

# End this script run until the cookie manager reports it is ready
# (presumably it is still waiting on the browser-side component).
if not cookies.ready():
    st.stop()
def main():
    """Entry point: a password gate in front of ``program()``.

    While the encrypted ``authenticated`` cookie is unset, a password form is
    shown; a correct password sets the cookie and reruns the script. Once the
    cookie is set, the app itself is rendered.
    """
    # Application title
    st.title("LLM Mitigation")
    if cookies.get("authenticated"):
        program()
        return

    # Secret key input
    user_key = st.text_input("Password:", type="password")
    if not st.button("Login"):
        return

    # Check the entered key against the secret key. compare_digest runs in
    # constant time (no timing side channel); an unset or empty SECRET_KEY
    # never authenticates anyone.
    if SECRET_KEY and hmac.compare_digest(
        (user_key or "").encode("utf-8"), SECRET_KEY.encode("utf-8")
    ):
        cookies["authenticated"] = "True"
        # st.rerun superseded st.experimental_rerun in newer Streamlit
        # releases; use whichever one this installation provides.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()
    else:
        st.error("Acceso denegado. Clave incorrecta.")


if __name__ == "__main__":
    main()