Chris4K committed on
Commit
33239be
·
verified ·
1 Parent(s): f9528ef

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +19 -8
services/strategy.py CHANGED
@@ -55,14 +55,25 @@ class MajorityVotingStrategy(GenerationStrategy):
55
 
56
class BestOfN(GenerationStrategy):
    """Best-of-N sampling: draw several candidate completions and return the
    one the PRM (preference/reward model) scores highest."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate num_samples completions for prompt and return the best one.

        Args:
            generator: Provides tokenizer, model, prm_model and device.
            prompt: Input prompt text.
            model_kwargs: Forwarded to model.generate; should enable sampling
                (e.g. do_sample=True), otherwise all candidates are identical.
            num_samples: Number of candidates to draw.

        Returns:
            The decoded candidate with the highest mean PRM logit.
        """
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # BUG FIX: prm_model.generate(...) returns generated token ids, which
            # carry no .logits attribute. Score via a plain forward call instead,
            # which yields a model output object exposing logits.
            response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
            score = generator.prm_model(**response_inputs).logits.mean().item()
            scored_outputs.append((response, score))
        # Highest-scoring candidate wins.
        return max(scored_outputs, key=lambda x: x[1])[0]
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  class BeamSearch(GenerationStrategy):
 
55
 
56
class BestOfN(GenerationStrategy):
    """Best-of-N sampling: draw several candidate completions and return the
    one the PRM (preference/reward model) scores highest."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate num_samples completions for prompt and return the best one.

        Args:
            generator: Provides tokenizer, model, prm_model and device.
            prompt: Input prompt text.
            model_kwargs: Forwarded to model.generate; should enable sampling
                (e.g. do_sample=True), otherwise all candidates are identical.
            num_samples: Number of candidates to draw; must be >= 1.

        Returns:
            The decoded candidate with the highest mean PRM logit.

        Raises:
            ValueError: If num_samples < 1 (previously surfaced as an opaque
                "max() arg is an empty sequence" error).
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        # The prompt encoding is loop-invariant: tokenize once, reuse per sample.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        scored_outputs = []
        for _ in range(num_samples):
            # Draw one candidate from the main model.
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)

            # Score the candidate with a forward pass through the PRM model
            # (a plain call, not .generate(): we need the logits, not new tokens).
            response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
            score = generator.prm_model(**response_inputs).logits.mean().item()
            scored_outputs.append((response, score))

        # Highest-scoring candidate wins.
        return max(scored_outputs, key=lambda x: x[1])[0]
77
 
78
 
79
  class BeamSearch(GenerationStrategy):