Spaces:
Runtime error
Runtime error
Commit
·
0743d21
1
Parent(s):
21a6ab6
created function for llama and gpt strategies
Browse files
app.py
CHANGED
@@ -112,6 +112,31 @@ def llama_respond(tab_name, message, chat_history):
|
|
112 |
time.sleep(2)
|
113 |
return tab_name, "", chat_history
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
116 |
formatted_prompt = ""
|
117 |
if (task_name == "POS Tagging"):
|
@@ -144,6 +169,38 @@ def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_
|
|
144 |
time.sleep(2)
|
145 |
return task_name, "", chat_history
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
def interface():
|
148 |
|
149 |
# prompt = template_single.format(tab_name, textbox_prompt)
|
@@ -205,11 +262,13 @@ def interface():
|
|
205 |
llama_S1_chatbot = gr.Chatbot(label="llama-7b")
|
206 |
gpt_S1_chatbot = gr.Chatbot(label="gpt-3.5")
|
207 |
gr.Markdown("Strategy 2 Instruction-Based Prompting")
|
|
|
208 |
with gr.Row():
|
209 |
vicuna_S2_chatbot = gr.Chatbot(label="vicuna-7b")
|
210 |
llama_S2_chatbot = gr.Chatbot(label="llama-7b")
|
211 |
gpt_S2_chatbot = gr.Chatbot(label="gpt-3.5")
|
212 |
gr.Markdown("Strategy 3 Structured Prompting")
|
|
|
213 |
with gr.Row():
|
214 |
vicuna_S3_chatbot = gr.Chatbot(label="vicuna-7b")
|
215 |
llama_S3_chatbot = gr.Chatbot(label="llama-7b")
|
@@ -223,9 +282,9 @@ def interface():
|
|
223 |
# Event Handlers for Vicuna Chatbot POS/Chunk
|
224 |
task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S1_chatbot],
|
225 |
outputs=[task, task_prompt, vicuna_S1_chatbot])
|
226 |
-
task_btn.click(vicuna_strategies_respond, inputs=[
|
227 |
outputs=[task, task_prompt, vicuna_S2_chatbot])
|
228 |
-
task_btn.click(vicuna_strategies_respond, inputs=[
|
229 |
outputs=[task, task_prompt, vicuna_S3_chatbot])
|
230 |
|
231 |
# Event Handler for LLaMA Chatbot POS/Chunk
|
|
|
112 |
time.sleep(2)
|
113 |
return tab_name, "", chat_history
|
114 |
|
115 |
+
def gpt_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history, max_convo_length = 10):
|
116 |
+
formatted_system_prompt = ""
|
117 |
+
if (task_name == "POS Tagging"):
|
118 |
+
if (strategy == "S1"):
|
119 |
+
formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
120 |
+
elif (strategy == "S2"):
|
121 |
+
formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
122 |
+
elif (strategy == "S3"):
|
123 |
+
formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
124 |
+
elif (task_name == "Chunking"):
|
125 |
+
if (strategy == "S1"):
|
126 |
+
formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
127 |
+
elif (strategy == "S2"):
|
128 |
+
formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
129 |
+
elif (strategy == "S3"):
|
130 |
+
formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
131 |
+
|
132 |
+
formatted_prompt = format_chat_prompt(message, chat_history, max_convo_length)
|
133 |
+
print('Prompt + Context:')
|
134 |
+
print(formatted_prompt)
|
135 |
+
bot_message = chat(system_prompt = formatted_system_prompt,
|
136 |
+
user_prompt = formatted_prompt)
|
137 |
+
chat_history.append((message, bot_message))
|
138 |
+
return "", chat_history
|
139 |
+
|
140 |
def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
141 |
formatted_prompt = ""
|
142 |
if (task_name == "POS Tagging"):
|
|
|
169 |
time.sleep(2)
|
170 |
return task_name, "", chat_history
|
171 |
|
172 |
+
def llama_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
|
173 |
+
formatted_prompt = ""
|
174 |
+
if (task_name == "POS Tagging"):
|
175 |
+
if (strategy == "S1"):
|
176 |
+
formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
177 |
+
elif (strategy == "S2"):
|
178 |
+
formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
179 |
+
elif (strategy == "S3"):
|
180 |
+
formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
|
181 |
+
elif (task_name == "Chunking"):
|
182 |
+
if (strategy == "S1"):
|
183 |
+
formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
|
184 |
+
elif (strategy == "S2"):
|
185 |
+
formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
186 |
+
elif (strategy == "S3"):
|
187 |
+
formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputing any additional text: {message}'''
|
188 |
+
|
189 |
+
# print('Llama Strategies - Prompt + Context:')
|
190 |
+
# print(formatted_prompt)
|
191 |
+
input_ids = llama_tokenizer.encode(formatted_prompt, return_tensors="pt")
|
192 |
+
output_ids = llama_model.generate(input_ids, do_sample=True, max_length=1024, num_beams=5, no_repeat_ngram_size=2)
|
193 |
+
bot_message = llama_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
194 |
+
# print(bot_message)
|
195 |
+
|
196 |
+
# Remove formatted prompt from bot_message
|
197 |
+
bot_message = bot_message.replace(formatted_prompt, '')
|
198 |
+
# print(bot_message)
|
199 |
+
|
200 |
+
chat_history.append((formatted_prompt, bot_message))
|
201 |
+
time.sleep(2)
|
202 |
+
return task_name, "", chat_history
|
203 |
+
|
204 |
def interface():
|
205 |
|
206 |
# prompt = template_single.format(tab_name, textbox_prompt)
|
|
|
262 |
llama_S1_chatbot = gr.Chatbot(label="llama-7b")
|
263 |
gpt_S1_chatbot = gr.Chatbot(label="gpt-3.5")
|
264 |
gr.Markdown("Strategy 2 Instruction-Based Prompting")
|
265 |
+
strategy2 = gr.Markdown("S2", visible=False)
|
266 |
with gr.Row():
|
267 |
vicuna_S2_chatbot = gr.Chatbot(label="vicuna-7b")
|
268 |
llama_S2_chatbot = gr.Chatbot(label="llama-7b")
|
269 |
gpt_S2_chatbot = gr.Chatbot(label="gpt-3.5")
|
270 |
gr.Markdown("Strategy 3 Structured Prompting")
|
271 |
+
strategy3 = gr.Markdown("S3", visible=False)
|
272 |
with gr.Row():
|
273 |
vicuna_S3_chatbot = gr.Chatbot(label="vicuna-7b")
|
274 |
llama_S3_chatbot = gr.Chatbot(label="llama-7b")
|
|
|
282 |
# Event Handlers for Vicuna Chatbot POS/Chunk
|
283 |
task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S1_chatbot],
|
284 |
outputs=[task, task_prompt, vicuna_S1_chatbot])
|
285 |
+
task_btn.click(vicuna_strategies_respond, inputs=[strategy2, task, task_linguistic_entities, task_prompt, vicuna_S2_chatbot],
|
286 |
outputs=[task, task_prompt, vicuna_S2_chatbot])
|
287 |
+
task_btn.click(vicuna_strategies_respond, inputs=[strategy3, task, task_linguistic_entities, task_prompt, vicuna_S3_chatbot],
|
288 |
outputs=[task, task_prompt, vicuna_S3_chatbot])
|
289 |
|
290 |
# Event Handler for LLaMA Chatbot POS/Chunk
|