import os
from threading import Thread, Event
from typing import Iterator

import gradio as gr

import torch
from transformers import (
    AutoModelForCausalLM,
    GemmaTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

DESCRIPTION = """\
# Gemma 2 2B IT
Gemma 2 is Google's latest iteration of open LLMs.
This is a demo of [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it), fine-tuned for instruction following.
For more details, please check [our post](https://huggingface.co/blog/gemma2).
👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it) and the 9B version in [this Space](https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it).
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the model and tokenizer
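# NOTE: the checkpoint loaded below differs from the google/gemma-2-2b-it model
# named in DESCRIPTION; swap in "google/gemma-2-2b-it" if that model is intended.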
tokenizer = GemmaTokenizerFast.from_pretrained("TenzinGayche/example")
model = AutoModelForCausalLM.from_pretrained("TenzinGayche/example", torch_dtype=torch.float16).to("cuda")

model.config.sliding_window = 4096
model.eval()

# Create a shared stop event
stop_event = Event()
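
# Breaking out of the streamer loop alone would leave model.generate running in
# its background thread until max_new_tokens is reached. A minimal sketch of a
# custom stopping criterion (via transformers' StoppingCriteria interface) that
# lets the Stop button halt the generation thread as well:
class StopOnEvent(StoppingCriteria):
    def __init__(self, event: Event) -> None:
        self.event = event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Evaluated after each generated token; returning True ends generation early.
        return self.event.is_set()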

def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Clear the stop event before starting a new generation
    stop_event.clear()

    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([StopOnEvent(stop_event)]),
    )

    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        if stop_event.is_set():
            break  # Stop streaming; the stopping criterion also halts the generation thread
        outputs.append(text)
        yield "".join(outputs)

# Define a function to stop the generation
def stop_generation():
    stop_event.set()

# Create the chat interface with additional inputs and the stop button
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # Create the chat interface
    chat_interface = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.Slider(
                label="Max new tokens",
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            ),
            gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=4.0,
                step=0.1,
                value=0.6,
            ),
            gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=0.9,
            ),
            gr.Slider(
                label="Top-k",
                minimum=1,
                maximum=1000,
                step=1,
                value=50,
            ),
            gr.Slider(
                label="Repetition penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.2,
            ),
        ],
        examples=[
            ["Hello there! How are you doing?"],
            ["Can you explain briefly to me what is the Python programming language?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["How many hours does it take a man to eat a Helicopter?"],
            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
        ],
        cache_examples=False,
        type="messages",
    )
    
    # Create the stop button inside the Blocks context
    stop_button = gr.Button("Stop", elem_id="stop-btn")
    stop_button.click(fn=stop_generation, inputs=[], outputs=[])

    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)