# sandbox/app.py
import os
import streamlit as st
import dotenv
import openai
from openai import OpenAI
import anthropic
from together import Together
import google.generativeai as genai
from constants import (
LLM_COUNCIL_MEMBERS,
PROVIDER_TO_AVATAR_MAP,
AGGREGATORS,
LLM_TO_UI_NAME_MAP,
)
from prompts import *
from judging_dataclasses import (
    # NOTE: Criteria and PairwiseComparison are referenced below; they are assumed to
    # be defined in judging_dataclasses alongside DirectAssessmentJudgingResponse.
    Criteria,
    DirectAssessmentJudgingResponse,
    PairwiseComparison,
)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
dotenv.load_dotenv()
PASSWORD = os.getenv("APP_PASSWORD")
# Load API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
# Initialize API clients
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)
# Set up API clients for OpenAI and Anthropic
openai.api_key = OPENAI_API_KEY
openai_client = OpenAI(
organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
# Plain OpenAI client used for structured-output parsing of judging responses.
client = OpenAI()
def anthropic_streamlit_streamer(stream):
"""
Process the Anthropic streaming response and yield content from the deltas.
:param stream: Streaming object from Anthropic API
:return: Yields content (text) from the streaming response.
"""
for event in stream:
if hasattr(event, "type"):
# Handle content blocks
if event.type == "content_block_delta" and hasattr(event, "delta"):
# Extract text delta from the event
text_delta = getattr(event.delta, "text", None)
if text_delta:
yield text_delta
# Handle message completion events (optional if needed)
elif event.type == "message_stop":
break # End of message, stop streaming
def google_streamlit_streamer(stream):
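    """Yield text chunks from a streaming Gemini (google.generativeai) response."""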
for chunk in stream:
yield chunk.text
def together_streamlit_streamer(stream):
    """Yield text deltas from a streaming Together chat completion."""
    for chunk in stream:
        # Some chunks (e.g., the final one) may carry no content; skip them.
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
def llm_streamlit_streamer(stream, llm):
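    """Wrap a provider-specific stream in the matching Streamlit-friendly generator."""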
if llm.startswith("anthropic"):
return anthropic_streamlit_streamer(stream)
elif llm.startswith("vertex"):
return google_streamlit_streamer(stream)
elif llm.startswith("together"):
return together_streamlit_streamer(stream)
# Helper functions for LLM council and aggregator selection
def llm_council_selector():
selected_council = st.radio(
"Choose a council configuration", options=list(LLM_COUNCIL_MEMBERS.keys())
)
return LLM_COUNCIL_MEMBERS[selected_council]
def aggregator_selector():
return st.radio("Choose an aggregator LLM", options=AGGREGATORS)
# API calls for different providers
def get_openai_response(model_name, prompt):
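    """Return a streaming chat completion from the OpenAI API."""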
return openai_client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
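    """Return a streaming Anthropic messages response, capped at 1024 output tokens."""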
return anthropic_client.messages.create(
max_tokens=1024,
messages=[{"role": "user", "content": prompt}],
model=model_name,
stream=True,
)
def get_together_response(model_name, prompt):
return together_client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
stream=True,
)
# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
model = genai.GenerativeModel(model_name)
return model.generate_content(prompt, stream=True)
def get_llm_response_stream(model_identifier, prompt):
"""Returns a streamlit-friendly stream of response tokens from the LLM."""
provider, model_name = model_identifier.split("://")
if provider == "openai":
return get_openai_response(model_name, prompt)
elif provider == "anthropic":
return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
elif provider == "together":
return together_streamlit_streamer(get_together_response(model_name, prompt))
elif provider == "vertex":
return google_streamlit_streamer(get_google_response(model_name, prompt))
else:
return None
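# Session-state key conventions: model responses are stored under "<model>__response",
# the aggregator response under "<model>__aggregator_response", and judging outputs
# under "direct_assessment_judge__<judge_model>__<response_model>".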
def get_response_key(model):
return model + "__response"
def get_model_from_response_key(response_key):
return response_key.split("__")[0]
def get_direct_assessment_judging_key(judge_model, response_model):
return "direct_assessment_judge__" + judge_model + "__" + response_model
def get_aggregator_response_key(model):
return model + "__aggregator_response"
def create_dataframe_for_direct_assessment_judging_response(
response: DirectAssessmentJudgingResponse,
):
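    """Flatten a DirectAssessmentJudgingResponse into a DataFrame with one row per
    (judge model, criterion) pair: llm_judge_model, criteria, score, explanation."""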
# Initialize empty list to collect data
data = []
# Loop through models
for judging_model in response.judging_models:
model_name = judging_model.model
# Loop through criteria_scores
for criteria_score in judging_model.criteria_scores:
data.append(
{
"llm_judge_model": model_name,
"criteria": criteria_score.criterion,
"score": criteria_score.score,
"explanation": criteria_score.explanation,
}
)
# Create DataFrame
return pd.DataFrame(data)
# Streamlit form UI
def render_criteria_form(criteria_num):
"""Render a criteria input form."""
with st.expander(f"Criteria {criteria_num + 1}"):
name = st.text_input(f"Name for Criteria {criteria_num + 1}")
description = st.text_area(f"Description for Criteria {criteria_num + 1}")
min_score = st.number_input(
f"Min Score for Criteria {criteria_num + 1}", min_value=0, step=1
)
max_score = st.number_input(
f"Max Score for Criteria {criteria_num + 1}", min_value=0, step=1
)
return Criteria(
name=name, description=description, min_score=min_score, max_score=max_score
)
def get_response_mapping():
# Inspect the session state for all the responses.
# This is a dictionary mapping model names to their responses.
# The aggregator response is also included in this mapping under the key "<model>__aggregator_response".
response_mapping = {}
for key in st.session_state.keys():
if "judge" in key:
continue
if key.endswith("__response"):
response_mapping[get_model_from_response_key(key)] = st.session_state[key]
if key.endswith("__aggregator_response"):
response_mapping[key] = st.session_state[key]
return response_mapping
def format_likert_comparison_options(options):
return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])
def format_criteria_list(criteria_list):
return "\n".join(
[f"{criteria.name}: {criteria.description}" for criteria in criteria_list]
)
def get_direct_assessment_prompt(
direct_assessment_prompt, user_prompt, response, criteria_list, options
):
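    """Fill the direct assessment judging template with the user prompt, the response
    under evaluation, the criteria list, and the Likert scale options."""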
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )
def get_default_direct_assessment_prompt(user_prompt):
return get_direct_assessment_prompt(
direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
user_prompt=user_prompt,
response="{response}",
criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
)
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
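    """Format the aggregator prompt with council responses from session state."""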
responses_from_other_llms = "\n\n".join(
[
f"{get_ui_friendly_name(model)} START\n{st.session_state.get(get_response_key(model))}\n\n{get_ui_friendly_name(model)} END\n\n\n"
for model in llms
]
)
return aggregator_prompt.format(
user_prompt=user_prompt,
responses_from_other_llms=responses_from_other_llms,
)
def get_default_aggregator_prompt(user_prompt, llms):
return get_aggregator_prompt(
DEFAULT_AGGREGATOR_PROMPT,
user_prompt=user_prompt,
llms=llms,
)
def get_ui_friendly_name(llm):
return LLM_TO_UI_NAME_MAP.get(llm, llm)
def get_parse_judging_response_for_direct_assessment_prompt(
judging_responses: dict[str, str],
criteria_list,
options,
):
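    """Build the prompt that asks a parser model to convert the free-text judging
    responses into structured per-criterion scores."""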
formatted_judging_responses = "\n\n".join(
[
f"{get_ui_friendly_name(model)} START\n{judging_responses[model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
for model in judging_responses.keys()
]
)
return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
judging_responses=formatted_judging_responses,
criteria_list=format_criteria_list(criteria_list),
options=format_likert_comparison_options(options),
)
def get_model_from_direct_assessment_judging_key(judging_key):
return judging_key.split("__")[1]
def get_direct_assessment_judging_responses():
# Get the judging responses from the session state.
judging_responses = {}
for key in st.session_state.keys():
if key.startswith("direct_assessment_judge__"):
judging_responses[get_model_from_direct_assessment_judging_key(key)] = (
st.session_state[key]
)
return judging_responses
def parse_judging_responses(prompt: str) -> DirectAssessmentJudgingResponse:
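    """Parse free-text judging responses into a DirectAssessmentJudgingResponse
    using OpenAI structured outputs (gpt-4o-mini)."""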
completion = client.beta.chat.completions.parse(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "Parse the judging responses into structured data.",
},
{"role": "user", "content": prompt},
],
response_format=DirectAssessmentJudgingResponse,
)
return completion.choices[0].message.parsed
def plot_criteria_scores(df):
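    """Plot mean score per criterion (averaged across judges) as horizontal bars
    with standard-deviation error bars, rendered via st.pyplot."""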
# Group by criteria and calculate mean and std over all judges.
grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
# Flatten the MultiIndex columns
grouped.columns = ["criteria", "mean_score", "std_score"]
# Fill NaN std with zeros (in case there's only one score per group)
grouped["std_score"] = grouped["std_score"].fillna(0)
# Set up the plot
plt.figure(figsize=(8, 5))
# Create a horizontal bar plot
ax = sns.barplot(
data=grouped,
x="mean_score",
y="criteria",
hue="criteria",
errorbar=None, # Updated parameter
orient="h",
)
# Add error bars manually
# Iterate over the bars and add error bars
for i, (mean, std) in enumerate(zip(grouped["mean_score"], grouped["std_score"])):
# Get the current bar
bar = ax.patches[i]
# Calculate the center of the bar
center = bar.get_y() + bar.get_height() / 2
# Add the error bar
ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")
# Set labels and title
ax.set_xlabel("")
ax.set_ylabel("")
plt.tight_layout()
# Display the plot in Streamlit
st.pyplot(plt.gcf())
# Main Streamlit App
def main():
st.set_page_config(
page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
)
# Custom CSS for the chat display
center_css = """
<style>
h1, h2, h3, h6 { text-align: center; }
.chat-container {
display: flex;
align-items: flex-start;
margin-bottom: 10px;
}
.avatar {
width: 50px;
margin-right: 10px;
}
.message {
background-color: #f1f1f1;
padding: 10px;
border-radius: 10px;
width: 100%;
}
</style>
"""
st.markdown(center_css, unsafe_allow_html=True)
# App title and description
st.title("Language Model Council Sandbox")
st.markdown("###### Invoke a council of LLMs to generate and judge each other.")
st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")
# Authentication system
if "authenticated" not in st.session_state:
st.session_state.authenticated = False
cols = st.columns([2, 1, 2])
if not st.session_state.authenticated:
with cols[1]:
password = st.text_input("Password", type="password")
if st.button("Login", use_container_width=True):
if password == PASSWORD:
st.session_state.authenticated = True
else:
st.error("Invalid credentials")
if st.session_state.authenticated:
st.success("Logged in successfully!")
# Council and aggregator selection
selected_models = llm_council_selector()
st.write("Selected Models:", selected_models)
selected_aggregator = aggregator_selector()
# Prompt input
user_prompt = st.text_area("Enter your prompt:")
if st.button("Submit"):
st.write("Responses:")
response_columns = st.columns(3)
            selected_models_to_streamlit_column_map = {
                model: response_columns[i % 3]
                for i, model in enumerate(selected_models)
            }
# Fetching and streaming responses from each selected model
for selected_model in selected_models:
with selected_models_to_streamlit_column_map[selected_model]:
st.write(get_ui_friendly_name(selected_model))
with st.chat_message(
selected_model,
avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
):
message_placeholder = st.empty()
stream = get_llm_response_stream(selected_model, user_prompt)
if stream:
st.session_state[get_response_key(selected_model)] = (
message_placeholder.write_stream(stream)
)
# Get the aggregator prompt.
aggregator_prompt = get_default_aggregator_prompt(
user_prompt=user_prompt, llms=selected_models
)
with st.expander("Aggregator Prompt"):
st.code(aggregator_prompt)
# Fetching and streaming response from the aggregator
st.write(
f"Mixture-of-Agents response from {get_ui_friendly_name(selected_aggregator)}"
)
with st.chat_message(
selected_aggregator,
avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
):
message_placeholder = st.empty()
aggregator_stream = get_llm_response_stream(
selected_aggregator, aggregator_prompt
)
                if aggregator_stream:
                    st.session_state[
                        get_aggregator_response_key(selected_aggregator)
                    ] = message_placeholder.write_stream(aggregator_stream)
# Judging.
st.markdown("#### Judging Configuration Form")
# Choose the type of assessment
assessment_type = st.radio(
"Select the type of assessment",
options=["Direct Assessment", "Pairwise Comparison"],
)
# Depending on the assessment type, render different forms
if assessment_type == "Direct Assessment":
with st.expander("Direct Assessment Prompt"):
direct_assessment_prompt = st.text_area(
"Prompt for the Direct Assessment",
value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
height=500,
)
# TODO: Add option to edit criteria list with a basic text field.
criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
# Create DirectAssessment object when form is submitted
if st.button("Submit Direct Assessment"):
                # Submit direct assessment.
responses_for_judging = get_response_mapping()
response_judging_columns = st.columns(3)
responses_for_judging_to_streamlit_column_map = {
model: response_judging_columns[i % 3]
for i, model in enumerate(responses_for_judging.keys())
}
# Get judging responses.
for response_model, response in responses_for_judging.items():
st_column = responses_for_judging_to_streamlit_column_map[
response_model
]
with st_column:
if "aggregator_response" in response_model:
judging_model_header = "Mixture-of-Agents Response"
else:
judging_model_header = get_ui_friendly_name(response_model)
st.write(f"Judging for {judging_model_header}")
judging_prompt = get_direct_assessment_prompt(
direct_assessment_prompt=direct_assessment_prompt,
user_prompt=user_prompt,
response=response,
criteria_list=criteria_list,
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
)
with st.expander("Final Judging Prompt"):
st.code(judging_prompt)
for judging_model in selected_models:
with st.expander(
get_ui_friendly_name(judging_model), expanded=False
):
with st.chat_message(
judging_model,
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
):
message_placeholder = st.empty()
judging_stream = get_llm_response_stream(
judging_model, judging_prompt
)
if judging_stream:
st.session_state[
get_direct_assessment_judging_key(
judging_model, response_model
)
] = message_placeholder.write_stream(
judging_stream
)
                        # When all of the judging is finished for the given response,
                        # parse the free-text judging responses into structured scores
                        # using gpt-4o-mini with structured outputs.
judging_responses = get_direct_assessment_judging_responses()
parse_judging_response_prompt = (
get_parse_judging_response_for_direct_assessment_prompt(
judging_responses,
criteria_list,
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
)
)
# Issue the prompt to openai mini with structured outputs
parsed_judging_responses = parse_judging_responses(
parse_judging_response_prompt
)
df = create_dataframe_for_direct_assessment_judging_response(
parsed_judging_responses
)
st.write(df)
# Log the output using st.write() under an st.expander
# with st.expander("Parsed Judging Responses", expanded=True):
# st.write(parsed_judging_responses)
plot_criteria_scores(df)
# TODO: Use parsed_judging_responses for further processing or display
elif assessment_type == "Pairwise Comparison":
pairwise_comparison_prompt = st.text_area(
"Prompt for the Pairwise Comparison"
)
granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
ties_allowed = st.checkbox("Are ties allowed?")
position_swapping = st.checkbox("Enable position swapping?")
reference_model = st.text_input("Reference Model")
# Create PairwiseComparison object when form is submitted
if st.button("Submit Pairwise Comparison"):
pairwise_comparison_config = PairwiseComparison(
type="pairwise_comparison",
granularity=granularity,
ties_allowed=ties_allowed,
position_swapping=position_swapping,
reference_model=reference_model,
                    prompt=pairwise_comparison_prompt,
)
st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
# Submit pairwise comparison.
responses_for_judging = get_response_mapping()
else:
with cols[1]:
st.warning("Please log in to access this app.")
if __name__ == "__main__":
main()