from huggingface_hub import InferenceClient
import gradio as gr

# Inference API client for google/gemma-7b-it. The model is gated, so the
# environment running this app needs a Hugging Face token with access to it.
client = InferenceClient(
    "google/gemma-7b-it"
)
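# Illustrative sketch (assumption, not part of the original app): the token is
# typically passed explicitly or read from an environment variable, e.g.
#   client = InferenceClient("google/gemma-7b-it", token=os.environ.get("HF_TOKEN"))
# which would also require `import os` above.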

def format_prompt(message, history):
    # Build a Gemma chat prompt: each turn is wrapped in
    # "<start_of_turn>{role}\n...<end_of_turn>\n", and the prompt ends with an
    # open model turn for the endpoint to complete.
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
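# Illustrative sketch of the resulting layout (not executed): for
# history = [("Hi", "Hello!")] and message = "How are you?", format_prompt
# returns:
#
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#   Hello!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model
#
# The trailing model turn is left open so generation continues from there.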

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    # ChatInterface passes None or [] on the first turn; normalise to a list.
    if not history:
        history = []

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        # Skip special tokens (e.g. the end-of-sequence token) so they do not
        # leak into the chat window.
        if not response.token.special:
            output += response.token.text
        yield output
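# Illustrative usage sketch (assumption, not part of the original app): outside
# Gradio, generate() can be consumed as a plain generator, e.g.
#
#   for partial in generate("What is the meaning of life?", history=[]):
#       print(partial)
#
# Each iteration yields the text accumulated so far, which is what lets
# gr.ChatInterface stream the reply token by token.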


additional_inputs=[
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

# Create a Chatbot object with the desired height
chatbot = gr.Chatbot(height=450,
                     layout="bubble",
                     placeholder="Type here to chat...")

with gr.Blocks() as demo:
    gr.HTML("<center><h1>🤖 Google-Gemma-7B-Chat 💬</h1></center>")
    gr.ChatInterface(
        generate,
        chatbot=chatbot,  # Use the created Chatbot object
        additional_inputs=additional_inputs,
        examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
        textbox=gr.Textbox(placeholder="Type here to chat..."),
    )

demo.queue().launch(debug=True)