Spaces:
Runtime error
Runtime error
import random | |
import json | |
import os | |
import httpx | |
from gradio_client import Client | |
from urllib.parse import urljoin | |
import gspread | |
from oauth2client.service_account import ServiceAccountCredentials | |
import gradio as gr | |
from huggingface_hub import login | |
import asyncio | |
from concurrent.futures import ThreadPoolExecutor | |
# Book title (as shown in the dropdown) -> context prefix prepended to every
# question sent to the models for that book.
# NOTE(review): the prompt for "A Gift from Earth" capitalises "From"; confirm
# which spelling the fine-tuned models were trained on before changing either.
BOOK_MAPPING = {
    "Spaceman on a Spree": 'In the context of "Spaceman on a Spree", written by Mack Reynolds in 1961,',
    "Charity Case": 'In the context of "Charity Case", written by Jim Harmon in 1972,',
    "A Gift from Earth": 'In the context of "A Gift From Earth", written by Manly Banister in 1950,',
    "Pick a Crime": 'In the context of "Pick a Crime", written by Richard Rein Smith in 1970,',
    "Dangerous Quarry": 'In the context of "Dangerous Quarry", written by Jim Harmon in 1972,',
    "Lost in Translation": 'In the context of "Lost in Translation", written by Larry M. Harris in 1972,',
}

# Question-category label shown in the dropdown -> lowercase category key.
CATEGORY_MAPPING = {
    "Identity": "identity",
    "Motivation": "motivation",
    "Relationship": "relationship",
    "Event": "event",
}

# Names of the three fine-tuned model variants being compared blind.
MODEL_VARIANTS = ["rephrase", "rephrase_summarize", "entigraph"]

# Latest text returned by each model client, keyed by client slot
# ("Model_A" = rephrase, "Model_B" = rephrase_summarize, "Model_C" = entigraph);
# written when models are queried and read when ratings are logged.
model_responses = {"Model_A": "", "Model_B": "", "Model_C": ""}
class ModelManager:
    """Own the three model clients and the per-book blind slot assignments.

    Clients: ``model_A`` is the rephrase Space, ``model_B`` the
    rephrase-summarize Space and ``model_C`` the entigraph Space. Each book is
    given a random permutation of MODEL_VARIANTS over the display slots
    "Model A"/"Model B"/"Model C"; once assigned it is cached and reused.
    """

    def __init__(self, books):
        """Assign slot permutations, connect to the model Spaces, load the prompt template.

        Args:
            books: iterable of book titles to pre-assign a slot permutation to.
        """
        self.book_model_assignments = {}
        self.assign_models_to_books(books)
        self.model_A = Client("mep296/llama-3-8b-rephrase-quality")
        self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
        self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
        self.template_text = self._load_template()

    def _load_template(self):
        """Return the contents of ``prompt_template.txt`` (read as UTF-8)."""
        with open("prompt_template.txt", "r", encoding="utf-8") as file:
            return file.read()

    async def get_model_response_async(self, model_name, client, prompt):
        """Query one model Space; return its answer or a user-facing error message.

        Whatever is returned is also stored in ``model_responses[model_name]``
        so the text the rater saw is what gets logged with the ratings.

        Args:
            model_name: key in ``model_responses`` ("Model_A"/"Model_B"/"Model_C").
            client: gradio_client ``Client`` connected to the model's Space.
            prompt: text substituted into the template's positional placeholder.

        Returns:
            The model output, or a warning/error string if the call failed.
        """
        try:
            formatted_prompt = self.template_text.format(prompt)
            loop = asyncio.get_running_loop()
            # Client.predict blocks, so run it in the loop's default executor.
            # run_in_executor only forwards positional arguments and predict
            # takes api_name as a keyword, hence the lambda; passing "/predict"
            # positionally would send it to the model as a second input.
            response = await loop.run_in_executor(
                None, lambda: client.predict(formatted_prompt, api_name="/predict")
            )
        except httpx.TimeoutException as e:
            # Covers ReadTimeout/ConnectTimeout (and Write/Pool timeouts):
            # typically a sleeping Space that is still starting up.
            print(f"Timeout while getting response from {model_name}: {str(e)}")
            response = f"⚠️ Model {model_name} is waking up... This may take a few minutes. Please wait and try again shortly. ⏳"
        except Exception as e:
            print(f"Error getting response from {model_name}: {str(e)}")
            response = f"Error: Could not get response from {model_name}. Please try again in a few minutes."
        # Record every outcome, including failures, so a rating submitted
        # afterwards is never paired with a stale answer to an earlier question.
        model_responses[model_name] = response
        return response

    async def get_all_model_responses_async(self, prompt):
        """Query all three models concurrently.

        Returns:
            A list of three strings in client order: model_A, model_B, model_C.
        """
        return await asyncio.gather(
            self.get_model_response_async("Model_A", self.model_A, prompt),
            self.get_model_response_async("Model_B", self.model_B, prompt),
            self.get_model_response_async("Model_C", self.model_C, prompt),
        )

    def get_book_model_mapping(self, book):
        """Return (creating on first use) the display-slot -> variant mapping for *book*."""
        if book not in self.book_model_assignments:
            shuffled_models = random.sample(MODEL_VARIANTS, len(MODEL_VARIANTS))
            self.book_model_assignments[book] = dict(
                zip(("Model A", "Model B", "Model C"), shuffled_models)
            )
        return self.book_model_assignments[book]

    def assign_models_to_books(self, books):
        """Ensure every title in *books* has a slot mapping."""
        for book in books:
            self.get_book_model_mapping(book)
