# Author: Marcus Posey
# Commit e68ea89 — "Add application file" (12 kB)
import random
import json
import os
import httpx
from gradio_client import Client
from urllib.parse import urljoin
import gspread
from oauth2client.service_account import ServiceAccountCredentials
import gradio as gr
from huggingface_hub import login
# Book dropdown label -> context sentence prefixed to every question about
# that book. The second tuple field is the title as quoted in the sentence; it
# differs from the label only in capitalisation for "A Gift From Earth".
BOOK_MAPPING = {
    label: f'In the context of "{title}", written by {author} in {year},'
    for label, title, author, year in (
        ("Spaceman on a Spree", "Spaceman on a Spree", "Mack Reynolds", 1961),
        ("Charity Case", "Charity Case", "Jim Harmon", 1972),
        ("A Gift from Earth", "A Gift From Earth", "Manly Banister", 1950),
        ("Pick a Crime", "Pick a Crime", "Richard Rein Smith", 1970),
        ("Dangerous Quarry", "Dangerous Quarry", "Jim Harmon", 1972),
    )
}
# Question-category label shown in the UI -> lowercase tag.
CATEGORY_MAPPING = {
    label: label.lower()
    for label in ("Character", "Relationship", "Plot", "Numerical")
}
# The three fine-tuned model variants under comparison.
MODEL_VARIANTS = ["rephrase", "rephrase_summarize", "entigraph"]
# Latest reply per anonymous slot. This is a module-level global, so it is
# shared by every session served by this process.
model_responses = {f"Model_{slot}": "" for slot in "ABC"}
class ModelManager:
    """Query the three fine-tuned models behind anonymous "Model A/B/C" slots.

    Each variant in ``MODEL_VARIANTS`` is served by its own Gradio Space
    (``VARIANT_SPACES``). One random slot -> variant assignment is drawn when
    the manager is created. That single assignment decides which Space answers
    for each slot, and it is also what ``get_book_model_mapping`` reports. So a
    rating given to a slot is attributed to the variant that actually produced
    that slot's response.

    NOTE: the assignment is per manager, not per book, because
    ``get_all_model_responses`` is not told which book a prompt belongs to.
    """

    # Gradio Space that serves each model variant.
    VARIANT_SPACES = {
        "rephrase": "mep296/llama-3-8b-rephrase-quality",
        "rephrase_summarize": "mep296/llama-3-8b-rephrase-summarize-quality",
        "entigraph": "mep296/llama-3-8b-entigraph-quality",
    }
    # Anonymous slot letters, in display order.
    SLOTS = ("A", "B", "C")

    def __init__(self):
        # Per-book cache of the {"Model X": variant} dicts handed to callers.
        self.book_model_assignments = {}
        shuffled = random.sample(MODEL_VARIANTS, len(MODEL_VARIANTS))
        # Slot letter -> variant; the single source of truth for routing and
        # for de-anonymising ratings.
        self.slot_variants = dict(zip(self.SLOTS, shuffled))
        self.model_A = Client(self.VARIANT_SPACES[self.slot_variants["A"]])
        self.model_B = Client(self.VARIANT_SPACES[self.slot_variants["B"]])
        self.model_C = Client(self.VARIANT_SPACES[self.slot_variants["C"]])
        self.template_text = self._load_template()

    def _load_template(self):
        """Return the contents of ``prompt_template.txt``.

        The path is relative to the current working directory.
        """
        with open("prompt_template.txt", "r", encoding="utf-8") as file:
            return file.read()

    def get_model_response(self, model_name, prompt):
        """Send ``prompt``, wrapped in the prompt template, to one slot's model.

        Args:
            model_name: Slot key, one of "Model_A", "Model_B", "Model_C".
            prompt: Question text, already prefixed with the book context.

        Returns:
            Whatever the Space's ``/predict`` endpoint returns. If the call
            fails, returns an "Error: ..." string instead of raising, so the UI
            can still render all three columns.
        """
        try:
            # The template uses a positional "{}" placeholder for the prompt.
            formatted_prompt = self.template_text.format(prompt)
            model_clients = {
                "Model_A": self.model_A,
                "Model_B": self.model_B,
                "Model_C": self.model_C,
            }
            client = model_clients[model_name]
            return client.predict(prompt=formatted_prompt, api_name="/predict")
        except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
            print(f"Timeout while getting response from {model_name}: {e}")
            return f"Error: Model {model_name} timed out. Please try again."
        except Exception as e:
            print(f"Error getting response from {model_name}: {e}")
            return f"Error: Could not get response from {model_name}. Please try again."

    def get_all_model_responses(self, prompt):
        """Query every slot in A/B/C order and return the three replies as a list.

        Each reply is also stored in the module-level ``model_responses`` dict,
        where the rating code reads it back when ratings are submitted.
        """
        responses = []
        for slot in self.SLOTS:
            key = f"Model_{slot}"
            response = self.get_model_response(key, prompt)
            model_responses[key] = response
            responses.append(response)
        return responses

    def get_book_model_mapping(self, book):
        """Return ``{"Model A": variant, "Model B": ..., "Model C": ...}`` for ``book``.

        The mapping mirrors the slot assignment the clients were wired with in
        ``__init__``. Each book's dict is cached, so repeated calls for the
        same book return the same object.
        """
        if book not in self.book_model_assignments:
            self.book_model_assignments[book] = {
                f"Model {slot}": variant
                for slot, variant in self.slot_variants.items()
            }
        return self.book_model_assignments[book]
