# acr_viewer / app.py
# Author: Pratyush Maini — commit dffff5c ("strong"), 2.16 kB.
import os
import gradio as gr
import pandas as pd
# Model identifiers offered in the UI; each one names a results file
# model_csvs/<name>.csv. Sizes are listed largest first.
MODELS = [f"pythia-{size}" for size in ("12B", "6.7B", "1.4B", "410M")]
def load_csv(model_name):
    """Load the precomputed results table for one model.

    The file is looked up as ``model_csvs/<model_name>.csv`` inside the
    directory that contains this script, not relative to the process's
    current working directory. The app therefore works regardless of where
    it is launched from.

    Args:
        model_name: Model identifier (e.g. an entry of ``MODELS``), used
            verbatim as the CSV file stem.

    Returns:
        pandas.DataFrame with the contents of the model's CSV.

    Raises:
        FileNotFoundError: if no CSV exists for ``model_name``.
    """
    # Anchor on this file's location; a bare relative path would depend on the CWD.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(base_dir, "model_csvs", f"{model_name}.csv")
    return pd.read_csv(csv_path)
def get_result(model_name, target_string):
    """Look up the stored memorization result for one target string.

    Args:
        model_name: Model whose results CSV is searched (via ``load_csv``).
        target_string: Exact value to match in the ``target_str`` column.

    Returns:
        Tuple ``(num_free_tokens, target_length, optimal_prompt, ratio,
        memorized)``. The fields are built-in ``int``, ``int``, the raw
        ``optimal_prompt`` cell, ``float`` and ``bool``. If several rows
        share the same target string, the first one is used.

    Raises:
        ValueError: if ``target_string`` does not occur in the model's CSV
            (for example when it is ``None`` because nothing was selected).
    """
    df = load_csv(model_name)
    matches = df[df['target_str'] == target_string]
    if matches.empty:
        # Without this check ``.iloc[0]`` fails with an opaque IndexError.
        raise ValueError(
            f"Target string {target_string!r} not found in results for "
            f"model {model_name!r}."
        )
    row = matches.iloc[0]  # first match wins if a target string is duplicated
    # Convert pandas/numpy scalars to built-in Python types.
    num_free_tokens = int(row['num_free_tokens'])
    target_length = int(row['target_length'])
    optimal_prompt = row['optimal_prompt']
    ratio = float(row['ratio'])
    memorized = bool(row['memorized'])
    return num_free_tokens, target_length, optimal_prompt, ratio, memorized
def update_csv_dropdown(model_name):
    """Build a refreshed, interactive target-string dropdown for a model.

    The choices are the ``target_str`` column of the model's results CSV,
    in file order.
    """
    target_strings = load_csv(model_name)['target_str'].tolist()
    return gr.Dropdown(interactive=True, choices=target_strings)
# Build the web UI. NOTE: Gradio Blocks lays components out in the order they
# are created inside each context, so reordering statements here changes the
# page layout.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Model Memorization Checker</center></h1>")
    # Input row: choose a model, then a target string from that model's CSV.
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
        # Starts with no choices; populated by update_csv_dropdown once a model is picked.
        csv_dropdown = gr.Dropdown(choices=[], label="Select Target String", interactive=True)
        run_button = gr.Button("Run")
    # Output row showing the stored result for the selected target string.
    with gr.Row():
        target_length_output = gr.Number(label="# Target Tokens")
        num_free_tokens_output = gr.Number(label="# Optimal Prompt Tokens")
        optimal_prompt_output = gr.Textbox(label="Optimal Prompt")
        ratio_output = gr.Number(label="Adversarial Compression Ratio")
        memorized_output = gr.Textbox(label="Memorized")
    # Refresh the target-string choices whenever the model selection changes.
    model_dropdown.change(fn=update_csv_dropdown, inputs=model_dropdown, outputs=csv_dropdown)
    def run_check(model_name, target_string):
        """Handle a Run click: fetch the stored result and format it for display.

        Returns the five values from ``get_result`` in the same order, with
        ``memorized`` turned into a string for the "Memorized" textbox.
        """
        num_free_tokens, target_length, optimal_prompt, ratio, memorized = get_result(model_name, target_string)
        return num_free_tokens, target_length, optimal_prompt, ratio, str(memorized)
    # The outputs list must follow run_check's return order, not the on-screen
    # order (where "# Target Tokens" is created first).
    run_button.click(fn=run_check, inputs=[model_dropdown, csv_dropdown], outputs=[num_free_tokens_output, target_length_output, optimal_prompt_output, ratio_output, memorized_output])
# Start the app at import/run time. NOTE(review): debug/show_error are Gradio
# launch options for surfacing errors; confirm exact semantics in the Gradio docs.
demo.launch(debug=True, show_error=True)