jdev8's picture
Update app.py
9eff11a verified
raw
history blame
6.67 kB
import gradio as gr
import pandas as pd
import os
import uuid
import datetime
import logging
from huggingface_hub import hf_hub_download, upload_file, list_repo_tree
from dotenv import load_dotenv
# Pull settings from a local .env file (when one exists) into the process environment.
load_dotenv()

# Configuration. Every value is read from the environment and is None when unset.
# Input: Hub dataset repo id, path of the data file inside that repo, and the
# names of the columns holding the item id and the two generations to compare.
HF_INPUT_DATASET = os.environ.get("HF_INPUT_DATASET")
HF_INPUT_DATASET_PATH = os.environ.get("HF_INPUT_DATASET_PATH")
HF_INPUT_DATASET_ID_COLUMN = os.environ.get("HF_INPUT_DATASET_ID_COLUMN")
HF_INPUT_DATASET_COLUMN_A = os.environ.get("HF_INPUT_DATASET_COLUMN_A")
HF_INPUT_DATASET_COLUMN_B = os.environ.get("HF_INPUT_DATASET_COLUMN_B")
# Output: Hub dataset repo that receives the labels, and the directory inside it.
HF_OUTPUT_DATASET = os.environ.get("HF_OUTPUT_DATASET")
HF_OUTPUT_DATASET_DIR = os.environ.get("HF_OUTPUT_DATASET_DIR")
# Markdown shown at the top of the UI. The choice names below must match the
# button labels / recorded judgment strings used by the Gradio event handlers.
INSTRUCTIONS = """
# Pairwise Model Output Labeling
Enter your user ID and press Enter to load the first pair.

Please compare the two model outputs shown below (Model Output A on the left, Model Output B on the right) and select which one you think is better.
- Choose "A is better" if output A is superior
- Choose "B is better" if output B is superior
- Choose "Tie" if they are equally good or bad
- Choose "Can't choose" if you cannot make a determination
"""
class PairwiseLabeler:
    """Serve pairs of model generations to labelers and save their votes to the Hub.

    The pairs are loaded once, at construction. Votes are held in memory as
    ``{user_id: [vote_dict, ...]}`` and a labeler's whole vote list is uploaded
    to ``HF_OUTPUT_DATASET`` after each judgment.

    NOTE(review): ``self.results`` starts empty on every process start and each
    upload writes the same ``results_<user_id>.jsonl`` path, so after a restart
    a returning labeler's first new vote presumably replaces their earlier file
    on the Hub. Consider seeding ``self.results`` from the existing file.
    """

    def __init__(self):
        self.df = self.read_hf_dataset()
        # user_id -> list of vote dicts (keys are listed in submit_judgment).
        self.results = {}

    def __len__(self):
        """Return the number of pairs available for labeling."""
        return len(self.df)

    def read_hf_dataset(self) -> pd.DataFrame:
        """Download the input file from the Hub and load it into a DataFrame.

        The reader is chosen from the file extension (.json, .jsonl, .csv or
        .parquet, matched case-insensitively). On any failure (missing
        configuration, download error, unsupported extension, parse error) the
        error is logged with its traceback and a five-row sample DataFrame is
        returned, so the UI can still start.
        """
        try:
            local_file = hf_hub_download(
                repo_id=HF_INPUT_DATASET,
                repo_type="dataset",
                filename=HF_INPUT_DATASET_PATH,
            )
            extension = os.path.splitext(local_file)[1].lower()
            if extension == ".json":
                return pd.read_json(local_file)
            if extension == ".jsonl":
                return pd.read_json(local_file, orient="records", lines=True)
            if extension == ".csv":
                return pd.read_csv(local_file)
            if extension == ".parquet":
                return pd.read_parquet(local_file)
            raise ValueError(f"Unsupported file type: {local_file}")
        except Exception:
            # Deliberate best-effort fallback, but keep the cause in the log.
            logging.exception(
                "Couldn't read HF dataset from %s. Using sample data instead.",
                HF_INPUT_DATASET_PATH,
            )
            sample_data = {
                HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in range(5)],
                HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in range(5)],
                HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in range(5)],
            }
            return pd.DataFrame(sample_data)

    def get_current_pair(self, user_id, user_index):
        """Return ``(item_id, text_a, text_b)`` for row ``user_index`` of the dataset.

        ``user_id`` is currently unused. Returns ``(None, None, None)`` when
        ``user_index`` is outside ``[0, len(self.df))``. A missing id column
        falls back to ``f"item_{user_index}"``; missing text columns fall back to "".
        """
        # Reject negative indices too: iloc would silently count from the end.
        if not 0 <= user_index < len(self.df):
            return None, None, None
        row = self.df.iloc[user_index]
        item_id = row.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{user_index}")
        left_text = row.get(HF_INPUT_DATASET_COLUMN_A, "")
        right_text = row.get(HF_INPUT_DATASET_COLUMN_B, "")
        return item_id, left_text, right_text

    def submit_judgment(self, user_id, user_index, item_id, left_text, right_text, choice):
        """Record ``choice`` for the pair on screen, save it, and return the next pair.

        Returns ``(next_item_id, next_text_a, next_text_b, next_user_index)``.

        - No pair loaded (``item_id`` is None or ""): nothing is recorded and
          ``(None, None, None, user_index)`` is returned.
        - Empty ``user_id``: the vote is ignored and the current pair is
          returned unchanged, so the labeler can enter an ID and retry.
        - A repeat vote on the same item overwrites the earlier judgment and
          timestamp instead of adding a second record.
        """
        # The hidden item-id textbox may come back as "" rather than None once
        # the last pair has been shown, so treat both as "no pair loaded".
        if item_id is None or item_id == "":
            return None, None, None, user_index
        if not user_id:
            logging.warning("Ignoring judgment for item %s: no user ID entered.", item_id)
            return item_id, left_text, right_text, user_index

        votes = self.results.setdefault(user_id, [])
        timestamp = datetime.datetime.now().isoformat()
        existing_vote = next((vote for vote in votes if vote["item_id"] == item_id), None)
        if existing_vote is not None:
            existing_vote["judgment"] = choice
            existing_vote["timestamp"] = timestamp
        else:
            votes.append({
                "item_id": item_id,
                "generation_a": left_text,
                "generation_b": right_text,
                "judgment": choice,
                "timestamp": timestamp,
                "labeler_id": user_id,
            })

        # Save after every vote so nothing is lost if the session ends.
        self.save_results(user_id)

        user_index += 1
        return (*self.get_current_pair(user_id, user_index), user_index)

    @staticmethod
    def _results_filename(user_id) -> str:
        """Return ``results_<user_id>.jsonl`` with unsafe characters replaced by "_".

        ``user_id`` is free text typed into the UI. Any character other than a
        letter, digit, "-", "_" or "." (in particular "/" and "\\") is replaced,
        so the name cannot escape the output directory.
        """
        safe_id = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(user_id))
        return f"results_{safe_id}.jsonl"

    def save_results(self, user_id):
        """Upload all of ``user_id``'s votes to the output dataset as JSON Lines.

        The file (named by ``_results_filename``) goes under
        ``HF_OUTPUT_DATASET_DIR``, or the repo root when that is unset, and is
        re-uploaded to the same path on every call. The payload is sent from
        memory, so no local file is written and nothing needs cleaning up.
        Errors are logged rather than raised, so a failed upload never
        interrupts labeling.
        """
        votes = self.results.get(user_id)
        if not votes:
            return
        try:
            payload = pd.DataFrame(votes).to_json(orient="records", lines=True).encode("utf-8")
            filename = self._results_filename(user_id)
            # Hub repo paths always use "/", regardless of the host OS.
            directory = (HF_OUTPUT_DATASET_DIR or "").strip("/")
            path_in_repo = f"{directory}/{filename}" if directory else filename
            upload_file(
                repo_id=HF_OUTPUT_DATASET,
                repo_type="dataset",
                path_in_repo=path_in_repo,
                path_or_fileobj=payload,
            )
        except Exception:
            logging.exception("Error saving results for user %s", user_id)
