File size: 7,192 Bytes
d39614d
 
 
e05a0c2
d39614d
f8f7ae7
19a7a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26304eb
d39614d
 
 
 
 
 
 
26304eb
d39614d
26304eb
d39614d
f59b437
 
 
 
 
121ac13
 
f59b437
121ac13
 
 
d39614d
26304eb
d39614d
52416ee
d39614d
 
121ac13
 
 
d39614d
 
26304eb
d39614d
85f410a
52416ee
33239be
 
 
 
 
 
 
 
 
f9deaa6
 
38dc5a0
f9deaa6
 
 
 
 
 
 
 
33239be
f9deaa6
85f410a
ae60821
f9deaa6
ae60821
f9deaa6
ae60821
 
 
 
 
 
 
f9deaa6
 
 
 
ae60821
 
85f410a
33239be
 
 
 
 
 
d39614d
26304eb
85f410a
ae60821
d39614d
52416ee
121ac13
 
d39614d
 
 
 
 
121ac13
d39614d
26304eb
d39614d
52416ee
d39614d
 
121ac13
 
 
 
d39614d
 
 
 
 
121ac13
 
 
 
d39614d
 
 
26304eb
d39614d
52416ee
d39614d
 
 
 
26304eb
d39614d
52416ee
d39614d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# strategy.py
#TODO UPDATE Paths
import os
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any

from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
from llama_cpp import Llama

# Initialize Langfuse
# SECURITY(review): credentials are hard-coded in source; move them to real
# environment variables / secret storage and read them here instead.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

try:
    langfuse = Langfuse()
except Exception:
    # Bind a sentinel so later `langfuse` references don't raise NameError
    # when the client could not be created (e.g. host unreachable).
    langfuse = None
    print("Langfuse Offline")


class GenerationStrategy(ABC):
    """Abstract interface shared by all text-generation strategies.

    Concrete subclasses implement :meth:`generate`, which turns a prompt
    into a completion using the model/tokenizer held by *generator*.
    """

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Produce a completion for *prompt*; *model_kwargs* is forwarded to the model."""
        ...


class DefaultStrategy(GenerationStrategy):
    """Single-pass generation: one tokenize → generate → decode round trip."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate one completion for *prompt*.

        Args:
            generator: Object exposing ``tokenizer``, ``model`` and ``device``.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to ``model.generate``.

        Returns:
            The decoded completion with special tokens stripped.
        """
        # Encode the prompt and move the tensor to the generator's device.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)


class MajorityVotingStrategy(GenerationStrategy):
    """Sample several completions and return the most frequent one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate *num_samples* completions and majority-vote over the decoded text.

        Args:
            generator: Object exposing ``tokenizer``, ``model`` and ``device``.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to ``model.generate``.
            num_samples: Number of independent samples to draw.

        Returns:
            The decoded completion that occurs most often; ties resolve
            arbitrarily (set iteration order), matching prior behavior.
        """
        # The prompt encoding is loop-invariant: tokenize once, reuse per sample.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs: List[str] = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        return max(set(outputs), key=outputs.count)


class BestOfN(GenerationStrategy):
    """Sample *num_samples* completions, score each with the PRM model, keep the best."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Best-of-N sampling scored by ``generator.prm_model``.

        Args:
            generator: Object exposing ``tokenizer``, ``model``, ``device`` and
                ``prm_model`` (a llama-cpp style callable).
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to ``model.generate``.
            num_samples: Number of candidates to draw and score.

        Returns:
            The candidate response with the highest PRM score.
        """
        scored_outputs: List[Tuple[str, float]] = []
        # Loop-invariant: encode the prompt once instead of once per sample.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)

            # BUG FIX: the original used a plain (non-f) string, so the literal
            # text "{system_message}...{response}" was sent to the PRM model and
            # the candidate was never scored. `system_message` was undefined
            # anywhere, so it is dropped here — TODO confirm the intended system
            # prompt with the author.
            prm_output = generator.prm_model(
                f"<|system|>\n</s>\n<|user|>\n{response}</s>\n<|assistant|>",
                max_tokens=512,  # Generate up to 512 tokens
                stop=["</s>"],   # NOTE(review): verify this stop token suits the PRM model's template
                echo=True        # Whether to echo the prompt
            )

            # NOTE(review): a llama-cpp completion is a dict with no `logits`
            # attribute, so this falls back to 0.0 — confirm the PRM output
            # schema and extract a real score.
            score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
            scored_outputs.append((response, score))

        # Return the response with the highest score.
        return max(scored_outputs, key=lambda x: x[1])[0]




class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns every beam's decoded text."""

    # BUG FIX: the original annotation claimed `-> str` but the method returns
    # a list of decoded beams; the annotation now matches the runtime value.
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Run beam search with *num_samples* beams.

        Args:
            generator: Object exposing ``tokenizer``, ``model`` and ``device``.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to ``model.generate``.
            num_samples: Beam width and number of returned sequences.

        Returns:
            One decoded string per returned beam (note: a list, unlike the
            other strategies).
        """
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


class DVT(GenerationStrategy):
    """Breadth/depth tree search: sample, score with the PRM model, extend the best."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Diverse tree search over candidate completions.

        Args:
            generator: Object exposing ``tokenizer``, ``model``, ``device`` and
                ``prm_model``.
            prompt: Raw prompt text.
            model_kwargs: Forwarded verbatim to ``model.generate``.
            num_samples: Default tree breadth.
            **kwargs: Optional ``breadth`` (int) and ``depth`` (int) overrides.

        Returns:
            The highest-scoring response found anywhere in the tree.
        """
        # BUG FIX: `breadth` and `depth` were referenced but never defined
        # (guaranteed NameError). They are now read from **kwargs with
        # backward-compatible defaults.
        breadth: int = kwargs.get("breadth", num_samples)
        depth: int = kwargs.get("depth", 2)

        results: List[Tuple[str, float]] = []

        # Initial layer: `breadth` samples from the original prompt.
        # The prompt encoding is loop-invariant, so tokenize once.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        for _ in range(breadth):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))

        # Deeper layers: extend the current best `breadth` candidates.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))

        return max(results, key=lambda x: x[1])[0]
        

class COT(GenerationStrategy):
    """Chain-of-thought strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Placeholder; TODO: implement chain-of-thought prompting."""
        return "Not implemented yet"


class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Placeholder; TODO: implement the ReAct framework."""
        return "Not implemented yet"
#  Add other strategy implementations...