Ravi21 committed on
Commit e99b7ca · 1 Parent(s): 9f34020

Update app.py

Files changed (1): app.py (+238, -30)
app.py CHANGED
@@ -1,16 +1,56 @@
-import pandas as pd
-import numpy as np
 import gradio as gr
-pip install transformers
-from transformers import AutoModelForMultipleChoice, AutoTokenizer
-
-
-# Load the model and tokenizer
-model_path = "/kaggle/input/deberta-v3-large-hf-weights"
-model = AutoModelForMultipleChoice.from_pretrained(model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-# Define the preprocessing function
 def preprocess(sample):
     first_sentences = [sample["prompt"]] * 5
     second_sentences = [sample[option] for option in "ABCDE"]
@@ -28,25 +68,193 @@ def predict(data):
     predictions_as_ids = torch.argsort(-logits, dim=1)
     answers = np.array(list("ABCDE"))[predictions_as_ids.tolist()]
     return ["".join(i) for i in answers[:, :3]]
-
-# Create the Gradio interface
-iface = gr.Interface(
-    fn=predict,
-    inputs=gr.Interface.DataType.json,
-    outputs=gr.outputs.Label(num_top_classes=3),
-    live=True,
-    examples=[
-        {"prompt": "This is the prompt", "A": "Option A text", "B": "Option B text", "C": "Option C text", "D": "Option D text", "E": "Option E text"}
-    ],
-    title="LLM Science Exam Demo",
-    description="Enter the prompt and options (A to E) below and get predictions.",
-)
-
-# Run the interface locally
-iface.launch(share=True)
-
-# Once you have verified that the interface works as expected, proceed to create the Hugging Face space:
-"""
-repo_url = hf_hub_url("your-username/your-repo-name")
-repo = Repository.from_hf_hub(repo_url)
-repo.push(path="./my_model", model=model, tokenizer=tokenizer, config=model.config)"""
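
The old `predict` body (old lines 17-27) is collapsed in this view; only its argsort tail is visible above. Below is a minimal sketch of the multiple-choice pipeline those lines imply: `preprocess` and the argsort tail come from the visible diff context, while the tokenization and forward pass are assumptions.

# Sketch only: old lines 17-27 are collapsed above, so the tokenization and
# forward pass here are assumptions, not the committed code.
import numpy as np
import torch
from transformers import AutoModelForMultipleChoice, AutoTokenizer

model_path = "/kaggle/input/deberta-v3-large-hf-weights"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMultipleChoice.from_pretrained(model_path)

def predict(sample: dict) -> list[str]:
    first_sentences = [sample["prompt"]] * 5
    second_sentences = [sample[option] for option in "ABCDE"]
    enc = tokenizer(first_sentences, second_sentences,
                    truncation=True, padding=True, return_tensors="pt")
    # AutoModelForMultipleChoice expects (batch, n_choices, seq_len)
    batch = {k: v.unsqueeze(0) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**batch).logits                  # shape (1, 5)
    predictions_as_ids = torch.argsort(-logits, dim=1)  # best choice first
    answers = np.array(list("ABCDE"))[predictions_as_ids.tolist()]
    return ["".join(i) for i in answers[:, :3]]         # top-3 letters, e.g. ["BAC"]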
 
+from typing import Iterator
+
 import gradio as gr
+import numpy as np  # still needed by preprocess()/predict() below
+import torch
+
+from model import get_input_token_length, run
+
+MAX_MAX_NEW_TOKENS = 512
+DEFAULT_MAX_NEW_TOKENS = 256  # must stay <= MAX_MAX_NEW_TOKENS (the original 1024 exceeded it)
+MAX_INPUT_TOKEN_LENGTH = 512
+
+# Referenced below but never defined in this commit; placeholder values assumed.
+DESCRIPTION = '# LLM Science Exam Demo'
+DEFAULT_SYSTEM_PROMPT = ''
+LICENSE = ''
+
+if not torch.cuda.is_available():
+    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+
+
+def clear_and_save_textbox(message: str) -> tuple[str, str]:
+    return '', message
+
+
+def display_input(message: str,
+                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+    history.append((message, ''))
+    return history
+
+
+def delete_prev_fn(
+        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+    try:
+        message, _ = history.pop()
+    except IndexError:
+        message = ''
+    return history, message or ''
+
+
+def generate(
+    message: str,
+    history_with_input: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+) -> Iterator[list[tuple[str, str]]]:
+    if max_new_tokens > MAX_MAX_NEW_TOKENS:
+        raise ValueError(f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}')
+
+    history = history_with_input[:-1]
+    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    try:
+        first_response = next(generator)
+        yield history + [(message, first_response)]
+    except StopIteration:
+        yield history + [(message, '')]
+    for response in generator:
+        yield history + [(message, response)]
 
 
 def preprocess(sample):
     first_sentences = [sample["prompt"]] * 5
     second_sentences = [sample[option] for option in "ABCDE"]
⋮ (new-file lines 57-67 unchanged, collapsed in this view)
     predictions_as_ids = torch.argsort(-logits, dim=1)
     answers = np.array(list("ABCDE"))[predictions_as_ids.tolist()]
     return ["".join(i) for i in answers[:, :3]]
+def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
+    input_token_length = get_input_token_length(message, chat_history, system_prompt)
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value='Duplicate Space for private use',
+                       elem_id='duplicate-button')
+
+    with gr.Group():
+        chatbot = gr.Chatbot(label='Chatbot')
+        with gr.Row():
+            textbox = gr.Textbox(
+                container=False,
+                show_label=False,
+                placeholder='Type a message...',
+                scale=10,
+            )
+            submit_button = gr.Button('Submit',
+                                      variant='primary',
+                                      scale=1,
+                                      min_width=0)
+    with gr.Row():
+        retry_button = gr.Button('🔄 Retry', variant='secondary')
+        undo_button = gr.Button('↩️ Undo', variant='secondary')
+        clear_button = gr.Button('🗑️ Clear', variant='secondary')
+
+    saved_input = gr.State()
+
+    with gr.Accordion(label='Advanced options', open=False):
+        system_prompt = gr.Textbox(label='System prompt',
+                                   value=DEFAULT_SYSTEM_PROMPT,
+                                   lines=6)
+        max_new_tokens = gr.Slider(
+            label='Max new tokens',
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        )
+        temperature = gr.Slider(
+            label='Temperature',
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=1.0,
+        )
+        top_p = gr.Slider(
+            label='Top-p (nucleus sampling)',
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.95,
+        )
+        top_k = gr.Slider(
+            label='Top-k',
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        )
+
+    # gr.Interface.DataType.json and gr.outputs.Label do not exist in any
+    # Gradio release; gr.JSON / gr.Label are the assumed modern equivalents.
+    gr.Interface(
+        fn=predict,
+        inputs=gr.JSON(),
+        outputs=gr.Label(num_top_classes=3),
+        live=True,
+        examples=[
+            {"prompt": "This is the prompt", "A": "Option A text", "B": "Option B text", "C": "Option C text", "D": "Option D text", "E": "Option E text"}
+        ],
+        title="LLM Science Exam Demo",
+        description="Enter the prompt and options (A to E) below and get predictions.",
+    )
+    gr.Markdown(LICENSE)
+
+    textbox.submit(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    button_event_preprocess = submit_button.click(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    retry_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    undo_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=lambda x: x,
+        inputs=[saved_input],
+        outputs=textbox,
+        api_name=False,
+        queue=False,
+    )
+
+    clear_button.click(
+        fn=lambda: ([], ''),
+        outputs=[chatbot, saved_input],
+        queue=False,
+        api_name=False,
+    )
+
+demo.queue(max_size=20).launch()
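
The new version's `generate` and `check_input_token_length` assume a `model` module exporting `run`, a generator that yields the cumulative response after each new token, and `get_input_token_length`, a token counter. That module is not part of this commit; the following is a minimal, hypothetical stand-in satisfying the assumed contract, useful for running the UI without the real model.

# Hypothetical stand-in for the `model` module this commit imports from.
from typing import Iterator


def get_input_token_length(message: str,
                           chat_history: list[tuple[str, str]],
                           system_prompt: str) -> int:
    # Whitespace split as a crude proxy for a real tokenizer count.
    text = system_prompt + message + ''.join(u + a for u, a in chat_history)
    return len(text.split())


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int) -> Iterator[str]:
    # Yield the cumulative response so far after each "token", which is
    # the pattern generate() consumes above.
    response = ''
    for token in ['Echo', ': ', message]:
        response += token
        yield response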