# Build the labeler once at import time; every browser session shares this
# instance (and therefore its in-memory vote store).
labeler = PairwiseLabeler()

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown(INSTRUCTIONS)

    user_id = gr.Textbox(label="Enter your user ID", interactive=True)
    # Index of the row currently shown; gr.State keeps one copy per browser session.
    user_index = gr.State(0)

    with gr.Row():
        with gr.Column():
            left_output = gr.Textbox(label="Model Output A", lines=10, interactive=False)
        with gr.Column():
            right_output = gr.Textbox(label="Model Output B", lines=10, interactive=False)

    # Invisible field that carries the id of the pair on screen back to the handlers.
    item_id = gr.Textbox(visible=False)

    with gr.Row():
        left_btn = gr.Button("⬅️ A is better")
        right_btn = gr.Button("➡️ B is better")
        tie_btn = gr.Button("🤝 Tie")
        cant_choose_btn = gr.Button("🤔 Can't choose")

    def load_first_pair(user_id):
        """Show the first pair once a user ID is submitted; blank everything otherwise.

        Returns ``(item_id, text_a, text_b, user_index)`` in the order of the
        event outputs. Always starts at row 0, even for a user who voted before.
        """
        if user_id:
            return (*labeler.get_current_pair(user_id, 0), 0)
        return None, None, None, 0

    def judge(choice, user_id, user_index, item_id, left_text, right_text):
        """Forward a button press to the labeler.

        ``choice`` comes first because each click event lists its constant
        verdict State as the first input.
        """
        return labeler.submit_judgment(
            user_id=user_id,
            user_index=user_index,
            item_id=item_id,
            left_text=left_text,
            right_text=right_text,
            choice=choice,
        )

    # Every handler fills the same four components, in this order.
    pair_outputs = [item_id, left_output, right_output, user_index]
    user_id.submit(load_first_pair, inputs=[user_id], outputs=pair_outputs)

    # Each button submits a fixed verdict string through a constant gr.State input.
    for button, verdict in (
        (left_btn, "A is better"),
        (right_btn, "B is better"),
        (tie_btn, "Tie"),
        (cant_choose_btn, "Can't choose"),
    ):
        button.click(
            judge,
            inputs=[gr.State(verdict), user_id, user_index, item_id, left_output, right_output],
            outputs=pair_outputs,
        )

if __name__ == "__main__":
    app.launch()