Update app.py
app.py (CHANGED)
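This commit instruments app.py with @observe() tracing decorators: ModelManager.load_model and unload_model, the abstract BaseGenerator.generate, every GenerationStrategy subclass (DefaultStrategy, MajorityVotingStrategy, BestOfN, BeamSearch, DVT, COT, ReAct), the prompt-template classes, and the LlamaGenerator entry points. It also adds a self.tokenizer initialization inside LlamaGenerator.__init__.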
@@ -36,6 +36,7 @@ class ModelManager:
         self.models: Dict[str, Any] = {}
         self.tokenizers: Dict[str, Any] = {}
 
+    @observe()
     def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
         """Load a model with specified configuration."""
         try:
@@ -73,6 +74,7 @@ class ModelManager:
             self.logger.error(f"Failed to load model {model_id}: {str(e)}")
             raise
 
+    @observe()
     def unload_model(self, model_id: str) -> None:
         """Unload a model and free resources."""
         if model_id in self.models:
@@ -195,7 +197,8 @@ class BaseGenerator(ABC):
     @abstractmethod
     def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
         pass
-
+
+    @observe()
     @abstractmethod
     def generate(
         self,
@@ -240,6 +243,7 @@ class BaseGenerator(ABC):
 from abc import ABC, abstractmethod
 from typing import List, Tuple
 
+@observe()
 class GenerationStrategy(ABC):
     """Base class for generation strategies."""
 
@@ -247,12 +251,14 @@ class GenerationStrategy(ABC):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         pass
 
+@observe()
 class DefaultStrategy(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
         output = generator.model.generate(input_ids, **model_kwargs)
         return generator.tokenizer.decode(output[0], skip_special_tokens=True)
 
+@observe()
 class MajorityVotingStrategy(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         outputs = []
@@ -262,6 +268,7 @@ class MajorityVotingStrategy(GenerationStrategy):
             outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
         return max(set(outputs), key=outputs.count)
 
+@observe()
 class BestOfN(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         scored_outputs = []
@@ -272,7 +279,8 @@ class BestOfN(GenerationStrategy):
             score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
             scored_outputs.append((response, score))
         return max(scored_outputs, key=lambda x: x[1])[0]
-
+
+@observe()
 class BeamSearch(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
@@ -284,6 +292,7 @@ class BeamSearch(GenerationStrategy):
         )
         return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
+@observe()
 class DVT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         results = []
@@ -304,12 +313,14 @@ class DVT(GenerationStrategy):
             results.append((extended_response, score))
         return max(results, key=lambda x: x[1])[0]
 
+@observe()
 class COT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the chain of thought strategy
 
         return "Not implemented yet"
-
+
+@observe()
 class ReAct(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the ReAct framework
@@ -320,11 +331,13 @@ class ReAct(GenerationStrategy):
 from typing import Protocol, List, Tuple
 from transformers import AutoTokenizer
 
+@observe()
 class PromptTemplate(Protocol):
     """Protocol for prompt templates."""
     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
         pass
 
+@observe()
 class LlamaPromptTemplate:
     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
         system_message = f"Please assist based on the following context: {context}"
@@ -337,7 +350,8 @@ class LlamaPromptTemplate:
         prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
         prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
         return prompt
-
+
+@observe()
 class TransformersPromptTemplate:
     def __init__(self, model_path: str):
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -403,6 +417,7 @@ class HealthCheck:
 # llama_generator.py
 from config.config import GenerationConfig, ModelConfig
 
+@observe()
 class LlamaGenerator(BaseGenerator):
     def __init__(
         self,
@@ -418,19 +433,22 @@ class LlamaGenerator(BaseGenerator):
 
     ):
 
-
-
+    @observe()
     def load_model(self, model_name: str):
         # Code to load your model, e.g., Hugging Face's transformers library
         from transformers import AutoModelForCausalLM
         return AutoModelForCausalLM.from_pretrained(model_name)
-
+
+    @observe()
     def load_tokenizer(self, model_name: str):
         # Load the tokenizer associated with the model
         from transformers import AutoTokenizer
         return AutoTokenizer.from_pretrained(model_name)
 
-
+        self.tokenizer = load_tokenizer(llama_model_name)  # Add this line to initialize the tokenizer
+
+
+    @observe()
     super().__init__(
         llama_model_name,
         device,
@@ -482,9 +500,11 @@ class LlamaGenerator(BaseGenerator):
             if hasattr(config, key)
         }
 
+    @observe()
     def generate_stream (self):
         return " NOt implememnted yet "
-
+
+    @observe()
     def generate(
         self,
         prompt: str,
@@ -522,6 +542,7 @@ class LlamaGenerator(BaseGenerator):
             **kwargs # Any additional strategy-specific arguments
         )
 
+    @observe()
     def generate_with_context(
         self,
         context: str,
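The import behind @observe() is not shown anywhere in this diff. Its call signature matches Langfuse's tracing decorator, so the sketch below assumes that SDK; the import path, function names, and model id are illustrative assumptions, not code from this repository.

# Hypothetical sketch of the pattern this commit applies, assuming @observe()
# is Langfuse's tracing decorator (the commit never shows the import).
from langfuse.decorators import observe  # assumption: Langfuse Python SDK v2 import path

@observe()  # records arguments, return value, latency, and exceptions as a trace span
def load_model(model_name: str):
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained(model_name)

@observe()  # calls made inside another observed function appear as child spans
def generate_answer(prompt: str) -> str:
    model = load_model("gpt2")  # example model id, not taken from the diff
    return f"loaded {model.config.model_type} for: {prompt}"

Langfuse's observe wraps a callable, so functions and methods are its usual targets; Python only accepts decorator syntax directly above a def or class statement.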