Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import dotenv | |
import openai | |
from openai import OpenAI | |
import anthropic | |
from together import Together | |
import google.generativeai as genai | |
import time | |
from typing import List, Optional, Literal, Union, Dict | |
from constants import ( | |
LLM_COUNCIL_MEMBERS, | |
PROVIDER_TO_AVATAR_MAP, | |
AGGREGATORS, | |
LLM_TO_UI_NAME_MAP, | |
) | |
from prompts import * | |
from judging_dataclasses import DirectAssessmentJudgingResponse | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
# Load variables (API keys, app password) from a local .env file, if one exists.
dotenv.load_dotenv()

# Password required by the login form in main(); compared verbatim with the input.
PASSWORD = os.getenv("APP_PASSWORD")

# Load API keys from environment variables.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

# OpenAI organization / project for the council client. Overridable through the
# environment (same variable names the OpenAI SDK itself reads); the defaults are
# the previously hard-coded IDs, so behavior is unchanged when they are unset.
OPENAI_ORGANIZATION = os.getenv("OPENAI_ORG_ID", "org-kUoRSK0nOw4W2nQYMVGWOt03")
OPENAI_PROJECT = os.getenv("OPENAI_PROJECT_ID", "proj_zb6k1DdgnSEbiAEMWxSOVVu4")

# Initialize API clients.
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)

# Module-level default key for the openai package, plus the client used for council calls.
openai.api_key = OPENAI_API_KEY
openai_client = OpenAI(
    api_key=OPENAI_API_KEY,
    organization=OPENAI_ORGANIZATION,
    project=OPENAI_PROJECT,
)

# Pass the key explicitly, consistent with the Together / Google / OpenAI clients above.
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

# Second OpenAI client without an explicit organization/project;
# parse_judging_responses() uses this one for structured-output parsing.
client = OpenAI(api_key=OPENAI_API_KEY)
def anthropic_streamlit_streamer(stream):
    """
    Convert an Anthropic streaming response into a generator of text fragments.

    Only ``content_block_delta`` events that carry a ``delta`` with non-empty
    ``text`` produce output. Iteration stops at the first ``message_stop`` event
    or when the stream is exhausted. Events without a ``type`` are ignored.

    :param stream: Event iterable returned by ``messages.create(..., stream=True)``.
    :return: Generator of text chunks, suitable for ``st.write_stream``.
    """
    for event in stream:
        event_type = getattr(event, "type", None)
        if event_type == "message_stop":
            # End of message: stop streaming.
            return
        if event_type != "content_block_delta" or not hasattr(event, "delta"):
            continue
        fragment = getattr(event.delta, "text", None)
        if fragment:
            yield fragment
def google_streamlit_streamer(stream):
    """Yield the ``text`` attribute of every chunk in a Gemini stream, in order."""
    yield from (piece.text for piece in stream)
def together_streamlit_streamer(stream):
    """
    Yield the text deltas of a Together chat-completion stream.

    Chunks with no ``choices``, and deltas whose ``content`` is ``None`` or empty,
    are skipped. Previously ``None`` was yielded to ``st.write_stream`` and an empty
    ``choices`` list raised IndexError. This matches the Anthropic streamer, which
    also yields only non-empty text.

    :param stream: Iterable of chunks from ``together_client.chat.completions.create(..., stream=True)``.
    :return: Generator of non-empty text chunks.
    """
    for chunk in stream:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        delta = getattr(choices[0], "delta", None)
        content = getattr(delta, "content", None)
        if content:
            yield content
def llm_streamlit_streamer(stream, llm):
    """
    Adapt a raw provider stream into something ``st.write_stream`` can consume.

    :param stream: Raw streaming response returned by the provider SDK.
    :param llm: Model identifier of the form ``"<provider>://<model>"``.
    :return: A text generator for Anthropic / Vertex / Together streams. For OpenAI,
        the stream itself, passed through unchanged exactly as get_llm_response_stream
        does (previously OpenAI fell through to an implicit ``None``). ``None`` for
        any other provider.
    """
    if llm.startswith("anthropic"):
        return anthropic_streamlit_streamer(stream)
    if llm.startswith("vertex"):
        return google_streamlit_streamer(stream)
    if llm.startswith("together"):
        return together_streamlit_streamer(stream)
    if llm.startswith("openai"):
        return stream
    return None
