import asyncio
import logging
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import torch
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Download the quantized process reward model (PRM) weights in GGUF format
prm_model_path = hf_hub_download(
    repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
    filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
)

class GenerationStrategy(str, Enum):
    DEFAULT = "default"
    MAJORITY_VOTING = "majority_voting"
    BEST_OF_N = "best_of_n"
    BEAM_SEARCH = "beam_search"
    DVTS = "dvts"


@dataclass
class GenerationConfig:
    num_samples: int = 5
    depth: int = 3
    breadth: int = 2
    max_history_turns: int = 3
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
    strategy: GenerationStrategy = GenerationStrategy.DEFAULT

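# Example (illustrative): a config selecting best-of-n sampling with 4 candidates.
#
#   config = GenerationConfig(strategy=GenerationStrategy.BEST_OF_N, num_samples=4)
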
class LlamaGenerator:
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None
    ):
        """Initialize the LlamaGenerator with a base LLM and a PRM scorer."""
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.default_config = default_generation_config or GenerationConfig()

        self.logger.info(f"Initializing LlamaGenerator on device: {self.device}")

        try:
            self._initialize_models(llama_model_name, prm_model_path)
        except Exception as e:
            self.logger.error(f"Failed to initialize models: {str(e)}")
            raise

    def _initialize_models(self, llama_model_name: str, prm_model_path: str):
        """Initialize the tokenizer, base model, and quantized PRM with error handling."""
        self.llama_tokenizer = AutoTokenizer.from_pretrained(
            llama_model_name,
            padding_side='left',
            trust_remote_code=True
        )
        if self.llama_tokenizer.pad_token is None:
            self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        self.llama_model = AutoModelForCausalLM.from_pretrained(
            llama_model_name,
            device_map="auto",
            trust_remote_code=True
        )

        # Process reward model (PRM) used to score candidate generations
        self.prm_model = self._load_quantized_model(prm_model_path)

        # Streaming is implemented with transformers.TextIteratorStreamer (see generate_stream)
        self.supports_streaming = True

    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Stream decoded text chunks as they are generated."""
        if not self.supports_streaming:
            raise NotImplementedError("This model doesn't support streaming")

        config = config or self.default_config
        inputs = self.llama_tokenizer(prompt, return_tensors="pt").to(self.llama_model.device)

        # Run generation in a background thread and consume the streamer asynchronously
        streamer = TextIteratorStreamer(
            self.llama_tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(inputs, streamer=streamer, **self._get_generation_kwargs(config))
        thread = threading.Thread(target=self.llama_model.generate, kwargs=generation_kwargs)
        thread.start()

        loop = asyncio.get_running_loop()
        while True:
            chunk = await loop.run_in_executor(None, lambda: next(streamer, None))
            if chunk is None:
                break
            yield chunk
        thread.join()

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            "max_new_tokens": config.max_new_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "do_sample": config.temperature > 0,
        }

    def _load_quantized_model(self, model_path: str) -> Llama:
        """Load a quantized GGUF model using llama-cpp-python.

        Args:
            model_path (str): Path to the GGUF model file

        Returns:
            Llama: Loaded model instance
        """
        try:
            # Offload all layers to the GPU when CUDA is available, otherwise run on CPU
            n_gpu_layers = -1 if torch.cuda.is_available() else 0

            model = Llama(
                model_path=model_path,
                n_ctx=2048,
                n_batch=512,
                n_gpu_layers=n_gpu_layers,
                # logits_all lets us request per-token logprobs when scoring (see _score_with_prm)
                logits_all=True,
                verbose=False
            )

            self.logger.info(f"Successfully loaded GGUF model from {model_path}")
            return model

        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise

    def _score_with_prm(self, text: str) -> float:
        """Score text using the PRM model.

        Args:
            text (str): Text to score

        Returns:
            float: Mean token log-probability of the text under the PRM (higher is better)
        """
        try:
            # Echo the prompt back with per-token logprobs (OpenAI-style completion
            # response from llama-cpp-python) and average them as a simple score
            result = self.prm_model.create_completion(
                text,
                max_tokens=1,
                temperature=0.0,
                echo=True,
                logprobs=1,
            )
            token_logprobs = result["choices"][0]["logprobs"]["token_logprobs"]
            scores = [lp for lp in token_logprobs if lp is not None]
            return sum(scores) / len(scores) if scores else float("-inf")

        except Exception as e:
            self.logger.error(f"Error scoring text with PRM: {str(e)}")
            return float('-inf')

    def _construct_prompt(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        max_history_turns: int = 3
    ) -> str:
        """Construct a formatted prompt from the input components."""
        system_message = f"Please assist based on the following context: {context}"
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"

        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"

        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
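
    # A minimal alternative sketch (not used above), assuming the tokenizer ships a
    # chat template; `messages` here is an illustrative local variable:
    #
    #   messages = [{"role": "system", "content": system_message}]
    #   for user_msg, assistant_msg in chat_history[-max_history_turns:]:
    #       messages += [{"role": "user", "content": user_msg},
    #                    {"role": "assistant", "content": assistant_msg}]
    #   messages.append({"role": "user", "content": user_input})
    #   prompt = self.llama_tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True
    #   )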

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2
    ) -> str:
        """Generate a response using the specified strategy.

        Args:
            prompt (str): The input prompt
            model_kwargs (dict): Additional arguments for model.generate()
            strategy (str): Generation strategy ('default', 'majority_voting', 'best_of_n', 'beam_search', 'dvts')
            num_samples (int): Number of samples (or beams) for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        if strategy == "default":
            input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            output = self.llama_model.generate(input_ids, **model_kwargs)
            return self.llama_tokenizer.decode(output[0], skip_special_tokens=True)

        elif strategy == "majority_voting":
            # Sampling (do_sample=True) is needed in model_kwargs for the candidates to differ
            outputs = []
            for _ in range(num_samples):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                outputs.append(self.llama_tokenizer.decode(output[0], skip_special_tokens=True))
            return max(set(outputs), key=outputs.count)

        elif strategy == "best_of_n":
            scored_outputs = []
            for _ in range(num_samples):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                # Rank candidates with the PRM
                score = self._score_with_prm(response)
                scored_outputs.append((response, score))
            return max(scored_outputs, key=lambda x: x[1])[0]

        elif strategy == "beam_search":
            input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            outputs = self.llama_model.generate(
                input_ids,
                num_beams=num_samples,
                num_return_sequences=1,
                **model_kwargs
            )
            # Return the single highest-scoring beam so the return type stays str
            return self.llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

        elif strategy == "dvts":
            # Diverse-verifier tree search: sample `breadth` roots, then repeatedly
            # extend the best candidates, keeping PRM scores for the final pick
            results = []
            for _ in range(breadth):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                results.append((response, self._score_with_prm(response)))

            for _ in range(depth - 1):
                best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
                for response, _ in best_responses:
                    input_ids = self.llama_tokenizer(response, return_tensors="pt").input_ids.to(self.device)
                    output = self.llama_model.generate(input_ids, **model_kwargs)
                    extended_response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                    results.append((extended_response, self._score_with_prm(extended_response)))
            return max(results, key=lambda x: x[1])[0]

        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self._construct_prompt(
            context,
            user_input,
            chat_history,
            max_history_turns
        )
        return self.generate(
            prompt,
            model_kwargs,
            strategy,
            num_samples,
            depth,
            breadth
        )

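# A minimal usage sketch (illustrative, not executed): the model name and generation
# kwargs below are assumptions, not requirements of the class.
#
#   generator = LlamaGenerator(
#       llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
#       prm_model_path=prm_model_path,
#   )
#   answer = generator.generate(
#       prompt="Explain beam search in one sentence.",
#       model_kwargs={"max_new_tokens": 64, "do_sample": True, "temperature": 0.7},
#       strategy="best_of_n",
#       num_samples=4,
#   )
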
import json
import uuid
from datetime import datetime

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user/assistant)")
    content: str = Field(..., description="Content of the message")


class GenerationRequest(BaseModel):
    context: Optional[str] = Field(None, description="Context for the conversation")
    messages: List[ChatMessage] = Field(..., description="Chat history")
    config: Optional[Dict] = Field(None, description="Generation configuration")
    stream: bool = Field(False, description="Whether to stream the response")


class GenerationResponse(BaseModel):
    id: str = Field(..., description="Generation ID")
    content: str = Field(..., description="Generated content")
    created_at: datetime = Field(default_factory=datetime.now)


app = FastAPI(title="LLaMA Generation Service")

# NOTE: permissive CORS settings; restrict allow_origins before deploying publicly
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

generator: Optional[LlamaGenerator] = None

@app.on_event("startup")
async def startup_event():
    global generator
    try:
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
    except Exception as e:
        print(f"Failed to initialize generator: {str(e)}")
        raise

@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    if not generator:
        raise HTTPException(status_code=503, detail="Generator not initialized")
    if not request.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    try:
        # _construct_prompt expects (user, assistant) pairs, so pair up prior turns
        history = request.messages[:-1]
        chat_history = [
            (history[i].content, history[i + 1].content)
            for i in range(0, len(history) - 1, 2)
            if history[i].role == "user" and history[i + 1].role == "assistant"
        ]
        user_input = request.messages[-1].content

        config = GenerationConfig(**request.config) if request.config else generator.default_config

        # Run the blocking generation in a worker thread, mapping the config onto
        # generate_with_context's keyword arguments
        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=generator._get_generation_kwargs(config),
            max_history_turns=config.max_history_turns,
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=config.depth,
            breadth=config.breadth,
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

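# Example request (hypothetical client sketch using the third-party `requests`
# package; URL and payload values are illustrative):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={
#           "context": "You are a concise assistant.",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "config": {"strategy": "best_of_n", "num_samples": 3},
#       },
#   )
#   print(resp.json()["content"])
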
@app.websocket("/generate/stream")
async def generate_stream(websocket: WebSocket):
    await websocket.accept()
    if generator is None:
        await websocket.send_text(json.dumps({"error": "Generator not initialized"}))
        await websocket.close()
        return

    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.parse_raw(request_data)

            # Pair prior turns into (user, assistant) tuples for _construct_prompt
            history = request.messages[:-1]
            chat_history = [
                (history[i].content, history[i + 1].content)
                for i in range(0, len(history) - 1, 2)
                if history[i].role == "user" and history[i + 1].role == "assistant"
            ]
            user_input = request.messages[-1].content

            config = GenerationConfig(**request.config) if request.config else None

            async for token in generator.generate_stream(
                prompt=generator._construct_prompt(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))

    except WebSocketDisconnect:
        # Client disconnected; nothing further to send
        pass
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
        await websocket.close()

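# Example streaming client (hypothetical sketch using the third-party `websockets`
# package; URL and payload are illustrative):
#
#   import asyncio, json, websockets
#
#   async def stream():
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "Tell me a joke."}]
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg.get("finished") or "error" in msg:
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(stream())
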
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)