Chris4K committed (verified)
Commit 919c57c · 1 Parent(s): 02c1ae0

Create llama_generator.py

Files changed (1)
  1. services/llama_generator.py +190 -0
services/llama_generator.py ADDED
@@ -0,0 +1,190 @@
+ # llama_generator.py
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from langfuse.decorators import observe  # assumed source of the @observe decorator
+
+ from config.config import GenerationConfig, ModelConfig
+
+ # Project-local imports; these module paths are assumed and may need adjusting.
+ from services.base_generator import BaseGenerator
+ from services.health import HealthStatus
+ from services.prompts import LlamaPromptTemplate
+ from services.strategies import (
+     BeamSearch,
+     BestOfN,
+     DefaultStrategy,
+     DVT,
+     MajorityVotingStrategy,
+ )
+
+
+ class LlamaGenerator(BaseGenerator):
+     def __init__(
+         self,
+         llama_model_name: str,
+         prm_model_path: str,
+         device: Optional[str] = None,
+         default_generation_config: Optional[GenerationConfig] = None,
+         model_config: Optional[ModelConfig] = None,
+         cache_size: int = 1000,
+         max_batch_size: int = 32,
+     ):
+         super().__init__(
+             llama_model_name,
+             device,
+             default_generation_config,
+             model_config,
+             cache_size,
+             max_batch_size,
+         )
+
+         # Initialize the tokenizer for the base model.
+         self.tokenizer = self.load_tokenizer(llama_model_name)
+
+         # Initialize models: the Llama model plus the PRM scoring model.
+         self.model_manager.load_model(
+             "llama",
+             llama_model_name,
+             "llama",
+             self.model_config,
+         )
+         self.model_manager.load_model(
+             "prm",
+             prm_model_path,
+             "gguf",
+             self.model_config,
+         )
+
+         self.prompt_builder = LlamaPromptTemplate()
+         self._init_strategies()
+
+     @observe()
+     def load_model(self, model_name: str):
+         # Load the model via Hugging Face's transformers library.
+         from transformers import AutoModelForCausalLM
+         return AutoModelForCausalLM.from_pretrained(model_name)
+
+     @observe()
+     def load_tokenizer(self, model_name: str):
+         # Load the tokenizer associated with the model.
+         from transformers import AutoTokenizer
+         return AutoTokenizer.from_pretrained(model_name)
+
+     def _init_strategies(self):
+         self.strategies = {
+             "default": DefaultStrategy(),
+             "majority_voting": MajorityVotingStrategy(),
+             "best_of_n": BestOfN(),
+             "beam_search": BeamSearch(),
+             "dvts": DVT(),
+         }
+
+     def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
+         """Get generation kwargs based on config."""
+         return {
+             key: getattr(config, key)
+             for key in [
+                 "max_new_tokens",
+                 "temperature",
+                 "top_p",
+                 "top_k",
+                 "repetition_penalty",
+                 "length_penalty",
+                 "do_sample",
+             ]
+             if hasattr(config, key)
+         }
+
+     @observe()
+     def generate_stream(self):
+         raise NotImplementedError("Streaming generation is not implemented yet.")
+
+     @observe()
+     def generate(
+         self,
+         prompt: str,
+         model_kwargs: Dict[str, Any],
+         strategy: str = "default",
+         **kwargs,
+     ) -> str:
+         """
+         Generate text using the given strategy.
+
+         Args:
+             prompt (str): Input prompt for text generation.
+             model_kwargs (Dict[str, Any]): Additional arguments for model generation.
+             strategy (str): The generation strategy to use (default: "default").
+             **kwargs: Additional arguments passed to the strategy.
+
+         Returns:
+             str: Generated text response.
+
+         Raises:
+             ValueError: If the specified strategy is not available.
+         """
+         # Validate that the strategy exists.
+         if strategy not in self.strategies:
+             raise ValueError(
+                 f"Unknown strategy: {strategy}. "
+                 f"Available strategies are: {list(self.strategies.keys())}"
+             )
+
+         # Drop "generator" from kwargs if present so it is not passed twice.
+         kwargs.pop("generator", None)
+
+         # Call the selected strategy with the provided arguments.
+         return self.strategies[strategy].generate(
+             generator=self,  # the generator instance
+             prompt=prompt,  # the input prompt
+             model_kwargs=model_kwargs,  # arguments for the model
+             **kwargs,  # any additional strategy-specific arguments
+         )
+
+     @observe()
+     def generate_with_context(
+         self,
+         context: str,
+         user_input: str,
+         chat_history: List[Tuple[str, str]],
+         model_kwargs: Dict[str, Any],
+         max_history_turns: int = 3,
+         strategy: str = "default",
+         num_samples: int = 5,
+         depth: int = 3,
+         breadth: int = 2,
+     ) -> str:
+         """Generate a response using context and chat history.
+
+         Args:
+             context (str): Context for the conversation
+             user_input (str): Current user input
+             chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
+             model_kwargs (dict): Additional arguments for model.generate()
+             max_history_turns (int): Maximum number of history turns to include
+             strategy (str): Generation strategy
+             num_samples (int): Number of samples for applicable strategies
+             depth (int): Depth for the DVTS strategy
+             breadth (int): Breadth for the DVTS strategy
+
+         Returns:
+             str: Generated response
+         """
+         prompt = self.prompt_builder.format(
+             context,
+             user_input,
+             chat_history,
+             max_history_turns,
+         )
+         return self.generate(
+             prompt=prompt,
+             model_kwargs=model_kwargs,
+             strategy=strategy,
+             num_samples=num_samples,
+             depth=depth,
+             breadth=breadth,
+         )
+
+     def check_health(self) -> HealthStatus:
+         """Check the health status of the generator."""
+         # TODO: also include model load status in the health report.
+         return self.health_check.check_system_resources()
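
A minimal usage sketch (not part of the commit), assuming the project-local modules imported above exist and that BaseGenerator wires up model_manager and health_check; the model id and GGUF path below are placeholders:

    # usage sketch -- illustrative only
    from services.llama_generator import LlamaGenerator

    generator = LlamaGenerator(
        llama_model_name="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        prm_model_path="models/prm.gguf",                  # placeholder GGUF path
    )

    # Single-shot generation with the default strategy.
    answer = generator.generate(
        prompt="Explain beam search in one sentence.",
        model_kwargs={"max_new_tokens": 64},
    )

    # Context-aware generation; extra kwargs such as num_samples are forwarded
    # to the selected strategy (here best_of_n, which presumably ranks the
    # sampled candidates with the loaded PRM model).
    followup = generator.generate_with_context(
        context="You are a concise assistant.",
        user_input="And how does sampling differ?",
        chat_history=[("What is beam search?", answer)],
        model_kwargs={"max_new_tokens": 64},
        strategy="best_of_n",
        num_samples=5,
    )

For reference, the dispatch in generate() implies that each strategy object exposes a generate(generator, prompt, model_kwargs, **kwargs) method. A hypothetical DefaultStrategy could look like the sketch below; the get_model helper is invented for illustration and is not confirmed by the commit:

    class DefaultStrategy:
        def generate(self, generator, prompt, model_kwargs, **kwargs) -> str:
            # Tokenize with the generator's tokenizer, run the underlying
            # model, and decode the completion.
            inputs = generator.tokenizer(prompt, return_tensors="pt")
            output_ids = generator.model_manager.get_model("llama").generate(  # hypothetical helper
                **inputs, **model_kwargs
            )
            return generator.tokenizer.decode(output_ids[0], skip_special_tokens=True)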