# Helper functions for LLM council and aggregator selection
def llm_council_selector():
    """Let the user pick a council configuration and return its list of member models."""
    council_names = list(LLM_COUNCIL_MEMBERS.keys())
    chosen_council = st.radio("Choose a council configuration", options=council_names)
    return LLM_COUNCIL_MEMBERS[chosen_council]
def aggregator_selector():
    """Let the user pick the aggregator LLM (one of AGGREGATORS) and return its identifier."""
    chosen_aggregator = st.radio("Choose an aggregator LLM", options=AGGREGATORS)
    return chosen_aggregator
# API calls for different providers
def get_openai_response(model_name, prompt):
    """Start a streaming OpenAI chat completion with ``prompt`` as the only user message."""
    conversation = [{"role": "user", "content": prompt}]
    return openai_client.chat.completions.create(
        model=model_name,
        messages=conversation,
        stream=True,
    )
# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt, max_tokens=1024):
    """
    Start a streaming Anthropic Messages API call with ``prompt`` as the only user message.

    :param model_name: Anthropic model name (the part after ``anthropic://``).
    :param prompt: User message content.
    :param max_tokens: Upper bound on generated tokens (a required Messages API
        field). Defaults to 1024, the value previously hard-coded here, so existing
        callers are unaffected.
    :return: The raw event stream; see anthropic_streamlit_streamer for consumption.
    """
    return anthropic_client.messages.create(
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        stream=True,
    )
def get_together_response(model_name, prompt):
    """Start a streaming Together chat completion with ``prompt`` as the only user message."""
    request = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    return together_client.chat.completions.create(**request)
# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    """Start a streaming Gemini generation for ``prompt`` using the named model."""
    return genai.GenerativeModel(model_name).generate_content(prompt, stream=True)
def get_llm_response_stream(model_identifier, prompt):
    """
    Return a streamlit-friendly stream of response tokens from the LLM.

    :param model_identifier: ``"<provider>://<model>"``; only the first ``://``
        separates the provider from the model name.
    :param prompt: User prompt to send.
    :return: The raw OpenAI stream (handed to ``st.write_stream`` as-is), a text
        generator for Anthropic / Together / Vertex, or ``None`` when the identifier
        has no ``://`` or names an unsupported provider.
    """
    provider, separator, model_name = model_identifier.partition("://")
    if not separator:
        # Previously a ValueError from tuple unpacking; treat a malformed identifier
        # like an unsupported provider, which callers already handle (falsy result).
        return None
    if provider == "openai":
        return get_openai_response(model_name, prompt)
    if provider == "anthropic":
        return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
    if provider == "together":
        return together_streamlit_streamer(get_together_response(model_name, prompt))
    if provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    return None
def get_response_key(model):
    """Session-state key under which ``model``'s council response is stored."""
    return f"{model}__response"
def get_model_from_response_key(response_key):
    """Return everything before the first ``"__"`` (the inverse of get_response_key)."""
    model, _separator, _remainder = response_key.partition("__")
    return model
def get_direct_assessment_judging_key(judge_model, response_model):
    """Session-state key for ``judge_model``'s direct assessment of ``response_model``'s output."""
    return f"direct_assessment_judge__{judge_model}__{response_model}"
def get_aggregator_response_key(model):
    """Session-state key under which the aggregator ``model``'s response is stored."""
    return f"{model}__aggregator_response"
def create_dataframe_for_direct_assessment_judging_response(
    response: DirectAssessmentJudgingResponse,
) -> pd.DataFrame:
    """
    Flatten a parsed direct-assessment result into one row per (judge, criterion).

    :param response: Parsed judging result. Each entry of ``response.judging_models``
        has a ``model`` and ``criteria_scores``; each score has ``criterion``,
        ``score`` and ``explanation``.
    :return: DataFrame with columns ``llm_judge_model``, ``criteria``, ``score`` and
        ``explanation``. The columns exist even when there are no rows, so callers
        such as ``df.groupby(["criteria"])`` do not fail with KeyError on empty input.
    """
    columns = ["llm_judge_model", "criteria", "score", "explanation"]
    rows = [
        {
            "llm_judge_model": judging_model.model,
            "criteria": criteria_score.criterion,
            "score": criteria_score.score,
            "explanation": criteria_score.explanation,
        }
        for judging_model in response.judging_models
        for criteria_score in judging_model.criteria_scores
    ]
    return pd.DataFrame(rows, columns=columns)