class SheetManager:
    """Append blind-comparison ratings as rows of a Google Sheet."""

    # Keys of the rating dict, in the column order of each appended row.
    COLUMNS = (
        "book",
        "category",
        "prompt",
        "rephrase_rating",
        "rephrase_summarize_rating",
        "entigraph_rating",
        "rephrase_response",
        "rephrase_summarize_response",
        "entigraph_response",
    )

    def __init__(self, credentials=None, sheet_name="Model Blind Comparison Ratings"):
        """Authorise with a service account and open the first worksheet of *sheet_name*.

        Args:
            credentials: service-account key dict. When omitted, the
                module-level ``variables_keys`` (set by the script entry point)
                is used.
            sheet_name: title of the spreadsheet to open.
        """
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        if credentials is None:
            credentials = variables_keys
        creds = ServiceAccountCredentials.from_json_keyfile_dict(credentials, scope)
        client = gspread.authorize(creds)
        self.sheet = client.open(sheet_name).sheet1

    def append_rating(self, rating_data):
        """Append one row built from *rating_data*.

        Raises:
            KeyError: if *rating_data* is missing any key in ``COLUMNS``.
        """
        self.sheet.append_row([rating_data[column] for column in self.COLUMNS])
class ModelComparisonApp:
    """Gradio front end for blind, side-by-side rating of the three model variants.

    A question about the selected book is sent to all three variants; the
    answers are shown in the display slots "Model A"/"Model B"/"Model C"
    using the book's random slot mapping from ModelManager, and submitted
    ratings are mapped back to variant names before being written to the sheet.

    NOTE(review): selection fields and ``model_responses`` are shared by every
    browser session of the app, not stored per user.
    """

    # Order in which ModelManager.get_all_model_responses_async returns its
    # results (clients model_A, model_B, model_C).
    _RESPONSE_ORDER = ("rephrase", "rephrase_summarize", "entigraph")
    # Display slots, in the order of the chatbots and rating sliders.
    _DISPLAY_SLOTS = ("Model A", "Model B", "Model C")

    def __init__(self):
        self.model_manager = ModelManager(BOOK_MAPPING.keys())
        self.sheet_manager = SheetManager()
        # Rating dicts that were successfully written to the sheet.
        self.votes = []
        # Fallback book used by respond() when no book is passed in.
        self.selected_book = "Spaceman on a Spree"
        self.selected_book_string = BOOK_MAPPING["Spaceman on a Spree"]
        self.selected_category_string = ""
        # Initial (empty) histories for the three chatbots.
        self.chat_history_A = []
        self.chat_history_B = []
        self.chat_history_C = []
        # Created in create_interface(), inside the gr.Blocks context.
        self.state = None

    def create_interface(self):
        """Build and return the gr.Blocks UI with all event handlers wired."""
        text_size = gr.themes.sizes.text_lg
        with gr.Blocks(theme=gr.themes.Default(text_size=text_size), fill_width=True) as demo:
            # Create the State inside the Blocks context so it is registered
            # with this app; it is used as an event output below.
            self.state = gr.State(value="")
            gr.Markdown("# Model Blind Comparison")
            with gr.Group():
                with gr.Row():
                    chat_interfaces = self._create_chat_interfaces()
                with gr.Row():
                    ratings = self._create_rating_sliders()
            with gr.Row(equal_height=True):
                submit_button = gr.Button(value="⭐ Submit Ratings", interactive=False)
                submission_status = gr.Textbox(label="Submission Status", interactive=False)
            with gr.Row(equal_height=True):
                input_elements = self._create_input_elements()
            self._setup_event_handlers(demo, chat_interfaces, ratings, submit_button, submission_status, input_elements)
        return demo

    def _create_chat_interfaces(self):
        """Create one message-format chatbot per display slot, keyed 'A'/'B'/'C'."""
        return {
            slot: gr.Chatbot(
                getattr(self, f'chat_history_{slot}'),
                type="messages",
                label=f"Model {slot}",
                height=650,
                show_copy_button=True,
            )
            for slot in ('A', 'B', 'C')
        }

    def _create_rating_sliders(self):
        """Create 1-5 rating sliders keyed '1'/'2'/'3' for responses A/B/C."""
        return {
            str(i): gr.Slider(1, 5, step=1, label=f"Rate Response {chr(64 + i)}",
                              interactive=True, value=3)
            for i in range(1, 4)
        }

    def _create_input_elements(self):
        """Create the book/category dropdowns, question box and Send button."""
        return {
            'book': gr.Dropdown(choices=list(BOOK_MAPPING.keys()),
                                label="Select a Book", interactive=True, scale=1),
            'category': gr.Dropdown(choices=list(CATEGORY_MAPPING.keys()),
                                    label="Select a Question Category", interactive=True, scale=1),
            'question': gr.Textbox(label="Type a Question", max_lines=1,
                                   placeholder="e.g. What is the relationship between Harry Potter and Sirius Black?",
                                   interactive=True, scale=2),
            'send': gr.Button("Send", scale=0, variant="primary", interactive=False),
        }

    @classmethod
    def _build_display_chats(cls, prompt, responses, mapping_dict):
        """Pair each response with *prompt* and reorder them into display-slot order.

        Args:
            prompt: text shown as the user turn of every chat.
            responses: model outputs in ``_RESPONSE_ORDER`` order.
            mapping_dict: display slot ("Model A"/"Model B"/"Model C") -> variant.

        Returns:
            Three message-format chat histories, for slots A, B and C.
        """
        chats_by_variant = {
            variant: [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ]
            for variant, response in zip(cls._RESPONSE_ORDER, responses)
        }
        return [chats_by_variant[mapping_dict[slot]] for slot in cls._DISPLAY_SLOTS]

    @classmethod
    def _ratings_by_variant(cls, model_mapping, rating_A, rating_B, rating_C):
        """Invert the blind slot mapping: return variant -> rating given to its slot."""
        slot_ratings = dict(zip(cls._DISPLAY_SLOTS, (rating_A, rating_B, rating_C)))
        return {variant: slot_ratings[slot] for slot, variant in model_mapping.items()}

    def respond(self, message, book=None):
        """Ask all three variants about *book* and return the chats in display order.

        Args:
            message: the user's question.
            book: title from BOOK_MAPPING; when omitted, the instance's
                ``selected_book``/``selected_book_string`` are used.

        Returns:
            Three message-format chat histories for the Model A/B/C chatbots.

        Raises:
            gr.Error: if *message* is empty or only whitespace.
        """
        if not message or not message.strip():
            raise gr.Error("Message cannot be empty!")
        if book:
            # Use the value delivered with this event rather than shared
            # instance state that another session may have changed.
            book_string = BOOK_MAPPING.get(book, "")
        else:
            book, book_string = self.selected_book, self.selected_book_string
        prompt = f"{book_string} {message}"
        mapping_dict = self.model_manager.get_book_model_mapping(book)
        # Sync Gradio handlers run in a worker thread, off the server's event
        # loop, so give this call its own loop instead of driving a loop that
        # was obtained on another thread.
        responses = asyncio.run(self.model_manager.get_all_model_responses_async(prompt))
        return self._build_display_chats(prompt, responses, mapping_dict)

    def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
        """Map the slot ratings back to variants and save them to the sheet.

        Returns:
            (status message, gr.update for the submit button). The button is
            disabled after a successful save and left enabled for a retry if
            writing to the sheet fails.
        """
        model_mapping = self.model_manager.get_book_model_mapping(book)
        variant_ratings = self._ratings_by_variant(model_mapping, rating_A, rating_B, rating_C)
        rating_data = {
            "book": book,
            "category": category,
            "prompt": question,
            "rephrase_rating": variant_ratings["rephrase"],
            "rephrase_summarize_rating": variant_ratings["rephrase_summarize"],
            "entigraph_rating": variant_ratings["entigraph"],
            # model_responses is keyed by client; clients A/B/C are the
            # rephrase/rephrase_summarize/entigraph models respectively.
            "rephrase_response": model_responses["Model_A"],
            "rephrase_summarize_response": model_responses["Model_B"],
            "entigraph_response": model_responses["Model_C"],
        }
        try:
            self.sheet_manager.append_rating(rating_data)
        except Exception as e:  # UI boundary: report instead of crashing the handler
            print(f"Error saving ratings: {str(e)}")
            return ("Error: Could not save ratings. Please try again.", gr.update(interactive=True))
        self.votes.append(rating_data)
        return ("Ratings submitted successfully!", gr.update(interactive=False))

    def _setup_event_handlers(self, demo, chat_interfaces, ratings, submit_button, submission_status, input_elements):
        """Wire the input, send, rating and submit events."""
        book_input = input_elements['book']
        category_input = input_elements['category']
        question_input = input_elements['question']
        send_button = input_elements['send']
        chatbots = list(chat_interfaces.values())

        def enable_send_btn(book, category, question):
            return gr.update(interactive=bool(book and category and question and question.strip()))

        def enable_button_group(model_A, model_B, model_C):
            return gr.update(interactive=bool(model_A and model_B and model_C))

        def update_selected_book(book_selection):
            self.selected_book = book_selection
            self.selected_book_string = BOOK_MAPPING.get(book_selection, "")
            # The responses on screen were produced under the previous book's
            # slot mapping; rating them against the new book's mapping would
            # credit the wrong variants, so require a fresh Send first.
            return self.selected_book_string, gr.update(interactive=False)

        for component in (book_input, category_input, question_input):
            component.change(
                enable_send_btn,
                inputs=[book_input, category_input, question_input],
                outputs=[send_button],
            )
        book_input.change(
            update_selected_book,
            inputs=[book_input],
            outputs=[self.state, submit_button],
        )
        submit_button.click(
            self.get_votes,
            inputs=[book_input, category_input, question_input,
                    ratings['1'], ratings['2'], ratings['3']],
            outputs=[submission_status, submit_button],
        )
        send_button.click(
            self.respond,
            inputs=[question_input, book_input],
            outputs=chatbots,
        )
        for chatbot in chatbots:
            chatbot.change(
                enable_button_group,
                inputs=chatbots,
                outputs=[submit_button],
            )
