# Source captured from a Hugging Face Spaces page (Space status shown: "Sleeping").
import datetime
import logging
import os
import posixpath
import uuid

import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, list_repo_tree, upload_file
load_dotenv()  # populate os.environ from a local .env file, if one exists

# Configuration. Every value comes from the environment and is None when the
# variable is unset; nothing here checks that the required ones are present.
HF_INPUT_DATASET = os.environ.get("HF_INPUT_DATASET")  # dataset repo holding the pairs
HF_INPUT_DATASET_PATH = os.environ.get("HF_INPUT_DATASET_PATH")  # file inside that repo
HF_INPUT_DATASET_ID_COLUMN = os.environ.get("HF_INPUT_DATASET_ID_COLUMN")  # per-row item ID
HF_INPUT_DATASET_COLUMN_A = os.environ.get("HF_INPUT_DATASET_COLUMN_A")  # text shown as output A
HF_INPUT_DATASET_COLUMN_B = os.environ.get("HF_INPUT_DATASET_COLUMN_B")  # text shown as output B
HF_OUTPUT_DATASET = os.environ.get("HF_OUTPUT_DATASET")  # dataset repo receiving judgments
HF_OUTPUT_DATASET_DIR = os.environ.get("HF_OUTPUT_DATASET_DIR")  # folder for results files
# Markdown shown at the top of the labeling UI. The choice names must match the
# vote-button labels and the judgment strings that get recorded
# ("A is better", "B is better", "Tie", "Can't choose").
INSTRUCTIONS = """
# Pairwise Model Output Labeling
Please compare the two model outputs shown below (output A on the left, output B on the right) and select which one you think is better.
- Choose "A is better" if output A (left) is superior
- Choose "B is better" if output B (right) is superior
- Choose "Tie" if they are equally good or bad
- Choose "Can't choose" if you cannot make a determination
"""
class PairwiseLabeler:
    """Serve pairs of model generations and record per-user pairwise judgments.

    The pairs are loaded once from the configured Hugging Face input dataset
    (falling back to a 5-row synthetic sample if that fails). Judgments are kept
    in memory per user ID. After every vote, that user's full judgment list is
    re-uploaded as ``results_<user_id>.jsonl`` to ``HF_OUTPUT_DATASET``, which
    means one Hub commit per vote.

    NOTE(review): ``self.results`` starts empty on each process start, so after
    a restart a returning user's first vote overwrites their previously uploaded
    results file with only the new session's votes. Consider seeding
    ``self.results`` from the output dataset before relying on this long-term.
    """

    def __init__(self):
        self.df = self.read_hf_dataset()
        # user_id -> list of judgment records (dicts), in first-vote order.
        self.results = {}

    def __len__(self):
        """Return the number of pairs available for labeling."""
        return len(self.df)

    def read_hf_dataset(self) -> pd.DataFrame:
        """Download and parse the input dataset file, or return sample data on failure.

        The reader is picked from the file extension (.json, .jsonl, .csv,
        .parquet). Any failure (download error, unsupported extension, parse
        error) is logged with its traceback. A small sample DataFrame keyed by
        the configured column names is then returned so the app still starts.
        """
        try:
            local_file = hf_hub_download(
                repo_id=HF_INPUT_DATASET,
                repo_type="dataset",
                filename=HF_INPUT_DATASET_PATH,
            )
            if local_file.endswith(".json"):
                return pd.read_json(local_file)
            if local_file.endswith(".jsonl"):
                return pd.read_json(local_file, orient="records", lines=True)
            if local_file.endswith(".csv"):
                return pd.read_csv(local_file)
            if local_file.endswith(".parquet"):
                return pd.read_parquet(local_file)
            raise ValueError(f"Unsupported file type: {local_file}")
        except Exception:
            # Deliberate best-effort fallback, but keep the cause (traceback) in the log.
            logging.exception(
                "Couldn't read HF dataset from %s. Using sample data instead.",
                HF_INPUT_DATASET_PATH,
            )
            sample_data = {
                HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in range(5)],
                HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in range(5)],
                HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in range(5)],
            }
            return pd.DataFrame(sample_data)

    def get_current_pair(self, user_id, user_index):
        """Return ``(item_id, text_a, text_b)`` for dataset row ``user_index``.

        ``user_id`` is currently unused. Returns ``(None, None, None)`` when
        ``user_index`` does not address a row (negative, or past the last pair).
        Missing columns fall back to ``f"item_{user_index}"`` for the ID and
        ``""`` for the texts.
        """
        # Reject negative indexes too; iloc would otherwise wrap around to the end.
        if not 0 <= user_index < len(self.df):
            return None, None, None
        row = self.df.iloc[user_index]
        item_id = row.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{user_index}")
        left_text = row.get(HF_INPUT_DATASET_COLUMN_A, "")
        right_text = row.get(HF_INPUT_DATASET_COLUMN_B, "")
        return item_id, left_text, right_text

    def submit_judgment(self, user_id, user_index, item_id, left_text, right_text, choice):
        """Record ``choice`` for ``item_id`` by ``user_id`` and advance to the next pair.

        Returns ``(next_item_id, next_text_a, next_text_b, next_index)``.
        Re-voting on an item the user already judged updates that vote in place.
        If no pair is loaded (``item_id`` is None or ``""``), nothing is recorded
        and ``(None, None, None, user_index)`` is returned. If ``user_id`` is
        empty, the vote is ignored and the current pair is returned unchanged.
        """
        # item_id arrives via a hidden textbox, so "no pair" may be "" rather than None.
        if item_id is None or item_id == "":
            return None, None, None, user_index
        if not user_id:
            # Don't attribute votes to an anonymous labeler; keep the pair on screen.
            return item_id, left_text, right_text, user_index

        now = datetime.datetime.now().isoformat()
        votes = self.results.setdefault(user_id, [])
        existing_vote = next((vote for vote in votes if vote["item_id"] == item_id), None)
        if existing_vote is not None:
            existing_vote["judgment"] = choice
            existing_vote["timestamp"] = now
        else:
            votes.append({
                "item_id": item_id,
                "generation_a": left_text,
                "generation_b": right_text,
                "judgment": choice,
                "timestamp": now,
                "labeler_id": user_id,
            })

        # Persist immediately so a closed tab does not lose the vote.
        self.save_results(user_id)

        user_index += 1
        next_id, next_left, next_right = self.get_current_pair(user_id, user_index)
        return next_id, next_left, next_right, user_index

    @staticmethod
    def _results_filename(user_id):
        """Return the results filename for ``user_id``, safe to use as a path component.

        Any character other than a letter, digit, ``-``, ``_`` or ``.`` becomes
        ``_``, so a user-typed ID cannot introduce directory separators.
        """
        safe_id = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(user_id))
        return f"results_{safe_id}.jsonl"

    def save_results(self, user_id):
        """Upload all of ``user_id``'s judgments to the output dataset as JSON Lines.

        The file is serialized in memory and uploaded as bytes. Nothing is
        written to local disk, so there is no leftover file on failure and no
        shared local path for concurrent saves to collide on. Failures are
        logged with a traceback and otherwise ignored; the votes stay in
        ``self.results`` and are included in the next successful upload.
        """
        records = self.results.get(user_id)
        if not records:
            return
        try:
            filename = self._results_filename(user_id)
            # Hub paths always use "/", whatever the local OS separator is;
            # an unset output dir means "repo root".
            path_in_repo = posixpath.join(HF_OUTPUT_DATASET_DIR or "", filename)
            payload = pd.DataFrame(records).to_json(orient="records", lines=True)
            upload_file(
                repo_id=HF_OUTPUT_DATASET,
                repo_type="dataset",
                path_in_repo=path_in_repo,
                path_or_fileobj=payload.encode("utf-8"),
            )
        except Exception as e:
            logging.exception("Error saving results: %s", e)