# Streamlit form UI
def render_criteria_form(criteria_num):
    """Render the inputs for one criterion (0-based ``criteria_num``) and return them as a ``Criteria``."""
    label_index = criteria_num + 1
    with st.expander(f"Criteria {label_index}"):
        criterion_name = st.text_input(f"Name for Criteria {label_index}")
        criterion_description = st.text_area(f"Description for Criteria {label_index}")
        lowest_score = st.number_input(
            f"Min Score for Criteria {label_index}", min_value=0, step=1
        )
        highest_score = st.number_input(
            f"Max Score for Criteria {label_index}", min_value=0, step=1
        )
    # NOTE(review): ``Criteria`` is not defined in this file; presumably it comes
    # from ``from prompts import *`` — confirm.
    return Criteria(
        name=criterion_name,
        description=criterion_description,
        min_score=lowest_score,
        max_score=highest_score,
    )
def get_response_mapping():
    """
    Collect the stored council and aggregator responses from session state.

    - ``"<model>__response"`` entries are keyed by the model identifier.
    - ``"<model>__aggregator_response"`` entries keep their full session-state key.
    - Any key containing ``"judge"`` is skipped, so judging outputs are excluded.
    """
    mapping = {}
    for state_key in st.session_state.keys():
        if "judge" in state_key:
            continue
        if state_key.endswith("__response"):
            model = get_model_from_response_key(state_key)
            mapping[model] = st.session_state[state_key]
        if state_key.endswith("__aggregator_response"):
            mapping[state_key] = st.session_state[state_key]
    return mapping
def format_likert_comparison_options(options):
    """Render ``options`` as a 1-based numbered list (``"1: ..."``), one per line."""
    numbered_lines = []
    for position, label in enumerate(options, start=1):
        numbered_lines.append(f"{position}: {label}")
    return "\n".join(numbered_lines)
def format_criteria_list(criteria_list):
    """Render each criterion as a ``"<name>: <description>"`` line."""
    return "\n".join(f"{item.name}: {item.description}" for item in criteria_list)
def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    """
    Fill a direct-assessment prompt template.

    :param direct_assessment_prompt: Template with ``{user_prompt}``, ``{response}``,
        ``{criteria_list}`` and ``{options}`` placeholders (``str.format`` syntax).
    :param user_prompt: The prompt the council answered.
    :param response: The response being judged.
    :param criteria_list: Criteria objects exposing ``name`` and ``description``.
    :param options: Likert option labels; rendered as a 1-based numbered list.
    :return: The filled prompt.

    Previously ``criteria_list`` and ``options`` were ignored and the module defaults
    were always used. They are now honored; callers that pass the defaults see
    identical output.

    NOTE(review): ``str.format`` treats every ``{``/``}`` in the template as a
    placeholder. A template that was already formatted once (e.g. edited in the UI)
    and contains braces from the user prompt may raise KeyError/ValueError here.
    """
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )
def get_default_direct_assessment_prompt(user_prompt):
    """
    Build the default direct-assessment prompt for ``user_prompt``.

    ``{response}`` is substituted with itself, so the result can be formatted again
    later with each concrete response being judged.
    """
    return get_direct_assessment_prompt(
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
        direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{response}",
    )
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    """
    Fill the aggregator template with the user prompt and each council response.

    Each model's stored response (session-state key from get_response_key) is wrapped
    in ``<UI name> START`` / ``<UI name> END`` markers. A model without a stored
    response contributes the text ``None``.
    """
    sections = []
    for llm in llms:
        display_name = get_ui_friendly_name(llm)
        stored_response = st.session_state.get(get_response_key(llm))
        sections.append(
            f"{display_name} START\n{stored_response}\n\n{display_name} END\n\n\n"
        )
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms="\n\n".join(sections),
    )
