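"""Gradio chat UI that streams replies from an OpenAI-compatible
chat-completions endpoint serving Meta-Llama-3.1-405B-Instruct."""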
import gradio as gr
import aiohttp
import json
import os
import datetime

# The endpoint and key come from the environment (on Hugging Face Spaces,
# set them as Space secrets); requests will fail if either is missing.
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

# Defaults mirrored by the UI sliders; used below to log only the
# parameters a user has changed.
DEFAULT_PARAMS = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "repetition_penalty": 1.1,
    "max_tokens": 512
}
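
# Note: top_k and repetition_penalty are not part of OpenAI's own API; they
# are extensions accepted by OpenAI-compatible servers such as vLLM or TGI,
# which this endpoint is assumed to be.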

def get_timestamp():
    """Timestamp used to prefix console log lines."""
    return datetime.datetime.now().strftime("%H:%M:%S")

# NOTE: this flag is module-global, so it is shared by every connected user;
# pressing Stop cancels generation for all active sessions.
should_stop = False

async def predict(message, history, system_prompt, temperature, top_p, top_k,
                  frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
    global should_stop

    # Rebuild the whole conversation in chat-completions message format.
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    # Messages starting with '*' or '"' are kept out of the console log.
    if not message.startswith(('*', '"')):
        print(f"<|system|> {system_prompt}")
        print(f"{get_timestamp()} <|user|> {message}")

    current_params = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }
    # Log only the sampling parameters that differ from the defaults.
    non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
    if non_default_params and not message.startswith(('*', '"')):
        for param, value in non_default_params.items():
            print(f"{param}={value}")

    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
        "messages": history_format,
        "stream": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(API_URL, headers=headers, json=data) as response:
            response.raise_for_status()  # fail loudly instead of streaming nothing
            partial_message = ""
            # Consume the server-sent-events stream line by line.
            async for line in response.content:
                if should_stop:
                    break
                line = line.decode('utf-8')
                if line.startswith("data: "):
                    if line.strip() == "data: [DONE]":
                        break
                    try:
                        json_data = json.loads(line[6:])
                        if 'choices' in json_data and json_data['choices']:
                            content = json_data['choices'][0]['delta'].get('content', '')
                            if content:
                                partial_message += content
                                yield partial_message
                    except json.JSONDecodeError:
                        # Skip keep-alives and fragments that are not valid JSON.
                        continue

            if partial_message:
                yield partial_message
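
# For reference, each streamed line from an OpenAI-compatible backend is
# expected to look roughly like this (exact fields vary by server):
#   data: {"choices": [{"delta": {"content": "Hel"}, "index": 0}]}
#   data: [DONE]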

def import_chat(custom_format_string):
    """Parse a chat in the <|system|>/<|user|>/<|assistant|> text format back
    into (history, system_prompt); returns (None, None) on failure."""
    try:
        sections = custom_format_string.split('<|')
        imported_history = []
        system_prompt = ""
        for section in sections:
            # removeprefix strips only the leading tag, so a tag appearing
            # inside a message body is left intact.
            if section.startswith('system|>'):
                system_prompt = section.removeprefix('system|>').strip()
            elif section.startswith('user|>'):
                user_message = section.removeprefix('user|>').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.removeprefix('assistant|>').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant message with no preceding user turn.
                    imported_history.append(["", assistant_message])
        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None

def export_chat(history, system_prompt):
    """Serialize the current chat into the plain-text import/export format."""
    export_data = f"<|system|> {system_prompt}\n\n"
    if history is not None:
        for user_msg, assistant_msg in history:
            export_data += f"<|user|> {user_msg}\n\n"
            if assistant_msg:
                export_data += f"<|assistant|> {assistant_msg}\n\n"
    return export_data
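
# Round-trip example of the format produced by export_chat and accepted by
# import_chat:
#   <|system|> You are a helpful assistant.
#   <|user|> Hello!
#   <|assistant|> Hi there, how can I help?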

def stop_generation():
    global should_stop
    should_stop = True
    # Re-enable the message box and the Regenerate button after cancelling.
    return gr.update(interactive=True), gr.update(interactive=True)
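
# Per-session stop, as a sketch only (not wired in here): keep a mutable
# flag in gr.State, e.g.
#   stop_flag = gr.State({"stop": False})
# Every handler in a session receives the same dict object, so the Stop
# handler can set flag["stop"] = True and the streaming loop in bot() can
# check it, leaving other users' generations untouched.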

with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        # Left column: chat window, message box, and chat actions.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(value=[])
            msg = gr.Textbox(label="Message")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")
        # Right column: system prompt and sampling controls; slider defaults
        # match DEFAULT_PARAMS.
        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")

    def user(user_message, history):
        """Record the user's turn and clear the input box."""
        history = history or []
        return "", history + [[user_message, None]]

    async def bot(history, system_prompt, temperature, top_p, top_k,
                  frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        """Stream the assistant's reply into the last history entry."""
        global should_stop
        should_stop = False
        history = history or []
        if not history:
            yield history
            return
        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt, temperature,
                              top_p, top_k, frequency_penalty, presence_penalty,
                              repetition_penalty, max_tokens)
        history[-1][1] = ""
        try:
            async for chunk in bot_message:
                if should_stop:
                    break
                # Each chunk is the full reply so far, not a delta.
                history[-1][1] = chunk
                yield history
        except Exception as e:
            print(f"Error in bot function: {str(e)}")
            history[-1][1] = "An error occurred while generating the response."
            yield history
        finally:
            should_stop = False

    async def regenerate_response(history, system_prompt, temperature, top_p, top_k,
                                  frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        """Discard the last assistant reply and stream a fresh one."""
        global should_stop
        should_stop = False
        if history and len(history) > 0:
            history[-1][1] = None
            async for new_history in bot(history, system_prompt, temperature, top_p, top_k,
                                         frequency_penalty, presence_penalty,
                                         repetition_penalty, max_tokens):
                if should_stop:
                    break
                yield new_history
        else:
            yield []
        should_stop = False

    def import_chat_wrapper(custom_format_string):
        # Thin wrapper so the click handler maps onto (chatbot, system_prompt).
        imported_history, imported_system_prompt = import_chat(custom_format_string)
        return imported_history, imported_system_prompt

    # Submitting a message records the user turn, then streams the reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty,
         presence_penalty, repetition_penalty, max_tokens],
        chatbot,
        concurrency_limit=10
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    regenerate_event = regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty,
         presence_penalty, repetition_penalty, max_tokens],
        chatbot,
        concurrency_limit=10
    )

    # Stop both sets the flag (so the streaming loops exit promptly) and
    # cancels the in-flight submit/regenerate events.
    stop_btn.click(
        stop_generation,
        inputs=[],
        outputs=[msg, regenerate],
        cancels=[submit_event, regenerate_event],
        queue=False
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox],
                        outputs=[chatbot, system_prompt], queue=False)
    export_button.click(
        export_chat,
        inputs=[chatbot, system_prompt],
        outputs=[import_textbox],
        queue=False
    )

if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces; it matters only locally.
    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True, max_threads=40)