research14 committed on
Commit
221eddd
·
1 Parent(s): 9f1cf26

commented out

Browse files
Files changed (1) hide show
  1. run_llm.py +90 -90
run_llm.py CHANGED
@@ -70,24 +70,24 @@ model_mapping = {
70
  # 'koala-13b': 'koala-13b',
71
  }
72
 
73
- # for m in model_mapping.keys():
74
- # for eid, ent in enumerate(ents):
75
- # os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
76
 
77
- # os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
78
- # os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
79
- # os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
80
 
81
- # os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
82
- # os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
83
- # os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
84
 
85
 
86
- # s = int(sys.argv[1])
87
- # e = int(sys.argv[2])
88
 
89
- # s = 0
90
- # e = 1000
91
  with open('sample_uniform_1k_2.txt', 'r') as f:
92
  selected_idx = f.readlines()
93
  selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
@@ -219,34 +219,34 @@ def main(args=None):
219
  text = ptb[gid]['text']
220
 
221
  ## POS tagging
222
- # if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
223
- # print(gid, 'skip')
224
 
225
- # else:
226
- # msg = prompt2_pos.format(text)
227
 
228
- # if 'gpt3' in args.model_path:
229
- # outputs = gpt3(msg)
230
- # if outputs is None:
231
- # continue
232
- # time.sleep(0.2)
233
 
234
- # else:
235
- # conv = get_conversation_template(args.model_path)
236
- # conv.append_message(conv.roles[0], msg)
237
- # conv.append_message(conv.roles[1], None)
238
- # conv.system = ''
239
- # prompt = conv.get_prompt()
240
 
241
- # outputs = fastchat(prompt, model, tokenizer)
242
 
243
- # with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
244
- # f.write(outputs)
245
 
246
 
247
  ## Sentence chunking
248
- # if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
249
- # print(gid, 'skip')
250
  if False:
251
  pass
252
  else:
@@ -273,29 +273,29 @@ def main(args=None):
273
 
274
 
275
  ## Parsing
276
- # if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
277
- # print(gid, 'skip')
278
 
279
- # else:
280
- # msg = prompt2_parse.format(text)
281
 
282
- # if 'gpt3' in args.model_path:
283
- # outputs = gpt3(msg)
284
- # if outputs is None:
285
- # continue
286
- # time.sleep(0.2)
287
 
288
- # else:
289
- # conv = get_conversation_template(args.model_path)
290
- # conv.append_message(conv.roles[0], msg)
291
- # conv.append_message(conv.roles[1], None)
292
- # conv.system = ''
293
- # prompt = conv.get_prompt()
294
 
295
- # outputs = fastchat(prompt, model, tokenizer)
296
 
297
- # with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
298
- # f.write(outputs)
299
 
300
 
301
 
@@ -306,29 +306,29 @@ def main(args=None):
306
  poss = ptb[gid]['uni_poss']
307
 
308
  ## POS tagging
309
- # if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
310
- # print(gid, 'skip')
311
- # continue
312
 
313
- # prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
314
 
315
- # if 'gpt3' in args.model_path:
316
- # outputs = gpt3(prompt)
317
- # if outputs is None:
318
- # continue
319
- # time.sleep(0.2)
320
 
321
- # else:
322
- # pred_poss = []
323
- # for _tok, _pos in zip(tokens, poss):
324
- # prompt = prompt + ' ' + _tok + '_'
325
- # outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
326
- # prompt = prompt + outputs
327
- # pred_poss.append(outputs)
328
 
329
- # outputs = ' '.join(pred_poss)
330
- # with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
331
- # f.write(outputs)
332
 
333
 
334
  ## Chunking
@@ -366,27 +366,27 @@ def main(args=None):
366
  f.write(outputs)
367
 
368
  ## Parsing
369
- # if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
370
- # print(gid, 'skip')
371
- # continue
372
 
373
- # prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
374
 
375
- # if 'gpt3' in args.model_path:
376
- # outputs = gpt3(prompt)
377
- # if outputs is None:
378
- # continue
379
- # time.sleep(0.2)
380
 
381
- # else:
382
- # pred_syn = []
383
- # for _tok, _pos in zip(tokens, poss):
384
- # prompt = prompt + _tok + '_'
385
- # outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
386
- # pred_syn.append(outputs)
387
 
388
- # with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
389
- # f.write(' '.join(pred_syn))
390
 
391
 
392
  def structured_prompt(prompt, model, tokenizer, bad_words_ids):
