"""Language Model Council Sandbox.

Streamlit app in which a council of LLMs answers a prompt, an aggregator LLM
combines the answers, and the council members judge each other's responses.
Judgments are parsed with structured output prompting, one response model and
one judge model at a time.
"""
import os
from collections import defaultdict

import dotenv
import streamlit as st
import anthropic
import google.generativeai as genai
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from openai import OpenAI
from together import Together

from constants import (
    LLM_COUNCIL_MEMBERS,
    PROVIDER_TO_AVATAR_MAP,
    AGGREGATORS,
    LLM_TO_UI_NAME_MAP,
)
from prompts import *
from judging_dataclasses import DirectAssessmentCriteriaScores

dotenv.load_dotenv()

PASSWORD = os.getenv("APP_PASSWORD")

# Load API keys from environment variables.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

# Initialize API clients.
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
openai_client = OpenAI(
    api_key=OPENAI_API_KEY,
    organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
    project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)
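
# Council members and aggregators are referred to throughout this module by
# "provider://model_name" identifiers (see get_llm_response_stream), e.g. the
# hypothetical "openai://gpt-4o-mini". The concrete identifiers come from
# constants.LLM_COUNCIL_MEMBERS and constants.AGGREGATORS.
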
def anthropic_streamlit_streamer(stream, llm):
    """Processes an Anthropic streaming response and yields text deltas.

    :param stream: Streaming object from the Anthropic API.
    :param llm: Model identifier used to track token usage in session state.
    :return: Yields content (text) from the streaming response.
    """
    for event in stream:
        if hasattr(event, "type"):
            # Count token usage reported at the start of the message.
            if event.type == "message_start":
                st.session_state["input_token_usage"][
                    llm
                ] += event.message.usage.input_tokens
                st.session_state["output_token_usage"][
                    llm
                ] += event.message.usage.output_tokens
            # Count output token usage reported in message deltas.
            if event.type == "message_delta":
                st.session_state["output_token_usage"][llm] += event.usage.output_tokens
            # Handle content blocks.
            if event.type == "content_block_delta" and hasattr(event, "delta"):
                # Extract the text delta from the event.
                text_delta = getattr(event.delta, "text", None)
                if text_delta:
                    yield text_delta
            # Handle message completion events.
            elif event.type == "message_stop":
                break  # End of message; stop streaming.

def get_ui_friendly_name(llm):
    """Maps a model identifier to its UI display name; aggregators render as "MoA (...)"."""
    if "agg__" in llm:
        return (
            "MoA ("
            + LLM_TO_UI_NAME_MAP.get(llm.split("__")[1], llm.split("__")[1])
            + ")"
        )
    return LLM_TO_UI_NAME_MAP.get(llm, llm)


def google_streamlit_streamer(stream):
    """Yields text chunks from a Google Gemini streaming response."""
    # TODO: Count token usage.
    for chunk in stream:
        yield chunk.text

def openai_streamlit_streamer(stream, llm):
    """Yields text deltas from an OpenAI streaming response and tracks token usage."""
    # https://platform.openai.com/docs/api-reference/streaming
    for event in stream:
        if event.usage:
            st.session_state["input_token_usage"][llm] += event.usage.prompt_tokens
            st.session_state["output_token_usage"][llm] += event.usage.completion_tokens
        if event.choices:
            if event.choices[0].delta.content:
                yield event.choices[0].delta.content


def together_streamlit_streamer(stream, llm):
    """Yields text deltas from a Together streaming response and tracks token usage."""
    # https://docs.together.ai/docs/chat-overview#streaming-responses
    for chunk in stream:
        if chunk.usage:
            st.session_state["input_token_usage"][llm] += chunk.usage.prompt_tokens
            st.session_state["output_token_usage"][llm] += chunk.usage.completion_tokens
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

# Helper functions for LLM council and aggregator selection.
def llm_council_selector():
    """Renders a radio selector for a council configuration and returns its members."""
    selected_council = st.radio(
        "Choose a council configuration", options=list(LLM_COUNCIL_MEMBERS.keys())
    )
    return LLM_COUNCIL_MEMBERS[selected_council]


def aggregator_selector():
    """Renders a radio selector for the aggregator LLM."""
    return st.radio("Choose an aggregator LLM", options=AGGREGATORS)

# API calls for different providers.
def get_openai_response(model_name, prompt):
    return openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        stream_options={"include_usage": True},
    )


# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
    return anthropic_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        stream=True,
    )


def get_together_response(model_name, prompt):
    return together_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt, stream=True)

def get_llm_response_stream(model_identifier, prompt):
    """Returns a Streamlit-friendly stream of response tokens from the LLM."""
    provider, model_name = model_identifier.split("://")
    if provider == "openai":
        return openai_streamlit_streamer(
            get_openai_response(model_name, prompt), model_identifier
        )
    elif provider == "anthropic":
        return anthropic_streamlit_streamer(
            get_anthropic_response(model_name, prompt), model_identifier
        )
    elif provider == "together":
        return together_streamlit_streamer(
            get_together_response(model_name, prompt), model_identifier
        )
    elif provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    else:
        return None
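
# Typical usage (mirroring st_render_responses below), assuming `model_id` is
# one of the selected "provider://model_name" identifiers:
#     stream = get_llm_response_stream(model_id, prompt)
#     st.session_state.responses[model_id] = st.empty().write_stream(stream)
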
def create_dataframe_for_direct_assessment_judging_response(
    response: DirectAssessmentCriteriaScores, judging_model: str
) -> pd.DataFrame:
    """Flattens a parsed judging response into one row per criterion."""
    data = []
    for criteria_score in response.criteria_scores:
        data.append(
            {
                "judging_model": judging_model,
                "criteria": criteria_score.criterion,
                "score": criteria_score.score,
                "explanation": criteria_score.explanation,
            }
        )
    return pd.DataFrame(data)

# Streamlit form UI.
def render_criteria_form(criteria_num):
    """Render a criteria input form."""
    with st.expander(f"Criteria {criteria_num + 1}"):
        name = st.text_input(
            f"Name for Criteria {criteria_num + 1}", key=f"criteria_name_{criteria_num}"
        )
        description = st.text_area(
            f"Description for Criteria {criteria_num + 1}",
            key=f"criteria_desc_{criteria_num}",
        )
        min_score = st.number_input(
            f"Min Score for Criteria {criteria_num + 1}",
            min_value=0,
            step=1,
            key=f"criteria_min_{criteria_num}",
        )
        max_score = st.number_input(
            f"Max Score for Criteria {criteria_num + 1}",
            min_value=0,
            step=1,
            key=f"criteria_max_{criteria_num}",
        )
    return Criteria(
        name=name, description=description, min_score=min_score, max_score=max_score
    )

def format_likert_comparison_options(options):
    """Formats Likert options as a numbered list, one option per line."""
    return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])


def format_criteria_list(criteria_list):
    """Formats criteria as "name: description" lines."""
    return "\n".join(
        [f"{criteria.name}: {criteria.description}" for criteria in criteria_list]
    )


def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    """Fills in the direct assessment prompt template."""
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )


def get_default_direct_assessment_prompt(user_prompt):
    """Returns the default direct assessment prompt for the given user prompt."""
    return get_direct_assessment_prompt(
        direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{response}",
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
    )
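
# Note: get_default_direct_assessment_prompt deliberately leaves the literal
# "{response}" placeholder unfilled so that st_direct_assessment_results can
# format the same template once per council response.
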
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    """Fills in the aggregator prompt with the responses collected from the council."""
    responses_from_other_llms = "\n\n".join(
        [
            f"{get_ui_friendly_name(model)} START\n{st.session_state['responses'][model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
            for model in llms
        ]
    )
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms=responses_from_other_llms,
    )


def get_default_aggregator_prompt(user_prompt, llms):
    """Returns the aggregator prompt built from the default template."""
    return get_aggregator_prompt(
        DEFAULT_AGGREGATOR_PROMPT,
        user_prompt=user_prompt,
        llms=llms,
    )

def get_parse_judging_response_for_direct_assessment_prompt(
    judging_response: str,
    criteria_list,
    options,
) -> str:
    """Builds the prompt used to parse one judging response into structured data."""
    formatted_judging_response = (
        f"----- START -----\n\n\n{judging_response}\n\n\n----- END -----\n\n\n"
    )
    return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
        judging_response=formatted_judging_response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )

def get_parsed_judging_response_obj_using_llm(
    prompt: str,
) -> DirectAssessmentCriteriaScores:
    """Parses a free-form judging response into structured criterion scores."""
    completion = openai_client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "Parse the judging responses into structured data.",
            },
            {"role": "user", "content": prompt},
        ],
        response_format=DirectAssessmentCriteriaScores,
    )
    # Track token usage.
    st.session_state["input_token_usage"][
        "gpt-4o-mini"
    ] += completion.usage.prompt_tokens
    st.session_state["output_token_usage"][
        "gpt-4o-mini"
    ] += completion.usage.completion_tokens
    return completion.choices[0].message.parsed
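
# The parsing step above relies on OpenAI structured outputs
# (beta.chat.completions.parse) with DirectAssessmentCriteriaScores as the
# response format. The parser model is fixed to gpt-4o-mini, and its token
# usage is tracked under the "gpt-4o-mini" key.
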
def get_llm_avatar(model_identifier):
    """Returns the avatar image for a model; aggregator entries get the council icon."""
    if "agg__" in model_identifier:
        return "img/council_icon.png"
    else:
        return PROVIDER_TO_AVATAR_MAP[model_identifier]

def plot_criteria_scores(df):
    """Plots mean criterion scores (with std error bars) across all judges."""
    # Group by criteria and compute mean and std over all judges.
    grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
    # Flatten the MultiIndex columns.
    grouped.columns = ["criteria", "mean_score", "std_score"]
    # Fill NaN std with zeros (in case there's only one score per group).
    grouped["std_score"] = grouped["std_score"].fillna(0)

    # Set up a horizontal bar plot.
    plt.figure(figsize=(8, 5))
    ax = sns.barplot(
        data=grouped,
        x="mean_score",
        y="criteria",
        hue="criteria",
        errorbar=None,
        orient="h",
    )
    # Add error bars manually, one per bar.
    for i, (mean, std) in enumerate(zip(grouped["mean_score"], grouped["std_score"])):
        # Get the current bar and its vertical center.
        bar = ax.patches[i]
        center = bar.get_y() + bar.get_height() / 2
        # Add the error bar.
        ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")

    # Set labels and render in Streamlit.
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.tight_layout()
    st.pyplot(plt.gcf())

def plot_overall_scores(overall_scores_df):
    """Plots mean overall scores per response model, sorted in descending order."""
    # Calculate mean and standard deviation.
    summary = (
        overall_scores_df.groupby("response_model")
        .agg({"score": ["mean", "std"]})
        .reset_index()
    )
    summary.columns = ["response_model", "mean_score", "std_score"]

    # Add UI-friendly names and sort by mean score in descending order.
    summary["ui_friendly_name"] = summary["response_model"].apply(get_ui_friendly_name)
    summary = summary.sort_values("mean_score", ascending=False)

    # Plot bars with rainbow colors.
    plt.figure(figsize=(8, 5))
    ax = sns.barplot(
        x="ui_friendly_name",
        y="mean_score",
        hue="ui_friendly_name",
        data=summary,
        palette="rainbow",
        capsize=0.1,
        legend=False,
    )
    # Add error bars manually.
    x_coords = range(len(summary))
    plt.errorbar(
        x=x_coords,
        y=summary["mean_score"],
        yerr=summary["std_score"],
        fmt="none",
        c="black",
        capsize=5,
        zorder=10,  # Ensure error bars are on top.
    )
    # Add text annotations using the actual positions of the bars.
    for patch, row in zip(ax.patches, summary.itertuples()):
        # Get the center (x) and top (y) of each bar.
        x = patch.get_x() + patch.get_width() / 2
        y = patch.get_height()
        # Add the text annotation.
        ax.text(
            x,
            y,
            f"{row.mean_score:.2f}",
            ha="center",
            va="bottom",
            color="black",
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
        )

    # Customize the plot and render in Streamlit.
    plt.xlabel("")
    plt.ylabel("Overall Score")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    st.pyplot(plt.gcf())

def plot_per_judge_overall_scores(df):
    """Plots each judge's overall score (its mean score across criteria) for one response."""
    # Average each judge's scores across all criteria.
    grouped = df.groupby(["judging_model"]).agg({"score": ["mean"]}).reset_index()
    grouped.columns = ["judging_model", "overall_score"]

    # Create the vertical bar plot.
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        data=grouped,
        x="judging_model",
        y="overall_score",
        hue="judging_model",
        orient="v",
        palette="rainbow",
    )
    # Customize the plot (judges on the x-axis, scores on the y-axis).
    plt.title("Overall Score from each LLM Judge")
    plt.xlabel("LLM Judge")
    plt.ylabel("Overall Score")

    # Adjust layout and display the plot.
    plt.tight_layout()
    st.pyplot(plt.gcf())

def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
    """Maps each selected model to a Streamlit column, cycling through the columns."""
    return {
        model: st_columns[i % len(st_columns)]
        for i, model in enumerate(selected_models)
    }


def get_aggregator_key(llm_aggregator):
    """Returns the session-state key used for an aggregator's response."""
    return "agg__" + llm_aggregator
def st_render_responses(user_prompt):
    """Renders the responses from the LLMs.

    Uses cached responses from the session state, if available; otherwise,
    streams the responses anew. Assumes that the session state has already been
    set up with the selected models and the selected aggregator.
    """
    st.markdown("#### Responses")

    response_columns = st.columns(3)
    selected_models_to_streamlit_column_map = (
        get_selected_models_to_streamlit_column_map(
            response_columns, st.session_state.selected_models
        )
    )
    for response_model in st.session_state.selected_models:
        st_column = selected_models_to_streamlit_column_map.get(
            response_model, response_columns[0]
        )
        with st_column.chat_message(
            response_model,
            avatar=get_llm_avatar(response_model),
        ):
            st.write(get_ui_friendly_name(response_model))
            if response_model in st.session_state.responses:
                # Use the cached response from session state.
                st.write(st.session_state.responses[response_model])
            else:
                # Stream the response from the LLM.
                message_placeholder = st.empty()
                stream = get_llm_response_stream(response_model, user_prompt)
                st.session_state.responses[response_model] = (
                    message_placeholder.write_stream(stream)
                )

    # Render the aggregator response.
    aggregator_prompt = get_default_aggregator_prompt(
        user_prompt=user_prompt, llms=st.session_state.selected_models
    )
    with st.chat_message(
        get_aggregator_key(st.session_state.selected_aggregator),
        avatar="img/council_icon.png",
    ):
        st.write(
            get_ui_friendly_name(get_aggregator_key(st.session_state.selected_aggregator))
        )
        if (
            get_aggregator_key(st.session_state.selected_aggregator)
            in st.session_state.responses
        ):
            # Use the cached aggregator response from session state.
            st.write(
                st.session_state.responses[
                    get_aggregator_key(st.session_state.selected_aggregator)
                ]
            )
        else:
            # Stream the aggregator response.
            message_placeholder = st.empty()
            aggregator_stream = get_llm_response_stream(
                st.session_state.selected_aggregator, aggregator_prompt
            )
            if aggregator_stream:
                st.session_state.responses[
                    get_aggregator_key(st.session_state.selected_aggregator)
                ] = message_placeholder.write_stream(aggregator_stream)

    st.session_state.responses_collected = True

def st_direct_assessment_results(user_prompt, direct_assessment_prompt, criteria_list):
    """Renders the direct assessment judging results.

    Uses judging responses and parsed scores cached in session state when
    available; otherwise fetches them from the judge LLMs and caches them.
    Assumes that the council responses are already in session state.
    """
    responses_for_judging = st.session_state.responses

    # Get judging responses.
    response_judging_columns = st.columns(3)
    responses_for_judging_to_streamlit_column_map = (
        get_selected_models_to_streamlit_column_map(
            response_judging_columns, responses_for_judging.keys()
        )
    )
    for response_model, response in responses_for_judging.items():
        st_column = responses_for_judging_to_streamlit_column_map[response_model]
        with st_column:
            st.write(f"Judging for {get_ui_friendly_name(response_model)}")
            judging_prompt = get_direct_assessment_prompt(
                direct_assessment_prompt=direct_assessment_prompt,
                user_prompt=user_prompt,
                response=response,
                criteria_list=criteria_list,
                options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
            )
            with st.expander("Final Judging Prompt"):
                st.code(judging_prompt)

            for judging_model in st.session_state.selected_models:
                with st.expander(get_ui_friendly_name(judging_model), expanded=True):
                    with st.chat_message(
                        judging_model,
                        avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                    ):
                        if (
                            judging_model
                            in st.session_state.direct_assessment_judging_responses[
                                response_model
                            ]
                        ):
                            # Use the cached judging response from session state.
                            st.write(
                                st.session_state.direct_assessment_judging_responses[
                                    response_model
                                ][judging_model]
                            )
                        else:
                            # Stream the judging response from the LLM.
                            message_placeholder = st.empty()
                            judging_stream = get_llm_response_stream(
                                judging_model, judging_prompt
                            )
                            st.session_state.direct_assessment_judging_responses[
                                response_model
                            ][judging_model] = message_placeholder.write_stream(
                                judging_stream
                            )
                    # Parse the judging response with structured output prompting,
                    # unless the parsed results are already cached.
                    parse_judging_response_prompt = (
                        get_parse_judging_response_for_direct_assessment_prompt(
                            judging_response=st.session_state.direct_assessment_judging_responses[
                                response_model
                            ][judging_model],
                            criteria_list=criteria_list,
                            options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                        )
                    )
                    st.write("Parse judging response prompt:")
                    st.write(parse_judging_response_prompt)
                    if (
                        response_model
                        not in st.session_state.direct_assessment_judging_by_response_and_judging_model_df
                        or judging_model
                        not in st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ]
                    ):
                        parsed_judging_response_obj = (
                            get_parsed_judging_response_obj_using_llm(
                                parse_judging_response_prompt
                            )
                        )
                        st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ][judging_model] = create_dataframe_for_direct_assessment_judging_response(
                            parsed_judging_response_obj, judging_model
                        )
                    st.write("Structured output parsing response:")
                    st.write(
                        st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ][judging_model]
                    )

            # Combine the per-judge dataframes into a single dataframe for this
            # response model.
            if response_model not in st.session_state.direct_assessment_judging_df:
                combined_judging_df = pd.DataFrame()
                for judging_model in st.session_state.selected_models:
                    combined_judging_df = pd.concat(
                        [
                            combined_judging_df,
                            st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                                response_model
                            ][judging_model],
                        ]
                    )
                st.session_state.direct_assessment_judging_df[response_model] = (
                    combined_judging_df
                )
            with st.expander("Judging results from all judges"):
                st.write(st.session_state.direct_assessment_judging_df[response_model])

            # Plot the criteria scores and per-judge overall scores for this
            # response model.
            plot_criteria_scores(
                st.session_state.direct_assessment_judging_df[response_model]
            )
            plot_per_judge_overall_scores(
                st.session_state.direct_assessment_judging_df[response_model]
            )

            grouped = (
                st.session_state.direct_assessment_judging_df[response_model]
                .groupby(["judging_model"])
                .agg({"score": ["mean"]})
                .reset_index()
            )
            grouped.columns = ["judging_model", "overall_score"]
            # Save the overall scores to the session state.
            for record in grouped.to_dict(orient="records"):
                st.session_state.direct_assessment_overall_scores[
                    get_ui_friendly_name(response_model)
                ][get_ui_friendly_name(record["judging_model"])] = record[
                    "overall_score"
                ]
            overall_score = grouped["overall_score"].mean()
            controversy = grouped["overall_score"].std()
            st.write(f"Overall Score: {overall_score:.2f}")
            st.write(f"Controversy: {controversy:.2f}")

    # Mark judging as complete.
    st.session_state.judging_status = "complete"

# Main Streamlit app.
def main():
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display.
    center_css = """
    <style>
    h1, h2, h3, h6 { text-align: center; }
    .chat-container {
        display: flex;
        align-items: flex-start;
        margin-bottom: 10px;
    }
    .avatar {
        width: 50px;
        margin-right: 10px;
    }
    .message {
        background-color: #f1f1f1;
        padding: 10px;
        border-radius: 10px;
        width: 100%;
    }
    </style>
    """
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description.
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to judge each other's responses.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")

    # Authentication system.
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])
    if not st.session_state.authenticated:
        with cols[1]:
            with st.form("login_form"):
                password = st.text_input("Password", type="password")
                submit_button = st.form_submit_button("Login", use_container_width=True)
                if submit_button:
                    if password == PASSWORD:
                        st.session_state.authenticated = True
                        st.success("Logged in successfully!")
                        st.rerun()
                    else:
                        st.error("Invalid credentials")

    if st.session_state.authenticated:
        # Initialize session state for collecting responses.
        if "responses_collected" not in st.session_state:
            st.session_state["responses_collected"] = False
        if "responses" not in st.session_state:
            st.session_state.responses = defaultdict(str)
        # Initialize session state for token usage.
        if "input_token_usage" not in st.session_state:
            st.session_state["input_token_usage"] = defaultdict(int)
        if "output_token_usage" not in st.session_state:
            st.session_state["output_token_usage"] = defaultdict(int)
        # Initialize session state for the council selection.
        if "selected_models" not in st.session_state:
            st.session_state["selected_models"] = []
        if "selected_aggregator" not in st.session_state:
            st.session_state["selected_aggregator"] = None
        # Initialize session state for direct assessment judging.
        if "direct_assessment_judging_df" not in st.session_state:
            st.session_state.direct_assessment_judging_df = {}
        if (
            "direct_assessment_judging_by_response_and_judging_model_df"
            not in st.session_state
        ):
            st.session_state.direct_assessment_judging_by_response_and_judging_model_df = defaultdict(
                dict
            )
        if "direct_assessment_judging_responses" not in st.session_state:
            st.session_state.direct_assessment_judging_responses = defaultdict(dict)
        if "direct_assessment_overall_scores" not in st.session_state:
            st.session_state.direct_assessment_overall_scores = defaultdict(dict)
        if "judging_status" not in st.session_state:
            st.session_state.judging_status = "incomplete"
        if "direct_assessment_config" not in st.session_state:
            st.session_state.direct_assessment_config = {}
        if "pairwise_comparison_config" not in st.session_state:
            st.session_state.pairwise_comparison_config = {}
        if "assessment_type" not in st.session_state:
            st.session_state.assessment_type = None
        with st.form(key="prompt_form"):
            st.markdown("#### LLM Council Member Selection")
            # Council and aggregator selection.
            selected_models = llm_council_selector()
            selected_aggregator = aggregator_selector()

            # Prompt input and submission form.
            st.markdown("#### Enter your prompt")
            _, center_column, _ = st.columns([3, 5, 3])
            with center_column:
                user_prompt = st.text_area(
                    "Enter your prompt",
                    value="Say 'Hello World'",
                    key="user_prompt",
                    label_visibility="hidden",
                )
                submit_button = st.form_submit_button(
                    "Submit", use_container_width=True
                )

        if submit_button:
            # Update state.
            st.session_state.selected_models = selected_models
            st.session_state.selected_aggregator = selected_aggregator
            # Render the chats.
            st_render_responses(user_prompt)
        # Render cached chats when they are available, even if the submit
        # button wasn't clicked on this rerun.
        elif st.session_state.responses:
            st_render_responses(user_prompt)
        # Judging.
        if st.session_state.responses_collected:
            with st.form(key="judging_form"):
                st.markdown("#### Judging Configuration")
                # Choose the type of assessment.
                assessment_type = st.radio(
                    "Select the type of assessment",
                    options=["Direct Assessment", "Pairwise Comparison"],
                )
                _, center_column, _ = st.columns([3, 5, 3])
                # Depending on the assessment type, render different forms.
                if assessment_type == "Direct Assessment":
                    # Direct assessment prompt.
                    with center_column.expander("Direct Assessment Prompt"):
                        direct_assessment_prompt = st.text_area(
                            "Prompt for the Direct Assessment",
                            value=get_default_direct_assessment_prompt(
                                user_prompt=user_prompt
                            ),
                            height=500,
                            key="direct_assessment_prompt",
                        )
                    # TODO: Add an option to edit the criteria list with a basic text field.
                    criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
                with center_column:
                    judging_submit_button = st.form_submit_button(
                        "Submit Judging", use_container_width=True
                    )

            if judging_submit_button:
                # Update session state.
                st.session_state.assessment_type = assessment_type
                if st.session_state.assessment_type == "Direct Assessment":
                    st.session_state.direct_assessment_config = {
                        "prompt": direct_assessment_prompt,
                        "criteria_list": criteria_list,
                    }
                    st_direct_assessment_results(
                        user_prompt=st.session_state.user_prompt,
                        direct_assessment_prompt=direct_assessment_prompt,
                        criteria_list=criteria_list,
                    )
            # If judging is already complete but the submit button wasn't
            # clicked on this rerun, still render the results from the saved
            # configuration.
            elif st.session_state.judging_status == "complete":
                if st.session_state.assessment_type == "Direct Assessment":
                    st_direct_assessment_results(
                        user_prompt=st.session_state.user_prompt,
                        direct_assessment_prompt=st.session_state.direct_assessment_config[
                            "prompt"
                        ],
                        criteria_list=st.session_state.direct_assessment_config[
                            "criteria_list"
                        ],
                    )
        # Judging is complete. Render results that aren't stream-specific: the
        # session state now contains the overall scores for each response from
        # each judge.
        if st.session_state.judging_status == "complete":
            st.write("#### Results")
            overall_scores_df_raw = pd.DataFrame(
                st.session_state.direct_assessment_overall_scores
            ).reset_index()
            overall_scores_df = pd.melt(
                overall_scores_df_raw,
                id_vars=["index"],
                var_name="response_model",
                value_name="score",
            ).rename(columns={"index": "judging_model"})

            # Print the overall winner.
            overall_winner = overall_scores_df.loc[overall_scores_df["score"].idxmax()]
            st.write(
                f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
            )
            # TODO: Calculate separability (how much the standard deviation
            # overlaps with other models).
            st.write(f"**Confidence:** {overall_winner['score']:.2f}")

            left_column, right_column = st.columns([1, 1])
            with left_column:
                plot_overall_scores(overall_scores_df)
            with right_column:
                # All overall scores.
                overall_scores_df = overall_scores_df[
                    ["response_model", "judging_model", "score"]
                ]
                overall_scores_df["response_model"] = overall_scores_df[
                    "response_model"
                ].apply(get_ui_friendly_name)
                with st.expander("Overall scores from all judges"):
                    st.write(st.session_state.direct_assessment_overall_scores)
                    st.dataframe(overall_scores_df_raw)
                    st.dataframe(overall_scores_df)

            # All criteria scores.
            with right_column:
                all_scores_df = pd.DataFrame()
                for (
                    response_model,
                    score_df,
                ) in st.session_state.direct_assessment_judging_df.items():
                    score_df["response_model"] = response_model
                    all_scores_df = pd.concat([all_scores_df, score_df])
                all_scores_df = all_scores_df.reset_index(drop=True)
                # Reorder the columns.
                all_scores_df = all_scores_df[
                    [
                        "response_model",
                        "judging_model",
                        "criteria",
                        "score",
                        "explanation",
                    ]
                ]
                with st.expander(
                    "Criteria-specific scores and explanations from all judges"
                ):
                    st.dataframe(all_scores_df)
        # Token usage.
        if st.session_state.responses:
            st.divider()
            with st.expander("Token Usage"):
                st.write("Input tokens used.")
                st.write(st.session_state.input_token_usage)
                st.write(
                    f"Input Tokens Total: {sum(st.session_state.input_token_usage.values())}"
                )
                st.write("Output tokens used.")
                st.write(st.session_state.output_token_usage)
                st.write(
                    f"Output Tokens Total: {sum(st.session_state.output_token_usage.values())}"
                )
    else:
        with cols[1]:
            st.warning("Please log in to access this app.")


if __name__ == "__main__":
    main()