Mattral committed on
Commit 4d323ac · verified · 1 Parent(s): 055e4a6

Update app.py

Files changed (1)
  1. app.py +94 -207
app.py CHANGED
@@ -1,214 +1,101 @@
  import gradio as gr
- from typing import Iterator, List, Tuple
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftConfig, PeftModel
- from huggingface_hub import login
- import os
-
- # Authenticate with Hugging Face
- HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')  # Ensure this environment variable is set
- login(token=HUGGINGFACE_TOKEN)
-
- base_model = "mistralai/Mistral-7B-Instruct-v0.2"
- adapter = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
-
- # Load tokenizer
- tokenizer = AutoTokenizer.from_pretrained(
-     base_model,
-     add_bos_token=True,
-     trust_remote_code=True,
-     padding_side='left'
- )
-
- # Create peft model using base_model and finetuned adapter
- config = PeftConfig.from_pretrained(adapter)
- model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
-                                              load_in_4bit=True,
-                                              device_map='auto',
-                                              torch_dtype='auto')
- model = PeftModel.from_pretrained(model, adapter)
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
- model.eval()
-
- DEFAULT_SYSTEM_PROMPT = "You are Phoenix AI Healthcare. You are professional, you are polite, give only truthful information and are based on the Mistral-7B model from Mistral AI about Healthcare and Wellness. You can communicate in different languages equally well."
-
- MAX_MAX_NEW_TOKENS = 4096
- DEFAULT_MAX_NEW_TOKENS = 256
- MAX_INPUT_TOKEN_LENGTH = 4000
-
- DESCRIPTION = """
- # Simple Healthcare Chatbot
- ### Powered by Mistral-7B with Healthcare Fine-Tuning
- """
-
- def clear_and_save_textbox(message: str) -> tuple[str, str]:
-     return "", message
-
- def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
-     history.append((message, ""))
-     return history
-
- def delete_prev_fn(history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
-     try:
-         message, _ = history.pop()
-     except IndexError:
-         message = ""
-     return history, message or ""
-
- def generate(
-     message: str,
-     history_with_input: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int,
- ) -> Iterator[list[tuple[str, str]]]:
-     if max_new_tokens > MAX_MAX_NEW_TOKENS:
-         raise ValueError("Max new tokens exceeded")
-
-     history = history_with_input[:-1]
-     conversation = [{"role": "system", "content": system_prompt}] + \
-                    [{"role": "user", "content": user_input} for user_input, _ in history] + \
-                    [{"role": "user", "content": message}]
-     input_ids = tokenizer.apply_chat_template(conversation=conversation,
-                                               tokenize=True,
-                                               add_generation_prompt=True,
-                                               return_tensors='pt').to(device)
-     output_ids = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens,
-                                 do_sample=True, pad_token_id=tokenizer.pad_token_id)
-     response = tokenizer.batch_decode(output_ids.detach().cpu().numpy(), skip_special_tokens=True)
-     response_text = response[0]
-
-     yield history + [(message, response_text)]
-
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
-     input_token_length = len(tokenizer.encode(message)) + sum(len(tokenizer.encode(msg)) for msg, _ in chat_history)
-     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
-
- with gr.Blocks(css="./styles/style.css") as demo:  # Link to CSS file
-     gr.Markdown(DESCRIPTION)
-     gr.Button("Duplicate Space for private use", elem_id="duplicate-button")
-
-     with gr.Group():
-         chatbot = gr.Chatbot(label="Chat with Healthcare AI")
-         with gr.Row():
-             textbox = gr.Textbox(
-                 container=False,
-                 show_label=False,
-                 placeholder="Ask me anything about Healthcare and Wellness...",
-                 scale=10,
-             )
-             submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
-
-     with gr.Row():
-         retry_button = gr.Button('🔄 Retry', variant='secondary')
-         undo_button = gr.Button('↩️ Undo', variant='secondary')
-         clear_button = gr.Button('🗑️ Clear', variant='secondary')
-
-     saved_input = gr.State()
-
-     with gr.Accordion(label="⚙️ Advanced options", open=False):
-         system_prompt = gr.Textbox(
-             label="System prompt",
-             value=DEFAULT_SYSTEM_PROMPT,
-             lines=5,
-             interactive=False,
-         )
-         max_new_tokens = gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         )
-         temperature = gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.1,
-         )
-         top_p = gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         )
-         top_k = gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=10,
-         )
-
-     textbox.submit(
-         fn=clear_and_save_textbox,
-         inputs=textbox,
-         outputs=[textbox, saved_input],
-     ).then(
-         fn=display_input,
-         inputs=[saved_input, chatbot],
-         outputs=chatbot,
-     ).then(
-         fn=check_input_token_length,
-         inputs=[saved_input, chatbot, system_prompt],
-     ).success(
-         fn=generate,
-         inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
-         outputs=chatbot,
      )

-     submit_button.click(
-         fn=clear_and_save_textbox,
-         inputs=textbox,
-         outputs=[textbox, saved_input],
-     ).then(
-         fn=display_input,
-         inputs=[saved_input, chatbot],
-         outputs=chatbot,
-     ).then(
-         fn=check_input_token_length,
-         inputs=[saved_input, chatbot, system_prompt],
-     ).success(
-         fn=generate,
-         inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
-         outputs=chatbot,
-     )

-     retry_button.click(
-         fn=delete_prev_fn,
-         inputs=chatbot,
-         outputs=[chatbot, saved_input],
-     ).then(
-         fn=display_input,
-         inputs=[saved_input, chatbot],
-         outputs=chatbot,
-     ).then(
-         fn=generate,
-         inputs=[saved_input, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k],
-         outputs=chatbot,
-     )

-     undo_button.click(
-         fn=delete_prev_fn,
-         inputs=chatbot,
-         outputs=[chatbot, saved_input],
-     ).then(
-         fn=lambda x: x,
-         inputs=[saved_input],
-         outputs=textbox,
-     )

-     clear_button.click(
-         fn=lambda: ([], ""),
-         outputs=[chatbot, saved_input],
-     )

- demo.queue(max_size=32).launch(share=False)
+ from huggingface_hub import InferenceClient
+ import random
+ import textwrap
+
+ # Define the model to be used
+ model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # Load model directly
+ #model = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
+ client = InferenceClient(model)
+
+ # Embedded system prompt
+ system_prompt_text = "You are Phoenix AI Healthcare. You are professional, you are polite, give only truthful information and are based on the Mistral-7B model from Mistral AI about Healthcare and Wellness. You can communicate in different languages equally well."
+
+ # Read the content of the info.md file
+ with open("info.md", "r") as file:
+     info_md_content = file.read()
+
+ # Chunk the info.md content into smaller sections
+ chunk_size = 2000  # Adjust this size as needed
+ info_md_chunks = textwrap.wrap(info_md_content, chunk_size)
+
+ def get_all_chunks(chunks):
+     return "\n\n".join(chunks)
+
+ def format_prompt_mixtral(message, history, info_md_chunks):
+     prompt = "<s>"
+     all_chunks = get_all_chunks(info_md_chunks)
+     prompt += f"{all_chunks}\n\n"  # Add all chunks of info.md at the beginning
+     prompt += f"{system_prompt_text}\n\n"  # Add the system prompt
+
+     if history:
+         for user_prompt, bot_response in history:
+             prompt += f"[INST] {user_prompt} [/INST]"
+             prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
+     generate_kwargs = dict(
+         temperature=temp,
+         max_new_tokens=tokens,
+         top_p=top_p,
+         repetition_penalty=rep_p,
+         do_sample=True,
+         seed=seed,
      )

+     formatted_prompt = format_prompt_mixtral(prompt, history, info_md_chunks)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+     for response in stream:
+         output += response.token.text
+         yield [(prompt, output)]
+     history.append((prompt, output))
+     yield history

+ def clear_fn():
+     return None, None

+ rand_val = random.randint(1, 1111111111111111)

+ def check_rand(inp, val):
+     if inp:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
+     else:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))

+ with gr.Blocks() as app:  # Add auth here
+     gr.HTML("""<center><h1 style='font-size:xx-large;'>PhoenixAI</h1><br><h3> made with love by Omdena </h3><br><h7>EXPERIMENTAL</center>""")
+     with gr.Row():
+         chat = gr.Chatbot(height=500)
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=3):
+                 inp = gr.Textbox(label="Prompt", lines=5, interactive=True)  # Increased lines and interactive
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         btn = gr.Button("Chat")
+                     with gr.Column(scale=1):
+                         with gr.Group():
+                             stop_btn = gr.Button("Stop")
+                             clear_btn = gr.Button("Clear")
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     rand = gr.Checkbox(label="Random Seed", value=True)
+                     seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
+                     tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
+                     temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+
+     hid1 = gr.Number(value=1, visible=False)
+
+     go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [inp, chat, seed, temp, tokens, top_p, rep_p], chat)
+
+     stop_btn.click(None, None, None, cancels=[go])
+     clear_btn.click(clear_fn, None, [inp, chat])
+
+ app.queue(default_concurrency_limit=10).launch(share=True)
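For readers unfamiliar with the Mixtral instruction format that the new format_prompt_mixtral targets, the sketch below reproduces the string layout it builds. The document chunks, system prompt, history, and message are made-up placeholder values, not the app's real data.

# Minimal sketch of the <s>[INST] ... [/INST] prompt layout built by format_prompt_mixtral.
# All values below are placeholders for illustration only.
doc_chunks = ["General wellness notes ...", "Sleep hygiene tips ..."]
system_prompt = "You are Phoenix AI Healthcare."
history = [("What is sleep hygiene?", "It refers to habits that support good sleep ...")]
message = "How much sleep do adults need?"

prompt = "<s>"
prompt += "\n\n".join(doc_chunks) + "\n\n"      # grounding text from info.md comes first
prompt += system_prompt + "\n\n"                # then the embedded system prompt
for user_turn, bot_turn in history:             # prior turns wrapped in [INST] ... [/INST]
    prompt += f"[INST] {user_turn} [/INST] {bot_turn}</s> "
prompt += f"[INST] {message} [/INST]"           # current user message, left open for the model
print(prompt)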
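The streaming loop in chat_inf relies on huggingface_hub's InferenceClient.text_generation with stream=True and details=True, which yields token-level responses whose text is accumulated into the reply. A minimal sketch of that call pattern outside Gradio, assuming the model is reachable through the hosted Inference API and a valid token is configured; the prompt string is a placeholder:

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
prompt = "<s>[INST] Give one short tip for better sleep. [/INST]"

stream = client.text_generation(
    prompt,
    max_new_tokens=128,
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.0,
    do_sample=True,
    seed=42,
    stream=True,            # yield partial results token by token
    details=True,           # each streamed item exposes response.token.text
    return_full_text=False,
)

output = ""
for response in stream:
    output += response.token.text   # accumulate the completion, as chat_inf does
print(output)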