def get_default_aggregator_prompt(user_prompt, llms):
    """Aggregator prompt built from DEFAULT_AGGREGATOR_PROMPT for ``user_prompt`` and the council ``llms``."""
    template = DEFAULT_AGGREGATOR_PROMPT
    return get_aggregator_prompt(template, user_prompt=user_prompt, llms=llms)
def get_ui_friendly_name(llm):
    """Display name for ``llm`` from LLM_TO_UI_NAME_MAP; falls back to the identifier itself."""
    ui_name = LLM_TO_UI_NAME_MAP.get(llm, llm)
    return ui_name
def get_parse_judging_response_for_direct_assessment_prompt(
    judging_responses: dict[str, str],
    criteria_list,
    options,
):
    """
    Build the prompt that asks a model to structure free-form judge outputs.

    :param judging_responses: Judge model identifier -> raw judging text; each is
        wrapped in ``<UI name> START`` / ``<UI name> END`` markers.
    :param criteria_list: Criteria objects exposing ``name`` and ``description``.
    :param options: Likert option labels, rendered as a 1-based numbered list.
    :return: PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT with all fields filled.
    """
    sections = []
    for judge_model, judge_text in judging_responses.items():
        judge_name = get_ui_friendly_name(judge_model)
        sections.append(f"{judge_name} START\n{judge_text}\n\n{judge_name} END\n\n\n")
    return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
        judging_responses="\n\n".join(sections),
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )
def get_model_from_direct_assessment_judging_key(judging_key):
    """Return the judge model: the segment between the first and second ``"__"``."""
    judge_model = judging_key.split("__", 2)[1]
    return judge_model
def get_direct_assessment_judging_responses():
    """
    Collect the stored direct-assessment judgments, keyed by judge model.

    NOTE(review): the judged response model is dropped from the key. When one judge
    has assessed several responses, only the last matching session-state key
    iterated survives in the result.
    """
    judging_prefix = "direct_assessment_judge__"
    return {
        get_model_from_direct_assessment_judging_key(state_key): st.session_state[state_key]
        for state_key in st.session_state.keys()
        if state_key.startswith(judging_prefix)
    }
def parse_judging_responses(
    prompt: str, model: str = "gpt-4o-mini"
) -> DirectAssessmentJudgingResponse:
    """
    Parse free-form judge outputs into a DirectAssessmentJudgingResponse.

    Uses OpenAI structured outputs through the module-level ``client``.

    :param prompt: Prompt from get_parse_judging_response_for_direct_assessment_prompt.
    :param model: OpenAI model used for parsing; defaults to the previously
        hard-coded ``"gpt-4o-mini"``.
    :return: The parsed, schema-validated judging response.
    :raises ValueError: If the model refuses or no parsed object is returned.
        Previously ``None`` was returned, which failed later with an AttributeError
        when the caller accessed ``.judging_models``.
    """
    completion = client.beta.chat.completions.parse(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "Parse the judging responses into structured data.",
            },
            {"role": "user", "content": prompt},
        ],
        response_format=DirectAssessmentJudgingResponse,
    )
    message = completion.choices[0].message
    if message.parsed is None:
        refusal = getattr(message, "refusal", None)
        raise ValueError(
            f"Could not parse judging responses: {refusal or 'no parsed output returned'}"
        )
    return message.parsed
