File size: 5,329 Bytes
e8d28ee
312e99a
 
 
 
e8d28ee
f96052b
7e126fe
f96052b
312e99a
 
f96052b
7e126fe
312e99a
 
 
bfc8250
53b2631
312e99a
e8d28ee
312e99a
 
 
 
 
 
53b2631
 
312e99a
 
53b2631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d28ee
 
bfc8250
f96052b
e8d28ee
312e99a
d7665a7
 
5a80d05
 
 
 
 
 
312e99a
5a80d05
53b2631
312e99a
e8d28ee
312e99a
e8d28ee
312e99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d28ee
312e99a
 
 
 
 
 
 
e8d28ee
312e99a
 
 
 
 
e8d28ee
312e99a
 
e8d28ee
312e99a
 
e8d28ee
312e99a
 
 
 
 
 
31494d4
e8d28ee
312e99a
e8d28ee
312e99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d28ee
53b2631
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Title shown above the chat UI (currently blank).
TITLE = ''

# Description markdown rendered under the title (currently blank).
DESCRIPTION = ''

# Footer markdown; Llama license attribution required by Meta's terms.
LICENSE = """
<p>Built with Llama</p>
"""

# HTML shown inside the empty chatbot before the first message.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.85;">Gameapp</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.75;">Ask me anything...</p>
</div>
"""

# Custom CSS injected into gr.Blocks: centers the h1 and skins the
# container, input box, and chat message bubbles.
css = """
h1 {
  text-align: center;
  display: block;
  display: flex;
  align-items: center;
  justify-content: center;
}

.gradio-container {
  border: 1px solid #ddd;
  border-radius: 10px;
  padding: 20px;
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}

.gradio-chatbot .input-container {
  border-top: 1px solid #ddd;
  padding-top: 10px;
}

.gradio-chatbot .input-container textarea {
  border: 1px solid #ddd;
  border-radius: 5px;
  padding: 10px;
  width: 100%;
  box-sizing: border-box;
  resize: none;
  height: 50px;
}

.gradio-chatbot .message {
  border-radius: 10px;
  padding: 10px;
  margin: 10px 0;
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}

.gradio-chatbot .message.user {
  background-color: #f5f5f5;
}

.gradio-chatbot .message.assistant {
  background-color: #e6f7ff;
}
"""

model_id = "abhillubillu/gameapp_model"
# Token for gated/private Hub repos; None is fine for public ones.
hf_token = os.getenv("HF_API_TOKEN")

# Load the tokenizer and model (device_map="auto" shards across available GPUs).
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto")

# Ensure eos_token_id is set; fall back to pad_token_id if the tokenizer
# does not define an EOS token.
eos_token_id = tokenizer.eos_token_id
if eos_token_id is None:
    eos_token_id = tokenizer.pad_token_id

# Stop ids for generation. The previous code called convert_tokens_to_ids("")
# (an empty string), which returns the unk id or None instead of a real stop
# token; for Llama-3-style chat models the end-of-turn token is "<|eot_id|>".
# Filter out None so model.generate() never receives an invalid id.
terminators = [
    token_id
    for token_id in (eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>"))
    if token_id is not None
]

# Hard cap on prompt length; longer conversations are trimmed from the left.
MAX_INPUT_TOKEN_LENGTH = 4096

# Gradio inference function
@spaces.GPU(duration=120)
def chat_llama3_1_8b(message: str, 
              history: list, 
              temperature: float, 
              max_new_tokens: int
             ) -> str:
    """
    Generate a streaming response using the llama3-8b model.
    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface,
            as (user, assistant) pairs.
        temperature (float): The temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response accumulated so far (ChatInterface re-renders each yield).
    """
    # Rebuild the full chat in the role/content format the chat template expects.
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header so the model
    # answers the user instead of continuing the user's turn (the previous
    # code omitted it, which degrades chat-model output).
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens of context.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Streamer yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=temperature != 0,  # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
        temperature=temperature,
        eos_token_id=terminators,
    )

    # Run generation in a background thread so we can consume the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio block
# Chat widget rendered inside the ChatInterface; PLACEHOLDER shows until
# the first message is sent.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:

    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    # Wire the streaming inference function to the chat UI; the two sliders
    # map positionally to the temperature and max_new_tokens parameters.
    gr.ChatInterface(
        fn=chat_llama3_1_8b,
        chatbot=chatbot,
        fill_height=True,
        examples_per_page=3,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1, 
                      step=0.1,
                      value=0.95, 
                      label="Temperature", 
                      render=False),
            gr.Slider(minimum=128, 
                      maximum=4096,
                      step=1,
                      value=512, 
                      label="Max new tokens", 
                      render=False ),
            ],
        examples=[
            ["There's a llama in my garden 😱 What should I do?"],
            ["What is the best way to open a can of worms?"],
            ["The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1. "],
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
            ],
        cache_examples=False,  # examples hit the GPU; don't pre-run them at build time
                     )
    
    gr.Markdown(LICENSE)
    
if __name__ == "__main__":
    demo.launch()