import asyncio
import functools
import os
import random
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import gspread
import httpx
from gradio_client import Client
from oauth2client.service_account import ServiceAccountCredentials

# Context prefixes prepended to every user question so the models answer with
# respect to a specific book.
BOOK_MAPPING = {
    "Spaceman on a Spree": "In the context of \"Spaceman on a Spree\", written by Mack Reynolds in 1961,",
    "Charity Case": "In the context of \"Charity Case\", written by Jim Harmon in 1972,",
    "A Gift from Earth": "In the context of \"A Gift From Earth\", written by Manly Banister in 1950,",
    "Pick a Crime": "In the context of \"Pick a Crime\", written by Richard Rein Smith in 1970,",
    "Dangerous Quarry": "In the context of \"Dangerous Quarry\", written by Jim Harmon in 1972,",
    "Lost in Translation": "In the context of \"Lost in Translation\", written by Larry M. Harris in 1972,"
}

CATEGORY_MAPPING = {
    "Identity": "identity",
    "Motivation": "motivation",
    "Relationship": "relationship",
    "Event": "event",
}

# The three fine-tuning variants being compared; their display labels are
# shuffled per book so raters cannot tell which variant they are scoring.
MODEL_VARIANTS = ["rephrase", "rephrase_summarize", "entigraph"]

# Latest raw response from each backend, keyed by its fixed backend slot.
model_responses = {"Model_A": "", "Model_B": "", "Model_C": ""}


class ModelManager:
    def __init__(self, books):
        self.book_model_assignments = {}
        self.assign_models_to_books(books)
        self.model_A = Client("mep296/llama-3-8b-rephrase-quality")
        self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
        self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
        self.template_text = self._load_template()

    def _load_template(self):
        with open("prompt_template.txt", "r", encoding="utf-8") as file:
            return file.read()

    async def get_model_response_async(self, model_name, client, prompt):
        try:
            formatted_prompt = self.template_text.format(prompt)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                # gradio_client's predict() is blocking, so run it on a worker
                # thread; api_name must be passed as a keyword argument.
                response = await loop.run_in_executor(
                    executor,
                    functools.partial(client.predict, formatted_prompt, api_name="/predict"),
                )
            model_responses[model_name] = response
            return response
        except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
            print(f"Timeout while getting response from {model_name}: {str(e)}")
            return f"⚠️ Model {model_name} is waking up... This may take a few minutes. Please wait and try again shortly. ⏳"
        except Exception as e:
            print(f"Error getting response from {model_name}: {str(e)}")
            return f"Error: Could not get response from {model_name}. Please try again in a few minutes."
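    # The three Hugging Face Spaces are queried concurrently: each blocking
    # gradio_client call runs on an executor thread, and the coroutines are
    # awaited together with asyncio.gather below.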
    async def get_all_model_responses_async(self, prompt):
        tasks = [
            self.get_model_response_async("Model_A", self.model_A, prompt),
            self.get_model_response_async("Model_B", self.model_B, prompt),
            self.get_model_response_async("Model_C", self.model_C, prompt)
        ]
        responses = await asyncio.gather(*tasks)
        return responses

    def get_book_model_mapping(self, book):
        # Assign a random, persistent variant-to-label shuffle for each book.
        if book not in self.book_model_assignments:
            shuffled_models = random.sample(MODEL_VARIANTS, len(MODEL_VARIANTS))
            self.book_model_assignments[book] = {
                "Model A": shuffled_models[0],
                "Model B": shuffled_models[1],
                "Model C": shuffled_models[2]
            }
        return self.book_model_assignments[book]

    def assign_models_to_books(self, books):
        for book in books:
            self.get_book_model_mapping(book)


class SheetManager:
    """Appends each submitted rating as a row in the shared Google Sheet."""

    def __init__(self):
        scope = ["https://spreadsheets.google.com/feeds",
                 "https://www.googleapis.com/auth/drive"]
        # `variables_keys` is the service-account dict assembled in the
        # __main__ block below from environment variables.
        creds = ServiceAccountCredentials.from_json_keyfile_dict(variables_keys, scope)
        client = gspread.authorize(creds)
        self.sheet = client.open("Model Blind Comparison Ratings").sheet1

    def append_rating(self, rating_data):
        self.sheet.append_row([
            rating_data["book"],
            rating_data["category"],
            rating_data["prompt"],
            rating_data["rephrase_rating"],
            rating_data["rephrase_summarize_rating"],
            rating_data["entigraph_rating"],
            rating_data["rephrase_response"],
            rating_data["rephrase_summarize_response"],
            rating_data["entigraph_response"],
        ])


class ModelComparisonApp:
    def __init__(self):
        self.model_manager = ModelManager(BOOK_MAPPING.keys())
        self.sheet_manager = SheetManager()
        self.votes = []
        self.selected_book = "Spaceman on a Spree"
        self.selected_book_string = BOOK_MAPPING["Spaceman on a Spree"]
        self.selected_category_string = ""
        self.chat_history_A = []
        self.chat_history_B = []
        self.chat_history_C = []
        self.state = gr.State(value="")

    def create_interface(self):
        text_size = gr.themes.sizes.text_lg
        with gr.Blocks(theme=gr.themes.Default(text_size=text_size), fill_width=True) as demo:
            gr.Markdown("# Model Blind Comparison")
            with gr.Group():
                with gr.Row():
                    chat_interfaces = self._create_chat_interfaces()
                with gr.Row():
                    ratings = self._create_rating_sliders()
                with gr.Row(equal_height=True):
                    submit_button = gr.Button(value="⭐ Submit Ratings", interactive=False)
                    submission_status = gr.Textbox(label="Submission Status", interactive=False)
                with gr.Row(equal_height=True):
                    input_elements = self._create_input_elements()
            self._setup_event_handlers(demo, chat_interfaces, ratings, submit_button,
                                       submission_status, input_elements)
        return demo

    def _create_chat_interfaces(self):
        interfaces = {}
        for model in ['A', 'B', 'C']:
            interfaces[model] = gr.Chatbot(
                getattr(self, f'chat_history_{model}'),
                type="messages",
                label=f"Model {model}",
                height=650,
                show_copy_button=True
            )
        return interfaces

    def _create_rating_sliders(self):
        return {
            str(i): gr.Slider(1, 5, step=1, label=f"Rate Response {chr(64 + i)}",
                              interactive=True, value=3)
            for i in range(1, 4)
        }

    def _create_input_elements(self):
        return {
            'book': gr.Dropdown(choices=list(BOOK_MAPPING.keys()), label="Select a Book",
                                interactive=True, scale=1),
            'category': gr.Dropdown(choices=list(CATEGORY_MAPPING.keys()),
                                    label="Select a Question Category",
                                    interactive=True, scale=1),
            'question': gr.Textbox(label="Type a Question", max_lines=1,
                                   placeholder="e.g. What is the relationship between Harry Potter and Sirius Black?",
                                   interactive=True, scale=2),
            'send': gr.Button("Send", scale=0, variant="primary", interactive=False)
        }
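    # respond() sends the prompt to the three backends in their fixed order and then
    # reorders the resulting chats into this book's shuffled "Model A/B/C" labels,
    # so raters never learn which variant produced which answer.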
    def respond(self, message):
        if not message.strip():
            raise gr.Error("Message cannot be empty!")

        prompt = f"{self.selected_book_string} {message}"
        mapping_dict = self.model_manager.book_model_assignments[self.selected_book]
        model_order = ["rephrase", "rephrase_summarize", "entigraph"]
        model_to_index = {model: i for i, model in enumerate(model_order)}

        # respond() runs in a synchronous Gradio handler thread, so drive the
        # async fan-out on its own event loop.
        responses = asyncio.run(
            self.model_manager.get_all_model_responses_async(prompt)
        )

        chats = []
        for response in responses:
            chat = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ]
            chats.append(chat)

        reordered_chats = [chats[model_to_index[mapping_dict[model]]]
                           for model in ["Model A", "Model B", "Model C"]]
        return reordered_chats

    def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
        model_mapping = self.model_manager.get_book_model_mapping(book)
        # De-blind the slider scores: map each displayed label back to the
        # variant it was showing for this book.
        rating_data = {
            "book": book,
            "category": category,
            "prompt": question,
            "rephrase_rating": rating_A if model_mapping["Model A"] == "rephrase"
                else rating_B if model_mapping["Model B"] == "rephrase"
                else rating_C,
            "rephrase_summarize_rating": rating_A if model_mapping["Model A"] == "rephrase_summarize"
                else rating_B if model_mapping["Model B"] == "rephrase_summarize"
                else rating_C,
            "entigraph_rating": rating_A if model_mapping["Model A"] == "entigraph"
                else rating_B if model_mapping["Model B"] == "entigraph"
                else rating_C,
            "rephrase_response": model_responses["Model_A"],
            "rephrase_summarize_response": model_responses["Model_B"],
            "entigraph_response": model_responses["Model_C"]
        }
        self.votes.append(rating_data)
        self.sheet_manager.append_rating(rating_data)
        return ("Ratings submitted successfully!", gr.update(interactive=False))

    def _setup_event_handlers(self, demo, chat_interfaces, ratings, submit_button,
                              submission_status, input_elements):
        def enable_send_btn(book, category, question):
            return gr.update(interactive=bool(book and category and question))

        def enable_button_group(model_A, model_B, model_C):
            return gr.update(interactive=bool(model_A and model_B and model_C))

        def update_selected_book(book_selection):
            self.selected_book = book_selection
            self.selected_book_string = BOOK_MAPPING.get(book_selection, "")
            return self.selected_book_string

        for input_name in ['book', 'category', 'question']:
            input_elements[input_name].change(
                enable_send_btn,
                inputs=[input_elements['book'], input_elements['category'],
                        input_elements['question']],
                outputs=[input_elements['send']]
            )

        input_elements['book'].change(
            update_selected_book,
            inputs=[input_elements['book']],
            outputs=[self.state]
        )

        submit_button.click(
            self.get_votes,
            inputs=[input_elements['book'], input_elements['category'],
                    input_elements['question'],
                    ratings['1'], ratings['2'], ratings['3']],
            outputs=[submission_status, submit_button]
        )

        input_elements['send'].click(
            self.respond,
            inputs=[input_elements['question']],
            outputs=list(chat_interfaces.values())
        )

        for interface in chat_interfaces.values():
            interface.change(
                enable_button_group,
                inputs=list(chat_interfaces.values()),
                outputs=[submit_button]
            )
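# Entry point. The Google service-account credentials consumed by SheetManager are
# reassembled from the PRIVATE_KEY and PRIVATE_KEY_ID environment variables so that
# no secret material is committed alongside the code.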
if __name__ == "__main__":
    # The private key arrives with escaped "\n" sequences in the environment
    # variable; restore real newlines before building the credentials dict.
    PRIVATE_KEY = os.environ["PRIVATE_KEY"].replace("\\n", "\n")
    PRIVATE_KEY_ID = os.environ["PRIVATE_KEY_ID"]

    variables_keys = {
        "type": "service_account",
        "project_id": "summer-presence-450117-r7",
        "private_key_id": PRIVATE_KEY_ID,
        "private_key": PRIVATE_KEY,
        "client_email": "model-blind-comparison@summer-presence-450117-r7.iam.gserviceaccount.com",
        "client_id": "117681363507032419648",
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
        "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/model-blind-comparison%40summer-presence-450117-r7.iam.gserviceaccount.com",
        "universe_domain": "googleapis.com"
    }

    app = ModelComparisonApp()
    demo = app.create_interface()
    demo.launch()
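# Example local launch (a sketch; assumes this file is saved as app.py and that the
# Google Sheet "Model Blind Comparison Ratings" has been shared with the service
# account's client_email):
#   export PRIVATE_KEY_ID=<service-account key id>
#   export PRIVATE_KEY=<service-account private key, with literal "\n" line breaks>
#   python app.py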