Chris4K commited on
Commit
121ac13
·
verified ·
1 Parent(s): f59b437

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +23 -23
services/strategy.py CHANGED
@@ -35,21 +35,21 @@ class DefaultStrategy(GenerationStrategy):
35
  return generator.tokenizer.decode(output[0], skip_special_tokens=True)
36
  #def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
37
  #
38
- # tokenizer = generator.tokenizers["llama"]
39
- # model = generator.models["llama"].generate
40
  #
41
- # input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
42
- # output = generator.models["llama"].generate(input_ids, **model_kwargs)
43
- # return generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
44
 
45
 
46
  class MajorityVotingStrategy(GenerationStrategy):
47
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
48
  outputs = []
49
  for _ in range(num_samples):
50
- input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
51
- output = generator.models["llama"].generate(input_ids, **model_kwargs)
52
- outputs.append(generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True))
53
  return max(set(outputs), key=outputs.count)
54
 
55
 
@@ -57,43 +57,43 @@ class BestOfN(GenerationStrategy):
57
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
58
  scored_outputs = []
59
  for _ in range(num_samples):
60
- input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
61
- output = generator.models["llama"].generate(input_ids, **model_kwargs)
62
- response =generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
63
- score = generator.prm_model(**generator.tokenizers["llama"](response, return_tensors="pt").to(generator.device)).logits.mean().item()
64
  scored_outputs.append((response, score))
65
  return max(scored_outputs, key=lambda x: x[1])[0]
66
 
67
 
68
  class BeamSearch(GenerationStrategy):
69
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
70
- input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
71
- outputs = generator.models["llama"].generate(
72
  input_ids,
73
  num_beams=num_samples,
74
  num_return_sequences=num_samples,
75
  **model_kwargs
76
  )
77
- return [generator.tokenizers["llama"].decode(output, skip_special_tokens=True) for output in outputs]
78
 
79
 
80
  class DVT(GenerationStrategy):
81
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
82
  results = []
83
  for _ in range(breadth):
84
- input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
85
- output = generator.models["llama"].generate(input_ids, **model_kwargs)
86
- response = generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
87
- score = generator.prm_model(**generator.tokenizers["llama"](response, return_tensors="pt").to(generator.device)).logits.mean().item()
88
  results.append((response, score))
89
 
90
  for _ in range(depth - 1):
91
  best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
92
  for response, _ in best_responses:
93
- input_ids = generator.tokenizers["llama"](response, return_tensors="pt").input_ids.to(generator.device)
94
- output = generator.models["llama"].generate(input_ids, **model_kwargs)
95
- extended_response = generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
96
- score = generator.prm_model(**generator.tokenizers["llama"](extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
97
  results.append((extended_response, score))
98
  return max(results, key=lambda x: x[1])[0]
99
 
 
35
  return generator.tokenizer.decode(output[0], skip_special_tokens=True)
36
  #def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
37
  #
38
+ # tokenizer = generator.tokenizer
39
+ # model = generator.model.generate
40
  #
41
+ # input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
42
+ # output = generator.model.generate(input_ids, **model_kwargs)
43
+ # return generator.tokenizer.decode(output[0], skip_special_tokens=True)
44
 
45
 
46
  class MajorityVotingStrategy(GenerationStrategy):
47
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
48
  outputs = []
49
  for _ in range(num_samples):
50
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
51
+ output = generator.model.generate(input_ids, **model_kwargs)
52
+ outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
53
  return max(set(outputs), key=outputs.count)
54
 
55
 
 
57
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
58
  scored_outputs = []
59
  for _ in range(num_samples):
60
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
61
+ output = generator.model.generate(input_ids, **model_kwargs)
62
+ response =generator.tokenizer.decode(output[0], skip_special_tokens=True)
63
+ score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
64
  scored_outputs.append((response, score))
65
  return max(scored_outputs, key=lambda x: x[1])[0]
66
 
67
 
68
  class BeamSearch(GenerationStrategy):
69
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
70
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
71
+ outputs = generator.model.generate(
72
  input_ids,
73
  num_beams=num_samples,
74
  num_return_sequences=num_samples,
75
  **model_kwargs
76
  )
77
+ return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
78
 
79
 
80
  class DVT(GenerationStrategy):
81
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
82
  results = []
83
  for _ in range(breadth):
84
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
85
+ output = generator.model.generate(input_ids, **model_kwargs)
86
+ response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
87
+ score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
88
  results.append((response, score))
89
 
90
  for _ in range(depth - 1):
91
  best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
92
  for response, _ in best_responses:
93
+ input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
94
+ output = generator.model.generate(input_ids, **model_kwargs)
95
+ extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
96
+ score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
97
  results.append((extended_response, score))
98
  return max(results, key=lambda x: x[1])[0]
99