# strategy.py
# TODO: update paths
from abc import ABC, abstractmethod
from typing import Any, Dict, List

# Assumed source of the @observe tracing decorator used below (e.g. Langfuse);
# adjust this import if `observe` is defined elsewhere in the project.
from langfuse.decorators import observe

# @observe() tracing is applied to the concrete generate() implementations below;
# wrapping the ABC itself would replace the class with a function and break subclassing.
class GenerationStrategy(ABC):
    """Base class for generation strategies."""

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a completion for `prompt` using the given generator."""
        raise NotImplementedError

 
class DefaultStrategy(GenerationStrategy):
    """Single generation pass with no extra test-time compute."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        # Tokenize the prompt, run one generation pass, and decode the result.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)

class MajorityVotingStrategy(GenerationStrategy):
    """Sample num_samples completions and return the most frequent one."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Majority vote: pick the completion that occurs most often.
        return max(set(outputs), key=outputs.count)

class BestOfN(GenerationStrategy):
    """Sample num_samples completions and return the one the reward model scores highest."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score the response with the generator's process reward model
            # (assumed attribute `prm_model`, mirroring `model` / `tokenizer`).
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        # Return the highest-scoring completion.
        return max(scored_outputs, key=lambda x: x[1])[0]

class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns the top num_samples beams."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        # Decode every returned beam; callers receive all candidates.
        return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

class DVT(GenerationStrategy):
    """Sample an initial set of candidates, then repeatedly extend the best ones
    and return the highest-scoring response according to the reward model."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5,
                 breadth: int = 3, depth: int = 2, **kwargs) -> str:
        # breadth controls the fan-out per round and depth the number of refinement
        # rounds; the default values here are assumptions.
        results = []
        for _ in range(breadth):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score with the generator's process reward model (assumed attribute `prm_model`).
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))

        for _ in range(depth - 1):
            # Extend the current best candidates and rescore the extensions.
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]
        
class COT(GenerationStrategy):
    """Chain-of-thought prompting strategy (not yet implemented)."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the chain-of-thought strategy.
        raise NotImplementedError("COT strategy is not implemented yet")

class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy (not yet implemented)."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the ReAct framework.
        raise NotImplementedError("ReAct strategy is not implemented yet")

# Add other strategy implementations...
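
# The block below is an illustrative sketch, not part of the original module: it shows
# one way callers could select a strategy by name. The registry name `STRATEGIES`, the
# short names, and the `get_strategy` helper are assumptions, not an existing API.

STRATEGIES: Dict[str, type] = {
    "default": DefaultStrategy,
    "majority_voting": MajorityVotingStrategy,
    "best_of_n": BestOfN,
    "beam_search": BeamSearch,
    "dvt": DVT,
    "cot": COT,
    "react": ReAct,
}


def get_strategy(name: str) -> GenerationStrategy:
    """Instantiate a registered strategy by its short name."""
    try:
        return STRATEGIES[name]()
    except KeyError as exc:
        raise ValueError(f"Unknown generation strategy: {name!r}") from exc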