"""Gradio app for a pairwise image-composition user study.

Participants compare two model outputs (baseline vs. tf-icon) for each sample.
Preferences are appended to a local CSV and periodically uploaded to a
Hugging Face dataset repository.
"""

import atexit
import csv
import os
import random
import shutil
import tempfile
import threading
import time
from datetime import datetime
from pathlib import Path

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, hf_hub_download, login
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

# --- Configuration ---
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "matsant01/user-study-collected-preferences")
HF_TOKEN = os.getenv("HF_TOKEN")
RESULTS_FILENAME_IN_REPO = "preferences.csv"
TEMP_DIR = tempfile.mkdtemp()
LOCAL_RESULTS_FILE = Path(TEMP_DIR) / RESULTS_FILENAME_IN_REPO
UPLOAD_INTERVAL_HOURS = 0.1
DATA_DIR = Path("data")
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]

# --- Global State for Upload Logic ---
hf_api = None
scheduler = BackgroundScheduler(daemon=True)
upload_lock = threading.Lock()
new_preferences_recorded_since_last_upload = threading.Event()


# --- Hugging Face Hub Login & Initialization ---
def initialize_hub_and_results():
    global hf_api
    if HF_TOKEN:
        print("Logging into Hugging Face Hub...")
        try:
            login(token=HF_TOKEN)
            hf_api = HfApi()
            print(f"Attempting initial download of {RESULTS_FILENAME_IN_REPO} from {DATASET_REPO_ID}")
            hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=RESULTS_FILENAME_IN_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
                local_dir=TEMP_DIR,
                local_dir_use_symlinks=False,
            )
            print(f"Successfully downloaded existing {RESULTS_FILENAME_IN_REPO} to {LOCAL_RESULTS_FILE}")
        except EntryNotFoundError:
            print(f"{RESULTS_FILENAME_IN_REPO} not found in repo. Will create locally.")
        except RepositoryNotFoundError:
            print(f"Error: Dataset repository {DATASET_REPO_ID} not found or token lacks permissions.")
            print("Results saving will be disabled.")
            hf_api = None
        except Exception as e:
            print(f"Error during initial download/login: {e}")
            print("Proceeding without initial download. File will be created locally.")
    else:
        print("Warning: HF_TOKEN secret not found. Results will not be saved to the Hub.")
        hf_api = None
# --- Data Loading ---
def find_image(folder_path: Path, base_name: str) -> Path | None:
    for ext in IMAGE_EXTENSIONS:
        file_path = folder_path / f"{base_name}{ext}"
        if file_path.exists():
            return file_path
    return None


def get_sample_ids() -> list[str]:
    sample_ids = []
    if DATA_DIR.is_dir():
        for item in DATA_DIR.iterdir():
            if item.is_dir():
                prompt_file = item / "prompt.txt"
                input_bg = find_image(item, "input_bg")
                input_fg = find_image(item, "input_fg")
                output_baseline = find_image(item, "baseline")
                output_tficon = find_image(item, "tf-icon")
                if prompt_file.exists() and input_bg and input_fg and output_baseline and output_tficon:
                    sample_ids.append(item.name)
    return sample_ids


def load_sample_data(sample_id: str) -> dict | None:
    sample_path = DATA_DIR / sample_id
    if not sample_path.is_dir():
        return None
    prompt_file = sample_path / "prompt.txt"
    input_bg_path = find_image(sample_path, "input_bg")
    input_fg_path = find_image(sample_path, "input_fg")
    output_baseline_path = find_image(sample_path, "baseline")
    output_tficon_path = find_image(sample_path, "tf-icon")
    if not all([prompt_file.exists(), input_bg_path, input_fg_path, output_baseline_path, output_tficon_path]):
        print(f"Warning: Missing files in sample {sample_id}")
        return None
    try:
        prompt = prompt_file.read_text().strip()
    except Exception as e:
        print(f"Error reading prompt for {sample_id}: {e}")
        return None
    return {
        "id": sample_id,
        "prompt": prompt,
        "input_bg": str(input_bg_path),
        "input_fg": str(input_fg_path),
        "output_baseline": str(output_baseline_path),
        "output_tficon": str(output_tficon_path),
    }
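# Expected layout of DATA_DIR, as assumed by find_image()/get_sample_ids()/
# load_sample_data() above. Folder names are arbitrary sample IDs; image
# extensions may be any entry of IMAGE_EXTENSIONS (".png" shown for illustration):
#
#   data/
#     <sample_id>/
#       prompt.txt
#       input_bg.png
#       input_fg.png
#       baseline.png
#       tf-icon.png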
# --- State and UI Logic ---
INITIAL_SAMPLE_IDS = get_sample_ids()


def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]:
    if not available_ids:
        return None, []
    chosen_id = random.choice(available_ids)
    remaining_ids = [id for id in available_ids if id != chosen_id]
    sample_data = load_sample_data(chosen_id)
    return sample_data, remaining_ids


def display_new_sample(state: dict, available_ids: list[str]):
    sample_data, remaining_ids = get_next_sample(available_ids)
    if not sample_data:
        return {
            prompt_display: gr.update(value="No more samples available. Thank you!"),
            input_bg_display: gr.update(value=None, visible=False),
            input_fg_display: gr.update(value=None, visible=False),
            output_a_display: gr.update(value=None, visible=False),
            output_b_display: gr.update(value=None, visible=False),
            choice_button_a: gr.update(visible=False),
            choice_button_b: gr.update(visible=False),
            next_button: gr.update(visible=False),
            status_display: gr.update(value="Completed!"),
            app_state: state,
            available_samples_state: remaining_ids,
        }
    # Randomize which model appears as Output A vs. Output B for this sample.
    outputs = [
        {"model_name": "baseline", "path": sample_data["output_baseline"]},
        {"model_name": "tf-icon", "path": sample_data["output_tficon"]},
    ]
    random.shuffle(outputs)
    output_a = outputs[0]
    output_b = outputs[1]
    state = {
        "current_sample_id": sample_data["id"],
        "output_a_model_name": output_a["model_name"],
        "output_b_model_name": output_b["model_name"],
    }
    return {
        prompt_display: gr.update(value=f"Prompt: {sample_data['prompt']}"),
        input_bg_display: gr.update(value=sample_data["input_bg"], visible=True),
        input_fg_display: gr.update(value=sample_data["input_fg"], visible=True),
        output_a_display: gr.update(value=output_a["path"], visible=True),
        output_b_display: gr.update(value=output_b["path"], visible=True),
        choice_button_a: gr.update(visible=True, interactive=True),
        choice_button_b: gr.update(visible=True, interactive=True),
        next_button: gr.update(visible=False),
        status_display: gr.update(value="Please choose the image you prefer."),
        app_state: state,
        available_samples_state: remaining_ids,
    }
def record_preference(choice: str, state: dict, request: gr.Request):
    if not request:
        print("Error: Request object is None. Cannot get session ID.")
        session_id = "unknown_session"
    else:
        try:
            session_id = request.client.host
        except AttributeError:
            print("Error: request.client is None or has no 'host' attribute.")
            session_id = "unknown_client"
    if not state or "current_sample_id" not in state:
        print("Warning: State missing, cannot record preference.")
        return {
            choice_button_a: gr.update(interactive=False),
            choice_button_b: gr.update(interactive=False),
            next_button: gr.update(visible=True, interactive=True),
            status_display: gr.update(value="Error: Session state lost. Click Next Sample."),
            app_state: state,
        }
    chosen_model_name = state["output_a_model_name"] if choice == "A" else state["output_b_model_name"]
    baseline_display = "A" if state["output_a_model_name"] == "baseline" else "B"
    tficon_display = "B" if state["output_a_model_name"] == "baseline" else "A"
    new_row = {
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id,
        "sample_id": state["current_sample_id"],
        "baseline_displayed_as": baseline_display,
        "tficon_displayed_as": tficon_display,
        "chosen_display": choice,
        "chosen_model_name": chosen_model_name,
    }
    header = list(new_row.keys())
    try:
        with upload_lock:
            file_exists = LOCAL_RESULTS_FILE.exists()
            mode = 'a' if file_exists else 'w'
            with open(LOCAL_RESULTS_FILE, mode, newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=header)
                if not file_exists or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
                    writer.writeheader()
                    print(f"Created or wrote header to {LOCAL_RESULTS_FILE}")
                writer.writerow(new_row)
            print(f"Appended preference for {state['current_sample_id']} to local file.")
            new_preferences_recorded_since_last_upload.set()
    except Exception as e:
        print(f"Error writing local results file {LOCAL_RESULTS_FILE}: {e}")
        return {
            choice_button_a: gr.update(interactive=False),
            choice_button_b: gr.update(interactive=False),
            next_button: gr.update(visible=True, interactive=True),
            status_display: gr.update(value=f"Error saving preference locally: {e}. Click Next."),
            app_state: state,
        }
    return {
        choice_button_a: gr.update(interactive=False),
        choice_button_b: gr.update(interactive=False),
        next_button: gr.update(visible=True, interactive=True),
        status_display: gr.update(value=f"Preference recorded (Chose {choice}). Click Next Sample."),
        app_state: state,
    }
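# Each recorded preference becomes one row of preferences.csv with the columns
# built in new_row above. An illustrative row (values are made up for
# documentation only; session_id is the client host as captured above):
#
#   timestamp,session_id,sample_id,baseline_displayed_as,tficon_displayed_as,chosen_display,chosen_model_name
#   2024-01-01T12:00:00,203.0.113.7,sample_001,A,B,B,tf-icon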
def upload_preferences_to_hub():
    print("Periodic upload check triggered.")
    if not hf_api:
        print("Upload check skipped: Hugging Face API not available.")
        return
    if not new_preferences_recorded_since_last_upload.is_set():
        print("Upload check skipped: No new preferences recorded since last upload.")
        return
    with upload_lock:
        if not new_preferences_recorded_since_last_upload.is_set():
            print("Upload check skipped (race condition avoided): No new preferences.")
            return
        if not LOCAL_RESULTS_FILE.exists() or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
            print("Upload check skipped: Local results file is missing or empty.")
            new_preferences_recorded_since_last_upload.clear()
            return
        try:
            print(f"Attempting to upload {LOCAL_RESULTS_FILE} to {DATASET_REPO_ID}/{RESULTS_FILENAME_IN_REPO}")
            start_time = time.time()
            hf_api.upload_file(
                path_or_fileobj=str(LOCAL_RESULTS_FILE),
                path_in_repo=RESULTS_FILENAME_IN_REPO,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message=f"Periodic upload of preferences - {datetime.now().isoformat()}",
            )
            end_time = time.time()
            print(f"Successfully uploaded preferences. Took {end_time - start_time:.2f} seconds.")
            new_preferences_recorded_since_last_upload.clear()
        except Exception as e:
            print(f"Error uploading results file: {e}")


def handle_choice_a(state: dict, request: gr.Request):
    return record_preference("A", state, request)


def handle_choice_b(state: dict, request: gr.Request):
    return record_preference("B", state, request)


with gr.Blocks(title="Image Composition User Study") as demo:
    gr.Markdown("# Image Composition User Study")
    gr.Markdown(
        "Please look at the input images and the prompt below. "
        "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
    )
    app_state = gr.State({})
    available_samples_state = gr.State(INITIAL_SAMPLE_IDS)
    prompt_display = gr.Textbox(label="Prompt", interactive=False)
    status_display = gr.Textbox(label="Status", value="Loading first sample...", interactive=False)
    with gr.Row():
        input_bg_display = gr.Image(label="Input Background", type="filepath", height=300, width=300, interactive=False)
        input_fg_display = gr.Image(label="Input Foreground", type="filepath", height=300, width=300, interactive=False)
    gr.Markdown("---")
    gr.Markdown("## Choose your preferred output:")
    with gr.Row():
        with gr.Column():
            output_a_display = gr.Image(label="Output A", type="filepath", height=400, width=400, interactive=False)
            choice_button_a = gr.Button("Choose Output A", variant="primary")
        with gr.Column():
            output_b_display = gr.Image(label="Output B", type="filepath", height=400, width=400, interactive=False)
            choice_button_b = gr.Button("Choose Output B", variant="primary")
    next_button = gr.Button("Next Sample", visible=False)

    demo.load(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=[
            prompt_display, input_bg_display, input_fg_display,
            output_a_display, output_b_display,
            choice_button_a, choice_button_b, next_button,
            status_display, app_state, available_samples_state,
        ],
    )
    choice_button_a.click(
        fn=handle_choice_a,
        inputs=[app_state],
        outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
        api_name=False,
    )
    choice_button_b.click(
        fn=handle_choice_b,
        inputs=[app_state],
        outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
        api_name=False,
    )
    next_button.click(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=[
            prompt_display, input_bg_display, input_fg_display,
            output_a_display, output_b_display,
            choice_button_a, choice_button_b, next_button,
            status_display, app_state, available_samples_state,
        ],
        api_name=False,
    )


def cleanup_temp_dir():
    if Path(TEMP_DIR).exists():
        print(f"Cleaning up temporary directory: {TEMP_DIR}")
        shutil.rmtree(TEMP_DIR, ignore_errors=True)
def shutdown_hook():
    print("Application shutting down. Performing final upload check...")
    upload_preferences_to_hub()
    if scheduler.running:
        print("Shutting down scheduler...")
        scheduler.shutdown(wait=False)
    cleanup_temp_dir()
    print("Shutdown complete.")


atexit.register(shutdown_hook)


if __name__ == "__main__":
    initialize_hub_and_results()
    if not INITIAL_SAMPLE_IDS:
        print("Error: No valid samples found in the 'data' directory.")
        print("Please ensure the 'data' directory exists and contains subdirectories")
        print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',")
        print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.")
    elif not DATASET_REPO_ID:
        print("Error: DATASET_REPO_ID environment variable is not set or is set to the default placeholder.")
        print("Please set the DATASET_REPO_ID environment variable or update the script.")
    elif hf_api:
        print(f"Starting periodic upload scheduler (every {UPLOAD_INTERVAL_HOURS} hours)...")
        scheduler.add_job(upload_preferences_to_hub, 'interval', hours=UPLOAD_INTERVAL_HOURS)
        scheduler.start()
        print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
        print(f"Configured to save results periodically to Hugging Face Dataset: {DATASET_REPO_ID}")
        print("Starting Gradio app...")
        demo.launch(server_name="0.0.0.0")
    else:
        print("Warning: Running without Hugging Face Hub integration (HF_TOKEN or DATASET_REPO_ID missing/invalid).")
        print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
        print("Starting Gradio app...")
        demo.launch(server_name="0.0.0.0")
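# Local usage (sketch; the filename "app.py" is an assumption, adjust it to this
# script's actual name). HF_TOKEN and DATASET_REPO_ID are optional: without them,
# preferences are only written to the temporary local CSV and never uploaded.
#
#   export HF_TOKEN=...            # optional
#   export DATASET_REPO_ID=...     # optional, defaults to the repo configured above
#   python app.py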