import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Models are loaded lazily: keep them as None until the first request needs them
model, model_instruct = None, None
tokenizer, tokenizer_instruct = None, None

def generate_response_base(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
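    """Generate a completion from the Zamba2-7B base model, loading the model and tokenizer lazily on first use."""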
    global model, tokenizer
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
        model = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
        )
    selected_model = model
    selected_tokenizer = tokenizer

    # Tokenize and generate response
    input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
    outputs = selected_model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

def generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
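    """Generate the next assistant turn with the Zamba2-7B instruct model, rebuilding the conversation from the Gradio chat history."""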
    global model_instruct, tokenizer_instruct
    if model_instruct is None:
        tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
        model_instruct = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
        )
    selected_model = model_instruct
    selected_tokenizer = tokenizer_instruct

    # Build the sample
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
    # Format the conversation with the model's chat template and append the assistant generation prompt
    chat_sample = selected_tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=True)
    # Tokenize input and generate output
    input_ids = selected_tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(selected_model.device)
    outputs = selected_model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    # Decode only the newly generated tokens so the chatbot reply does not repeat the prompt
    response = selected_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def clear_text():
    return ""

with gr.Blocks() as demo:
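    # Two tabs: free-form completion with the base model, and a chat interface backed by the instruct model.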
    gr.Markdown("# Zamba2-7B Model Selector")
    with gr.Tabs():
        with gr.TabItem("Base Model"):
            gr.Markdown("### Zamba2-7B Base Model")
            input_text = gr.Textbox(lines=2, placeholder="Enter your input text...", label="Input Text")
            output_text = gr.Textbox(label="Generated Response")
            max_new_tokens = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
            submit_button = gr.Button("Generate Response")
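            # Both handlers fire on the same click: Gradio captures the input text for generation before the box is cleared.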
            submit_button.click(fn=generate_response_base, inputs=[input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty], outputs=output_text)
            submit_button.click(fn=clear_text, outputs=input_text)
        with gr.TabItem("Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            chat_history = gr.Chatbot()
            message = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
            max_new_tokens_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
            temperature_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature")
            top_k_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
            top_p_instruct = gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P")
            repetition_penalty_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
            num_beams_instruct = gr.Slider(1, 10, step=1, value=5, label="Number of Beams")
            length_penalty_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

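            # Chat flow: user_message appends the new user turn and clears the textbox, then bot_response fills in the assistant reply.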
            def user_message(message, chat_history):
                chat_history = chat_history + [[message, None]]
                return "", chat_history

            def bot_response(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty):
                # Slider values must be passed in through the event's inputs list; referencing the
                # slider components directly would hand Gradio component objects to the model.
                response = generate_response_instruct(chat_history, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty)
                chat_history[-1][1] = response
                return chat_history

            message.submit(user_message, [message, chat_history], [message, chat_history], queue=False).then(
                bot_response,
                inputs=[chat_history, max_new_tokens_instruct, temperature_instruct, top_k_instruct, top_p_instruct, repetition_penalty_instruct, num_beams_instruct, length_penalty_instruct],
                outputs=[chat_history]
            )

if __name__ == "__main__":
    demo.launch()