# llama_generator.py
from typing import List, Dict, Any, Optional, Tuple
import logging
import os

from config.config import GenerationConfig, ModelConfig
from services.prompt_builder import LlamaPromptTemplate
from services.model_manager import ModelManager
from services.base_generator import BaseGenerator
# Strategy classes used in _init_strategies; the module path is assumed.
from services.strategy import (
    DefaultStrategy,
    MajorityVotingStrategy,
    BestOfN,
    BeamSearch,
    DVT,
)

from langfuse import Langfuse
from langfuse.decorators import observe
# Initialize Langfuse (credentials are hard-coded here; moving them into the
# environment or a secrets store would be safer).
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

logger = logging.getLogger(__name__)

try:
    langfuse = Langfuse()
except Exception:
    langfuse = None
    logger.warning("Langfuse is offline; tracing is disabled.")
class LlamaGenerator(BaseGenerator):
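    """Generator that pairs a Llama chat model with a PRM (process reward
    model) and dispatches generation through pluggable strategies
    (default, majority voting, best-of-n, beam search, DVTS)."""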
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        logger.info("Initializing LlamaGenerator (llama=%s, prm=%s)", llama_model_name, prm_model_path)
        self.model_manager = ModelManager()
        # The tokenizer is needed before the base class is initialized.
        self.tokenizer = self.model_manager.load_tokenizer(llama_model_name)
super().__init__(
llama_model_name,
device,
default_generation_config,
model_config,
cache_size,
max_batch_size
)
# Initialize models
self.model_manager.load_model(
"llama",
llama_model_name,
"llama",
self.model_config
)
self.model_manager.load_model(
"prm",
prm_model_path,
"gguf",
self.model_config
)
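        # Two models are loaded: "llama" for text generation and "prm" (a
        # GGUF model), presumably used to score candidate generations in the
        # search-based strategies below.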
self.prompt_builder = LlamaPromptTemplate()
self._init_strategies()
def _init_strategies(self):
self.strategies = {
"default": DefaultStrategy(),
"majority_voting": MajorityVotingStrategy(),
"best_of_n": BestOfN(),
"beam_search": BeamSearch(),
"dvts": DVT(),
}
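        # Each strategy object is assumed to expose a single entry point of
        # the form (a sketch, not confirmed by this file):
        #
        #     def generate(self, generator, prompt, model_kwargs, **kwargs) -> str: ...
        #
        # which is how self.generate() below invokes them.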
def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
"""Get generation kwargs based on config."""
return {
key: getattr(config, key)
for key in [
"max_new_tokens",
"temperature",
"top_p",
"top_k",
"repetition_penalty",
"length_penalty",
"do_sample"
]
if hasattr(config, key)
}
    @observe()
    def generate_stream(self):
        """Streaming generation; not implemented yet."""
        raise NotImplementedError("Streaming generation is not implemented yet.")
@observe()
def generate(
self,
prompt: str,
model_kwargs: Dict[str, Any],
strategy: str = "default",
**kwargs
) -> str:
"""
Generate text based on a given strategy.
Args:
prompt (str): Input prompt for text generation.
model_kwargs (Dict[str, Any]): Additional arguments for model generation.
strategy (str): The generation strategy to use (default: "default").
**kwargs: Additional arguments passed to the strategy.
Returns:
str: Generated text response.
Raises:
ValueError: If the specified strategy is not available.
"""
# Validate that the strategy exists
if strategy not in self.strategies:
raise ValueError(f"Unknown strategy: {strategy}. Available strategies are: {list(self.strategies.keys())}")
        # Drop any caller-supplied `generator` so it is not passed twice below
        kwargs.pop("generator", None)
# Call the selected strategy with the provided arguments
return self.strategies[strategy].generate(
generator=self, # The generator instance
prompt=prompt, # The input prompt
model_kwargs=model_kwargs, # Arguments for the model
**kwargs # Any additional strategy-specific arguments
)
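    # Example call (values are illustrative):
    #     text = generator.generate(
    #         "Explain retrieval-augmented generation.",
    #         {"max_new_tokens": 128},
    #         strategy="best_of_n",
    #         num_samples=5,
    #     )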
@observe()
def generate_with_context(
self,
context: str,
user_input: str,
chat_history: List[Tuple[str, str]],
model_kwargs: Dict[str, Any],
max_history_turns: int = 3,
strategy: str = "default",
num_samples: int = 5,
depth: int = 3,
breadth: int = 2,
) -> str:
"""Generate a response using context and chat history.
Args:
context (str): Context for the conversation
user_input (str): Current user input
chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
model_kwargs (dict): Additional arguments for model.generate()
max_history_turns (int): Maximum number of history turns to include
strategy (str): Generation strategy
num_samples (int): Number of samples for applicable strategies
depth (int): Depth for DVTS strategy
breadth (int): Breadth for DVTS strategy
Returns:
str: Generated response
"""
prompt = self.prompt_builder.format(
context,
user_input,
chat_history,
max_history_turns
)
        # `generator` is supplied inside generate(); no need to pass it here.
        return self.generate(
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth
        )
    def check_health(self):  # -> HealthStatus
        """Check the health status of the generator."""
        # TODO: return self.health_check.check_system_resources() plus model status.
        return "Health check not implemented yet."