import os
import random
from dataclasses import dataclass
from typing import Optional
from uuid import UUID

import gradio as gr
import openai
from dotenv import load_dotenv
from supabase import create_client, Client

load_dotenv()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

SHOW_CONFIG = True

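# Args collects the sampling defaults; get_completion falls back to these same
# values whenever a config row omits a key.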
@dataclass
class Args:
    frequency_penalty: float = 0
    max_tokens: int = 32
    n: int = 1
    presence_penalty: float = 0
    seed: int = 42
    stop: Optional[str] = None
    stream: bool = False
    temperature: float = 0.8
    top_p: float = 0.95

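# Illustrative shape of a config row pulled from the Supabase "configs" table
# (column names beyond 'model' and 'sys_prompt' are inferred from the lookups
# in this file):
#   {"id": "...", "round": 1, "model": "...", "sys_prompt": "...",
#    "temperature": 0.8, "top_p": 0.95, "max_tokens": 32, ...}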
def get_completion(client, config, messages):
    """Call the chat completions endpoint with sampling parameters from a config row."""
    print("GETTING COMPLETION")
    completion_args = {
        "model": config['model'],
        "messages": messages,
        "frequency_penalty": config.get('frequency_penalty', 0),
        # Keyed as 'max_tokens' to match the Args defaults; the original
        # 'max_length' lookup never matched and always fell back to 32.
        "max_tokens": config.get('max_tokens', 32),
        "n": config.get('n', 1),
        "presence_penalty": config.get('presence_penalty', 0),
        "seed": config.get('seed', 42),
        "stop": config.get('stop', None),
        "stream": config.get('stream', False),
        "temperature": config.get('temperature', 0.8),
        "top_p": config.get('top_p', 0.95),
    }

    try:
        print("TRYING TO GET COMPLETION")
        response = client.chat.completions.create(**completion_args)
        print("GOT COMPLETION")
        return response
    except Exception as e:
        print(f"Error during API call: {e}")
        return None

def get_two_random_configs(round_num: int):
    """Pick two distinct configs registered for the given round."""
    print("GETTING TWO RANDOM CONFIGS")
    response = supabase.table("configs")\
        .select("*")\
        .eq("round", round_num)\
        .execute()

    if not response.data or len(response.data) < 2:
        return None, None

    selected_configs = random.sample(response.data, 2)
    return selected_configs[0], selected_configs[1]

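# Note: initialize_session is invoked once at Blocks build time (see the UI
# section below), so every visitor shares the same randomly drawn config pair
# until the process restarts.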
def initialize_session(state):
    print("INITIALIZING SESSION")
    current_round = get_current_round()
    if not current_round:
        state.value["error"] = "Error: No active round found."
        return

    config_a, config_b = get_two_random_configs(round_num=current_round)
    if not config_a or not config_b:
        state.value["error"] = "Error: Not enough configurations available for voting."
        return

    state.value['config_a'] = config_a
    state.value['config_b'] = config_b
    state.value['conversation_a'] = []
    state.value['conversation_b'] = []
    state.value['round'] = current_round

def chat_response_a(message, history):
    print("CHAT RESPONSE A")
    return chat_response(message, history, 'a')


def chat_response_b(message, history):
    print("CHAT RESPONSE B")
    return chat_response(message, history, 'b')

def chat_response(message, history, config_type):
    print("CHAT RESPONSE")
    # State is stashed on demo.blocks at build time, so it is shared across
    # all sessions rather than held per-visitor.
    current_state = demo.blocks['state'].value
    config_a = current_state.get('config_a')
    config_b = current_state.get('config_b')

    if not config_a or not config_b:
        initialize_session(demo.blocks['state'])
        config_a = current_state.get('config_a')
        config_b = current_state.get('config_b')
        if not config_a or not config_b:
            return "Error: Configurations not initialized sufficiently."

    # The vLLM server is OpenAI-compatible and expects a shared bearer token;
    # prefer the environment, falling back to the deployment default.
    openai_api_key = os.getenv("OPENAI_API_KEY", "super-secret-token")
    base_url = "https://turingtest--example-vllm-openai-compatible-serve.modal.run/v1"
    client = openai.OpenAI(api_key=openai_api_key, base_url=base_url)

    # Rebuild the message history for whichever side this handler serves.
    config = config_a if config_type == 'a' else config_b
    conversation_key = 'conversation_a' if config_type == 'a' else 'conversation_b'

    messages = [{"role": "system", "content": config['sys_prompt']}]
    for user_msg, assistant_msg in current_state[conversation_key]:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    response = get_completion(client, config, messages)

    assistant_reply = (
        response.choices[0].message.content if response and response.choices else
        "Error: Please retry or contact support if retried more than twice."
    )

    current_state[conversation_key].append((message, assistant_reply))
    demo.blocks['state'].value = current_state

    return assistant_reply

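# retry_btn / undo_btn / clear_btn below are Gradio 4.x ChatInterface options;
# later Gradio releases removed them.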
def create_chat_interface(model_label):
    print("CREATE CHAT INTERFACE")
    # Both sides are identical apart from the response handler.
    fn = chat_response_a if model_label == 'a' else chat_response_b
    return gr.ChatInterface(
        fn=fn,
        chatbot=gr.Chatbot(height=400, label=f"Choice {model_label}"),
        textbox=gr.Textbox(placeholder="Message", container=False, scale=7),
        description="",
        theme="dark",
        retry_btn=None,
        undo_btn=None,
        clear_btn=None,
    )

def submit_vote(vote: str, state):
    print("SUBMIT VOTE")
    a_config_id = state.value['config_a']['id']
    b_config_id = state.value['config_b']['id']
    conversation_a = state.value.get('conversation_a', [])
    conversation_b = state.value.get('conversation_b', [])

    # Persist both transcripts (anonymous: no user id is recorded).
    supabase.table("conversations").insert([
        {
            "user_id": None,
            "configuration_id": a_config_id,
            "messages": conversation_a
        },
        {
            "user_id": None,
            "configuration_id": b_config_id,
            "messages": conversation_b
        }
    ]).execute()

    # created_at is left to the column default: PostgREST stores the literal
    # string "now()" rather than evaluating it as SQL.
    supabase.table("votes").insert({
        "a_config_id": str(a_config_id),
        "b_config_id": str(b_config_id),
        "voted_by_uid": None,
        "round": get_current_round(),
        "is_tie": vote == "tie",
        "a_wins": vote == "a",
    }).execute()

    # Apply the rating change for this matchup (update_elo was otherwise
    # never invoked).
    update_elo(a_config_id, b_config_id, vote)

    # Reset both transcripts for the next matchup.
    state.value['conversation_a'] = []
    state.value['conversation_b'] = []

    return "Vote submitted!"

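# Note: this is a deliberately simplified rating update with a fixed +/-10
# step per match, not the standard Elo scheme (expected score
# E_a = 1 / (1 + 10 ** ((R_b - R_a) / 400)), update R_a += K * (S_a - E_a)).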
def update_elo(a_config_id: UUID, b_config_id: UUID, vote: str):
    print("UPDATE ELO")
    # The elos table is keyed by user_id, which here holds the config id.
    a_elo_response = supabase.table("elos").select("rating").eq("user_id", a_config_id).single().execute()
    b_elo_response = supabase.table("elos").select("rating").eq("user_id", b_config_id).single().execute()

    if not a_elo_response.data or not b_elo_response.data:
        return

    a_elo = a_elo_response.data["rating"]
    b_elo = b_elo_response.data["rating"]

    if vote == "a":
        a_new = a_elo + 10
        b_new = b_elo - 10
    elif vote == "b":
        a_new = a_elo - 10
        b_new = b_elo + 10
    else:
        # Tie: ratings are unchanged.
        a_new = a_elo
        b_new = b_elo

    supabase.table("elos").update({"rating": a_new}).eq("user_id", a_config_id).execute()
    supabase.table("elos").update({"rating": b_new}).eq("user_id", b_config_id).execute()

def get_current_round():
    print("GET CURRENT ROUND")
    response = supabase.table("round_status").select("round").eq("is_eval_active", True).single().execute()
    if response.data:
        return response.data["round"]
    return None

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), head="""
    <style>
        body {
            font-family: 'Calibri', sans-serif;
        }
    </style>
""") as demo:
    gr.Markdown("## Turing Test Prompt Comp")

    state = gr.State({
        "config_a": None,
        "config_b": None,
        "conversation_a": [],
        "conversation_b": [],
        "round": 1,
        "error": None
    })
    # Stash the state component on the Blocks instance so the module-level
    # chat handlers can reach it (shared across sessions, not per-visitor).
    demo.blocks['state'] = state

    initialize_session(state)
    with gr.Row():
        with gr.Column():
            chat_a = create_chat_interface('a')
        with gr.Column():
            chat_b = create_chat_interface('b')

    with gr.Row():
        a_better = gr.Button("A is better 👈", scale=1)
        tie = gr.Button("🤝 Tie", scale=1)
        b_better = gr.Button("👉 B is better", scale=1)

    output_message = gr.Textbox(label="Status", interactive=False)
    def submit_vote_a():
        return submit_vote('a', state)

    def submit_vote_b():
        return submit_vote('b', state)

    def submit_vote_tie():
        return submit_vote('tie', state)

    a_better.click(
        submit_vote_a,
        inputs=None,
        outputs=output_message
    )
    b_better.click(
        submit_vote_b,
        inputs=None,
        outputs=output_message
    )
    tie.click(
        submit_vote_tie,
        inputs=None,
        outputs=output_message
    )
    prompt_input = gr.Textbox(placeholder="Message for both...", container=False)
    send_btn = gr.Button("Send to Both", variant="primary")

    def send_prompt(prompt):
        # Queue the prompt into both transcripts; each ChatInterface generates
        # its reply when its own handler fires. Return a single empty string
        # to clear the shared textbox (it was previously listed twice in
        # outputs).
        current_state = state.value
        if prompt:
            current_state['conversation_a'].append((prompt, None))
            current_state['conversation_b'].append((prompt, None))
            state.value = current_state
        return ""

    send_btn.click(
        send_prompt,
        inputs=prompt_input,
        outputs=prompt_input
    )
    prompt_input.submit(
        send_prompt,
        inputs=prompt_input,
        outputs=prompt_input
    )
if __name__ == "__main__":
    demo.launch(share=True)
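
# To run locally (assuming this file is saved as app.py and SUPABASE_URL /
# SUPABASE_KEY are set in .env):
#   python app.py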