@@ -426,9 +426,9 @@ def fastchat(prompt, model, tokenizer):
426
  output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
427
  )
428
 
429
- # print('Empty system message')
430
- # print(f"{conv.roles[0]}: {msg}")
431
- # print(f"{conv.roles[1]}: {outputs}")
432
 
433
  return outputs
434
 
 
70
  # 'koala-13b': 'koala-13b',
71
  }
72
 
73
+ for m in model_mapping.keys():
74
+ for eid, ent in enumerate(ents):
75
+ os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
76
 
77
+ os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
78
+ os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
79
+ os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
80
 
81
+ os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
82
+ os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
83
+ os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
84
 
85
 
86
+ #s = int(sys.argv[1])
87
+ #e = int(sys.argv[2])
88
 
89
+ #s = 0
90
+ #e = 1000
91
  with open('sample_uniform_1k_2.txt', 'r') as f:
92
  selected_idx = f.readlines()
93
  selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
 
219
  text = ptb[gid]['text']
220
 
221
  ## POS tagging
222
+ if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
223
+ print(gid, 'skip')
224
 
225
+ else:
226
+ msg = prompt2_pos.format(text)
227
 
228
+ if 'gpt3' in args.model_path:
229
+ outputs = gpt3(msg)
230
+ if outputs is None:
231
+ continue
232
+ time.sleep(0.2)
233
 
234
+ else:
235
+ conv = get_conversation_template(args.model_path)
236
+ conv.append_message(conv.roles[0], msg)
237
+ conv.append_message(conv.roles[1], None)
238
+ conv.system = ''
239
+ prompt = conv.get_prompt()
240
 
241
+ outputs = fastchat(prompt, model, tokenizer)
242
 
243
+ with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
244
+ f.write(outputs)
245
 
246
 
247
  ## Sentence chunking
248
+ if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
249
+ print(gid, 'skip')
250
  if False:
251
  pass
252
  else:
 
273
 
274
 
275
  ## Parsing
276
+ if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
277
+ print(gid, 'skip')
278
 
279
+ else:
280
+ msg = prompt2_parse.format(text)
281
 
282
+ if 'gpt3' in args.model_path:
283
+ outputs = gpt3(msg)
284
+ if outputs is None:
285
+ continue
286
+ time.sleep(0.2)
287
 
288
+ else:
289
+ conv = get_conversation_template(args.model_path)
290
+ conv.append_message(conv.roles[0], msg)
291
+ conv.append_message(conv.roles[1], None)
292
+ conv.system = ''
293
+ prompt = conv.get_prompt()
294
 
295
+ outputs = fastchat(prompt, model, tokenizer)
296
 
297
+ with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
298
+ f.write(outputs)
299
 
300
 
301
 
 
306
  poss = ptb[gid]['uni_poss']
307
 
308
  ## POS tagging
309
+ if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
310
+ print(gid, 'skip')
311
+ continue
312
 
313
+ prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
314
 
315
+ if 'gpt3' in args.model_path:
316
+ outputs = gpt3(prompt)
317
+ if outputs is None:
318
+ continue
319
+ time.sleep(0.2)
320
 
321
+ else:
322
+ pred_poss = []
323
+ for _tok, _pos in zip(tokens, poss):
324
+ prompt = prompt + ' ' + _tok + '_'
325
+ outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
326
+ prompt = prompt + outputs
327
+ pred_poss.append(outputs)
328
 
329
+ outputs = ' '.join(pred_poss)
330
+ with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
331
+ f.write(outputs)
332
 
333
 
334
  ## Chunking
 
366
  f.write(outputs)
367
 
368
  ## Parsing
369
+ if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
370
+ print(gid, 'skip')
371
+ continue
372
 
373
+ prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
374
 
375
+ if 'gpt3' in args.model_path:
376
+ outputs = gpt3(prompt)
377
+ if outputs is None:
378
+ continue
379
+ time.sleep(0.2)
380
 
381
+ else:
382
+ pred_syn = []
383
+ for _tok, _pos in zip(tokens, poss):
384
+ prompt = prompt + _tok + '_'
385
+ outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
386
+ pred_syn.append(outputs)
387
 
388
+ with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
389
+ f.write(' '.join(pred_syn))
390
 
391
 
392
  def structured_prompt(prompt, model, tokenizer, bad_words_ids):
 
426
  output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
427
  )
428
 
429
+ print('Empty system message')
430
+ print(f"{conv.roles[0]}: {msg}")
431
+ print(f"{conv.roles[1]}: {outputs}")
432
 
433
  return outputs
434