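# Gradio chat demo for teknium/OpenHermes-2.5-Mistral-7B,
# served with 4-bit quantization on a Hugging Face Space.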
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import bitsandbytes  # required for load_in_4bit
import accelerate    # required for device_map="auto"

model_name_or_path = "teknium/OpenHermes-2.5-Mistral-7B"
dtype = torch.bfloat16
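# Load the model in 4-bit (via bitsandbytes) with automatic device placement
# (via accelerate); this assumes a CUDA GPU is available in the Space.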
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    torch_dtype=dtype,
    trust_remote_code=False,
    load_in_4bit=True,
    revision="main",
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."
def clear_chat(chat_history_state, chat_message):
    chat_history_state = []
    chat_message = ''
    return chat_history_state, chat_message

def user(message, history):
    history = history or []
    history.append([message, ""])
    return "", history

def regenerate(chatbot, chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty):
    print("Regenerate function called")  # Debug print
    if not chat_history_state:
        print("Chat history is empty")  # Debug print
        return chatbot, chat_history_state, ""
    # Blank out only the last assistant message so chat() regenerates it
    print(f"Before: {chat_history_state[-1]}")  # Debug print
    chat_history_state[-1][1] = ""
    print(f"After: {chat_history_state[-1]}")  # Debug print
    # Re-run the chat function
    new_history, _, _ = chat(chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty)
    print(f"New history: {new_history}")  # Debug print
    return new_history, new_history, ""

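# chat() builds the prompt from the system message plus the newest user turn,
# generates a completion, and writes it into the last history slot.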
def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    print(f"Chat function called with history: {history}")
    history = history or []
    # Fall back to BASE_SYSTEM_MESSAGE if the system message is empty
    system_message_to_use = system_message if system_message.strip() else BASE_SYSTEM_MESSAGE
    # The user's latest message
    user_prompt = history[-1][0] if history else ""
    print(f"User prompt used for generation: {user_prompt}")  # Debug print
    # Prepare the model input in the ChatML format Hermes 2 expects
    prompt_template = f'''<|im_start|>system
{system_message_to_use.strip()}<|im_end|>
<|im_start|>user
{user_prompt}<|im_end|>
<|im_start|>assistant
'''
    input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(model.device)
    # Generate the output; max_new_tokens bounds only the completion length,
    # and do_sample=True makes temperature/top_p/top_k actually take effect
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
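    # Note: skip_special_tokens=True strips the ChatML markers on decode
    # (assuming the tokenizer registers <|im_start|>/<|im_end|> as special
    # tokens), leaving plain "system"/"user"/"assistant" labels behind,
    # which is what the split on 'assistant' below relies on.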
    # Decode the output
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    assistant_response = decoded_output.split('assistant')[-1].strip()  # Keep only the assistant's last reply
    print(f"Generated assistant response: {assistant_response}")  # Debug print
    # Update the history
    if history:
        history[-1][1] += assistant_response
    else:
        history.append(["", assistant_response])
    print(f"Updated history: {history}")
    return history, history, ""

start_message = ""
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ## OpenHermes-V2.5 Finetuned on Mistral 7B
            **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for the GPU!**
            **OpenHermes-V2.5 is currently SOTA on some benchmarks for 7B models.**
            **The Hermes 2 model was trained on 900,000 instructions. It surpasses all previous Hermes versions at 13B and below, and matches 70B models on some benchmarks! Hermes 2 changes the game with strong multi-turn chat skills, system-prompt capabilities, and the ChatML format. Its quality, diversity, and scale are unmatched in the current open-source LM landscape, and it does well not only on benchmarks but also on unmeasured capabilities such as roleplaying and task work.**
            """)
    with gr.Row():
        # chatbot = gr.Chatbot().style(height=500)
        chatbot = gr.Chatbot(elem_id="chatbot")
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )
    with gr.Row():
        submit = gr.Button(value="Send message", variant="secondary", scale=1)
        clear = gr.Button(value="New topic", variant="secondary", scale=0)
        stop = gr.Button(value="Stop", variant="secondary", scale=0)
        regen_btn = gr.Button(value="Regenerate", variant="secondary", scale=0)
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 512, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.7)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)
    system_msg = gr.Textbox(
        start_message, label="System Message", interactive=True, visible=True,
        placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5)
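    # When this box is left empty, chat() falls back to BASE_SYSTEM_MESSAGE.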
    chat_history_state = gr.State()

    # "New topic" resets both the stored history/message and the visible chatbot
    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

    submit_click_event = submit.click(
        fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
    ).then(
        fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True
    )

    # "Stop" cancels the in-flight submit event
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event], queue=False)
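    # "Regenerate" blanks the last assistant reply and re-runs chat() on the
    # same user turn with the current slider values.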
    regen_click_event = regen_btn.click(
        fn=regenerate,
        inputs=[chatbot, chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[chatbot, chat_history_state, message],
        queue=True
    )

demo.queue(max_size=128, concurrency_count=2)
demo.launch()