xdyu committed on
Commit
7b2dd63
·
verified ·
1 Parent(s): 21af911

Upload run_program.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_program.py +11 -8
run_program.py CHANGED
@@ -238,9 +238,11 @@ def update_question_with_new_parameters():
238
  json.dump(program_data, outfile, indent=4)
239
 
240
 
241
- def call_answer_question(question, model_name='gpt'):
242
- prompt_template = PROMPT_DICT['prompt_answer_question']
243
- # prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
 
 
244
  prompt = prompt_template.format_map(
245
  {"question": question}
246
  )
@@ -321,7 +323,8 @@ def call_answer_question(question, model_name='gpt'):
321
  outputs = llama_pipeline(
322
  messages,
323
  max_new_tokens=300,
324
- temperature=0.00001
 
325
  )
326
  # print(outputs[0]["generated_text"][-1])
327
  return outputs[0]["generated_text"][-1]['content']
@@ -332,19 +335,19 @@ def answer_question(model_name='gpt'):
332
  program_data = json.load(infile)
333
  print(len(program_data))
334
  for case in tqdm(program_data):
335
- response = call_answer_question(case['question'], model_name=model_name)
336
  case['prediction'] = response
337
  # print(case['prediction'])
338
  case['new_prediction'] = []
339
  for question in case['new_questions']:
340
- response = call_answer_question(question, model_name=model_name)
341
  case['new_prediction'].append(response)
342
  # print(case)
343
  # break
344
  # print(case)
345
  # break
346
- outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
347
- # outfile = open('data/math/gsm8k_cot_sc_qwen/temp=0.7_iter=4.json', 'w')
348
  json.dump(program_data, outfile, indent=4)
349
 
350
 
 
238
  json.dump(program_data, outfile, indent=4)
239
 
240
 
241
+ def call_answer_question(question, model_name='gpt', cot=False):
242
+ if cot:
243
+ prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
244
+ else:
245
+ prompt_template = PROMPT_DICT['prompt_answer_question']
246
  prompt = prompt_template.format_map(
247
  {"question": question}
248
  )
 
323
  outputs = llama_pipeline(
324
  messages,
325
  max_new_tokens=300,
326
+ # temperature=0.00001
327
+ temperature = 0.7
328
  )
329
  # print(outputs[0]["generated_text"][-1])
330
  return outputs[0]["generated_text"][-1]['content']
 
335
  program_data = json.load(infile)
336
  print(len(program_data))
337
  for case in tqdm(program_data):
338
+ response = call_answer_question(case['question'], model_name=model_name, cot=True)
339
  case['prediction'] = response
340
  # print(case['prediction'])
341
  case['new_prediction'] = []
342
  for question in case['new_questions']:
343
+ response = call_answer_question(question, model_name=model_name, cot=True)
344
  case['new_prediction'].append(response)
345
  # print(case)
346
  # break
347
  # print(case)
348
  # break
349
+ # outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
350
+ outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
351
  json.dump(program_data, outfile, indent=4)
352
 
353