Chris4K committed
Commit 4267e89 · verified · 1 Parent(s): c643a72

Update app.py

Files changed (1)
  1. app.py +30 -9
app.py CHANGED
@@ -36,6 +36,7 @@ class ModelManager:
         self.models: Dict[str, Any] = {}
         self.tokenizers: Dict[str, Any] = {}
 
+    @observe()
     def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
         """Load a model with specified configuration."""
         try:
@@ -73,6 +74,7 @@ class ModelManager:
             self.logger.error(f"Failed to load model {model_id}: {str(e)}")
             raise
 
+    @observe()
     def unload_model(self, model_id: str) -> None:
         """Unload a model and free resources."""
         if model_id in self.models:
@@ -195,7 +197,8 @@ class BaseGenerator(ABC):
     @abstractmethod
     def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
         pass
-
+
+    @observe()
     @abstractmethod
     def generate(
         self,
@@ -240,6 +243,7 @@ class BaseGenerator(ABC):
 from abc import ABC, abstractmethod
 from typing import List, Tuple
 
+@observe()
 class GenerationStrategy(ABC):
     """Base class for generation strategies."""
 
@@ -247,12 +251,14 @@ class GenerationStrategy(ABC):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         pass
 
+@observe()
 class DefaultStrategy(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
         output = generator.model.generate(input_ids, **model_kwargs)
         return generator.tokenizer.decode(output[0], skip_special_tokens=True)
 
+@observe()
 class MajorityVotingStrategy(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         outputs = []
@@ -262,6 +268,7 @@ class MajorityVotingStrategy(GenerationStrategy):
             outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
         return max(set(outputs), key=outputs.count)
 
+@observe()
 class BestOfN(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         scored_outputs = []
@@ -272,7 +279,8 @@ class BestOfN(GenerationStrategy):
             score = self.prm_model(**self.llama_tokenizer(response, return_tensors="pt").to(self.device)).logits.mean().item()
             scored_outputs.append((response, score))
         return max(scored_outputs, key=lambda x: x[1])[0]
-
+
+@observe()
 class BeamSearch(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
@@ -284,6 +292,7 @@ class BeamSearch(GenerationStrategy):
         )
         return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
+@observe()
 class DVT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         results = []
@@ -304,12 +313,14 @@ class DVT(GenerationStrategy):
             results.append((extended_response, score))
         return max(results, key=lambda x: x[1])[0]
 
+@observe()
 class COT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the chain of thought strategy
 
         return "Not implemented yet"
-
+
+@observe()
 class ReAct(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the ReAct framework
@@ -320,11 +331,13 @@ class ReAct(GenerationStrategy):
 from typing import Protocol, List, Tuple
 from transformers import AutoTokenizer
 
+@observe()
 class PromptTemplate(Protocol):
     """Protocol for prompt templates."""
     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
         pass
 
+@observe()
 class LlamaPromptTemplate:
     def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
         system_message = f"Please assist based on the following context: {context}"
@@ -337,7 +350,8 @@ class LlamaPromptTemplate:
         prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
         prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
         return prompt
-
+
+@observe()
 class TransformersPromptTemplate:
     def __init__(self, model_path: str):
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -403,6 +417,7 @@ class HealthCheck:
 # llama_generator.py
 from config.config import GenerationConfig, ModelConfig
 
+@observe()
 class LlamaGenerator(BaseGenerator):
     def __init__(
         self,
@@ -418,19 +433,22 @@ class LlamaGenerator(BaseGenerator):
 
     ):
 
-        self.tokenizer = self.load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
-
+    @observe()
     def load_model(self, model_name: str):
         # Code to load your model, e.g., Hugging Face's transformers library
         from transformers import AutoModelForCausalLM
         return AutoModelForCausalLM.from_pretrained(model_name)
-
+
+    @observe()
     def load_tokenizer(self, model_name: str):
         # Load the tokenizer associated with the model
         from transformers import AutoTokenizer
         return AutoTokenizer.from_pretrained(model_name)
 
-
+        self.tokenizer = load_tokenizer(llama_model_name) # Add this line to initialize the tokenizer
+
+
+    @observe()
     super().__init__(
         llama_model_name,
         device,
@@ -482,9 +500,11 @@ class LlamaGenerator(BaseGenerator):
             if hasattr(config, key)
         }
 
+    @observe()
     def generate_stream (self):
         return " NOt implememnted yet "
-
+
+    @observe()
     def generate(
         self,
         prompt: str,
@@ -522,6 +542,7 @@ class LlamaGenerator(BaseGenerator):
         **kwargs # Any additional strategy-specific arguments
     )
 
+    @observe()
     def generate_with_context(
         self,
         context: str,
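
The net effect of this commit is the blanket application of an `@observe()` decorator to methods and classes throughout app.py. The diff never shows where `observe` is imported from; if it is Langfuse's tracing decorator, the import would be `from langfuse.decorators import observe`, but that is an assumption. The sketch below is a minimal, self-contained stand-in (not the project's actual implementation) illustrating the decorator-factory pattern that a call-style decorator like `@observe()` implies:

# Hypothetical stand-in for @observe(); for illustration only,
# not the implementation this commit actually imports.
import functools
import logging
import time

logger = logging.getLogger("observe")

def observe():
    """Return a decorator that logs how long each wrapped call takes."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000
                logger.info("%s finished in %.2f ms", func.__qualname__, elapsed_ms)
        return wrapper
    return decorator

@observe()
def add(a: int, b: int) -> int:
    return a + b

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(add(2, 3))  # logs the call duration, then prints 5

Two caveats about decorator placement in this diff: when stacked on a `class` statement (as on `GenerationStrategy` and the other strategy classes), a decorator like this receives the class object rather than a function and must handle that case explicitly; and Python only accepts decorators immediately before a `def` or `class` statement, so the `@observe()` inserted before the bare `super().__init__(...)` call would be a `SyntaxError`.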