# llama_generator.py
from typing import Any, Dict, List, Optional, Tuple

from config.config import GenerationConfig, ModelConfig

# NOTE: observe, BaseGenerator, LlamaPromptTemplate, HealthStatus, and the
# strategy classes (DefaultStrategy, MajorityVotingStrategy, BestOfN,
# BeamSearch, DVT) are assumed to be defined elsewhere in this project;
# import them from wherever they live in your tree.


@observe()
class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size,
        )

        # Initialize the tokenizer for the base LLM.
        self.tokenizer = self.load_tokenizer(llama_model_name)

        # Initialize models: the base LLM and the process reward model (PRM).
        self.model_manager.load_model(
            "llama", llama_model_name, "llama", self.model_config
        )
        self.model_manager.load_model(
            "prm", prm_model_path, "gguf", self.model_config
        )

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    @observe()
    def load_model(self, model_name: str):
        # Load the model, e.g., via Hugging Face's transformers library.
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM.from_pretrained(model_name)

    @observe()
    def load_tokenizer(self, model_name: str):
        # Load the tokenizer associated with the model.
        from transformers import AutoTokenizer
        return AutoTokenizer.from_pretrained(model_name)

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample",
            ]
            if hasattr(config, key)
        }

    @observe()
    def generate_stream(self):
        raise NotImplementedError("Streaming generation is not implemented yet.")

    @observe()
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs,
    ) -> str:
        """Generate text based on a given strategy.

        Args:
            prompt (str): Input prompt for text generation.
            model_kwargs (Dict[str, Any]): Additional arguments for model generation.
            strategy (str): The generation strategy to use (default: "default").
            **kwargs: Additional arguments passed to the strategy.

        Returns:
            str: Generated text response.

        Raises:
            ValueError: If the specified strategy is not available.
        """
        # Validate that the strategy exists.
        if strategy not in self.strategies:
            raise ValueError(
                f"Unknown strategy: {strategy}. "
                f"Available strategies are: {list(self.strategies.keys())}"
            )

        # Drop any caller-supplied `generator` so it is not passed twice below.
        kwargs.pop("generator", None)

        # Call the selected strategy with the provided arguments.
        return self.strategies[strategy].generate(
            generator=self,             # The generator instance
            prompt=prompt,              # The input prompt
            model_kwargs=model_kwargs,  # Arguments for the model
            **kwargs,                   # Any additional strategy-specific arguments
        )

    @observe()
    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2,
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation.
            user_input (str): Current user input.
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs.
            model_kwargs (Dict[str, Any]): Additional arguments for model.generate().
            max_history_turns (int): Maximum number of history turns to include.
            strategy (str): Generation strategy.
            num_samples (int): Number of samples for applicable strategies.
            depth (int): Depth for the DVTS strategy.
            breadth (int): Breadth for the DVTS strategy.

        Returns:
            str: Generated response.
        """
        prompt = self.prompt_builder.format(
            context, user_input, chat_history, max_history_turns
        )
        return self.generate(
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth,
        )

    def check_health(self) -> HealthStatus:
        """Check the health status of the generator."""
        # TODO: also report per-model status, not just system resources.
        return self.health_check.check_system_resources()
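
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model id, PRM path, strategy choice,
# and model_kwargs values below are placeholders, not part of this module;
# substitute the models and settings used in your own deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = LlamaGenerator(
        llama_model_name="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
        prm_model_path="models/prm.gguf",                     # placeholder PRM path
    )
    response = generator.generate_with_context(
        context="You are a helpful assistant for a cooking website.",
        user_input="How long should I roast a 2 kg chicken?",
        chat_history=[("Hi!", "Hello! How can I help you today?")],
        model_kwargs={"max_new_tokens": 256, "temperature": 0.7},
        strategy="best_of_n",  # any key from self.strategies works here
        num_samples=5,
    )
    print(response)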