Chris4K commited on
Commit
311bab5
·
verified ·
1 Parent(s): 89c3be1

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +1 -1
services/strategy.py CHANGED
@@ -60,7 +60,7 @@ class BestOfN(GenerationStrategy):
60
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
61
  output = generator.model.generate(input_ids, **model_kwargs)
62
  response =generator.tokenizer.decode(output[0], skip_special_tokens=True)
63
- score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
64
  scored_outputs.append((response, score))
65
  return max(scored_outputs, key=lambda x: x[1])[0]
66
 
 
60
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
61
  output = generator.model.generate(input_ids, **model_kwargs)
62
  response =generator.tokenizer.decode(output[0], skip_special_tokens=True)
63
+ score = generator.prm_model.generate(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
64
  scored_outputs.append((response, score))
65
  return max(scored_outputs, key=lambda x: x[1])[0]
66