File size: 5,654 Bytes
d39614d
 
 
e05a0c2
d39614d
19a7a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26304eb
d39614d
 
 
 
 
 
 
26304eb
d39614d
26304eb
d39614d
f59b437
 
 
 
 
121ac13
 
f59b437
121ac13
 
 
d39614d
26304eb
d39614d
52416ee
d39614d
 
121ac13
 
 
d39614d
 
26304eb
d39614d
52416ee
d39614d
 
121ac13
 
 
311bab5
d39614d
 
 
26304eb
d39614d
52416ee
121ac13
 
d39614d
 
 
 
 
121ac13
d39614d
26304eb
d39614d
52416ee
d39614d
 
121ac13
 
 
 
d39614d
 
 
 
 
121ac13
 
 
 
d39614d
 
 
26304eb
d39614d
52416ee
d39614d
 
 
 
26304eb
d39614d
52416ee
d39614d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# strategy.py
# TODO: update hard-coded paths/credentials below.
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any



from langfuse.decorators import observe, langfuse_context
import os 

# Initialize Langfuse tracing.
# SECURITY(review): API keys are hard-coded and committed here — they should be
# rotated and loaded from the environment / a secrets manager instead.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

try:
    # BUG FIX: `Langfuse` was never imported (only langfuse.decorators is), so
    # this block always raised NameError and fell into the except branch.
    from langfuse import Langfuse

    langfuse = Langfuse()
except Exception:
    # Tracing is optional: fall back to running without it.
    print("Langfuse Offline")

class GenerationStrategy(ABC):
    """Abstract interface for text-generation strategies.

    Concrete strategies implement :meth:`generate`, producing text from
    *prompt* with the model/tokenizer held by *generator*.
    """

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate text for *prompt*; strategy-specific options arrive via **kwargs."""
        ...


class DefaultStrategy(GenerationStrategy):
    """Single-pass generation: one prompt in, one decoded completion out."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Tokenize *prompt*, run one `model.generate` call, and decode the result.

        Args:
            generator: Object exposing `tokenizer`, `model`, and `device`.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to `model.generate`.

        Returns:
            Decoded text of the first returned sequence (special tokens stripped).
        """
        # (Removed an 8-line commented-out duplicate of this body that was dead code.)
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)


class MajorityVotingStrategy(GenerationStrategy):
    """Self-consistency: sample several completions, return the most frequent."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Sample *num_samples* completions and return the majority answer.

        Args:
            generator: Object exposing `tokenizer`, `model`, and `device`.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to `model.generate`.
            num_samples: Number of independent samples to draw.

        Returns:
            The most frequent decoded completion.
        """
        # PERF: the prompt encoding is loop-invariant — tokenize once, not per sample.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Majority vote; ties resolved by set iteration order, as before.
        return max(set(outputs), key=outputs.count)


class BestOfN(GenerationStrategy):
    """Sample N completions, score each with the reward model, return the best."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Draw *num_samples* completions and return the highest-scoring one.

        Args:
            generator: Object exposing `tokenizer`, `model`, `prm_model`, `device`.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to `model.generate`.
            num_samples: Number of candidates to sample and score.

        Returns:
            The candidate whose mean reward-model logit is highest.
        """
        # PERF: loop-invariant prompt encoding hoisted out of the sampling loop.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        scored_outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # BUG FIX: score via a forward pass (`prm_model(...)`) as DVT does;
            # `prm_model.generate(...)` returns token ids, which have no `.logits`.
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]


class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns all `num_samples` beam hypotheses."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Run beam search with *num_samples* beams.

        Args:
            generator: Object exposing `tokenizer`, `model`, and `device`.
            prompt: Raw prompt text.
            model_kwargs: Forwarded to `model.generate` (must not also set
                `num_beams`/`num_return_sequences`).
            num_samples: Beam width and number of returned sequences.

        Returns:
            All decoded beam hypotheses. NOTE: the original annotation claimed
            `str`, but this method has always returned a list of strings; the
            annotation is corrected to match the actual behavior.
        """
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


class DVT(GenerationStrategy):
    """Tree search: sample a first layer of candidates, then iteratively extend
    the highest-scoring ones, returning the best-scoring response overall."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, breadth=None, depth: int = 2, **kwargs) -> str:
        """Run breadth-first sampling with *depth* refinement layers.

        Args:
            generator: Object exposing `tokenizer`, `model`, `prm_model`, `device`.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to `model.generate`.
            num_samples: Fallback value for `breadth`.
            breadth: Candidates sampled/kept per layer; defaults to `num_samples`.
            depth: Total number of layers (1 means no refinement round).

        Returns:
            The response with the highest mean reward-model logit.
        """
        # BUG FIX: `breadth` and `depth` were referenced but never defined, so
        # this method always raised NameError. They are now keyword parameters
        # with backward-compatible defaults (breadth falls back to num_samples).
        if breadth is None:
            breadth = num_samples

        results: List[Tuple[str, float]] = []

        # Layer 1: sample `breadth` candidates from the prompt.
        # PERF: the prompt encoding is loop-invariant — tokenize once.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        for _ in range(breadth):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))

        # Layers 2..depth: extend the current best `breadth` candidates.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _score in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]
        

class COT(GenerationStrategy):
    """Chain-of-thought strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the chain-of-thought strategy.
        return "Not implemented yet"


class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the ReAct framework.
        return "Not implemented yet"
#  Add other strategy implementations...