Yeyito committed
Commit 98282cd · 1 Parent(s): a0bf640

More aggressive OOM prevention, num_z reduced to 50 for gsm8k and mmlu

detect-pretrain-code-contamination/src/run.py CHANGED
@@ -90,7 +90,7 @@ def sample_generation(sentence, model, tokenizer, args,data_name):
     if data_name != "cais/mmlu" or data_name != "gsm8k":
         output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
     else:
-        output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
+        output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=int(args['num_z']/2), pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
     # print(output)
     complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
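
For context, a minimal sketch (not code from run.py) of the per-prompt generation budget the patched else-branch requests for cais/mmlu and gsm8k. It assumes the configured num_z is 100, so the halved value matches the "reduced to 50" in the commit message; the helper name generation_budget and the int() cast on max_new_tokens are illustrative assumptions, not part of the commit.

# Illustrative sketch only: mirrors the arithmetic of the patched else-branch.
# Halving num_return_sequences halves the batch dimension that model.generate()
# holds in memory at once, which is where the "more aggressive OOM prevention"
# in the commit message comes from.
def generation_budget(sentence: str, half_sentence_index: int, num_z: int):
    remaining_words = len(sentence.split()) - half_sentence_index
    max_new_tokens = int(remaining_words / 2)   # the diff passes a float here; int() is an assumption
    num_return_sequences = int(num_z / 2)       # e.g. 100 -> 50 sampled continuations per prompt
    return max_new_tokens, num_return_sequences

print(generation_budget("Natalia sold clips to 48 of her friends in April", 5, 100))
# -> (2, 50)

Peak memory inside generate() grows roughly with num_return_sequences times the per-sequence KV cache, so trimming the number of returned sequences is a more direct memory lever than shrinking max_new_tokens alone.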