# Coach / app.py
# Gradio chatbot playground (Hugging Face Space by Medvira, revision c4b81c5).
# NOTE(review): the lines above this file's imports were web-page scrape residue
# ("raw / history blame / 3.68 kB") and have been converted into this header
# comment so the module is valid Python.
import gradio as gr
import os
from openai import OpenAI
import base64
# Read API key from environment variable.
# Fail fast at import time: the app is useless without credentials, and a
# clear error here beats an opaque 401 from the first API call.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
# Single shared OpenAI client, reused by all handlers below.
client = OpenAI(api_key=OPENAI_API_KEY)
# Mutable module-level settings consumed by convert_history_to_openai_format()
# and bot().  global_system_prompt is lazily defaulted on first use.
global_system_prompt = None
global_model = 'gpt-4'
def encode_image(image_path):
    """Read the file at *image_path* and return its contents as a base64 string.

    Parameters:
        image_path (str): Path to an image file on disk.

    Returns:
        str: The raw bytes of the file, base64-encoded and decoded to UTF-8
        text, suitable for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode('utf-8')
def convert_history_to_openai_format(history):
    """
    Convert Gradio chat history to the OpenAI chat-completions message format.

    Parameters:
        history (list of tuples): The chat history where each entry is a
            ``(user_msg, assistant_msg)`` pair.  ``user_msg`` is either plain
            text or a ``(file_path, caption)`` tuple for an image upload;
            ``assistant_msg`` is the reply text or ``None`` while pending.

    Returns:
        list of dict: Messages for the OpenAI API, starting with a "system"
        message, then alternating "user"/"assistant" entries.  Image uploads
        become multimodal content parts with a base64 ``data:`` URL.
    """
    global global_system_prompt
    # Lazily default the system prompt the first time any history is converted.
    if global_system_prompt is None:
        global_system_prompt = "You are a helpful assistant."
    formatted_history = [{"role": "system", "content": global_system_prompt}]
    for entry in history:
        # Ensure entry is a tuple and has exactly two elements
        if isinstance(entry, tuple) and len(entry) == 2:
            user_msg, assistant_msg = entry
            # FIX: use a case-insensitive endswith() check (the old substring
            # test missed '.jpeg' and matched paths merely *containing* '.png').
            if isinstance(user_msg, tuple) and user_msg[0].lower().endswith(('.png', '.jpg', '.jpeg')):
                encoded_image = encode_image(user_msg[0])
                # FIX: declare the correct MIME type; the old code always sent
                # 'image/jpeg', even for PNG uploads.
                mime = 'image/png' if user_msg[0].lower().endswith('.png') else 'image/jpeg'
                text = 'Help me based on the image'
                if user_msg[1] != '':
                    text = user_msg[1]
                content = [
                    {'type': 'text', 'text': text},
                    {'type': 'image_url', 'image_url': {'url': f'data:{mime};base64,{encoded_image}'}},
                ]
                formatted_history.append({"role": 'user', "content": content})
            else:
                formatted_history.append({"role": 'user', "content": user_msg})
            # Skip the pending (None) assistant slot of the newest turn.
            if isinstance(assistant_msg, str):
                formatted_history.append({"role": 'assistant', "content": assistant_msg})
        else:
            print(f"Unexpected entry format in history: {entry}")  # Debugging output
    return formatted_history
def bot(history):
    """Fill in the assistant reply for the newest turn of *history*.

    Sends the full conversation (converted to OpenAI format) to the chat
    completions API using the module-level ``global_model`` and writes the
    stripped reply text into the pending slot of the last history entry.

    Parameters:
        history (list): Gradio chat history; the last entry's assistant slot
            is expected to be ``None`` (as appended by ``add_message``).

    Returns:
        list: The same history object with the last assistant slot filled.
    """
    global global_model
    response = client.chat.completions.create(
        model=global_model,
        messages=convert_history_to_openai_format(history)
    )
    # FIX: removed the leftover debug `print(response)` that dumped the entire
    # API response object (including usage metadata) to the server logs.
    chatbot_message = response.choices[0].message.content.strip()
    # FIX: rebuild the pair instead of `history[-1][1] = ...` — add_message
    # appends *tuples*, which do not support item assignment; this form works
    # whether Gradio delivers the entry as a list or a tuple.
    history[-1] = (history[-1][0], chatbot_message)
    return history
def add_message(history, message):
    """Append the user's multimodal submission to the chat history.

    Parameters:
        history (list): Current chat history of ``(user, assistant)`` pairs.
        message (dict): MultimodalTextbox payload with "text" and "files" keys.

    Returns:
        tuple: The updated history, plus a cleared, temporarily disabled
        MultimodalTextbox (re-enabled after the bot responds).
    """
    files = message["files"]
    if len(files) > 0:
        # One history entry per uploaded file; the typed text rides along
        # with each file as its caption.
        for path in files:
            history.append(((path, message["text"]), None))
    elif message["text"] != '':
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)
# Define the Gradio interface.
# Layout: a single column holding the chat transcript and a multimodal input
# (text + file upload).  Event chain on submit:
#   1. add_message  — record the user turn, clear & disable the input box
#   2. bot          — call the OpenAI API and fill in the assistant reply
#   3. lambda       — re-enable the input box once the reply has arrived
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Chatbot Playground")
            # Avatars are loaded from the working directory — assumes
            # user.png / ai.png ship alongside app.py (TODO confirm).
            chatbot = gr.Chatbot(label="Chatbot:", bubble_full_width=False, show_copy_button=True, min_width=400,
                                 avatar_images=(os.path.join(os.getcwd(), 'user.png'), os.path.join(os.getcwd(), 'ai.png')))
            chat_input = gr.MultimodalTextbox(interactive=True, placeholder="Enter message or upload file...", show_label=False)
            chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
            bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
            bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

# Launch the Gradio interface
demo.launch(share=True)  # Enable public sharing