abhillubillu committed (verified)
Commit b26b8bd · 1 Parent(s): 5a80d05

Update app.py

Files changed (1)
  1. app.py +47 -162
app.py CHANGED
@@ -1,178 +1,63 @@
  import gradio as gr
- import os
- import spaces
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
- from threading import Thread
+ from huggingface_hub import InferenceClient

- TITLE = ''
-
- DESCRIPTION = ''
-
- LICENSE = """
- <p>Built with Llama</p>
- """
-
- PLACEHOLDER = """
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.85;">Gameapp</h1>
-    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.75;">Ask me anything...</p>
- </div>
  """
-
- css = """
- h1 {
-   text-align: center;
-   display: block;
-   display: flex;
-   align-items: center;
-   justify-content: center;
- }
-
- .gradio-container {
-   border: 1px solid #ddd;
-   border-radius: 10px;
-   padding: 20px;
-   box-shadow: 0 4px 8px rgba(0,0,0,0.1);
- }
-
- .gradio-chatbot .input-container {
-   border-top: 1px solid #ddd;
-   padding-top: 10px;
- }
-
- .gradio-chatbot .input-container textarea {
-   border: 1px solid #ddd;
-   border-radius: 5px;
-   padding: 10px;
-   width: 100%;
-   box-sizing: border-box;
-   resize: none;
-   height: 50px;
- }
-
- .gradio-chatbot .message {
-   border-radius: 10px;
-   padding: 10px;
-   margin: 10px 0;
-   box-shadow: 0 4px 8px rgba(0,0,0,0.1);
- }
-
- .gradio-chatbot .message.user {
-   background-color: #f5f5f5;
- }
-
- .gradio-chatbot .message.assistant {
-   background-color: #e6f7ff;
- }
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

- model_id = "abhillubillu/gameapp_model"
- hf_token = os.getenv("HF_API_TOKEN")
-
- # Load the tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
- model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto")

- # Ensure eos_token_id is set
- eos_token_id = tokenizer.eos_token_id
- if eos_token_id is None:
-     eos_token_id = tokenizer.pad_token_id
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     messages = [{"role": "system", "content": system_message}]

- terminators = [
-     eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})

- MAX_INPUT_TOKEN_LENGTH = 4096
+     messages.append({"role": "user", "content": message})

- # Gradio inference function
- @spaces.GPU(duration=120)
- def chat_llama3_1_8b(message: str,
-                      history: list,
-                      temperature: float,
-                      max_new_tokens: int
-                      ) -> str:
-     """
-     Generate a streaming response using the llama3-8b model.
-     Args:
-         message (str): The input message.
-         history (list): The conversation history used by ChatInterface.
-         temperature (float): The temperature for generating the response.
-         max_new_tokens (int): The maximum number of new tokens to generate.
-     Returns:
-         str: The generated response.
-     """
-     conversation = []
-     for user, assistant in history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
+     response = ""

-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         input_ids=input_ids,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=temperature != 0,  # Greedy generation (do_sample=False) when temperature is 0, avoiding a crash.
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
          temperature=temperature,
-         eos_token_id=terminators,
-     )
-
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content

-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
+         response += token
+         yield response

- # Gradio block
- chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+ )

- with gr.Blocks(fill_height=True, css=css) as demo:

-     gr.Markdown(TITLE)
-     gr.Markdown(DESCRIPTION)
-     gr.ChatInterface(
-         fn=chat_llama3_1_8b,
-         chatbot=chatbot,
-         fill_height=True,
-         examples_per_page=3,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-         additional_inputs=[
-             gr.Slider(minimum=0,
-                       maximum=1,
-                       step=0.1,
-                       value=0.95,
-                       label="Temperature",
-                       render=False),
-             gr.Slider(minimum=128,
-                       maximum=4096,
-                       step=1,
-                       value=512,
-                       label="Max new tokens",
-                       render=False),
-         ],
-         examples=[
-             ["There's a llama in my garden 😱 What should I do?"],
-             ["What is the best way to open a can of worms?"],
-             ["The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1. "],
-             ['How to setup a human base on Mars? Give short answer.'],
-             ['Explain theory of relativity to me like I’m 8 years old.'],
-             ['What is 9,000 * 9,000?'],
-             ['Write a pun-filled happy birthday message to my friend Alex.'],
-             ['Justify why a penguin might make a good king of the jungle.']
-         ],
-         cache_examples=False,
-     )
-
-     gr.Markdown(LICENSE)
-
  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
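Since the new `respond` is a plain Python generator that yields the accumulated reply after each streamed token, it can be smoke-tested without launching the Gradio UI. Below is a minimal sketch, not part of the commit; it assumes app.py is importable as `app` and that the `HuggingFaceH4/zephyr-7b-beta` serverless Inference API endpoint is reachable.

    # Hypothetical smoke test for respond(); not part of app.py.
    from app import respond  # importing app builds `demo` but does not launch it

    last = ""
    for partial in respond(
        message="Hello!",
        history=[],                    # no prior (user, assistant) turns
        system_message="You are a friendly Chatbot.",
        max_tokens=64,
        temperature=0.7,
        top_p=0.95,
    ):
        last = partial                 # each yield is the full response so far
    print(last)

One caveat: with some backends the streamed `message.choices[0].delta.content` can be `None` (for example on the final chunk), in which case `response += token` raises a TypeError; guarding with `token or ""` is a common defensive tweak.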