import os
import streamlit as st
import dotenv
import openai
from openai import OpenAI
import anthropic
from together import Together
import google.generativeai as genai
import time
from typing import List, Optional, Literal, Union, Dict
from constants import (
    LLM_COUNCIL_MEMBERS,
    PROVIDER_TO_AVATAR_MAP,
    AGGREGATORS,
    LLM_TO_UI_NAME_MAP,
)
from prompts import *
from judging_dataclasses import (
    DirectAssessmentJudgingResponse,
    DirectAssessmentCriterionScore,
    DirectAssessmentCriteriaScores,
)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

dotenv.load_dotenv()

PASSWORD = os.getenv("APP_PASSWORD")

# Load API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

# Initialize API clients
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)

# Set up API clients for OpenAI and Anthropic
openai.api_key = OPENAI_API_KEY
openai_client = OpenAI(
    organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
    project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)

# anthropic_client = anthropic.Client(api_key=ANTHROPIC_API_KEY)
anthropic_client = anthropic.Anthropic()

client = OpenAI()
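
# NOTE: Model identifiers throughout this app are expected to follow the
# "provider://model-name" convention parsed by get_llm_response_stream below,
# e.g. "openai://gpt-4o-mini" or
# "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" (illustrative
# examples; the actual council members come from LLM_COUNCIL_MEMBERS in
# constants.py).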


def anthropic_streamlit_streamer(stream):
    """Process the Anthropic streaming response and yield content from the deltas.

    :param stream: Streaming object from the Anthropic API.
    :return: Yields content (text) from the streaming response.
    """
    for event in stream:
        if hasattr(event, "type"):
            # Handle content blocks
            if event.type == "content_block_delta" and hasattr(event, "delta"):
                # Extract text delta from the event
                text_delta = getattr(event.delta, "text", None)
                if text_delta:
                    yield text_delta
            # Handle message completion events (optional if needed)
            elif event.type == "message_stop":
                break  # End of message, stop streaming


def get_ui_friendly_name(llm):
    if "agg__" in llm:
        return (
            "MoA ("
            + LLM_TO_UI_NAME_MAP.get(llm.split("__")[1], llm.split("__")[1])
            + ")"
        )
    return LLM_TO_UI_NAME_MAP.get(llm, llm)


def google_streamlit_streamer(stream):
    for chunk in stream:
        yield chunk.text


def together_streamlit_streamer(stream):
    for chunk in stream:
        yield chunk.choices[0].delta.content


def llm_streamlit_streamer(stream, llm):
    if llm.startswith("anthropic"):
        return anthropic_streamlit_streamer(stream)
    elif llm.startswith("vertex"):
        return google_streamlit_streamer(stream)
    elif llm.startswith("together"):
        return together_streamlit_streamer(stream)


# Helper functions for LLM council and aggregator selection
def llm_council_selector():
    selected_council = st.radio(
        "Choose a council configuration", options=list(LLM_COUNCIL_MEMBERS.keys())
    )
    return LLM_COUNCIL_MEMBERS[selected_council]


def aggregator_selector():
    return st.radio("Choose an aggregator LLM", options=AGGREGATORS)


# API calls for different providers
def get_openai_response(model_name, prompt):
    return openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
    return anthropic_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        stream=True,
    )


def get_together_response(model_name, prompt):
    return together_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt, stream=True)


def get_llm_response_stream(model_identifier, prompt):
    """Returns a streamlit-friendly stream of response tokens from the LLM."""
    provider, model_name = model_identifier.split("://")
    if provider == "openai":
        return get_openai_response(model_name, prompt)
    elif provider == "anthropic":
        return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
    elif provider == "together":
        return together_streamlit_streamer(get_together_response(model_name, prompt))
    elif provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    else:
        return None
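
# A minimal usage sketch (the identifier here is illustrative; real identifiers
# come from LLM_COUNCIL_MEMBERS / AGGREGATORS in constants.py):
#
#     stream = get_llm_response_stream("openai://gpt-4o-mini", "Say 'Hello World'")
#     if stream:
#         st.write_stream(stream)  # Streamlit renders the tokens as they arrive.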


def create_dataframe_for_direct_assessment_judging_response(
    response: DirectAssessmentJudgingResponse,
):
    # Initialize empty list to collect data
    data = []
    # Loop through models
    for judging_model in response.judging_models:
        model_name = judging_model.model
        # Loop through criteria_scores
        for criteria_score in judging_model.criteria_scores:
            data.append(
                {
                    "llm_judge_model": model_name,
                    "criteria": criteria_score.criterion,
                    "score": criteria_score.score,
                    "explanation": criteria_score.explanation,
                }
            )
    # Create DataFrame
    return pd.DataFrame(data)
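
# The resulting DataFrame has one row per (judge, criterion) pair, with columns
# llm_judge_model, criteria, score, and explanation; plot_criteria_scores and
# plot_per_judge_overall_scores below assume this layout.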


# Streamlit form UI
def render_criteria_form(criteria_num):
    """Render a criteria input form."""
    with st.expander(f"Criteria {criteria_num + 1}"):
        name = st.text_input(f"Name for Criteria {criteria_num + 1}")
        description = st.text_area(f"Description for Criteria {criteria_num + 1}")
        min_score = st.number_input(
            f"Min Score for Criteria {criteria_num + 1}", min_value=0, step=1
        )
        max_score = st.number_input(
            f"Max Score for Criteria {criteria_num + 1}", min_value=0, step=1
        )
    return Criteria(
        name=name, description=description, min_score=min_score, max_score=max_score
    )


def format_likert_comparison_options(options):
    return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])


def format_criteria_list(criteria_list):
    return "\n".join(
        [f"{criteria.name}: {criteria.description}" for criteria in criteria_list]
    )


def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )


def get_default_direct_assessment_prompt(user_prompt):
    return get_direct_assessment_prompt(
        direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{response}",
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
    )
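
# Note that the default prompt is built in two passes: the literal "{response}"
# placeholder is left in the template here so that each council member's answer
# can be substituted later, at judging time, via get_direct_assessment_prompt.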


def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    responses_from_other_llms = "\n\n".join(
        [
            f"{get_ui_friendly_name(model)} START\n{st.session_state['responses'][model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
            for model in llms
        ]
    )
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms=responses_from_other_llms,
    )


def get_default_aggregator_prompt(user_prompt, llms):
    return get_aggregator_prompt(
        DEFAULT_AGGREGATOR_PROMPT,
        user_prompt=user_prompt,
        llms=llms,
    )


def get_parse_judging_response_for_direct_assessment_prompt(
    judging_responses: dict[str, str],
    criteria_list,
    options,
):
    formatted_judging_responses = "\n\n".join(
        [
            f"{get_ui_friendly_name(model)} START\n{judging_responses[model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
            for model in judging_responses.keys()
        ]
    )
    return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
        judging_responses=formatted_judging_responses,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )


# When True, skip the LLM parsing call and return canned judging scores, which
# is useful for exercising the UI without extra API calls.
DEBUG_MODE = True


def parse_judging_responses(
    prompt: str, judging_responses: dict[str, str]
) -> DirectAssessmentJudgingResponse:
    if DEBUG_MODE:
        return DirectAssessmentJudgingResponse(
            judging_models=[
                DirectAssessmentCriteriaScores(
                    model="together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
                    criteria_scores=[
                        DirectAssessmentCriterionScore(
                            criterion="helpfulness", score=3, explanation="explanation1"
                        ),
                        DirectAssessmentCriterionScore(
                            criterion="conciseness", score=4, explanation="explanation2"
                        ),
                        DirectAssessmentCriterionScore(
                            criterion="relevance", score=5, explanation="explanation3"
                        ),
                    ],
                ),
                DirectAssessmentCriteriaScores(
                    model="together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
                    criteria_scores=[
                        DirectAssessmentCriterionScore(
                            criterion="helpfulness", score=1, explanation="explanation1"
                        ),
                        DirectAssessmentCriterionScore(
                            criterion="conciseness", score=2, explanation="explanation2"
                        ),
                        DirectAssessmentCriterionScore(
                            criterion="relevance", score=3, explanation="explanation3"
                        ),
                    ],
                ),
            ]
        )
    else:
        completion = client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "system",
                    "content": "Parse the judging responses into structured data.",
                },
                {"role": "user", "content": prompt},
            ],
            response_format=DirectAssessmentJudgingResponse,
        )
        return completion.choices[0].message.parsed
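
# The non-debug path relies on OpenAI structured outputs
# (client.beta.chat.completions.parse with a Pydantic response_format), which
# assumes a recent openai Python SDK; the DEBUG_MODE branch above skips that
# extra API call during local development.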


def plot_criteria_scores(df):
    # Group by criteria and calculate mean and std over all judges.
    grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
    # Flatten the MultiIndex columns
    grouped.columns = ["criteria", "mean_score", "std_score"]
    # Fill NaN std with zeros (in case there's only one score per group)
    grouped["std_score"] = grouped["std_score"].fillna(0)

    # Set up the plot
    plt.figure(figsize=(8, 5))
    # Create a horizontal bar plot
    ax = sns.barplot(
        data=grouped,
        x="mean_score",
        y="criteria",
        hue="criteria",
        errorbar=None,
        orient="h",
    )
    # Add error bars manually by iterating over the bars
    for i, (mean, std) in enumerate(zip(grouped["mean_score"], grouped["std_score"])):
        # Get the current bar
        bar = ax.patches[i]
        # Calculate the center of the bar
        center = bar.get_y() + bar.get_height() / 2
        # Add the error bar
        ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")
    # Set labels and title
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.tight_layout()
    # Display the plot in Streamlit
    st.pyplot(plt.gcf())


def plot_overall_scores(overall_scores_df):
    # Calculate mean and standard deviation
    summary = (
        overall_scores_df.groupby("response_model")
        .agg({"score": ["mean", "std"]})
        .reset_index()
    )
    summary.columns = ["response_model", "mean_score", "std_score"]
    # Add UI-friendly names
    summary["ui_friendly_name"] = summary["response_model"].apply(get_ui_friendly_name)
    # Sort by mean_score in descending order and reset the index so that
    # positional indices line up with the plotted bar order below.
    summary = summary.sort_values("mean_score", ascending=False).reset_index(drop=True)

    # Create the plot
    plt.figure(figsize=(8, 5))
    # Plot bars with rainbow colors
    ax = sns.barplot(
        x="ui_friendly_name",
        y="mean_score",
        data=summary,
        palette="prism",
        capsize=0.1,
    )
    # Add error bars manually
    x_coords = range(len(summary))
    plt.errorbar(
        x=x_coords,
        y=summary["mean_score"],
        yerr=summary["std_score"],
        fmt="none",
        c="black",
        capsize=5,
        zorder=10,  # Ensure error bars are on top
    )
    # Add text annotations above each bar
    for i, row in summary.iterrows():
        ax.text(
            i,
            row["mean_score"],
            f"{row['mean_score']:.2f}",
            ha="center",
            va="bottom",
            fontweight="bold",
            color="black",
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
        )
    # Customize the plot
    plt.xlabel("")
    plt.ylabel("Overall Score")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    # Display the plot in Streamlit
    st.pyplot(plt.gcf())


def plot_per_judge_overall_scores(df):
    # Compute each judge's overall score as the mean over its criterion scores.
    grouped = df.groupby(["llm_judge_model"]).agg({"score": ["mean"]}).reset_index()
    grouped.columns = ["llm_judge_model", "overall_score"]

    # Create the horizontal bar plot
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        data=grouped,
        y="llm_judge_model",
        x="overall_score",
        hue="llm_judge_model",
        orient="h",
    )
    # Customize the plot
    plt.title("Overall Scores by LLM Judge Model")
    plt.xlabel("Overall Score")
    plt.ylabel("LLM Judge Model")
    # Adjust layout and display the plot
    plt.tight_layout()
    st.pyplot(plt.gcf())


# Main Streamlit App
def main():
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display
    center_css = """
    <style>
    h1, h2, h3, h6 { text-align: center; }
    .chat-container {
        display: flex;
        align-items: flex-start;
        margin-bottom: 10px;
    }
    .avatar {
        width: 50px;
        margin-right: 10px;
    }
    .message {
        background-color: #f1f1f1;
        padding: 10px;
        border-radius: 10px;
        width: 100%;
    }
    </style>
    """
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to judge each other's responses.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")

    # Authentication system
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])
    if not st.session_state.authenticated:
        with cols[1]:
            password = st.text_input("Password", type="password")
            if st.button("Login", use_container_width=True):
                if password == PASSWORD:
                    st.session_state.authenticated = True
                else:
                    st.error("Invalid credentials")

    if st.session_state.authenticated:
        # cols[1].success("Logged in successfully!")
        st.markdown("#### LLM Council Member Selection")

        # Council and aggregator selection
        selected_models = llm_council_selector()
        # st.write("Selected Models:", selected_models)
        selected_aggregator = aggregator_selector()

        # Initialize session state for collecting responses.
        if "responses" not in st.session_state:
            st.session_state.responses = {}
        # if "aggregator_response" not in st.session_state:
        #     st.session_state.aggregator_response = {}

        # Prompt input
        st.markdown("#### Enter your prompt")
        _, center_column, _ = st.columns([3, 5, 3])
        with center_column:
            user_prompt = st.text_area(
                "Prompt", value="Say 'Hello World'", label_visibility="collapsed"
            )

        if center_column.button("Submit", use_container_width=True):
            st.markdown("#### Responses")
            response_columns = st.columns(3)
            # Lay responses out in up to three columns, wrapping if the council
            # has more than three members.
            selected_models_to_streamlit_column_map = {
                model: response_columns[i % 3]
                for i, model in enumerate(selected_models)
            }

            # Fetching and streaming responses from each selected model
            for selected_model in selected_models:
                with selected_models_to_streamlit_column_map[selected_model]:
                    st.write(get_ui_friendly_name(selected_model))
                    with st.chat_message(
                        selected_model,
                        avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
                    ):
                        message_placeholder = st.empty()
                        stream = get_llm_response_stream(selected_model, user_prompt)
                        if stream:
                            st.session_state["responses"][selected_model] = (
                                message_placeholder.write_stream(stream)
                            )

            # Get the aggregator prompt.
            aggregator_prompt = get_default_aggregator_prompt(
                user_prompt=user_prompt, llms=selected_models
            )
            with st.expander("Aggregator Prompt"):
                st.code(aggregator_prompt)

            # Fetching and streaming the response from the aggregator
            st.write(
                f"Mixture-of-Agents ({get_ui_friendly_name(selected_aggregator)})"
            )
            with st.chat_message(
                selected_aggregator,
                avatar="img/council_icon.png",
            ):
                message_placeholder = st.empty()
                aggregator_stream = get_llm_response_stream(
                    selected_aggregator, aggregator_prompt
                )
                if aggregator_stream:
                    st.session_state["responses"]["agg__" + selected_aggregator] = (
                        message_placeholder.write_stream(aggregator_stream)
                    )
            # st.write("Responses (in session state):")
            # st.write(st.session_state["responses"])

        # Judging.
        st.markdown("#### Judging Configuration")

        # Choose the type of assessment
        assessment_type = st.radio(
            "Select the type of assessment",
            options=["Direct Assessment", "Pairwise Comparison"],
        )

        _, center_column, _ = st.columns([3, 5, 3])

        # Depending on the assessment type, render different forms
        if assessment_type == "Direct Assessment":
            # Initialize session state for direct assessment.
            if "direct_assessment_overall_score" not in st.session_state:
                st.session_state["direct_assessment_overall_score"] = {}
            if "direct_assessment_judging_df" not in st.session_state:
                st.session_state["direct_assessment_judging_df"] = {}
                for response_model in selected_models:
                    st.session_state["direct_assessment_judging_df"][
                        response_model
                    ] = {}
                # aggregator model
                st.session_state["direct_assessment_judging_df"][
                    "agg__" + selected_aggregator
                ] = {}
            if "direct_assessment_judging_responses" not in st.session_state:
                st.session_state["direct_assessment_judging_responses"] = {}
                for response_model in selected_models:
                    st.session_state["direct_assessment_judging_responses"][
                        response_model
                    ] = {}
                # aggregator model
                st.session_state["direct_assessment_judging_responses"][
                    "agg__" + selected_aggregator
                ] = {}
            if "direct_assessment_overall_scores" not in st.session_state:
                st.session_state["direct_assessment_overall_scores"] = {}
                for response_model in selected_models:
                    st.session_state["direct_assessment_overall_scores"][
                        response_model
                    ] = {}
                st.session_state["direct_assessment_overall_scores"][
                    "agg__" + selected_aggregator
                ] = {}
            if "judging_status" not in st.session_state:
                st.session_state["judging_status"] = "incomplete"

            # Direct assessment prompt.
            with center_column.expander("Direct Assessment Prompt"):
                direct_assessment_prompt = st.text_area(
                    "Prompt for the Direct Assessment",
                    value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
                    height=500,
                )

            # TODO: Add option to edit criteria list with a basic text field.
            criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST

            # Create DirectAssessment object when form is submitted
            if center_column.button(
                "Submit Direct Assessment", use_container_width=True
            ):
                # Submit direct assessment.
                responses_for_judging = st.session_state["responses"]
                # st.write("Responses for judging (in session state):")
                # st.write(responses_for_judging)
                response_judging_columns = st.columns(3)
                responses_for_judging_to_streamlit_column_map = {
                    model: response_judging_columns[i % 3]
                    for i, model in enumerate(responses_for_judging.keys())
                }

                # Get judging responses.
                for response_model, response in responses_for_judging.items():
                    st_column = responses_for_judging_to_streamlit_column_map[
                        response_model
                    ]
                    with st_column:
                        if "agg__" in response_model:
                            judging_model_header = "Mixture-of-Agents Response"
                        else:
                            judging_model_header = get_ui_friendly_name(response_model)
                        st.write(f"Judging for {judging_model_header}")
                        # st.write("Response being judged: ")
                        # st.write(response)
                        judging_prompt = get_direct_assessment_prompt(
                            direct_assessment_prompt=direct_assessment_prompt,
                            user_prompt=user_prompt,
                            response=response,
                            criteria_list=criteria_list,
                            options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                        )
                        with st.expander("Final Judging Prompt"):
                            st.code(judging_prompt)

                        for judging_model in selected_models:
                            with st.expander(
                                get_ui_friendly_name(judging_model), expanded=False
                            ):
                                with st.chat_message(
                                    judging_model,
                                    avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                                ):
                                    message_placeholder = st.empty()
                                    judging_stream = get_llm_response_stream(
                                        judging_model, judging_prompt
                                    )
                                    st.session_state[
                                        "direct_assessment_judging_responses"
                                    ][response_model][judging_model] = (
                                        message_placeholder.write_stream(judging_stream)
                                    )

                        # When all of the judging is finished for the given response,
                        # parse the actual scores out of the raw judging text
                        # (using gpt-4o-mini with structured outputs).
                        judging_responses = st.session_state[
                            "direct_assessment_judging_responses"
                        ][response_model]
                        # st.write("Judging responses (in session state):")
                        # st.write(judging_responses)
                        if not judging_responses:
                            st.error(f"No judging responses for {response_model}")
                            st.stop()

                        parse_judging_response_prompt = (
                            get_parse_judging_response_for_direct_assessment_prompt(
                                judging_responses,
                                criteria_list,
                                SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                            )
                        )
                        with st.expander("Parse Judging Response Prompt"):
                            st.code(parse_judging_response_prompt)

                        # Issue the prompt to gpt-4o-mini with structured outputs.
                        parsed_judging_responses = parse_judging_responses(
                            parse_judging_response_prompt, judging_responses
                        )
                        st.session_state["direct_assessment_judging_df"][
                            response_model
                        ] = create_dataframe_for_direct_assessment_judging_response(
                            parsed_judging_responses
                        )
                        judging_df = st.session_state["direct_assessment_judging_df"][
                            response_model
                        ]
                        st.write(judging_df)
                        plot_criteria_scores(judging_df)

                        # Find the overall score for each judge, then average over
                        # all judges.
                        plot_per_judge_overall_scores(judging_df)
                        grouped = (
                            judging_df.groupby(["llm_judge_model"])
                            .agg({"score": ["mean"]})
                            .reset_index()
                        )
                        grouped.columns = ["llm_judge_model", "overall_score"]
                        # st.write(
                        #     "Extracting overall scores from this grouped dataframe:"
                        # )
                        # st.write(grouped)
                        # Save the overall scores to the session state.
                        for record in grouped.to_dict(orient="records"):
                            st.session_state["direct_assessment_overall_scores"][
                                response_model
                            ][record["llm_judge_model"]] = record["overall_score"]
                        overall_score = grouped["overall_score"].mean()
                        controversy = grouped["overall_score"].std()
                        st.write(f"Overall Score: {overall_score:.2f}")
                        st.write(f"Controversy: {controversy:.2f}")

                st.session_state["judging_status"] = "complete"

            # Judging is complete.
            st.write("#### Results")
            # The session state now contains the overall scores for each response
            # from each judge.
            if st.session_state["judging_status"] == "complete":
                overall_scores_df_raw = pd.DataFrame(
                    st.session_state["direct_assessment_overall_scores"]
                ).reset_index()
                overall_scores_df = pd.melt(
                    overall_scores_df_raw,
                    id_vars=["index"],
                    var_name="response_model",
                    value_name="score",
                ).rename(columns={"index": "judging_model"})

                # Print the overall winner.
                overall_winner = overall_scores_df.loc[
                    overall_scores_df["score"].idxmax()
                ]
                st.write(
                    f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
                )
                # Find how much the standard deviation overlaps with other models.
                # Calculate separability.
                # TODO.
                st.write(f"**Confidence:** {overall_winner['score']:.2f}")

                left_column, right_column = st.columns([1, 1])
                with left_column:
                    plot_overall_scores(overall_scores_df)
                with right_column:
                    st.dataframe(overall_scores_df)

        elif assessment_type == "Pairwise Comparison":
            pass
            # pairwise_comparison_prompt = st.text_area(
            #     "Prompt for the Pairwise Comparison"
            # )
            # granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
            # ties_allowed = st.checkbox("Are ties allowed?")
            # position_swapping = st.checkbox("Enable position swapping?")
            # reference_model = st.text_input("Reference Model")
            # # Create PairwiseComparison object when form is submitted
            # if st.button("Submit Pairwise Comparison"):
            #     pairwise_comparison_config = PairwiseComparison(
            #         type="pairwise_comparison",
            #         granularity=granularity,
            #         ties_allowed=ties_allowed,
            #         position_swapping=position_swapping,
            #         reference_model=reference_model,
            #         prompt=prompt,
            #     )
            #     st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
            #     # Submit pairwise comparison.
            #     responses_for_judging = st.session_state["responses"]

    else:
        with cols[1]:
            st.warning("Please log in to access this app.")


if __name__ == "__main__":
    main()