Chris4K commited on
Commit
d39614d
·
verified ·
1 Parent(s): 5fa6a5e

Create strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +89 -0
services/strategy.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # strategy.py
2
+ #TODO UPDATE Paths
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Tuple
5
+
6
+ @observe()
7
+ class GenerationStrategy(ABC):
8
+ """Base class for generation strategies."""
9
+
10
+ @abstractmethod
11
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
12
+ pass
13
+
14
+
15
+ class DefaultStrategy(GenerationStrategy):
16
+
17
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
18
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
19
+ output = generator.model.generate(input_ids, **model_kwargs)
20
+ return generator.tokenizer.decode(output[0], skip_special_tokens=True)
21
+
22
+ @observe()
23
+ class MajorityVotingStrategy(GenerationStrategy):
24
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
25
+ outputs = []
26
+ for _ in range(num_samples):
27
+ input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
28
+ output = generator.model.generate(input_ids, **model_kwargs)
29
+ outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
30
+ return max(set(outputs), key=outputs.count)
31
+
32
+ @observe()
33
+ class BestOfN(GenerationStrategy):
34
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
35
+ scored_outputs = []
36
+ for _ in range(num_samples):
37
+ input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
38
+ output = self.llama_model.generate(input_ids, **model_kwargs)
39
+ response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
40
+ score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
41
+ scored_outputs.append((response, score))
42
+ return max(scored_outputs, key=lambda x: x[1])[0]
43
+
44
+ @observe()
45
+ class BeamSearch(GenerationStrategy):
46
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
47
+ input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
48
+ outputs = self.llama_model.generate(
49
+ input_ids,
50
+ num_beams=num_samples,
51
+ num_return_sequences=num_samples,
52
+ **model_kwargs
53
+ )
54
+ return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
55
+
56
+ @observe()
57
+ class DVT(GenerationStrategy):
58
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
59
+ results = []
60
+ for _ in range(breadth):
61
+ input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
62
+ output = self.llama_model.generate(input_ids, **model_kwargs)
63
+ response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
64
+ score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
65
+ results.append((response, score))
66
+
67
+ for _ in range(depth - 1):
68
+ best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
69
+ for response, _ in best_responses:
70
+ input_ids = self.llama_tokenizer(response, return_tensors="pt").input_ids.to(self.device)
71
+ output = self.llama_model.generate(input_ids, **model_kwargs)
72
+ extended_response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
73
+ score = self.prm_model(**self.llama_tokenizer(extended_response, return_tensors="pt").to(self.device)).logits.mean().item()
74
+ results.append((extended_response, score))
75
+ return max(results, key=lambda x: x[1])[0]
76
+
77
+ @observe()
78
+ class COT(GenerationStrategy):
79
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
80
+ #TODO implement the chain of thought strategy
81
+
82
+ return "Not implemented yet"
83
+
84
+ @observe()
85
+ class ReAct(GenerationStrategy):
86
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
87
+ #TODO implement the ReAct framework
88
+ return "Not implemented yet"
89
+ # Add other strategy implementations...