Chris4K committed on
Commit
3a7547b
·
verified ·
1 Parent(s): fbe3eee

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +6 -3
services/strategy.py CHANGED
@@ -66,12 +66,15 @@ class BestOfN(GenerationStrategy):
66
  response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
67
 
68
  # Tokenize the response for scoring with the PRM model
69
- #TODO use the real tokenizer from generator
70
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
71
 
72
- # Pass the response inputs correctly to the PRM model
73
- prm_output = generator.prm_model(response_inputs) # Pass the inputs correctly to the model
 
74
 
 
 
 
75
  # Check the expected output structure for prm_model and use it accordingly
76
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
77
 
 
66
  response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
67
 
68
  # Tokenize the response for scoring with the PRM model
 
69
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
70
 
71
+ # Extract the necessary inputs for prm_model
72
+ prm_input_ids = response_inputs["input_ids"] # Always present
73
+ attention_mask = response_inputs["attention_mask"] # Optional, depending on your model
74
 
75
+ # Pass only the required tensors to prm_model
76
+ prm_output = generator.prm_model(input_ids=prm_input_ids, attention_mask=attention_mask)
77
+
78
  # Check the expected output structure for prm_model and use it accordingly
79
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
80