class SheetManager:
    """Append rating rows to a Google Sheet via a service account."""

    SCOPE = [
        "https://spreadsheets.google.com/feeds",
        "https://www.googleapis.com/auth/drive",
    ]
    # Keys of ``rating_data``, in the column order each row is written.
    COLUMNS = (
        "book",
        "category",
        "prompt",
        "rephrase_rating",
        "rephrase_summarize_rating",
        "entigraph_rating",
        "rephrase_response",
        "rephrase_summarize_response",
        "entigraph_response",
    )

    def __init__(self, credentials_dict=None, spreadsheet_name="Model Blind Comparison Ratings"):
        """Authorise with Google and open the first worksheet of ``spreadsheet_name``.

        Args:
            credentials_dict: Service-account key as a dict. Defaults to the
                module-level ``variables_keys``, which the script entry point
                defines.
            spreadsheet_name: Title of the spreadsheet to open.

        Raises:
            RuntimeError: if no credentials were passed and ``variables_keys``
                is not defined (e.g. the module was imported, not run).
        """
        if credentials_dict is None:
            try:
                credentials_dict = variables_keys
            except NameError as err:
                raise RuntimeError(
                    "No Google service-account credentials: pass credentials_dict "
                    "or define the module-level variables_keys."
                ) from err
        creds = ServiceAccountCredentials.from_json_keyfile_dict(credentials_dict, self.SCOPE)
        client = gspread.authorize(creds)
        self.sheet = client.open(spreadsheet_name).sheet1

    def append_rating(self, rating_data):
        """Append one row built from ``rating_data`` in ``COLUMNS`` order.

        Raises:
            KeyError: if ``rating_data`` lacks any key listed in ``COLUMNS``.
        """
        self.sheet.append_row([rating_data[column] for column in self.COLUMNS])
