Chris4K commited on
Commit
fbe3eee
·
verified ·
1 Parent(s): 16d990c

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +1 -1
services/strategy.py CHANGED
@@ -70,7 +70,7 @@ class BestOfN(GenerationStrategy):
70
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
71
 
72
  # Pass the response inputs correctly to the PRM model
73
- prm_output = generator.prm_model(response_inputs, **model_kwargs) # Pass the inputs correctly to the model
74
 
75
  # Check the expected output structure for prm_model and use it accordingly
76
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
 
70
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
71
 
72
  # Pass the response inputs correctly to the PRM model
73
+ prm_output = generator.prm_model(response_inputs) # Pass the inputs correctly to the model
74
 
75
  # Check the expected output structure for prm_model and use it accordingly
76
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0