research14 committed
Commit 0743d21 · 1 Parent(s): 21a6ab6

created function for llama and gpt strategies

Files changed (1)
  1. app.py +61 -2
app.py CHANGED
@@ -112,6 +112,31 @@ def llama_respond(tab_name, message, chat_history):
     time.sleep(2)
     return tab_name, "", chat_history
 
+def gpt_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history, max_convo_length = 10):
+    formatted_system_prompt = ""
+    if (task_name == "POS Tagging"):
+        if (strategy == "S1"):
+            formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
+        elif (strategy == "S2"):
+            formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
+        elif (strategy == "S3"):
+            formatted_system_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
+    elif (task_name == "Chunking"):
+        if (strategy == "S1"):
+            formatted_system_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
+        elif (strategy == "S2"):
+            formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {message}'''
+        elif (strategy == "S3"):
+            formatted_system_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {message}'''
+
+    formatted_prompt = format_chat_prompt(message, chat_history, max_convo_length)
+    print('Prompt + Context:')
+    print(formatted_prompt)
+    bot_message = chat(system_prompt = formatted_system_prompt,
+                       user_prompt = formatted_prompt)
+    chat_history.append((message, bot_message))
+    return "", chat_history
+
 def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
     formatted_prompt = ""
     if (task_name == "POS Tagging"):
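The new `gpt_strategies_respond` handler leans on two helpers that already live elsewhere in app.py and are not part of this diff: `format_chat_prompt`, which folds recent chat history into a single prompt string, and `chat`, which calls the OpenAI API. A minimal sketch of what that pair plausibly looks like, assuming the openai-python 0.x `ChatCompletion` interface (the actual helpers in the repo may differ):

```python
import openai  # assumes openai.api_key is configured at app startup

def format_chat_prompt(message, chat_history, max_convo_length):
    # Fold the most recent turns into one prompt string, oldest first.
    prompt = ""
    for user_message, bot_message in chat_history[-max_convo_length:]:
        prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
    return f"{prompt}\nUser: {message}\nAssistant:"

def chat(system_prompt, user_prompt, model="gpt-3.5-turbo"):
    # openai-python 0.x style call; the strategy prompt rides in the
    # system role and the history-formatted prompt in the user role.
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response["choices"][0]["message"]["content"]
```

Note that `gpt_strategies_respond` returns two values (`"", chat_history`), while the Vicuna and LLaMA handlers return three (`task_name, "", chat_history`), so it will need a two-element `outputs` list when it is eventually wired to a button.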
@@ -144,6 +169,38 @@ def vicuna_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
     time.sleep(2)
     return task_name, "", chat_history
 
+def llama_strategies_respond(strategy, task_name, task_ling_ent, message, chat_history):
+    formatted_prompt = ""
+    if (task_name == "POS Tagging"):
+        if (strategy == "S1"):
+            formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
+        elif (strategy == "S2"):
+            formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
+        elif (strategy == "S3"):
+            formatted_prompt = f'''Please POS tag the following sentence using Universal POS tag set without generating any additional text: {message}'''
+    elif (task_name == "Chunking"):
+        if (strategy == "S1"):
+            formatted_prompt = f'''Generate the output only for the assistant. Please output any {task_ling_ent} in the following sentence one per line without any additional text: {message}'''
+        elif (strategy == "S2"):
+            formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {message}'''
+        elif (strategy == "S3"):
+            formatted_prompt = f'''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {message}'''
+
+    # print('Llama Strategies - Prompt + Context:')
+    # print(formatted_prompt)
+    input_ids = llama_tokenizer.encode(formatted_prompt, return_tensors="pt")
+    output_ids = llama_model.generate(input_ids, do_sample=True, max_length=1024, num_beams=5, no_repeat_ngram_size=2)
+    bot_message = llama_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    # print(bot_message)
+
+    # Remove formatted prompt from bot_message
+    bot_message = bot_message.replace(formatted_prompt, '')
+    # print(bot_message)
+
+    chat_history.append((formatted_prompt, bot_message))
+    time.sleep(2)
+    return task_name, "", chat_history
+
 def interface():
 
     # prompt = template_single.format(tab_name, textbox_prompt)
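As in `gpt_strategies_respond`, the S2 and S3 branches here currently build identical prompts; only S1 differs. On the generation side, `llama_model.generate` returns the prompt tokens followed by the continuation, which is why the handler strips the echo with `bot_message.replace(formatted_prompt, '')`. That string match only works when decoding reproduces the prompt byte for byte; tokenizers that normalize whitespace can break it. A sketch of a more robust variant that slices on token ids instead (same `llama_tokenizer` and `llama_model` objects assumed):

```python
input_ids = llama_tokenizer.encode(formatted_prompt, return_tensors="pt")
output_ids = llama_model.generate(input_ids, do_sample=True, max_length=1024,
                                  num_beams=5, no_repeat_ngram_size=2)
# Decode only the tokens generated after the prompt, so no string
# replacement is needed to remove the echoed prompt.
new_tokens = output_ids[0][input_ids.shape[-1]:]
bot_message = llama_tokenizer.decode(new_tokens, skip_special_tokens=True)
```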
@@ -205,11 +262,13 @@ def interface():
             llama_S1_chatbot = gr.Chatbot(label="llama-7b")
             gpt_S1_chatbot = gr.Chatbot(label="gpt-3.5")
         gr.Markdown("Strategy 2 Instruction-Based Prompting")
+        strategy2 = gr.Markdown("S2", visible=False)
         with gr.Row():
             vicuna_S2_chatbot = gr.Chatbot(label="vicuna-7b")
             llama_S2_chatbot = gr.Chatbot(label="llama-7b")
             gpt_S2_chatbot = gr.Chatbot(label="gpt-3.5")
         gr.Markdown("Strategy 3 Structured Prompting")
+        strategy3 = gr.Markdown("S3", visible=False)
         with gr.Row():
             vicuna_S3_chatbot = gr.Chatbot(label="vicuna-7b")
             llama_S3_chatbot = gr.Chatbot(label="llama-7b")
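The hidden `gr.Markdown("S2", visible=False)` and `gr.Markdown("S3", visible=False)` components exist only so each `.click()` call can pass a constant strategy label into the shared handler; they never render. The same effect can be had with `gr.State`, which holds a value without any UI element at all. A sketch of the equivalent wiring, as a design note rather than a required change:

```python
# Constant-valued inputs without invisible Markdown components.
strategy2 = gr.State("S2")
strategy3 = gr.State("S3")

task_btn.click(vicuna_strategies_respond,
               inputs=[strategy2, task, task_linguistic_entities, task_prompt, vicuna_S2_chatbot],
               outputs=[task, task_prompt, vicuna_S2_chatbot])
```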
@@ -223,9 +282,9 @@ def interface():
     # Event Handlers for Vicuna Chatbot POS/Chunk
     task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S1_chatbot],
                    outputs=[task, task_prompt, vicuna_S1_chatbot])
-    task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S2_chatbot],
+    task_btn.click(vicuna_strategies_respond, inputs=[strategy2, task, task_linguistic_entities, task_prompt, vicuna_S2_chatbot],
                    outputs=[task, task_prompt, vicuna_S2_chatbot])
-    task_btn.click(vicuna_strategies_respond, inputs=[strategy1, task, task_linguistic_entities, task_prompt, vicuna_S3_chatbot],
+    task_btn.click(vicuna_strategies_respond, inputs=[strategy3, task, task_linguistic_entities, task_prompt, vicuna_S3_chatbot],
                    outputs=[task, task_prompt, vicuna_S3_chatbot])
 
     # Event Handler for LLaMA Chatbot POS/Chunk
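This commit defines `gpt_strategies_respond` and `llama_strategies_respond` but only rewires the Vicuna handlers; click bindings for the new functions are not part of the diff. A hypothetical wiring that mirrors the Vicuna pattern, using the component names declared above:

```python
# Hypothetical, not in this commit: LLaMA handler returns three values.
task_btn.click(llama_strategies_respond,
               inputs=[strategy1, task, task_linguistic_entities, task_prompt, llama_S1_chatbot],
               outputs=[task, task_prompt, llama_S1_chatbot])

# Hypothetical, not in this commit: the GPT handler returns only
# ("", chat_history), so it maps to two outputs rather than three.
task_btn.click(gpt_strategies_respond,
               inputs=[strategy1, task, task_linguistic_entities, task_prompt, gpt_S1_chatbot],
               outputs=[task_prompt, gpt_S1_chatbot])
```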