Mattral committed on
Commit
233f32d
•
1 Parent(s): 5c426e5

Update app.py

Files changed (1)
  1. app.py +296 -92
app.py CHANGED
@@ -1,101 +1,305 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
- import random
- import textwrap
-
- # Define the model to be used
- model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-
- # Load model directly
- #model = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
- client = InferenceClient(model)
-
- # Embedded system prompt
- system_prompt_text = "Act like a compassionate and helpful health consultant and professional named PhoenixA. You help and support with any kind of request and provide a detailed answer or suggestion to the question. You are friendly and always willing to help, but if you are asked about something unethical or dangerous, you must provide a safe and respectful way to handle it. Never say that you cannot help them. Be sure to ask for the specific problem and do your best to give professional advice; remember, you are a professional."
-
- # Read the content of the info.md file
- with open("info.md", "r") as file:
-     info_md_content = file.read()
-
- # Chunk the info.md content into smaller sections
- chunk_size = 2000  # Adjust this size as needed
- info_md_chunks = textwrap.wrap(info_md_content, chunk_size)
-
- def get_all_chunks(chunks):
-     return "\n\n".join(chunks)
-
- def format_prompt_mixtral(message, history, info_md_chunks):
-     prompt = "<s>"
-     all_chunks = get_all_chunks(info_md_chunks)
-     prompt += f"{all_chunks}\n\n"  # Add all chunks of info.md at the beginning
-     prompt += f"{system_prompt_text}\n\n"  # Add the system prompt
-
-     if history:
-         for user_prompt, bot_response in history:
-             prompt += f"[INST] {user_prompt} [/INST]"
-             prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt
-
- def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
-     generate_kwargs = dict(
-         temperature=temp,
-         max_new_tokens=tokens,
-         top_p=top_p,
-         repetition_penalty=rep_p,
-         do_sample=True,
-         seed=seed,
      )

-     formatted_prompt = format_prompt_mixtral(prompt, history, info_md_chunks)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-     for response in stream:
-         output += response.token.text
-         yield [(prompt, output)]
-     history.append((prompt, output))
-     yield history
-
- def clear_fn():
-     return None, None
-
- rand_val = random.randint(1, 1111111111111111)
-
- def check_rand(inp, val):
-     if inp:
-         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
-     else:
-         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
-
- with gr.Blocks() as app:  # Add auth here
-     gr.HTML("""<center><h1 style='font-size:xx-large;'>PhoenixAI</h1><br><h3> made with love by Omdena </h3><br><h7>EXPERIMENTAL</center>""")
-     with gr.Row():
-         chat = gr.Chatbot(height=500)
      with gr.Group():
          with gr.Row():
-             with gr.Column(scale=3):
-                 inp = gr.Textbox(label="Prompt", lines=5, interactive=True)  # Increased lines and interactive
-                 with gr.Row():
-                     with gr.Column(scale=2):
-                         btn = gr.Button("Chat")
-                     with gr.Column(scale=1):
-                         with gr.Group():
-                             stop_btn = gr.Button("Stop")
-                             clear_btn = gr.Button("Clear")
-             with gr.Column(scale=1):
-                 with gr.Group():
-                     rand = gr.Checkbox(label="Random Seed", value=True)
-                     seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
-                     tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
-                     temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                     top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                     rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
-
-     hid1 = gr.Number(value=1, visible=False)
-
-     go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [inp, chat, seed, temp, tokens, top_p, rep_p], chat)
-
-     stop_btn.click(None, None, None, cancels=[go])
-     clear_btn.click(clear_fn, None, [inp, chat])
-
- app.queue(default_concurrency_limit=10).launch(share=True)
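For reference, the removed format_prompt_mixtral built a single Mixtral-instruct string: the info.md chunks and the system prompt up front, then each past turn wrapped in [INST] ... [/INST] with the bot reply terminated by </s>. A minimal, self-contained sketch of that layout (sample values are illustrative, not from the repo):

    # Sketch of the prompt string the removed code produced.
    def format_prompt(message, history, context, system_prompt):
        prompt = "<s>"
        prompt += f"{context}\n\n"        # info.md chunks first
        prompt += f"{system_prompt}\n\n"  # then the system prompt
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
        return prompt

    print(format_prompt("How can I sleep better?",
                        [("Hi", "Hello, how can I help?")],
                        "<info.md chunks>", "<system prompt>"))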
+ import os
+ from typing import Iterator
+
  import gradio as gr
+
+ from src.model import run
+
+ HF_PUBLIC = os.environ.get("HF_PUBLIC", False)
+
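One subtlety in the block above: os.environ.get returns the raw string whenever the variable is set, so HF_PUBLIC holds either the bool False or a string, and launch(share=HF_PUBLIC) treats any non-empty string, including "0", as truthy. An explicit boolean parse (a sketch, not part of this commit) could look like:

    import os

    # Hypothetical stricter parse: share only for recognized "truthy" strings.
    HF_PUBLIC = os.environ.get("HF_PUBLIC", "0").strip().lower() in ("1", "true", "yes")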
+ DEFAULT_SYSTEM_PROMPT = "You are Phoenix AI Healthcare. You are professional and polite, you give only truthful information, and you are based on the Mistral-7B model from Mistral AI, focused on Healthcare and Wellness. You can communicate equally well in different languages."
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 256
+ MAX_INPUT_TOKEN_LENGTH = 4000
+
+ DESCRIPTION = """
+ # [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+ """
+
+
+ def clear_and_save_textbox(message: str) -> tuple[str, str]:
+     """
+     Clear the textbox and save the input to a state variable.
+     :param message: The input message.
+     :return: A tuple of the empty string and the input message.
+     """
+     return "", message
+
+
+ def display_input(
+     message: str, history: list[tuple[str, str]]
+ ) -> list[tuple[str, str]]:
+     """
+     Display the input message in the chat history.
+     :param message: The input message.
+     :param history: The chat history.
+     :return: The chat history with the input message appended.
+     """
+     history.append((message, ""))
+     return history
+
+
+ def delete_prev_fn(
+         history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+     """
+     Delete the previous message from the chat history.
+     :param history: The chat history.
+     :return: The chat history with the last message removed
+         and the removed message.
+     """
+     try:
+         message, _ = history.pop()
+     except IndexError:
+         message = ""
+     return history, message or ""
+
+
+ def generate(
+     message: str,
+     history_with_input: list[tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     top_k: int,
+ ) -> Iterator[list[tuple[str, str]]]:
+     """
+     Generate a response to the input message.
+     :param message: The input message.
+     :param history_with_input: The chat history with
+         the input message appended.
+     :param system_prompt: The system prompt.
+     :param max_new_tokens: The maximum number of tokens to generate.
+     :param temperature: The temperature.
+     :param top_p: The top-p (nucleus sampling) probability.
+     :param top_k: The top-k value (number of highest-probability tokens kept for sampling).
+     :return: An iterator over the chat history with
+         the generated response appended.
+     """
+     if max_new_tokens > MAX_MAX_NEW_TOKENS:
+         raise ValueError
+
+     history = history_with_input[:-1]
+     generator = run(
+         message, history,
+         system_prompt, max_new_tokens, temperature, top_p, top_k
      )
+     try:
+         first_response = next(generator)
+         yield history + [(message, first_response)]
+     except StopIteration:
+         yield history + [(message, "")]
+     for response in generator:
+         yield history + [(message, response)]
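The next()/StopIteration handling above ensures the UI gets a first partial update immediately, and an empty reply if the backend yields nothing. A self-contained sketch of the same streaming pattern, with a stub standing in for src.model.run (whose implementation this diff does not show):

    from typing import Iterator

    def run_stub(message: str) -> Iterator[str]:
        # Stand-in for src.model.run: yields progressively longer partial texts.
        for partial in ("Drink", "Drink water", "Drink water daily."):
            yield partial

    def generate_stub(message: str) -> Iterator[str]:
        gen = run_stub(message)
        try:
            yield next(gen)   # first chunk right away
        except StopIteration:
            yield ""          # empty reply if nothing was produced
        for response in gen:  # then stream the rest
            yield response

    for update in generate_stub("hydration tips"):
        print(update)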
+
+
+ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
+     """
+     Process an example.
+     :param message: The input message.
+     :return: A tuple of the empty string and the chat history with the \
+         generated response appended.
+     """
+     generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+     for x in generator:
+         pass
+     return "", x
+
+
+ def check_input_token_length(
+     message: str, chat_history: list[tuple[str, str]], system_prompt: str
+ ) -> None:
+     """
+     Check that the accumulated input is not too long.
+     :param message: The input message.
+     :param chat_history: The chat history.
+     :param system_prompt: The system prompt.
+     :return: None.
+     """
+     input_token_length = len(message) + len(chat_history)
+     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+         raise gr.Error(
+             f"The accumulated input is too long \
+             ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). \
+             Clear your chat history and try again."
+         )
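Despite its name, input_token_length above is len(message) + len(chat_history): characters in the new message plus the number of past turns, a cheap proxy rather than a real token count. A token-accurate guard would need the model's tokenizer, for example (hypothetical; this app does not import transformers):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

    def count_tokens(message: str, chat_history: list[tuple[str, str]]) -> int:
        # Tokenize the accumulated conversation and count real input tokens.
        text = "".join(user + bot for user, bot in chat_history) + message
        return len(tokenizer(text).input_ids)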
+
+
+ with gr.Blocks(css="./styles/style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use", elem_id="duplicate-button"
+     )

      with gr.Group():
+         chatbot = gr.Chatbot(label="Playground")
          with gr.Row():
+             textbox = gr.Textbox(
+                 container=False,
+                 show_label=False,
+                 placeholder="Greetings! What Healthcare/Wellness topic can I help you with today?",
+                 scale=10,
+             )
+             submit_button = gr.Button("Submit", variant="primary",
+                                       scale=1, min_width=0)
+     with gr.Row():
+         retry_button = gr.Button('🔄 Retry', variant='secondary')
+         undo_button = gr.Button('↩️ Undo', variant='secondary')
+         clear_button = gr.Button('🗑️ Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label="⚙️ Advanced options", open=False):
+         system_prompt = gr.Textbox(
+             label="System prompt",
+             value=DEFAULT_SYSTEM_PROMPT,
+             lines=5,
+             interactive=False,
+         )
+         max_new_tokens = gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         )
+         temperature = gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.1,
+         )
+         top_p = gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         )
+         top_k = gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=10,
+         )
+
+     textbox.submit(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot, system_prompt],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     button_event_preprocess = (
+         submit_button.click(
+             fn=clear_and_save_textbox,
+             inputs=textbox,
+             outputs=[textbox, saved_input],
+             api_name=False,
+             queue=False,
+         )
+         .then(
+             fn=display_input,
+             inputs=[saved_input, chatbot],
+             outputs=chatbot,
+             api_name=False,
+             queue=False,
+         )
+         .then(
+             fn=check_input_token_length,
+             inputs=[saved_input, chatbot, system_prompt],
+             api_name=False,
+             queue=False,
+         )
+         .success(
+             fn=generate,
+             inputs=[
+                 saved_input,
+                 chatbot,
+                 system_prompt,
+                 max_new_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+             ],
+             outputs=chatbot,
+             api_name=False,
+         )
+     )
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             system_prompt,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=textbox,
+         api_name=False,
+         queue=False,
+     )
+
+     clear_button.click(
+         fn=lambda: ([], ""),
+         outputs=[chatbot, saved_input],
+         queue=False,
+         api_name=False,
+     )
+
+ demo.queue(max_size=32).launch(share=HF_PUBLIC, show_api=False)
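The event wiring above leans on Gradio's chaining semantics: .then() always runs the next step, while .success() runs only if the previous step finished without raising, so the gr.Error from check_input_token_length aborts the chain before generate is called. A stripped-down sketch of that pattern (illustrative names, not from the commit):

    import gradio as gr

    def validate(text: str) -> None:
        if len(text) > 100:
            raise gr.Error("Input too long.")  # stops the chain here

    def echo(text: str) -> str:
        return text

    with gr.Blocks() as sketch:
        box = gr.Textbox()
        out = gr.Textbox()
        # echo only runs if validate did not raise
        box.submit(fn=validate, inputs=box, queue=False).success(
            fn=echo, inputs=box, outputs=out
        )

    sketch.queue().launch()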