Chris4K commited on
Commit
ae307db
·
verified ·
1 Parent(s): fb159fa

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +31 -27
services/strategy.py CHANGED
@@ -30,75 +30,79 @@ class GenerationStrategy(ABC):
30
  class DefaultStrategy(GenerationStrategy):
31
  @observe()
32
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
33
- input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
34
- output = generator.generate(input_ids, **model_kwargs)
35
- return generator.tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
36
 
37
 
38
  class MajorityVotingStrategy(GenerationStrategy):
39
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
40
  outputs = []
41
  for _ in range(num_samples):
42
- input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
43
- output = generator.generate(input_ids, **model_kwargs)
44
- outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
45
  return max(set(outputs), key=outputs.count)
46
 
47
 
48
  class BestOfN(GenerationStrategy):
49
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
50
  scored_outputs = []
51
  for _ in range(num_samples):
52
- input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
53
- output = generator.generate(input_ids, **model_kwargs)
54
- response =generator.tokenizer.decode(output[0], skip_special_tokens=True)
55
- score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
56
  scored_outputs.append((response, score))
57
  return max(scored_outputs, key=lambda x: x[1])[0]
58
 
59
 
60
  class BeamSearch(GenerationStrategy):
61
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
62
- input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
63
- outputs = generator.generate(
64
  input_ids,
65
  num_beams=num_samples,
66
  num_return_sequences=num_samples,
67
  **model_kwargs
68
  )
69
- return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
70
 
71
 
72
  class DVT(GenerationStrategy):
73
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
74
  results = []
75
  for _ in range(breadth):
76
- input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
77
- output = generator.generate(input_ids, **model_kwargs)
78
- response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
79
- score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
80
  results.append((response, score))
81
 
82
  for _ in range(depth - 1):
83
  best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
84
  for response, _ in best_responses:
85
- input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
86
- output = generator.generate(input_ids, **model_kwargs)
87
- extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
88
- score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
89
  results.append((extended_response, score))
90
  return max(results, key=lambda x: x[1])[0]
91
 
92
 
93
  class COT(GenerationStrategy):
94
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
95
  #TODO implement the chain of thought strategy
96
 
97
  return "Not implemented yet"
98
 
99
 
100
  class ReAct(GenerationStrategy):
101
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
102
  #TODO implement the ReAct framework
103
  return "Not implemented yet"
104
  # Add other strategy implementations...
 
30
  class DefaultStrategy(GenerationStrategy):
31
  @observe()
32
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
33
+
34
+ tokenizer = generator.tokenizers["llama"]
35
+ model = generator.models["llama"].generate
36
+
37
+ input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
38
+ output = generator.models["llama"].generate(input_ids, **model_kwargs)
39
+ return generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
40
 
41
 
42
  class MajorityVotingStrategy(GenerationStrategy):
43
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
44
  outputs = []
45
  for _ in range(num_samples):
46
+ input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
47
+ output = generator.models["llama"].generate(input_ids, **model_kwargs)
48
+ outputs.append(generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True))
49
  return max(set(outputs), key=outputs.count)
50
 
51
 
52
  class BestOfN(GenerationStrategy):
53
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
54
  scored_outputs = []
55
  for _ in range(num_samples):
56
+ input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
57
+ output = generator.models["llama"].generate(input_ids, **model_kwargs)
58
+ response =generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
59
+ score = generator.prm_model(**generator.tokenizers["llama"](response, return_tensors="pt").to(generator.device)).logits.mean().item()
60
  scored_outputs.append((response, score))
61
  return max(scored_outputs, key=lambda x: x[1])[0]
62
 
63
 
64
  class BeamSearch(GenerationStrategy):
65
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
66
+ input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
67
+ outputs = generator.models["llama"].generate(
68
  input_ids,
69
  num_beams=num_samples,
70
  num_return_sequences=num_samples,
71
  **model_kwargs
72
  )
73
+ return [generator.tokenizers["llama"].decode(output, skip_special_tokens=True) for output in outputs]
74
 
75
 
76
  class DVT(GenerationStrategy):
77
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
78
  results = []
79
  for _ in range(breadth):
80
+ input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
81
+ output = generator.models["llama"].generate(input_ids, **model_kwargs)
82
+ response = generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
83
+ score = generator.prm_model(**generator.tokenizers["llama"](response, return_tensors="pt").to(generator.device)).logits.mean().item()
84
  results.append((response, score))
85
 
86
  for _ in range(depth - 1):
87
  best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
88
  for response, _ in best_responses:
89
+ input_ids = generator.tokenizers["llama"](response, return_tensors="pt").input_ids.to(generator.device)
90
+ output = generator.models["llama"].generate(input_ids, **model_kwargs)
91
+ extended_response = generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
92
+ score = generator.prm_model(**generator.tokenizers["llama"](extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
93
  results.append((extended_response, score))
94
  return max(results, key=lambda x: x[1])[0]
95
 
96
 
97
  class COT(GenerationStrategy):
98
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
99
  #TODO implement the chain of thought strategy
100
 
101
  return "Not implemented yet"
102
 
103
 
104
  class ReAct(GenerationStrategy):
105
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
106
  #TODO implement the ReAct framework
107
  return "Not implemented yet"
108
  # Add other strategy implementations...