Zamba2-7B / app.py
gabrielclark3330's picture
instruct and base chat types
459aa64
raw
history blame
5.81 kB
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Placeholders for the base and instruct checkpoints. Both stay None until
# the matching generate_response_* function runs for the first time, which
# then loads the pair and stores it here for reuse on later calls.
model = None
model_instruct = None
tokenizer = None
tokenizer_instruct = None
def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
    """Generate a continuation of ``input_text`` with the Zamba2-7B base model.

    The base tokenizer/model pair is loaded on the first call and cached in
    the module-level ``tokenizer`` / ``model`` globals for later calls.

    Args:
        input_text: Prompt text to continue.
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff (coerced to int).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to already-generated tokens.
        num_beams: Number of beams (coerced to int); used together with
            sampling since ``do_sample=True``.
        length_penalty: Length penalty used by beam-based generation.

    Returns:
        The decoded first returned sequence with special tokens removed. For
        this causal LM the sequence starts with the prompt, so the result is
        the prompt followed by the generated continuation.
    """
    global model, tokenizer
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
        model = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
        )
    # Move the whole encoding (ids + attention mask) to the model's device.
    # Passing only input_ids forces generate() to guess the mask and warn.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def _chat_history_to_messages(chat_history):
    """Convert [user, assistant] pairs into role/content chat messages, skipping empty slots."""
    messages = []
    for user_text, assistant_text in ((turn[0], turn[1]) for turn in chat_history):
        if user_text:
            messages.append({'role': 'user', 'content': user_text})
        if assistant_text:
            messages.append({'role': 'assistant', 'content': assistant_text})
    return messages


def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
    """Generate the next assistant reply for a chat with Zamba2-7B-instruct.

    The instruct tokenizer/model pair is loaded on the first call and cached
    in the module-level ``tokenizer_instruct`` / ``model_instruct`` globals.

    Args:
        chat_history: Sequence of ``[user_text, assistant_text]`` turns; falsy
            entries (e.g. the ``None`` reply of the pending turn) are skipped.
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff (coerced to int).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to already-generated tokens.
        num_beams: Number of beams (coerced to int); used together with
            sampling since ``do_sample=True``.
        length_penalty: Length penalty used by beam-based generation.

    Returns:
        Only the newly generated assistant text (special tokens removed), not
        the formatted prompt or earlier turns.
    """
    global model_instruct, tokenizer_instruct
    if model_instruct is None:
        tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
        model_instruct = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
        )
    messages = _chat_history_to_messages(chat_history)
    # add_generation_prompt=True appends the assistant-turn header so the model
    # answers as the assistant rather than continuing the user's message.
    chat_sample = tokenizer_instruct.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The template already contains any special tokens, so don't add them again.
    inputs = tokenizer_instruct(
        chat_sample, return_tensors='pt', add_special_tokens=False
    ).to(model_instruct.device)
    input_ids = inputs["input_ids"]
    outputs = model_instruct.generate(
        input_ids=input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1,
    )
    # generate() returns prompt + continuation; keep only the new tokens so the
    # chat shows just the assistant's reply instead of the whole transcript.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer_instruct.decode(new_tokens, skip_special_tokens=True)
def clear_text():
    """Return an empty string; wired as a click handler to blank a textbox."""
    cleared = ""
    return cleared
with gr.Blocks() as demo:
    gr.Markdown("# Zamba2-7B Model Selector")
    with gr.Tabs():
        with gr.TabItem("Base Model"):
            gr.Markdown("### Zamba2-7B Base Model")
            input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
            output_text = gr.Textbox(label="Generated Response")
            max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
            submit_button = gr.Button("Generate Response")
            submit_button.click(
                fn=generate_response_base,
                inputs=[input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty],
                outputs=output_text,
            )
            # A second listener on the same button clears the prompt box.
            submit_button.click(fn=clear_text, outputs=input_text)

        with gr.TabItem("Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            chat_history = gr.Chatbot()
            message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
            max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

            def user_message(text, history):
                """Append ``text`` as a new unanswered turn and clear the textbox.

                Blank or whitespace-only submissions are ignored so no empty
                user turn is sent to the model.
                """
                if not text or not text.strip():
                    return "", history
                return "", history + [[text, None]]

            def bot_response(history, max_tokens, temp, k, p, rep_penalty, beams, len_penalty):
                """Fill in the assistant reply for the most recent user turn.

                The generation settings arrive as slider *values* via the
                event's ``inputs``; passing the Slider components themselves
                would make ``int(...)`` inside the generator raise TypeError.
                """
                # Nothing to answer: empty history, or the last turn already has
                # a reply (e.g. user_message ignored a blank submission).
                if not history or history[-1][1]:
                    return history
                history[-1][1] = generate_response_instruct(
                    history, max_tokens, temp, k, p, rep_penalty, beams, len_penalty
                )
                return history

            message.submit(
                user_message, [message, chat_history], [message, chat_history], queue=False
            ).then(
                bot_response,
                inputs=[
                    chat_history,
                    max_new_tokens_instruct,
                    temperature_instruct,
                    top_k_instruct,
                    top_p_instruct,
                    repetition_penalty_instruct,
                    num_beams_instruct,
                    length_penalty_instruct,
                ],
                outputs=[chat_history],
            )

if __name__ == "__main__":
    demo.launch()