class ModelComparisonApp:
    """Gradio UI for blind side-by-side rating of three model variants.

    The user picks a book and a question category and asks a question. The UI
    shows the three anonymous answers ("Model A/B/C") and the user rates each
    on a 1-5 slider. On submit, the ratings and responses are de-anonymised
    through ``ModelManager.get_book_model_mapping`` and appended to the Google
    Sheet by ``SheetManager``.
    """

    # Anonymous slot letters, in display order.
    SLOTS = ("A", "B", "C")

    def __init__(self):
        self.model_manager = ModelManager()
        self.sheet_manager = SheetManager()
        # In-memory copy of every rating row submitted during this process.
        self.votes = []
        # Fallback book context for requests that carry no book selection.
        self.selected_book_string = BOOK_MAPPING["Spaceman on a Spree"]
        self.selected_category_string = ""
        self.chat_history_A = []
        self.chat_history_B = []
        self.chat_history_C = []
        # Replaced in create_interface() by a State rendered inside the Blocks.
        self.state = gr.State(value="")

    def create_interface(self):
        """Build and return the ``gr.Blocks`` app with all event handlers wired."""
        text_size = gr.themes.sizes.text_lg
        with gr.Blocks(theme=gr.themes.Default(text_size=text_size), fill_width=True) as demo:
            # Components used as event outputs must belong to this Blocks, so
            # the State is created inside the context rather than in __init__.
            self.state = gr.State(value="")
            gr.Markdown("# Model Blind Comparison")
            with gr.Group():
                with gr.Row():
                    chat_interfaces = self._create_chat_interfaces()
                with gr.Row():
                    ratings = self._create_rating_sliders()
            with gr.Row(equal_height=True):
                submit_button = gr.Button(value="⭐ Submit Ratings", interactive=False)
                submission_status = gr.Textbox(label="Submission Status", interactive=False)
            with gr.Row(equal_height=True):
                input_elements = self._create_input_elements()
            self._setup_event_handlers(demo, chat_interfaces, ratings, submit_button, submission_status, input_elements)
        return demo

    def _create_chat_interfaces(self):
        """Create one message-format Chatbot per slot, keyed by slot letter."""
        interfaces = {}
        for slot in self.SLOTS:
            interfaces[slot] = gr.Chatbot(
                getattr(self, f"chat_history_{slot}"),
                type="messages",
                label=f"Model {slot}",
                height=650,
                show_copy_button=True,
            )
        return interfaces

    def _create_rating_sliders(self):
        """Create 1-5 rating sliders keyed "1".."3", labelled Response A..C."""
        return {
            str(i): gr.Slider(1, 5, step=1, label=f"Rate Response {chr(64 + i)}",
                              interactive=True, value=3)
            for i in range(1, 4)
        }

    def _create_input_elements(self):
        """Create the book/category dropdowns, question textbox and Send button."""
        return {
            "book": gr.Dropdown(choices=list(BOOK_MAPPING.keys()),
                                label="Select a Book", interactive=True, scale=1),
            "category": gr.Dropdown(choices=list(CATEGORY_MAPPING.keys()),
                                    label="Select a Question Category", interactive=True, scale=1),
            "question": gr.Textbox(label="Type a Question", max_lines=1,
                                   placeholder="e.g. What is the relationship between characters?",
                                   interactive=True, scale=2),
            "send": gr.Button("Send", scale=0, variant="primary", interactive=False),
        }

    def respond(self, message, chat_A, chat_B, chat_C, book=None):
        """Ask all three models ``message`` in the context of the selected book.

        Args:
            message: Question text from the textbox.
            chat_A, chat_B, chat_C: Current chat histories. They are unused,
                because each send starts a fresh one-turn conversation.
            book: Book dropdown label for this request. If omitted or not in
                ``BOOK_MAPPING``, ``self.selected_book_string`` is used.

        Returns:
            Three message-format histories (user prompt + assistant reply),
            one per slot, in A/B/C order.

        Raises:
            gr.Error: if ``message`` is empty or whitespace-only.
        """
        if not message or not message.strip():
            raise gr.Error("Message cannot be empty!")
        # Take the book from this request: self.selected_book_string is shared
        # by every session and is overwritten by any user's dropdown change.
        book_string = BOOK_MAPPING.get(book, self.selected_book_string)
        prompt = f"{book_string} {message}"
        responses = self.model_manager.get_all_model_responses(prompt)
        return [
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ]
            for response in responses
        ]

    def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
        """Record one set of slot ratings, re-keyed by model variant.

        Args:
            book: Book dropdown label at submit time.
            category: Question-category dropdown label at submit time.
            question: Question textbox contents at submit time.
            rating_A, rating_B, rating_C: Slider values for the three slots.

        Returns:
            ``(status_message, update)``. The update disables the submit button
            until new responses arrive.
        """
        model_mapping = self.model_manager.get_book_model_mapping(book)
        ratings_by_slot = {"A": rating_A, "B": rating_B, "C": rating_C}
        # Invert {"Model A": variant, ...} into {variant: "A", ...}.
        slot_by_variant = {model_mapping[f"Model {slot}"]: slot for slot in self.SLOTS}
        rating_data = {"book": book, "category": category, "prompt": question}
        # All ratings first, then all responses, mirroring SheetManager's column order.
        rating_data.update({
            f"{variant}_rating": ratings_by_slot[slot_by_variant[variant]]
            for variant in MODEL_VARIANTS
        })
        rating_data.update({
            f"{variant}_response": model_responses[f"Model_{slot_by_variant[variant]}"]
            for variant in MODEL_VARIANTS
        })
        self.votes.append(rating_data)
        self.sheet_manager.append_rating(rating_data)
        return ("Ratings submitted successfully!", gr.update(interactive=False))

    def _setup_event_handlers(self, demo, chat_interfaces, ratings, submit_button, submission_status, input_elements):
        """Wire enable/disable logic, book selection, sending and rating submission."""

        def enable_send_btn(book, category, question):
            # Whitespace-only questions would be rejected by respond(), so keep
            # Send disabled for them.
            return gr.update(interactive=bool(book and category and question and question.strip()))

        def enable_button_group(model_A, model_B, model_C):
            # Ratings can be submitted only once every chat shows a response.
            return gr.update(interactive=bool(model_A and model_B and model_C))

        def update_selected_book(book_selection):
            self.selected_book_string = BOOK_MAPPING.get(book_selection, "")
            return self.selected_book_string

        for input_name in ["book", "category", "question"]:
            input_elements[input_name].change(
                enable_send_btn,
                inputs=[input_elements["book"], input_elements["category"], input_elements["question"]],
                outputs=[input_elements["send"]],
            )
        input_elements["book"].change(
            update_selected_book,
            inputs=[input_elements["book"]],
            outputs=[self.state],
        )
        submit_button.click(
            self.get_votes,
            inputs=[input_elements["book"], input_elements["category"], input_elements["question"],
                    ratings["1"], ratings["2"], ratings["3"]],
            outputs=[submission_status, submit_button],
        )
        input_elements["send"].click(
            self.respond,
            # The book dropdown is passed last, mapping onto respond's `book` argument.
            inputs=[input_elements["question"]] + list(chat_interfaces.values()) + [input_elements["book"]],
            outputs=list(chat_interfaces.values()),
        )
        for interface in chat_interfaces.values():
            interface.change(
                enable_button_group,
                inputs=list(chat_interfaces.values()),
                outputs=[submit_button],
            )
