Spaces:
Runtime error
Runtime error
More aggressive OOM prevention: num_z reduced to 50 for gsm8k and mmlu
Browse files
detect-pretrain-code-contamination/src/run.py
CHANGED
@@ -90,7 +90,7 @@ def sample_generation(sentence, model, tokenizer, args,data_name):
|
|
90 |
if data_name != "cais/mmlu" or data_name != "gsm8k":
|
91 |
output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
92 |
else:
|
93 |
-
output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
94 |
# print(output)
|
95 |
complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
96 |
|
|
|
90 |
if data_name != "cais/mmlu" or data_name != "gsm8k":
|
91 |
output = model.generate(input_ids, max_new_tokens=len(sentence.split())-half_sentence_index, min_new_tokens=1, num_return_sequences=args['num_z'], pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
92 |
else:
|
93 |
+
output = model.generate(input_ids, max_new_tokens=(len(sentence.split())-half_sentence_index)/2, min_new_tokens=1, num_return_sequences=int(args['num_z']/2), pad_token_id=tokenizer.eos_token_id, **args['generate_args'])
|
94 |
# print(output)
|
95 |
complete_generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
96 |
|