Chris4K committed
Commit 3455289 · verified · 1 Parent(s): ca508ba

Update app.py

Files changed (1)
  1. app.py +47 -1
app.py CHANGED
@@ -388,7 +388,7 @@ class LlamaGenerator(BaseGenerator):
         )
 
         # Initialize models
-        self.model_manager.load_model(
+        self.model_manager.load_model(
             "llama",
             llama_model_name,
             "llama",
@@ -448,6 +448,52 @@ class LlamaGenerator(BaseGenerator):
             model_kwargs,
             **kwargs
         )
+
+
+    def generate_with_context(
+        self,
+        context: str,
+        user_input: str,
+        chat_history: List[Tuple[str, str]],
+        model_kwargs: Dict[str, Any],
+        max_history_turns: int = 3,
+        strategy: str = "default",
+        num_samples: int = 5,
+        depth: int = 3,
+        breadth: int = 2,
+
+    ) -> str:
+        """Generate a response using context and chat history.
+
+        Args:
+            context (str): Context for the conversation
+            user_input (str): Current user input
+            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
+            model_kwargs (dict): Additional arguments for model.generate()
+            max_history_turns (int): Maximum number of history turns to include
+            strategy (str): Generation strategy
+            num_samples (int): Number of samples for applicable strategies
+            depth (int): Depth for DVTS strategy
+            breadth (int): Breadth for DVTS strategy
+
+        Returns:
+            str: Generated response
+        """
+        prompt = self._construct_prompt(
+            context,
+            user_input,
+            chat_history,
+            max_history_turns
+        )
+        return self.generate(
+            prompt,
+            model_kwargs,
+            strategy,
+            num_samples,
+            depth,
+            breadth
+        )
+
 
     def check_health(self) -> HealthStatus:
         """Check the health status of the generator."""