if __name__ == "__main__":
    # Service-account secrets are supplied through environment variables.
    PRIVATE_KEY = os.getenv('PRIVATE_KEY')
    PRIVATE_KEY_ID = os.getenv('PRIVATE_KEY_ID')
    if not PRIVATE_KEY:
        raise SystemExit("PRIVATE_KEY environment variable is not set; cannot authenticate to Google Sheets.")
    # If the secret was stored with escaped "\n" sequences, restore real
    # newlines so the PEM key parses. A PEM body never contains a backslash,
    # so this is a no-op for keys that already have real newlines.
    PRIVATE_KEY = PRIVATE_KEY.replace("\\n", "\n")
    # SheetManager reads this module-level dict when no credentials are passed.
    variables_keys = {
        "type": "service_account",
        "project_id": "summer-presence-450117-r7",
        # Prefer the configured key id; fall back to the previously hard-coded one.
        "private_key_id": PRIVATE_KEY_ID or "427fe03954113ce7174febe50871c7beba0384cc",
        "private_key": PRIVATE_KEY,
        "client_email": "model-blind-comparison@summer-presence-450117-r7.iam.gserviceaccount.com",
        "client_id": "117681363507032419648",
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
        "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/model-blind-comparison%40summer-presence-450117-r7.iam.gserviceaccount.com",
        "universe_domain": "googleapis.com"
    }
    app = ModelComparisonApp()
    demo = app.create_interface()
    demo.launch()