Update services/strategy.py
Browse files- services/strategy.py +1 -1
services/strategy.py
CHANGED
@@ -70,7 +70,7 @@ class BestOfN(GenerationStrategy):
|
|
70 |
response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
|
71 |
|
72 |
# Pass the response inputs correctly to the PRM model
|
73 |
-
prm_output = generator.prm_model(**
|
74 |
|
75 |
# Check the expected output structure for prm_model and use it accordingly
|
76 |
score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
|
|
|
70 |
response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
|
71 |
|
72 |
# Pass the response inputs correctly to the PRM model
|
73 |
+
prm_output = generator.prm_model(response_inputs, **model_kwargs) # Pass the inputs correctly to the model
|
74 |
|
75 |
# Check the expected output structure for prm_model and use it accordingly
|
76 |
score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
|