import random
import os
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

import httpx
import gradio as gr
import gspread
from gradio_client import Client
from oauth2client.service_account import ServiceAccountCredentials
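# Book-specific context prefixes prepended to every question so the models
# know which work is being asked about.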
BOOK_MAPPING = {
"Spaceman on a Spree": "In the context of \"Spaceman on a Spree\", written by Mack Reynolds in 1961,",
"Charity Case": "In the context of \"Charity Case\", written by Jim Harmon in 1972,",
"A Gift from Earth": "In the context of \"A Gift From Earth\", written by Manly Banister in 1950,",
"Pick a Crime": "In the context of \"Pick a Crime\", written by Richard Rein Smith in 1970,",
"Dangerous Quarry": "In the context of \"Dangerous Quarry\", written by Jim Harmon in 1972,",
"Lost in Translation": "In the context of \"Lost in Translation\", written by Larry M. Harris in 1972,"
}
CATEGORY_MAPPING = {
"Identity": "identity",
"Motivation": "motivation",
"Relationship": "relationship",
"Event": "event",
}
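# The three model variants compared blind; the order here matches the
# model_A / model_B / model_C clients created in ModelManager below.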
MODEL_VARIANTS = ["rephrase", "rephrase_summarize", "entigraph"]
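# Latest raw response from each backing client (Model_A = rephrase,
# Model_B = rephrase_summarize, Model_C = entigraph), not the blinded
# display slots shown in the UI.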
model_responses = {"Model_A": "", "Model_B": "", "Model_C": ""}
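# Wraps the three hosted Space clients and the per-book shuffling that hides
# which variant appears in which chat pane.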
class ModelManager:
def __init__(self, books):
self.book_model_assignments = {}
self.assign_models_to_books(books)
self.model_A = Client("mep296/llama-3-8b-rephrase-quality")
self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
self.template_text = self._load_template()
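    # The template file is expected to contain a single positional "{}"
    # placeholder that the formatted question is substituted into.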
def _load_template(self):
with open("prompt_template.txt", "r", encoding="utf-8") as file:
return file.read()
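    # gradio_client's predict() is blocking, so it is dispatched to a worker
    # thread; timeouts surface as a friendly "model is waking up" message.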
async def get_model_response_async(self, model_name, client, prompt):
try:
formatted_prompt = self.template_text.format(prompt)
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
response = await loop.run_in_executor(
executor,
                    # run_in_executor() forwards positional args only, so bind
                    # api_name with functools.partial; predict() itself is blocking.
                    functools.partial(client.predict, formatted_prompt, api_name="/predict")
)
model_responses[model_name] = response
return response
except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
print(f"Timeout while getting response from {model_name}: {str(e)}")
return f"⚠️ Model {model_name} is waking up... This may take a few minutes. Please wait and try again shortly. ⏳"
except Exception as e:
print(f"Error getting response from {model_name}: {str(e)}")
return f"Error: Could not get response from {model_name}. Please try again in a few minutes."
async def get_all_model_responses_async(self, prompt):
tasks = [
self.get_model_response_async("Model_A", self.model_A, prompt),
self.get_model_response_async("Model_B", self.model_B, prompt),
self.get_model_response_async("Model_C", self.model_C, prompt)
]
responses = await asyncio.gather(*tasks)
return responses
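    # Lazily create (and cache) a random variant -> display-slot assignment for
    # each book so the same book always shows the same blinded ordering.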
def get_book_model_mapping(self, book):
if book not in self.book_model_assignments:
shuffled_models = random.sample(MODEL_VARIANTS, len(MODEL_VARIANTS))
self.book_model_assignments[book] = {
"Model A": shuffled_models[0],
"Model B": shuffled_models[1],
"Model C": shuffled_models[2]
}
return self.book_model_assignments[book]
def assign_models_to_books(self, books):
for book in books:
self.get_book_model_mapping(book)
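# Appends ratings to a Google Sheet. Note: relies on the module-level
# `variables_keys` dict built in the __main__ block below.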
class SheetManager:
def __init__(self):
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds = ServiceAccountCredentials.from_json_keyfile_dict(variables_keys, scope)
client = gspread.authorize(creds)
self.sheet = client.open("Model Blind Comparison Ratings").sheet1
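    # One row per submission: book, category, question, the three ratings keyed
    # by variant, then the three raw model responses.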
def append_rating(self, rating_data):
self.sheet.append_row([
rating_data["book"],
rating_data["category"],
rating_data["prompt"],
rating_data["rephrase_rating"],
rating_data["rephrase_summarize_rating"],
rating_data["entigraph_rating"],
rating_data["rephrase_response"],
rating_data["rephrase_summarize_response"],
rating_data["entigraph_response"],
])
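# Gradio front end: three anonymized chat panes, rating sliders, and the
# plumbing that records votes.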
class ModelComparisonApp:
def __init__(self):
self.model_manager = ModelManager(BOOK_MAPPING.keys())
self.sheet_manager = SheetManager()
self.votes = []
self.selected_book = "Spaceman on a Spree"
self.selected_book_string = BOOK_MAPPING["Spaceman on a Spree"]
self.selected_category_string = ""
self.chat_history_A = []
self.chat_history_B = []
self.chat_history_C = []
self.state = gr.State(value="")
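    # Build the Blocks UI: chat panes, rating sliders, submit/status row, and
    # the book / category / question inputs.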
def create_interface(self):
text_size = gr.themes.sizes.text_lg
with gr.Blocks(theme=gr.themes.Default(text_size=text_size), fill_width=True) as demo:
gr.Markdown("# Model Blind Comparison")
with gr.Group():
with gr.Row():
chat_interfaces = self._create_chat_interfaces()
with gr.Row():
ratings = self._create_rating_sliders()
with gr.Row(equal_height=True):
submit_button = gr.Button(value="⭐ Submit Ratings", interactive=False)
submission_status = gr.Textbox(label="Submission Status", interactive=False)
with gr.Row(equal_height=True):
input_elements = self._create_input_elements()
self._setup_event_handlers(demo, chat_interfaces, ratings, submit_button, submission_status, input_elements)
return demo
def _create_chat_interfaces(self):
interfaces = {}
for model in ['A', 'B', 'C']:
interfaces[model] = gr.Chatbot(
getattr(self, f'chat_history_{model}'),
type="messages",
label=f"Model {model}",
height=650,
show_copy_button=True
)
return interfaces
def _create_rating_sliders(self):
return {
str(i): gr.Slider(1, 5, step=1, label=f"Rate Response {chr(64+i)}",
interactive=True, value=3)
for i in range(1, 4)
}
def _create_input_elements(self):
return {
'book': gr.Dropdown(choices=list(BOOK_MAPPING.keys()),
label="Select a Book", interactive=True, scale=1),
'category': gr.Dropdown(choices=list(CATEGORY_MAPPING.keys()),
label="Select a Question Category", interactive=True, scale=1),
'question': gr.Textbox(label="Type a Question", max_lines=1,
placeholder="e.g. What is the relationship between Harry Potter and Sirius Black?",
interactive=True, scale=2),
'send': gr.Button("Send", scale=0, variant="primary", interactive=False)
}
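    # Send the question (prefixed with the selected book's context string) to
    # all three models, then reorder the replies into this book's blinded slots.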
def respond(self, message):
if not message.strip():
raise gr.Error("Message cannot be empty!")
prompt = f"{self.selected_book_string} {message}"
mapping_dict = self.model_manager.book_model_assignments[self.selected_book]
model_order = ["rephrase", "rephrase_summarize", "entigraph"]
model_to_index = {model: i for i, model in enumerate(model_order)}
        # Run the three concurrent requests from this synchronous handler.
        responses = asyncio.run(
            self.model_manager.get_all_model_responses_async(prompt)
        )
chats = []
for response in responses:
chat = []
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": response})
chats.append(chat)
reordered_chats = [chats[model_to_index[mapping_dict[model]]] for model in ["Model A", "Model B", "Model C"]]
return reordered_chats
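    # Map the slot ratings (A/B/C) back to the underlying variants via the
    # per-book assignment, then append the row to the Google Sheet.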
def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
model_mapping = self.model_manager.get_book_model_mapping(book)
rating_data = {
"book": book,
"category": category,
"prompt": question,
"rephrase_rating": rating_A if model_mapping["Model A"] == "rephrase" else
rating_B if model_mapping["Model B"] == "rephrase" else rating_C,
"rephrase_summarize_rating": rating_A if model_mapping["Model A"] == "rephrase_summarize" else
rating_B if model_mapping["Model B"] == "rephrase_summarize" else rating_C,
"entigraph_rating": rating_A if model_mapping["Model A"] == "entigraph" else
rating_B if model_mapping["Model B"] == "entigraph" else rating_C,
"rephrase_response": model_responses["Model_A"],
"rephrase_summarize_response": model_responses["Model_B"],
"entigraph_response": model_responses["Model_C"]
}
self.votes.append(rating_data)
self.sheet_manager.append_rating(rating_data)
return ("Ratings submitted successfully!", gr.update(interactive=False))
def _setup_event_handlers(self, demo, chat_interfaces, ratings, submit_button, submission_status, input_elements):
def enable_send_btn(book, category, question):
return gr.update(interactive=bool(book and category and question))
def enable_button_group(model_A, model_B, model_C):
return gr.update(interactive=bool(model_A and model_B and model_C))
def update_selected_book(book_selection):
self.selected_book = book_selection
self.selected_book_string = BOOK_MAPPING.get(book_selection, "")
return self.selected_book_string
for input_name in ['book', 'category', 'question']:
input_elements[input_name].change(
enable_send_btn,
inputs=[input_elements['book'], input_elements['category'], input_elements['question']],
outputs=[input_elements['send']]
)
input_elements['book'].change(
update_selected_book,
inputs=[input_elements['book']],
outputs=[self.state]
)
submit_button.click(
self.get_votes,
inputs=[input_elements['book'], input_elements['category'], input_elements['question'],
ratings['1'], ratings['2'], ratings['3']],
outputs=[submission_status, submit_button]
)
input_elements['send'].click(
self.respond,
inputs=[input_elements['question']],
outputs=list(chat_interfaces.values())
)
for interface in chat_interfaces.values():
interface.change(
enable_button_group,
inputs=list(chat_interfaces.values()),
outputs=[submit_button]
)
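# Entry point: assemble service-account credentials from environment variables,
# then build and launch the Gradio app.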
if __name__ == "__main__":
    # Fail fast with a KeyError if the key is missing; the private key is
    # stored with escaped newlines that must be restored.
    PRIVATE_KEY = os.environ["PRIVATE_KEY"].replace("\\n", "\n")
    PRIVATE_KEY_ID = os.environ["PRIVATE_KEY_ID"]
variables_keys = {
"type": "service_account",
"project_id": "summer-presence-450117-r7",
"private_key_id": PRIVATE_KEY_ID,
"private_key": PRIVATE_KEY,
"client_email": "model-blind-comparison@summer-presence-450117-r7.iam.gserviceaccount.com",
"client_id": "117681363507032419648",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/model-blind-comparison%40summer-presence-450117-r7.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}
app = ModelComparisonApp()
demo = app.create_interface()
demo.launch()