# One labeler shared by every Gradio session; the input dataset is loaded once here.
labeler = PairwiseLabeler()

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown(INSTRUCTIONS)
    user_id = gr.Textbox(label="Enter your user ID", interactive=True)
    # gr.State holds a per-session value: this session's position in the dataset.
    user_index = gr.State(0)

    with gr.Row():
        with gr.Column():
            left_output = gr.Textbox(label="Model Output A", lines=10, interactive=False)
        with gr.Column():
            right_output = gr.Textbox(label="Model Output B", lines=10, interactive=False)
    # Hidden holder for the ID of the pair currently displayed.
    item_id = gr.Textbox(visible=False)

    with gr.Row():
        left_btn = gr.Button("⬅️ A is better")
        right_btn = gr.Button("➡️ B is better")
        tie_btn = gr.Button("🤝 Tie")
        cant_choose_btn = gr.Button("🤔 Can't choose")

    def load_first_pair(user_id):
        """Show the first pair once a user ID is entered, resetting progress to 0."""
        if user_id:
            first_id, first_left, first_right = labeler.get_current_pair(user_id, 0)
            return first_id, first_left, first_right, 0
        return None, None, None, 0

    def judge(choice, user_id, user_index, item_id, left_text, right_text):
        """Forward a button vote to the labeler; yields the next pair and index."""
        return labeler.submit_judgment(user_id, user_index, item_id, left_text, right_text, choice)

    # Every handler refreshes the same four components.
    _pair_outputs = [item_id, left_output, right_output, user_index]

    user_id.submit(load_first_pair, inputs=[user_id], outputs=_pair_outputs)

    # Each vote button passes its fixed judgment string to `judge` via a constant State.
    _vote_buttons = (
        (left_btn, "A is better"),
        (right_btn, "B is better"),
        (tie_btn, "Tie"),
        (cant_choose_btn, "Can't choose"),
    )
    for _button, _judgment in _vote_buttons:
        _button.click(
            judge,
            inputs=[gr.State(_judgment), user_id, user_index, item_id, left_output, right_output],
            outputs=_pair_outputs,
        )

if __name__ == "__main__":
    app.launch()