def plot_criteria_scores(df):
    """
    Plot each criterion's mean score across judges, with ±1 std error bars.

    :param df: DataFrame with at least ``criteria`` and ``score`` columns, such as
        the one built by create_dataframe_for_direct_assessment_judging_response.
    """
    if df.empty:
        # Nothing to aggregate; groupby/barplot would fail or draw an empty chart.
        st.info("No criteria scores to plot.")
        return

    # Mean and std over all judges, per criterion.
    grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
    # Flatten the MultiIndex columns produced by agg().
    grouped.columns = ["criteria", "mean_score", "std_score"]
    # std is NaN when a criterion has a single score.
    grouped["std_score"] = grouped["std_score"].fillna(0)

    # Own the figure explicitly so it can be closed after rendering. Figures created
    # via plt.figure() stay alive in pyplot until closed, and this function may run
    # on many Streamlit reruns.
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(
        data=grouped,
        x="mean_score",
        y="criteria",
        hue="criteria",
        errorbar=None,
        orient="h",
        ax=ax,
    )
    # Draw the std error bars manually, centered on each bar.
    for bar, mean, std in zip(ax.patches, grouped["mean_score"], grouped["std_score"]):
        center = bar.get_y() + bar.get_height() / 2
        ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")

    ax.set_xlabel("")
    ax.set_ylabel("")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)
# Main Streamlit App
def _render_council_responses(selected_models, selected_aggregator, user_prompt):
    """Stream every council member's answer, then the aggregator's, saving each in session state."""
    st.write("Responses:")
    response_columns = st.columns(3)
    # Wrap around the three columns (as the judging section does); indexing by i
    # raised IndexError for councils with more than three members.
    selected_models_to_streamlit_column_map = {
        model: response_columns[i % 3] for i, model in enumerate(selected_models)
    }

    # Fetch and stream the response of each selected model.
    for selected_model in selected_models:
        with selected_models_to_streamlit_column_map[selected_model]:
            st.write(get_ui_friendly_name(selected_model))
            with st.chat_message(
                selected_model,
                avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
            ):
                message_placeholder = st.empty()
                stream = get_llm_response_stream(selected_model, user_prompt)
                if stream:
                    st.session_state[get_response_key(selected_model)] = (
                        message_placeholder.write_stream(stream)
                    )

    aggregator_prompt = get_default_aggregator_prompt(
        user_prompt=user_prompt, llms=selected_models
    )
    with st.expander("Aggregator Prompt"):
        st.code(aggregator_prompt)

    # Fetch and stream the aggregator's response.
    st.write(
        f"Mixture-of-Agents response from {get_ui_friendly_name(selected_aggregator)}"
    )
    with st.chat_message(
        selected_aggregator,
        avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
    ):
        message_placeholder = st.empty()
        aggregator_stream = get_llm_response_stream(selected_aggregator, aggregator_prompt)
        if aggregator_stream:
            # A stream can be consumed only once. It used to be written twice, so the
            # second write (whose result was stored) saw an exhausted stream.
            st.session_state[get_aggregator_response_key(selected_aggregator)] = (
                message_placeholder.write_stream(aggregator_stream)
            )


def _run_direct_assessment(user_prompt, selected_models):
    """Render the direct-assessment form; on submit, have every council member judge every stored response."""
    with st.expander("Direct Assessment Prompt"):
        direct_assessment_prompt = st.text_area(
            "Prompt for the Direct Assessment",
            value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
            height=500,
        )
    # TODO: Add option to edit criteria list with a basic text field.
    criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST

    if not st.button("Submit Direct Assessment"):
        return

    responses_for_judging = get_response_mapping()
    response_judging_columns = st.columns(3)
    responses_for_judging_to_streamlit_column_map = {
        model: response_judging_columns[i % 3]
        for i, model in enumerate(responses_for_judging.keys())
    }

    for response_model, response in responses_for_judging.items():
        with responses_for_judging_to_streamlit_column_map[response_model]:
            if "aggregator_response" in response_model:
                judging_model_header = "Mixture-of-Agents Response"
            else:
                judging_model_header = get_ui_friendly_name(response_model)
            st.write(f"Judging for {judging_model_header}")

            judging_prompt = get_direct_assessment_prompt(
                direct_assessment_prompt=direct_assessment_prompt,
                user_prompt=user_prompt,
                response=response,
                criteria_list=criteria_list,
                options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
            )
            with st.expander("Final Judging Prompt"):
                st.code(judging_prompt)

            # Each council member judges this response.
            for judging_model in selected_models:
                with st.expander(get_ui_friendly_name(judging_model), expanded=False):
                    with st.chat_message(
                        judging_model,
                        avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                    ):
                        message_placeholder = st.empty()
                        judging_stream = get_llm_response_stream(
                            judging_model, judging_prompt
                        )
                        if judging_stream:
                            st.session_state[
                                get_direct_assessment_judging_key(
                                    judging_model, response_model
                                )
                            ] = message_placeholder.write_stream(judging_stream)

            # Gather the judgments of *this* response only.
            # get_direct_assessment_judging_responses() keys by judge alone, so it
            # mixed in (and let later keys overwrite) judgments of other responses,
            # including ones left over from earlier runs.
            judging_responses = {}
            for judging_model in selected_models:
                judging_key = get_direct_assessment_judging_key(
                    judging_model, response_model
                )
                if judging_key in st.session_state:
                    judging_responses[judging_model] = st.session_state[judging_key]
            if not judging_responses:
                continue

            # Structure the judgments with an OpenAI structured-outputs call.
            parse_judging_response_prompt = (
                get_parse_judging_response_for_direct_assessment_prompt(
                    judging_responses,
                    criteria_list,
                    SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                )
            )
            parsed_judging_responses = parse_judging_responses(
                parse_judging_response_prompt
            )
            df = create_dataframe_for_direct_assessment_judging_response(
                parsed_judging_responses
            )
            st.write(df)
            plot_criteria_scores(df)
            # TODO: Use parsed_judging_responses for further processing or display.