if __name__ == "__main__":
    def _read_env_secret(name):
        """Return environment variable *name* with escaped newlines restored.

        Secrets stored as single-line strings carry literal backslash-n
        sequences; these are turned back into real newlines.

        Raises:
            RuntimeError: if the variable is unset or empty, rather than the
                opaque AttributeError raised by calling .replace on None.
        """
        value = os.environ.get(name)
        if not value:
            raise RuntimeError(f"Required environment variable {name!r} is not set.")
        return value.replace('\\n', '\n')

    PRIVATE_KEY = _read_env_secret('PRIVATE_KEY')
    PRIVATE_KEY_ID = _read_env_secret('PRIVATE_KEY_ID')
    # Service-account key dict read by SheetManager as the module-level
    # ``variables_keys``; only the key material comes from the environment.
    variables_keys = {
        "type": "service_account",
        "project_id": "summer-presence-450117-r7",
        "private_key_id": PRIVATE_KEY_ID,
        "private_key": PRIVATE_KEY,
        "client_email": "model-blind-comparison@summer-presence-450117-r7.iam.gserviceaccount.com",
        "client_id": "117681363507032419648",
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
        "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/model-blind-comparison%40summer-presence-450117-r7.iam.gserviceaccount.com",
        "universe_domain": "googleapis.com",
    }
    app = ModelComparisonApp()
    demo = app.create_interface()
    demo.launch()