Upload run_program.py with huggingface_hub
Browse files
- run_program.py +11 -8
run_program.py
CHANGED
@@ -238,9 +238,11 @@ def update_question_with_new_parameters():
|
|
238 |
json.dump(program_data, outfile, indent=4)
|
239 |
|
240 |
|
241 |
-
def call_answer_question(question, model_name='gpt'):
|
242 |
-
|
243 |
-
|
|
|
|
|
244 |
prompt = prompt_template.format_map(
|
245 |
{"question": question}
|
246 |
)
|
@@ -321,7 +323,8 @@ def call_answer_question(question, model_name='gpt'):
|
|
321 |
outputs = llama_pipeline(
|
322 |
messages,
|
323 |
max_new_tokens=300,
|
324 |
-
temperature=0.00001
|
|
|
325 |
)
|
326 |
# print(outputs[0]["generated_text"][-1])
|
327 |
return outputs[0]["generated_text"][-1]['content']
|
@@ -332,19 +335,19 @@ def answer_question(model_name='gpt'):
|
|
332 |
program_data = json.load(infile)
|
333 |
print(len(program_data))
|
334 |
for case in tqdm(program_data):
|
335 |
-
response = call_answer_question(case['question'], model_name=model_name)
|
336 |
case['prediction'] = response
|
337 |
# print(case['prediction'])
|
338 |
case['new_prediction'] = []
|
339 |
for question in case['new_questions']:
|
340 |
-
response = call_answer_question(question, model_name=model_name)
|
341 |
case['new_prediction'].append(response)
|
342 |
# print(case)
|
343 |
# break
|
344 |
# print(case)
|
345 |
# break
|
346 |
-
outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
|
347 |
-
|
348 |
json.dump(program_data, outfile, indent=4)
|
349 |
|
350 |
|
|
|
238 |
json.dump(program_data, outfile, indent=4)
|
239 |
|
240 |
|
241 |
+
def call_answer_question(question, model_name='gpt', cot=False):
|
242 |
+
if cot:
|
243 |
+
prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
|
244 |
+
else:
|
245 |
+
prompt_template = PROMPT_DICT['prompt_answer_question']
|
246 |
prompt = prompt_template.format_map(
|
247 |
{"question": question}
|
248 |
)
|
|
|
323 |
outputs = llama_pipeline(
|
324 |
messages,
|
325 |
max_new_tokens=300,
|
326 |
+
# temperature=0.00001
|
327 |
+
temperature = 0.7
|
328 |
)
|
329 |
# print(outputs[0]["generated_text"][-1])
|
330 |
return outputs[0]["generated_text"][-1]['content']
|
|
|
335 |
program_data = json.load(infile)
|
336 |
print(len(program_data))
|
337 |
for case in tqdm(program_data):
|
338 |
+
response = call_answer_question(case['question'], model_name=model_name, cot=True)
|
339 |
case['prediction'] = response
|
340 |
# print(case['prediction'])
|
341 |
case['new_prediction'] = []
|
342 |
for question in case['new_questions']:
|
343 |
+
response = call_answer_question(question, model_name=model_name, cot=True)
|
344 |
case['new_prediction'].append(response)
|
345 |
# print(case)
|
346 |
# break
|
347 |
# print(case)
|
348 |
# break
|
349 |
+
# outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
|
350 |
+
outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
|
351 |
json.dump(program_data, outfile, indent=4)
|
352 |
|
353 |
|