import traceback

import anthropic
import streamlit as st
import streamlit.components.v1 as components
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from openai import OpenAI


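# Build an API client only if its key has already been saved to session state; otherwise return None.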
def create_client(client_class, api_key_name):
    if api_key_name not in st.session_state:
        return None
    return client_class(api_key=st.session_state[api_key_name])


# Create clients
openai_client = create_client(OpenAI, "openai_key")
claude_client = create_client(anthropic.Anthropic, "claude_key")
mistral_client = create_client(MistralClient, "mistral_key")

# Initialize counter
if "counter" not in st.session_state:
    st.session_state["counter"] = 0


# Increment the counter to re-run the JavaScript snippet below, so focus always returns to the input field.
def increment_counter():
    st.session_state.counter += 1


# Build the debate text from the chat history.
# The common system prompt comes first, then every message is appended under its role title.
# The final role title names the participant whose turn is next.
def create_debate_text(role):
    debate_text = ""
    debate_text += common_system_prompt + "\n"
    turn_titles = {
        "openai": "\nChatGPT:",
        "mistral": "\nMistral:",
        "claude": "\nClaude:",
        "user": "\nUser:",
        "end": "End of debate",
    }
    if len(st.session_state.messages) == 0:
        debate_text += "\n" + turn_titles[role]
        return debate_text
    for message in st.session_state.messages:
        debate_text += "\n".join(
            [turn_titles[message["role"]], message["content"], "\n"]
        )
    debate_text += "\n\n" + turn_titles[role]
    return debate_text


# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Create the sidebar: an API key and system prompt for each model, plus the shared debate prompt
with st.sidebar:
    st.image("assets/openai.svg", width=20)
    openai_api_key = st.text_input("OpenAI API Key", type="password")
    openai_system_prompt = st.text_area(
        "OpenAI System Prompt",
        value="You are ChatGPT. You are agreeing with the debate topic.",
    )
    if (
        "openai_key" not in st.session_state
        or openai_api_key != st.session_state.openai_key
    ) and openai_api_key != "":
        st.session_state.openai_key = openai_api_key
        openai_client = OpenAI(api_key=openai_api_key)
        st.toast("OpenAI API Key is set", icon="✔️")
    st.divider()
st.image("assets/Mistral.svg", width=20) | |
mistral_api_key = st.text_input("mistral API Key", type="password") | |
mistral_system_prompt = st.text_area( | |
"mistral System Prompt", | |
value="You are mistral. You are disagreeing with the debate topic.", | |
) | |
if ( | |
"mistral_key" not in st.session_state | |
or mistral_api_key != st.session_state.mistral_key | |
) and mistral_api_key != "": | |
st.session_state.mistral_key = mistral_api_key | |
mistral_client = MistralClient(api_key=mistral_api_key) | |
st.toast("mistral API Key is set to test", icon="✔️") | |
st.divider() | |
st.image("assets/claude-ai-icon.svg", width=20) | |
claude_api_key = st.text_input("Claude API Key", type="password") | |
claude_system_prompt = st.text_area( | |
"Claude System Prompt", | |
value="You are Claude. You are neutral to the debate topic.", | |
) | |
if ( | |
"claude_key" not in st.session_state | |
or claude_api_key != st.session_state.claude_key | |
) and claude_api_key != "": | |
st.session_state.claude_key = claude_api_key | |
claude_client = anthropic.Anthropic(api_key=claude_api_key) | |
st.toast("Claude API Key is set to test", icon="✔️") | |
st.divider() | |
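    # Shared preamble that create_debate_text prepends to every model's transcript.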
    common_system_prompt = st.text_area(
        "Common System Prompt",
        value="The following is a conversation from a debate group. You will state your opinion when it is your turn. The user will relay the other participants' responses to you, and your response to them, so that you can communicate.",
        height=300,
    )

# Display chat messages from history on app rerun
with st.container(border=True):
    for message in st.session_state.messages:
        with st.chat_message(message["name"], avatar=message["avatar"]):
            st.markdown(message["content"])

with st.container(border=True):

    def is_last_message_role(role):
        if len(st.session_state.messages) == 0:
            return False
        return st.session_state.messages[-1]["role"] == role

    def get_chatgpt_response():
        try:
            debate_text = create_debate_text("openai")
            completion = openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": openai_system_prompt},
                    {"role": "user", "content": debate_text},
                ],
            )
            st.session_state.messages.append(
                {
                    "name": "user",
                    "role": "openai",
                    "content": completion.choices[0].message.content,
                    "avatar": "assets/openai.svg",
                }
            )
        except Exception as e:
            print(e, traceback.format_exc())

    def get_mistral_response():
        try:
            debate_text = create_debate_text("mistral")
            message = mistral_client.chat(
                model="mistral-large-latest",
                messages=[
                    ChatMessage(role="system", content=mistral_system_prompt),
                    ChatMessage(role="user", content=debate_text),
                ],
            )
            st.session_state.messages.append(
                {
                    "name": "user",
                    "role": "mistral",
                    "content": message.choices[0].message.content,
                    "avatar": "assets/Mistral.svg",
                }
            )
        except Exception as e:
            print(e, traceback.format_exc())

    def get_claude_response():
        try:
            debate_text = create_debate_text("claude")
            message = claude_client.messages.create(
                model="claude-3-sonnet-20240229",
                max_tokens=1000,
                temperature=0,
                system=claude_system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": debate_text}],
                    },
                ],
            )
            st.session_state.messages.append(
                {
                    "name": "user",
                    "role": "claude",
                    "content": message.content[0].text,
                    "avatar": "assets/claude-ai-icon.svg",
                }
            )
        except Exception as e:
            print(e, traceback.format_exc())

    # React to user input
    with st.container():
        if prompt := st.chat_input(
            "Start the conversation.", on_submit=increment_counter
        ):
            st.session_state.messages.append(
                {"name": "user", "role": "user", "content": prompt, "avatar": "❔"}
            )
            st.rerun()

    with st.container(border=False):
        col1, col2, col3 = st.columns(3)
        with col1:
            st.button(
                "ChatGPT",
                on_click=get_chatgpt_response,
                disabled=is_last_message_role("openai"),
                type="primary",
            )
        with col2:
            st.button(
                "Mistral",
                on_click=get_mistral_response,
                disabled=is_last_message_role("mistral"),
                type="primary",
            )
        with col3:
            st.button(
                "Claude",
                on_click=get_claude_response,
                disabled=is_last_message_role("claude"),
                type="primary",
            )
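        # Utility row: clear the conversation or download it as plain text.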
        col1, col2 = st.columns(2)
        with col1:
            st.button(
                "Clear chat", on_click=lambda: st.session_state.pop("messages", None)
            )
        with col2:
            # Save the chat history to a text file
            st.download_button(
                "Save chat history",
                data=create_debate_text("end"),
                file_name="chat_history.txt",
                mime="text/plain",
            )

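# Center the button elements inside their column wrappers.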
custom_css = """ | |
<style> | |
.stButton{ | |
display: flex; | |
justify-content: center; | |
align-items: center; | |
} | |
</style> | |
""" | |
st.markdown(custom_css, unsafe_allow_html=True) | |
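# Zero-height HTML component: the embedded counter changes on every submit, which
# re-runs this script and moves focus back to the chat input textarea.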
components.html(
    f"""
    <div>some hidden container</div>
    <p>{st.session_state.counter}</p>
    <script>
        console.log("Hello from the other side");
        var input = window.parent.document.querySelectorAll("textarea[type=textarea]");
        console.log(input);
        for (var i = 0; i < input.length; ++i) {{
            console.log(input[i]);
            input[i].focus();
        }}
    </script>
    """,
    height=0,
)