def _run_pairwise_comparison():
    """Render the pairwise-comparison form; on submit, build and display its configuration."""
    pairwise_comparison_prompt = st.text_area("Prompt for the Pairwise Comparison")
    granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
    ties_allowed = st.checkbox("Are ties allowed?")
    position_swapping = st.checkbox("Enable position swapping?")
    reference_model = st.text_input("Reference Model")

    if st.button("Submit Pairwise Comparison"):
        # NOTE(review): PairwiseComparison is not defined in this file; presumably it
        # comes from ``from prompts import *`` — confirm.
        pairwise_comparison_config = PairwiseComparison(
            type="pairwise_comparison",
            granularity=granularity,
            ties_allowed=ties_allowed,
            position_swapping=position_swapping,
            reference_model=reference_model,
            # Was ``prompt=prompt``; no ``prompt`` is defined in this file, and the
            # prompt the user typed was otherwise unused.
            prompt=pairwise_comparison_prompt,
        )
        st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
        # TODO: Submit pairwise comparison over get_response_mapping().


def main():
    """App entry point: page setup, password gate, council generation, and judging UI."""
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display
    center_css = """
    <style>
    h1, h2, h3, h6 { text-align: center; }
    .chat-container {
        display: flex;
        align-items: flex-start;
        margin-bottom: 10px;
    }
    .avatar {
        width: 50px;
        margin-right: 10px;
    }
    .message {
        background-color: #f1f1f1;
        padding: 10px;
        border-radius: 10px;
        width: 100%;
    }
    </style>
    """
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to generate and judge each other.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")

    # Authentication system
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])
    if not st.session_state.authenticated:
        with cols[1]:
            password = st.text_input("Password", type="password")
            if st.button("Login", use_container_width=True):
                if password == PASSWORD:
                    st.session_state.authenticated = True
                else:
                    st.error("Invalid credentials")

    if not st.session_state.authenticated:
        with cols[1]:
            st.warning("Please log in to access this app.")
        return

    st.success("Logged in successfully!")

    # Council and aggregator selection
    selected_models = llm_council_selector()
    st.write("Selected Models:", selected_models)
    selected_aggregator = aggregator_selector()

    # Prompt input
    user_prompt = st.text_area("Enter your prompt:")
    if st.button("Submit"):
        _render_council_responses(selected_models, selected_aggregator, user_prompt)

    # Judging. Responses are read back from session state, so this section works on
    # later reruns, independently of the Submit button above.
    st.markdown("#### Judging Configuration Form")
    assessment_type = st.radio(
        "Select the type of assessment",
        options=["Direct Assessment", "Pairwise Comparison"],
    )
    if assessment_type == "Direct Assessment":
        _run_direct_assessment(user_prompt, selected_models)
    elif assessment_type == "Pairwise Comparison":
        _run_pairwise_comparison()


if __name__ == "__main__":
    main()