# EditArena / app.py
# taesiri's picture
# backup
# 212e953
import gradio as gr
import random
from datasets import load_dataset
import csv
from datetime import datetime
import os
import pandas as pd
import json
from huggingface_hub import CommitScheduler, HfApi, snapshot_download
import shutil
import uuid
import git
from pathlib import Path
from io import BytesIO
import PIL
import time # Add this import at the top
import re
# Hub API client authenticated with the HF_TOKEN environment variable.
# A missing variable raises KeyError as soon as this module is imported.
api = HfApi(token=os.environ["HF_TOKEN"])
# Dataset repo that stores the collected evaluation results.
RESULTS_BACKUP_REPO = "taesiri/PhotoEditBattleResults"
# Dataset repo that provides the samples shown to annotators.
MAIN_DATASET_REPO = "taesiri/IERv2-BattlePairs"
# Load the experimental dataset
dataset = load_dataset(MAIN_DATASET_REPO, split="train")
# De-duplicated list of every post_id in the dataset. The dataset is loaded a
# second time here, restricted to the "post_id" column.
dataset_post_ids = list(
    set(
        load_dataset(MAIN_DATASET_REPO, columns=["post_id"], split="train")
        .to_pandas()
        .post_id.tolist()
    )
)
# Download existing data from hub
def sync_with_hub():
    """
    Synchronize the local ``./data`` folder with the results repo on the hub.

    The results dataset repo is cloned (or pulled, if a previous clone is still
    present) into ``hub_data``. Its ``data/evaluation_results_exp.csv`` is
    merged with the local CSV (exact duplicate rows dropped), every other file
    under ``data/`` is copied over, and sub-directories are copied only when
    they do not exist locally yet.

    The temporary clone is always removed, even when a step fails, because its
    git remote URL embeds the access token.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set.
        Exception: git / pandas / filesystem errors are propagated unchanged.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    local_csv_path = data_dir / "evaluation_results_exp.csv"

    # Read existing local data if it exists
    local_data = None
    if local_csv_path.exists():
        local_data = pd.read_csv(local_csv_path)
        print(f"Found local data with {len(local_data)} entries")

    # Clone/pull latest data from hub
    token = os.environ["HF_TOKEN"]
    username = "taesiri"
    repo_url = (
        f"https://{username}:{token}@huggingface.co/datasets/{RESULTS_BACKUP_REPO}"
    )
    hub_data_dir = Path("hub_data")

    try:
        if hub_data_dir.exists():
            print("Pulling latest changes...")
            repo = git.Repo(hub_data_dir)
            origin = repo.remotes.origin
            # Refresh the remote URL so the pull uses the current token
            if "https://" in origin.url:
                origin.set_url(repo_url)
            origin.pull()
        else:
            print("Cloning repository...")
            git.Repo.clone_from(repo_url, hub_data_dir)

        # Merge hub data with local data
        hub_data_source = hub_data_dir / "data"
        if hub_data_source.exists():
            data_dir.mkdir(exist_ok=True)
            hub_csv_path = hub_data_source / "evaluation_results_exp.csv"
            if hub_csv_path.exists():
                hub_data = pd.read_csv(hub_csv_path)
                print(f"Found hub data with {len(hub_data)} entries")
                if local_data is not None:
                    # Merge data, keeping all entries and removing exact duplicates
                    merged_data = pd.concat([local_data, hub_data]).drop_duplicates()
                    print(f"Merged data has {len(merged_data)} entries")
                    merged_data.to_csv(local_csv_path, index=False)
                else:
                    # If no local data exists, just copy hub data
                    shutil.copy2(hub_csv_path, local_csv_path)

            # Copy any other files from hub
            for item in hub_data_source.glob("*"):
                if item.is_file() and item.name != "evaluation_results_exp.csv":
                    shutil.copy2(item, data_dir / item.name)
                elif item.is_dir():
                    dest = data_dir / item.name
                    if not dest.exists():
                        shutil.copytree(item, dest)
    finally:
        # Clean up the clone on success and on failure alike, so a failed sync
        # does not leave a checkout holding the token behind.
        if hub_data_dir.exists():
            shutil.rmtree(hub_data_dir)

    print("Finished syncing with hub!")
# Background uploader: commits the contents of ./data to the "data" folder of
# the RESULTS_BACKUP_REPO dataset repo. Per the huggingface_hub documentation,
# `every` is expressed in minutes.
scheduler = CommitScheduler(
    repo_id=RESULTS_BACKUP_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)
def save_evaluation(
    post_id, model_a, model_b, verdict, username, start_time, end_time, dataset_idx
):
    """Append one evaluation to ``data/evaluation_results_exp.csv``.

    Args:
        post_id: Identifier of the evaluated post.
        model_a: Name of the source shown as image A (a model name or "HUMAN").
        model_b: Name of the source shown as image B (a model name or "HUMAN").
        verdict: The annotator's choice.
        username: Identifier (email) of the annotator.
        start_time: Time the sample was shown, as returned by ``time.time()``.
        end_time: Time the verdict was confirmed, as returned by ``time.time()``.
        dataset_idx: Index of the sample in the dataset.

    The header row is written when the file is missing or empty. The write is
    done while holding ``scheduler.lock`` so the background CommitScheduler
    never uploads a half-written file and concurrent requests cannot both
    decide to write the header.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    duration = end_time - start_time
    os.makedirs("data", exist_ok=True)
    filename = "data/evaluation_results_exp.csv"

    header = [
        "timestamp",
        "post_id",
        "model_a",
        "model_b",
        "verdict",
        "username",
        "start_time",
        "end_time",
        "duration_seconds",
        "dataset_idx",
    ]
    row = [
        timestamp,
        post_id,
        model_a,
        model_b,
        verdict,
        username,
        start_time,
        end_time,
        duration,
        dataset_idx,
    ]

    with scheduler.lock:
        # Decide about the header inside the lock so the check and the write
        # cannot be interleaved with another request.
        needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
        # utf-8 matches the default encoding pandas uses when reading the file back
        with open(filename, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(header)
            writer.writerow(row)

    print(
        f"Saved evaluation: {post_id} - Model A: {model_a} - Model B: {model_b} - Verdict: {verdict} - Duration: {duration:.2f}s"
    )
def get_annotated_indices(username):
    """Return the set of dataset indices already annotated by ``username``.

    Reads ``data/evaluation_results_exp.csv``. An empty set is returned when
    the file is missing, lacks the ``username`` or ``dataset_idx`` column, or
    cannot be parsed. Rows with a blank or non-numeric ``dataset_idx`` are
    ignored, so the result contains plain ``int`` values only.
    """
    filename = "data/evaluation_results_exp.csv"
    if not os.path.exists(filename):
        print(f"No annotations found for user {username} (file doesn't exist)")
        return set()
    try:
        df = pd.read_csv(filename)
        if "dataset_idx" not in df.columns or "username" not in df.columns:
            print(f"No annotations found for user {username} (missing columns)")
            return set()
        raw_indices = df.loc[df["username"] == username, "dataset_idx"]
        # Blank cells are read as NaN (and turn the column into floats); coerce
        # and drop them so callers get clean integer indices.
        user_annotations = (
            pd.to_numeric(raw_indices, errors="coerce").dropna().astype(int).tolist()
        )
        print(f"User {username} has already processed {len(user_annotations)} posts")
        return set(user_annotations)
    except Exception as exc:
        # Best effort: an unreadable results file must not break sampling, but
        # the cause is reported instead of being silently discarded.
        print(f"Error reading annotations for user {username}: {exc}")
        return set()
def get_annotated_post_ids(username):
    """Return the set of post_ids already annotated by ``username``.

    Reads ``data/evaluation_results_exp.csv``. An empty set is returned when
    the file is missing, lacks the ``username`` or ``post_id`` column, or
    cannot be parsed. Rows with a blank ``post_id`` are ignored.
    """
    filename = "data/evaluation_results_exp.csv"
    if not os.path.exists(filename):
        print(f"No annotations found for user {username} (file doesn't exist)")
        return set()
    try:
        df = pd.read_csv(filename)
        if "post_id" not in df.columns or "username" not in df.columns:
            print(f"No annotations found for user {username} (missing columns)")
            return set()
        # Blank cells are read as NaN; drop them so NaN is never treated as a
        # "seen" post id.
        seen_ids = df.loc[df["username"] == username, "post_id"].dropna()
        user_annotations = seen_ids.tolist()
        print(f"User {username} has seen {len(set(user_annotations))} unique posts")
        return set(user_annotations)
    except Exception as exc:
        # Best effort: an unreadable results file must not break sampling, but
        # the cause is reported instead of being silently discarded.
        print(f"Error reading annotations for user {username}: {exc}")
        return set()
def get_random_sample(username):
    """Pick a random sample for ``username``, preferring unseen posts.

    Indices the user has already annotated are excluded; if every index has
    been annotated, repeats are allowed. Up to five distinct candidates are
    drawn at random and the first one whose ``post_id`` the user has not seen
    is used; if none qualifies, the last candidate drawn is used.

    The AI edit and the human edit are assigned to positions A and B at random.

    Returns:
        dict with keys ``post_id``, ``instruction`` and
        ``simplified_instruction`` (both wrapped in an HTML snippet),
        ``source_image``, ``image_a``, ``image_b``, ``model_a``, ``model_b``
        and ``dataset_idx``.

    Raises:
        IndexError: if the dataset is empty.
    """
    import html  # local import: only needed to escape dataset text below

    # Get indices and post_ids already annotated by this user
    annotated_indices = get_annotated_indices(username)
    annotated_post_ids = get_annotated_post_ids(username)

    # Get all valid indices that haven't been annotated
    all_indices = set(range(len(dataset)))
    available_indices = list(all_indices - annotated_indices)
    if not available_indices:
        # If user has annotated all items, allow repeats
        available_indices = list(all_indices)
    if not available_indices:
        raise IndexError("Cannot choose a sample from an empty dataset")

    # Draw up to 5 distinct candidates in one go (sampling without
    # replacement), instead of repeatedly removing entries from a list.
    max_attempts = 5
    candidates = random.sample(
        available_indices, min(max_attempts, len(available_indices))
    )
    for idx in candidates:
        sample = dataset[idx]
        if sample["post_id"] not in annotated_post_ids:
            break
    # When no candidate had an unseen post_id, idx/sample hold the last one drawn.

    def _request_html(text):
        """Wrap instruction text in the styled "Request:" banner."""
        # The text comes from the dataset; escape it so markup characters are
        # displayed literally instead of being interpreted as HTML.
        # NOTE(review): assumes instructions are plain text, not intentional HTML.
        return (
            '<div style="font-size: 1.8em; font-weight: bold; padding: 20px; background-color: white; border-radius: 10px; margin: 10px;"><span style="color: #888888;">Request:</span> <span style="color: black;">'
            + html.escape(text, quote=False)
            + "</span></div>"
        )

    # Randomly decide which image goes to position A and B
    if random.choice([True, False]):
        # AI edit is A, human edit is B
        image_a = sample["ai_edited_image"]
        image_b = sample["human_edited_image"]
        model_a = sample["model"]
        model_b = "HUMAN"
    else:
        # Human edit is A, AI edit is B
        image_a = sample["human_edited_image"]
        image_b = sample["ai_edited_image"]
        model_a = "HUMAN"
        model_b = sample["model"]

    return {
        "post_id": sample["post_id"],
        "instruction": _request_html(sample["instruction"]),
        "simplified_instruction": _request_html(sample["simplified_instruction"]),
        "source_image": sample["source_image"],
        "image_a": image_a,
        "image_b": image_b,
        "model_a": model_a,
        "model_b": model_b,
        "dataset_idx": idx,
    }
def evaluate(verdict, state):
    """Save the confirmed verdict and load the next sample.

    Args:
        verdict: The verdict string selected by the user.
        state: Dict describing the sample currently shown (as produced by
            ``get_random_sample`` plus ``username`` and ``start_time``), or
            ``None`` when nothing has been loaded yet.

    Returns:
        A 19-tuple, in this order: source image, image A, image B, instruction
        HTML, simplified instruction, model info text, new state, selected
        verdict, the four selection flags (A, B, neither, tie), the four button
        updates (A, B, neither, tie), post id, simplified instruction (debug)
        and username (debug).
    """
    if state is None:
        # Same slot order as the normal return below: the selected verdict is
        # cleared with None and all four selection flags are False.
        return (
            None,  # source_image
            None,  # image_a
            None,  # image_b
            None,  # instruction
            None,  # simplified_instruction
            None,  # model_info
            None,  # state
            None,  # selected_verdict
            False,  # a_better_selected
            False,  # b_better_selected
            False,  # neither_selected
            False,  # tie_selected
            gr.update(variant="secondary"),
            gr.update(variant="secondary"),
            gr.update(variant="secondary"),
            gr.update(variant="secondary"),
            None,  # post_id_display
            None,  # simplified_instruction_debug
            "",  # username_debug
        )

    # Record end time and save the evaluation
    end_time = time.time()
    save_evaluation(
        state["post_id"],
        state["model_a"],
        state["model_b"],
        verdict,
        state["username"],
        state["start_time"],
        end_time,
        state["dataset_idx"],
    )

    # Get next sample using username to avoid repeats
    next_sample = get_random_sample(state["username"])

    # Preserve username in state and set new start time
    next_state = next_sample.copy()
    next_state["username"] = state["username"]
    next_state["start_time"] = time.time()  # Set start time for next evaluation

    # Reset button styles
    a_better_reset = gr.update(variant="secondary")
    b_better_reset = gr.update(variant="secondary")
    neither_reset = gr.update(variant="secondary")
    tie_reset = gr.update(variant="secondary")

    return (
        next_sample["source_image"],
        next_sample["image_a"],
        next_sample["image_b"],
        next_sample["instruction"],
        next_sample["simplified_instruction"],
        f"Model A: {next_sample['model_a']} | Model B: {next_sample['model_b']}",
        next_state,  # Now includes username and start_time
        None,  # selected_verdict
        False,  # a_better_selected
        False,  # b_better_selected
        False,  # neither_selected
        False,  # tie_selected
        a_better_reset,  # reset A is better button style
        b_better_reset,  # reset B is better button style
        neither_reset,  # reset neither is good button style
        tie_reset,  # reset tie button style
        next_sample["post_id"],
        next_sample["simplified_instruction"],
        state["username"],  # Use username from state
    )
def select_verdict(verdict, state):
    """Translate a first-step button click into selection values.

    Returns a 5-tuple: the verdict followed by one boolean per option
    ("A is better", "B is better", "Neither is good", "Tie") telling whether
    that option is the selected one. Without a state nothing is selected.
    """
    if state is None:
        return None, False, False, False, False  # always 5 values

    options = ("A is better", "B is better", "Neither is good", "Tie")
    return (verdict, *(verdict == option for option in options))
def is_valid_email(email):
    """
    Return True if ``email`` looks like a safe, well-formed email address.

    Surrounding whitespace is ignored. The address is rejected when it:
    - is empty, not a string, or longer than 254 characters
    - contains characters commonly used for CSV/quote injection
    - contains anything other than printable ASCII
    - does not match ``local@domain.tld`` (exactly one ``@``, TLD of 2+ letters)
    - contains consecutive dots
    - has a local part longer than 64 characters or starting/ending with a dot
    - has a domain label that is empty, longer than 63 characters, or that
      starts/ends with a hyphen
    """
    if not email or not isinstance(email, str):
        return False

    # Strip first so padding whitespace does not count towards the length limit
    email = email.strip()
    if not email or len(email) > 254:  # Maximum length per RFC 5321
        return False

    # Check for common injection characters
    dangerous_chars = [";", '"', "'", ",", "\\", "\n", "\r", "\t"]
    if any(char in email for char in dangerous_chars):
        return False

    # Ensure all characters are printable ASCII
    if not all(32 <= ord(char) <= 126 for char in email):
        return False

    # Overall shape; neither character class contains "@", so a match
    # guarantees exactly one "@" in the address.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    if re.fullmatch(pattern, email) is None:
        return False

    if ".." in email:  # No consecutive dots
        return False

    local, _, domain = email.partition("@")
    if len(local) > 64:  # RFC 5321 limit
        return False
    if local.startswith(".") or local.endswith("."):
        return False

    # An empty label means the domain starts with a dot
    for label in domain.split("."):
        if not label or len(label) > 63:
            return False
        if label.startswith("-") or label.endswith("-"):
            return False

    return True
def handle_username_submit(email, current_page):
    """Validate the submitted email and move on to the next page.

    Args:
        email: Raw text from the email input box.
        current_page: Page number currently shown.

    Returns:
        A 4-tuple ``(page, input_update, debug_update, username)``. On success
        the page is 2, the input box is cleared and the cleaned email is
        returned as the username. On a missing or invalid email (or any
        unexpected error) a warning is shown, the page is unchanged and the
        username is ``None``.
    """
    try:
        if not email:
            gr.Warning("Please enter an email address")
            return current_page, gr.update(value=email), gr.update(value=""), None

        # Clean the input
        email = str(email).strip()
        if not is_valid_email(email):
            # The example address shown here had been garbled into
            # "[email protected]"; show a real-looking sample instead.
            gr.Warning("Please enter a valid email address (e.g., name@example.com)")
            return current_page, gr.update(value=email), gr.update(value=""), None

        # Sanitize email for CSV storage
        safe_email = email.replace('"', "").replace("'", "")
        return (
            2,  # next page
            gr.update(value=""),  # clear input
            gr.update(value=safe_email),  # update debug
            safe_email,  # update state
        )
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the handler
        print(f"Error in handle_username_submit: {e}")
        gr.Warning("An error occurred. Please try again.")
        return current_page, gr.update(value=""), gr.update(value=""), None
def initialize(username):
    """Load a first sample for ``username`` and build the initial UI values.

    Returns a 15-tuple: source image, image A, image B, instruction HTML,
    simplified instruction, model info text, state dict (the sample plus
    ``username`` and ``start_time``), cleared verdict, four cleared selection
    flags, post id, simplified instruction (debug) and username (debug, empty
    string when no username is given).
    """
    first = get_random_sample(username)

    # The session state is the sample itself plus who is rating and since when
    session = dict(first)
    session["username"] = username
    session["start_time"] = time.time()

    model_summary = f"Model A: {first['model_a']} | Model B: {first['model_b']}"
    cleared_flags = (False, False, False, False)  # A, B, neither, tie

    return (
        first["source_image"],
        first["image_a"],
        first["image_b"],
        first["instruction"],
        first["simplified_instruction"],
        model_summary,
        session,
        None,  # selected_verdict
        *cleared_flags,
        first["post_id"],
        first["simplified_instruction"],
        username or "",
    )
def update_button_styles(verdict):
    """Return label updates for the four verdict buttons.

    Args:
        verdict: The currently selected verdict. It does not influence the
            result: each button always gets its standard emoji label.

    Returns:
        A 4-tuple of updates for the "A is better", "B is better",
        "Both are bad" and "Tie" buttons, in that order.
    """
    # The previous version picked each label with a conditional whose two
    # branches were identical, so the labels never depended on the verdict.
    # NOTE(review): if the selected button is meant to be highlighted, that
    # would need an explicit variant change here - confirm intended behavior.
    labels = ("☝️ A is better", "☝️ B is better", "👎 Both are bad", "🤝 Tie")
    a_better_style, b_better_style, neither_style, tie_style = (
        gr.update(value=label) for label in labels
    )
    return a_better_style, b_better_style, neither_style, tie_style
# Layout helper shared by the instruction pages
def create_instruction_page(html_content, image_path=None):
    """Render an instruction page: an HTML block, optionally followed by an image.

    Must be called inside an active Gradio layout context; the components are
    placed in a new column.
    """
    column = gr.Column()
    with column:
        gr.HTML(html_content)
        if not image_path:
            return
        gr.Image(image_path, container=False)
def advance_page(current_page):
    """Return the number of the page that follows ``current_page``."""
    following_page = current_page + 1
    return following_page
# Main interface: three pages (email form, instructions, evaluation UI) plus
# the event wiring that connects them.
with gr.Blocks() as demo:
    # Page management and user info
    current_page = gr.State(1)  # 1 = email form, 2 = instructions, 3 = evaluation UI
    username_state = gr.State(None)  # set once a valid email has been submitted

    # Container for all pages
    with gr.Column() as page_container:
        # Page 1 - Username Collection
        with gr.Column(visible=True) as page1:
            create_instruction_page(
                """
                <div style="text-align: center; padding: 20px;">
                <h1>Welcome to the Image Edit Evaluation</h1>
                <p>Help us evaluate different image edits for a given instruction.</p>
                </div>
                """,
                image_path="./instructions/home.jpg",
            )
            username_input = gr.Textbox(
                label="Please enter your email address (if you don't want to share your email, please enter a fake email)",
                # The sample address had been garbled into "[email protected]"
                placeholder="name@example.com",
            )
            start_btn = gr.Button("Start", variant="primary")

        # Page 2 - First instruction page
        with gr.Column(visible=False) as page2:
            create_instruction_page(
                """
                <div style="text-align: center; padding: 20px;">
                <h1>How to Evaluate Edits</h1>
                </div>
                """,
                image_path="./instructions/page2.jpg",
            )
            next_btn1 = gr.Button("Start Evaluation", variant="primary")

        # Page 3 - Main Evaluation UI
        with gr.Column(visible=False) as main_ui:
            # Instruction panel at the top
            gr.HTML(
                """
                <div style="padding: 0.8rem; margin-bottom: 0.8rem; border-radius: 0.5rem; color: white; text-align: center;">
                <div style="font-size: 1.2rem; margin-bottom: 0.5rem;">Read the user instruction, look at the source image, then evaluate which edit (A or B) best satisfies the request better.</div>
                <div style="font-size: 1rem;">
                <strong>🤝 Tie</strong> &nbsp;&nbsp;|&nbsp;&nbsp;
                <strong> A is better</strong> &nbsp;&nbsp;|&nbsp;&nbsp;
                <strong> B is better</strong>
                </div>
                <div style="color: #ff4444; font-size: 0.9rem; margin-top: 0.5rem;">
                Please ignore any watermark on the image. Your rating should not be affected by any watermark on the image.
                </div>
                </div>
                """
            )
            with gr.Row():
                simplified_instruction = gr.Textbox(
                    label="Simplified Instruction", show_label=True, visible=False
                )
                instruction = gr.HTML(label="Original Instruction", show_label=True)
            with gr.Row():
                with gr.Column():
                    source_image = gr.Image(
                        label="Source Image", show_label=True, height=500
                    )
                    gr.HTML("<h2 style='text-align: center;'>Source Image</h2>")
                    tie_btn = gr.Button("🤝 Tie", variant="secondary")
                with gr.Column():
                    image_a = gr.Image(label="Image A", show_label=True, height=500)
                    gr.HTML("<h2 style='text-align: center;'>Image A</h2>")
                    a_better_btn = gr.Button("☝️ A is better", variant="secondary")
                with gr.Column():
                    image_b = gr.Image(label="Image B", show_label=True, height=500)
                    gr.HTML("<h2 style='text-align: center;'>Image B</h2>")
                    b_better_btn = gr.Button("☝️ B is better", variant="secondary")
            # Confirmation button in its own row; hidden until a verdict is selected
            with gr.Row():
                confirm_btn = gr.Button(
                    "Confirm Selection", variant="primary", visible=False
                )
            with gr.Row():
                neither_btn = gr.Button(
                    "👎 Both are bad", variant="secondary", visible=False
                )
            with gr.Accordion("DEBUG", open=False, visible=False):
                with gr.Column():
                    post_id_display = gr.Textbox(
                        label="Post ID", show_label=True, interactive=False
                    )
                    model_info = gr.Textbox(label="Model Information", show_label=True)
                    simplified_instruction_debug = gr.Textbox(
                        label="Simplified Instruction",
                        show_label=True,
                        interactive=False,
                    )
                    username_debug = gr.Textbox(
                        label="Username", show_label=True, interactive=False
                    )
            state = gr.State()
            selected_verdict = gr.State()
            # Hidden checkboxes mirroring which verdict button is selected
            a_better_selected = gr.Checkbox(visible=False)
            b_better_selected = gr.Checkbox(visible=False)
            neither_selected = gr.Checkbox(visible=False)
            tie_selected = gr.Checkbox(visible=False)

    # Component groups shared by several event handlers. Each list is defined
    # once so the handlers cannot drift out of sync with each other.
    selection_flags = [
        a_better_selected,
        b_better_selected,
        neither_selected,
        tie_selected,
    ]
    verdict_buttons = [a_better_btn, b_better_btn, neither_btn, tie_btn]
    sample_outputs = [
        source_image,
        image_a,
        image_b,
        instruction,
        simplified_instruction,
        model_info,
        state,
        selected_verdict,
        *selection_flags,
    ]
    debug_outputs = [post_id_display, simplified_instruction_debug, username_debug]
    # Matches the 15 values returned by initialize()
    initialize_outputs = sample_outputs + debug_outputs
    page_columns = [page1, page2, main_ui]

    def update_confirm_visibility(a_better, b_better, neither, tie):
        """Show the confirm button, labelled after the selected verdict."""
        if a_better:
            return gr.update(visible=True, value="Confirm A is better")
        elif b_better:
            return gr.update(visible=True, value="Confirm B is better")
        elif neither:
            return gr.update(visible=True, value="Confirm Neither is good")
        elif tie:
            return gr.update(visible=True, value="Confirm Tie")
        return gr.update(visible=False)

    def _verdict_selector(verdict):
        """Build a click handler that selects ``verdict``."""

        # A factory (rather than a lambda created in the loop below) binds the
        # verdict at creation time and avoids the late-binding closure pitfall.
        def _select(current_state):
            return select_verdict(verdict, current_state)

        return _select

    def update_page_visibility(page_num):
        """Return visibility updates for each page column"""
        return [
            gr.update(visible=(page_num == 1)),  # page1
            gr.update(visible=(page_num == 2)),  # page2
            gr.update(visible=(page_num == 3)),  # main_ui
        ]

    # Initialize the interface; no username is known on initial load
    demo.load(
        lambda: initialize(None),
        outputs=initialize_outputs,
    )

    # First step: clicking a verdict button records the selection, then
    # refreshes the button labels.
    for verdict_button, verdict_label in (
        (a_better_btn, "A is better"),
        (b_better_btn, "B is better"),
        (neither_btn, "Neither is good"),
        (tie_btn, "Tie"),
    ):
        verdict_button.click(
            _verdict_selector(verdict_label),
            inputs=[state],
            outputs=[selected_verdict, *selection_flags],
        ).then(
            update_button_styles,
            inputs=[selected_verdict],
            outputs=verdict_buttons,
        )

    # Update confirm button visibility when selection changes
    for checkbox in selection_flags:
        checkbox.change(
            update_confirm_visibility,
            inputs=selection_flags,
            outputs=[confirm_btn],
        )

    # Second step: confirming saves the verdict and loads the next sample.
    # Output order matches the 19 values returned by evaluate().
    confirm_btn.click(
        evaluate,
        inputs=[selected_verdict, state],
        outputs=sample_outputs + verdict_buttons + debug_outputs,
    )

    # Page navigation: validate the email, switch page, load a sample
    start_btn.click(
        handle_username_submit,
        inputs=[username_input, current_page],
        outputs=[
            current_page,
            username_input,
            username_debug,
            username_state,
        ],
    ).then(
        update_page_visibility,
        inputs=[current_page],
        outputs=page_columns,
    ).then(
        initialize,
        inputs=[username_state],
        outputs=initialize_outputs,
    )

    next_btn1.click(
        lambda _page: 3,  # always jump to the evaluation UI (page 3)
        inputs=[current_page],
        outputs=current_page,
    ).then(
        update_page_visibility,
        inputs=[current_page],
        outputs=page_columns,
    ).then(
        # Re-initialize so the timer starts when the evaluation UI appears
        initialize,
        inputs=[username_state],
        outputs=initialize_outputs,
    )
def main():
    """Script entry point: pull existing results from the hub, then serve the app."""
    # Sync with hub before launching
    sync_with_hub()
    demo.launch()


if __name__ == "__main__":
    main()