"""Gradio demo for comparing the Zamba2-7B base and instruct models."""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
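
# Models and tokenizers are loaded lazily on first use, so the app starts
# quickly and only the tab the user actually exercises pays the loading cost.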
model, model_instruct = None, None
tokenizer, tokenizer_instruct = None, None


def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p,
                           repetition_penalty, num_beams, length_penalty):
    """Generate a completion from the Zamba2-7B base model for a raw text prompt."""
    global model, tokenizer
    # Load the base model and tokenizer on the first request.
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
        model = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
        )
    selected_model = model
    selected_tokenizer = tokenizer

    input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
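    # Sliders report floats, so the integer-valued generation arguments are cast explicitly.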
    outputs = selected_model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1,
    )
    response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response


def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p,
                               repetition_penalty, num_beams, length_penalty):
    """Generate the next assistant turn from the instruct model for a chat history."""
    global model_instruct, tokenizer_instruct
    # Load the instruct model and tokenizer on the first request.
    if model_instruct is None:
        tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
        model_instruct = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
        )
    selected_model = model_instruct
    selected_tokenizer = tokenizer_instruct
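
    # Convert Gradio's [[user, assistant], ...] history into the role/content
    # message list that apply_chat_template expects.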
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
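
    # add_generation_prompt=True appends the assistant prefix so the model
    # continues as the assistant instead of predicting another user turn.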
    chat_sample = selected_tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=True)

    input_ids = selected_tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(selected_model.device)
    outputs = selected_model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1,
    )
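    # Decode only the newly generated tokens; decoding outputs[0] in full
    # would echo the entire templated chat history back into the reply.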
    response = selected_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response


def clear_text():
    return ""


with gr.Blocks() as demo:
    gr.Markdown("# Zamba2-7B Model Selector")
    with gr.Tabs():
        with gr.TabItem("Base Model"):
            gr.Markdown("### Zamba2-7B Base Model")
            input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
            output_text = gr.Textbox(label="Generated Response")
            max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
            submit_button = gr.Button("Generate Response")
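            # Chain the two callbacks so the input box is cleared only after
            # the response has been generated.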
            submit_button.click(
                fn=generate_response_base,
                inputs=[input_text, max_new_tokens, temperature, top_k, top_p,
                        repetition_penalty, num_beams, length_penalty],
                outputs=output_text,
            ).then(fn=clear_text, outputs=input_text)
        with gr.TabItem("Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            chat_history = gr.Chatbot()
            message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
            max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

            def user_message(message, chat_history):
                # Append the user's turn with an empty assistant slot and clear the textbox.
                chat_history = chat_history + [[message, None]]
                return "", chat_history

            def bot_response(chat_history, max_new_tokens, temperature, top_k, top_p,
                             repetition_penalty, num_beams, length_penalty):
                # Slider values must come in through the event's `inputs` list;
                # referencing the component objects from the enclosing scope would
                # pass gr.Slider instances, not their current values.
                response = generate_response_instruct(chat_history, max_new_tokens, temperature, top_k,
                                                      top_p, repetition_penalty, num_beams, length_penalty)
                chat_history[-1][1] = response
                return chat_history

            message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
                bot_response,
                inputs=[chat_history, max_new_tokens_instruct, temperature_instruct, top_k_instruct,
                        top_p_instruct, repetition_penalty_instruct, num_beams_instruct, length_penalty_instruct],
                outputs=[chat_history],
            )


if __name__ == "__main__":
    demo.launch()