# llama_generator.py
import os
from typing import List, Dict, Any, Optional, Tuple

from langfuse import Langfuse
from langfuse.decorators import observe

from config.config import GenerationConfig, ModelConfig

# NOTE: the import paths for these project-internal names are assumptions
# based on the package layout; adjust them to match the actual modules.
from services.base_generator import BaseGenerator
from services.health import HealthStatus
from services.prompt_templates import LlamaPromptTemplate
from services.strategies import (
    BeamSearch,
    BestOfN,
    DefaultStrategy,
    DVT,
    MajorityVotingStrategy,
)
# Initialize Langfuse. Hardcoding credentials in source is unsafe; prefer
# setting these in the deployment environment and rotating the keys.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

try:
    langfuse = Langfuse()
except Exception:
    print("Langfuse offline")
@observe()
class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        # Initialize the tokenizer for the main Llama model before handing
        # the remaining setup to the base class.
        self.tokenizer = self.load_tokenizer(llama_model_name)

        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size,
        )

        # Initialize models: the main Llama model and the PRM (process
        # reward model) used by the search-based strategies.
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config,
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config,
        )

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    @observe()
    def load_model(self, model_name: str):
        """Load the causal LM via Hugging Face's transformers library."""
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM.from_pretrained(model_name)

    @observe()
    def load_tokenizer(self, model_name: str):
        """Load the tokenizer associated with the model."""
        from transformers import AutoTokenizer
        return AutoTokenizer.from_pretrained(model_name)
    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample",
            ]
            if hasattr(config, key)
        }
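
    # A hedged illustration of the helper above: assuming GenerationConfig
    # exposes plain attributes (the field names come from the key list;
    # the concrete values here are made up), a config such as
    #     GenerationConfig(max_new_tokens=128, temperature=0.7, do_sample=True)
    # would yield
    #     {"max_new_tokens": 128, "temperature": 0.7, "do_sample": True}
    # since only attributes actually present on the config are copied over.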
    @observe()
    def generate_stream(self):
        # Streaming is still a stub.
        raise NotImplementedError("Streaming generation is not implemented yet.")
    @observe()
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs,
    ) -> str:
        """
        Generate text based on a given strategy.

        Args:
            prompt (str): Input prompt for text generation.
            model_kwargs (Dict[str, Any]): Additional arguments for model generation.
            strategy (str): The generation strategy to use (default: "default").
            **kwargs: Additional arguments passed to the strategy.

        Returns:
            str: Generated text response.

        Raises:
            ValueError: If the specified strategy is not available.
        """
        # Validate that the strategy exists
        if strategy not in self.strategies:
            raise ValueError(
                f"Unknown strategy: {strategy}. "
                f"Available strategies are: {list(self.strategies.keys())}"
            )

        # Drop any caller-supplied `generator` so it cannot collide with the
        # instance passed explicitly below.
        kwargs.pop("generator", None)

        # Call the selected strategy with the provided arguments
        return self.strategies[strategy].generate(
            generator=self,             # the generator instance
            prompt=prompt,              # the input prompt
            model_kwargs=model_kwargs,  # arguments for the model
            **kwargs,                   # any additional strategy-specific arguments
        )
    @observe()
    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2,
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self.prompt_builder.format(
            context,
            user_input,
            chat_history,
            max_history_turns,
        )
        # Note: `generate` supplies the generator instance to the strategy
        # itself, so it is not passed again here.
        return self.generate(
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth,
        )
    def check_health(self) -> HealthStatus:
        """Check the health status of the generator."""
        return self.health_check.check_system_resources()  # TODO: add model status
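
# Minimal usage sketch (not part of the service wiring, just an assumed
# driver showing how the class above might be exercised end to end). The
# model id and PRM path are placeholders; substitute whatever the
# deployment actually uses.
if __name__ == "__main__":
    generator = LlamaGenerator(
        llama_model_name="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
        prm_model_path="models/prm.gguf",                     # placeholder PRM path
    )

    # Plain single-prompt generation with the default strategy.
    print(generator.generate(
        "Summarize what a process reward model is in one sentence.",
        model_kwargs={"max_new_tokens": 64, "temperature": 0.7},
    ))

    # Context-aware generation; num_samples only matters for sampling
    # strategies such as "majority_voting" or "best_of_n".
    print(generator.generate_with_context(
        context="You are a helpful assistant for a cooking site.",
        user_input="How long should I proof pizza dough?",
        chat_history=[("Hi", "Hello! How can I help?")],
        model_kwargs={"max_new_tokens": 128},
        strategy="best_of_n",
        num_samples=4,
    ))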