import asyncio
import logging
from dataclasses import dataclass
from enum import Enum
from threading import Thread
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Download the quantized PRM (process reward model) weights from the Hugging Face Hub
prm_model_path = hf_hub_download(
    repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
    filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
)


class GenerationStrategy(str, Enum):
    DEFAULT = "default"
    MAJORITY_VOTING = "majority_voting"
    BEST_OF_N = "best_of_n"
    BEAM_SEARCH = "beam_search"
    DVTS = "dvts"


@dataclass
class GenerationConfig:
    num_samples: int = 5
    depth: int = 3
    breadth: int = 2
    max_history_turns: int = 3
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
    strategy: GenerationStrategy = GenerationStrategy.DEFAULT


class LlamaGenerator:
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: str = None,
        default_generation_config: Optional[GenerationConfig] = None
    ):
        """Initialize the LlamaGenerator with the specified models."""
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.default_config = default_generation_config or GenerationConfig()
        self.logger.info(f"Initializing LlamaGenerator on device: {self.device}")

        try:
            self._initialize_models(llama_model_name, prm_model_path)
        except Exception as e:
            self.logger.error(f"Failed to initialize models: {str(e)}")
            raise

    def _initialize_models(self, llama_model_name: str, prm_model_path: str):
        """Initialize models with error handling and logging."""
        # Initialize LLaMA model and tokenizer
        self.llama_tokenizer = AutoTokenizer.from_pretrained(
            llama_model_name,
            padding_side='left',
            trust_remote_code=True
        )
        if self.llama_tokenizer.pad_token is None:
            self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        self.llama_model = AutoModelForCausalLM.from_pretrained(
            llama_model_name,
            device_map="auto",
            trust_remote_code=True
        )

        # Initialize PRM model
        self.prm_model = self._load_quantized_model(prm_model_path)

        # Token streaming is provided via transformers' TextIteratorStreamer
        # (see generate_stream); the earlier hasattr(model, "streamer") check is
        # not a real transformers API and would always have disabled streaming.
        self.supports_streaming = True

    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Stream decoded text chunks as they are generated.

        Generation runs in a background thread; a TextIteratorStreamer yields
        decoded text as it becomes available.
        """
        if not self.supports_streaming:
            raise NotImplementedError("This model doesn't support streaming")

        config = config or self.default_config
        inputs = self.llama_tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            self.llama_tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(**inputs, streamer=streamer, **self._get_generation_kwargs(config))

        thread = Thread(target=self.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        try:
            loop = asyncio.get_running_loop()
            iterator = iter(streamer)
            while True:
                # next() blocks on the streamer's internal queue, so run it off the event loop
                chunk = await loop.run_in_executor(None, next, iterator, None)
                if chunk is None:
                    break
                yield chunk
        finally:
            thread.join()

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            "max_new_tokens": config.max_new_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "do_sample": config.temperature > 0,
        }

    def _load_quantized_model(self, model_path: str) -> Llama:
        """Load a quantized GGUF model using llama-cpp-python.

        Args:
            model_path (str): Path to the GGUF model file

        Returns:
            Llama: Loaded model instance
        """
        try:
            # Offload all layers to the GPU if CUDA is available
            n_gpu_layers = -1 if torch.cuda.is_available() else 0

            # Load the model
            model = Llama(
                model_path=model_path,
                n_ctx=2048,                 # Context window
                n_batch=512,                # Batch size for prompt processing
                n_gpu_layers=n_gpu_layers,  # Number of layers to offload to GPU
                logits_all=True,            # Keep per-token logits so prompt log-probs can be scored
                verbose=False
            )

            self.logger.info(f"Successfully loaded GGUF model from {model_path}")
            return model

        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise

    def _score_with_prm(self, text: str) -> float:
        """Score text using the PRM model.

        Args:
            text (str): Text to score

        Returns:
            float: Mean log-probability of the text under the PRM (higher is better)
        """
        try:
            # llama-cpp-python exposes token log-probabilities through its
            # OpenAI-style completion interface: with echo=True and logprobs set,
            # the prompt tokens themselves are scored (this relies on
            # logits_all=True, set in _load_quantized_model).
            result = self.prm_model(
                text,
                max_tokens=1,
                temperature=0.0,
                echo=True,
                logprobs=1,
            )
            token_logprobs = result["choices"][0]["logprobs"]["token_logprobs"]
            scores = [lp for lp in token_logprobs if lp is not None]
            if not scores:
                return float('-inf')
            return sum(scores) / len(scores)
        except Exception as e:
            self.logger.error(f"Error scoring text with PRM: {str(e)}")
            return float('-inf')  # Return a very low score on error

    def _construct_prompt(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        max_history_turns: int = 3
    ) -> str:
        """Construct a Llama-3-style chat prompt from the input components."""
        system_message = f"Please assist based on the following context: {context}"

        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"

        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"

        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2
    ) -> Union[str, List[str]]:
        """Generate a response using the specified strategy.

        Args:
            prompt (str): The input prompt
            model_kwargs (dict): Additional arguments for model.generate()
            strategy (str): Generation strategy ('default', 'majority_voting',
                'best_of_n', 'beam_search', 'dvts')
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for the DVTS strategy
            breadth (int): Breadth for the DVTS strategy

        Returns:
            Union[str, List[str]]: Generated response ('beam_search' returns the
                list of decoded beams; every other strategy returns a single string)
        """
        if strategy == "default":
            input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            output = self.llama_model.generate(input_ids, **model_kwargs)
            return self.llama_tokenizer.decode(output[0], skip_special_tokens=True)

        elif strategy == "majority_voting":
            # Sample several completions and return the most frequent one
            outputs = []
            for _ in range(num_samples):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                outputs.append(self.llama_tokenizer.decode(output[0], skip_special_tokens=True))
            return max(set(outputs), key=outputs.count)

        elif strategy == "best_of_n":
            # Sample several completions and keep the one the PRM scores highest.
            # The PRM is a llama-cpp (GGUF) model, so scoring goes through
            # _score_with_prm rather than a transformers-style forward pass.
            scored_outputs = []
            for _ in range(num_samples):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                score = self._score_with_prm(response)
                scored_outputs.append((response, score))
            return max(scored_outputs, key=lambda x: x[1])[0]

        elif strategy == "beam_search":
            input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            outputs = self.llama_model.generate(
                input_ids,
                num_beams=num_samples,
                num_return_sequences=num_samples,
                **model_kwargs
            )
            return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

        elif strategy == "dvts":
            # Diverse verifier tree search: keep the `breadth` best-scoring
            # candidates and repeatedly extend them for `depth` rounds.
            results = []
            for _ in range(breadth):
                input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                output = self.llama_model.generate(input_ids, **model_kwargs)
                response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                score = self._score_with_prm(response)
                results.append((response, score))
            for _ in range(depth - 1):
                best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
                for response, _ in best_responses:
                    input_ids = self.llama_tokenizer(response, return_tensors="pt").input_ids.to(self.device)
                    output = self.llama_model.generate(input_ids, **model_kwargs)
                    extended_response = self.llama_tokenizer.decode(output[0], skip_special_tokens=True)
                    score = self._score_with_prm(extended_response)
                    results.append((extended_response, score))
            return max(results, key=lambda x: x[1])[0]

        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for the DVTS strategy
            breadth (int): Breadth for the DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self._construct_prompt(
            context,
            user_input,
            chat_history,
            max_history_turns
        )
        return self.generate(
            prompt,
            model_kwargs,
            strategy,
            num_samples,
            depth,
            breadth
        )
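

# Illustrative sketch (never called anywhere): how LlamaGenerator might be used
# directly, without the FastAPI service defined below. The prompt text, history,
# kwargs and strategy values here are placeholders chosen for the example, not
# values required by the class.
def _example_direct_usage() -> str:
    generator = LlamaGenerator(
        llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
        prm_model_path=prm_model_path,
        default_generation_config=GenerationConfig(max_new_tokens=128),
    )
    return generator.generate_with_context(
        context="You are a concise assistant.",
        user_input="Explain beam search in one sentence.",
        chat_history=[("Hi!", "Hello! How can I help you today?")],
        model_kwargs={"max_new_tokens": 128, "do_sample": True, "temperature": 0.7},
        strategy="best_of_n",
        num_samples=3,
    )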


######################
#########
#################

from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import asyncio
import uuid
from datetime import datetime
import json


class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user/assistant)")
    content: str = Field(..., description="Content of the message")


class GenerationRequest(BaseModel):
    context: Optional[str] = Field(None, description="Context for the conversation")
    messages: List[ChatMessage] = Field(..., description="Chat history")
    config: Optional[Dict] = Field(None, description="Generation configuration")
    stream: bool = Field(False, description="Whether to stream the response")


class GenerationResponse(BaseModel):
    id: str = Field(..., description="Generation ID")
    content: str = Field(..., description="Generated content")
    created_at: datetime = Field(default_factory=datetime.now)


app = FastAPI(title="LLaMA Generation Service")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Store generator instance
generator = None


@app.on_event("startup")
async def startup_event():
    global generator
    try:
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
    except Exception as e:
        print(f"Failed to initialize generator: {str(e)}")
        raise


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    if not generator:
        raise HTTPException(status_code=503, detail="Generator not initialized")
    if not request.messages:
        raise HTTPException(status_code=400, detail="At least one message is required")

    try:
        # The last message is the current user input; earlier messages are folded
        # into (user, assistant) pairs, as expected by _construct_prompt.
        history = request.messages[:-1]
        chat_history = [
            (history[i].content, history[i + 1].content)
            for i in range(0, len(history) - 1, 2)
            if history[i].role == "user" and history[i + 1].role == "assistant"
        ]
        user_input = request.messages[-1].content

        # Create generation config
        config = GenerationConfig(**request.config) if request.config else generator.default_config

        # Run generation off the event loop. generate_with_context takes the
        # strategy parameters directly, so the config is unpacked here rather
        # than passed as a single object.
        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=generator._get_generation_kwargs(config),
            max_history_turns=config.max_history_turns,
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=config.depth,
            breadth=config.breadth,
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
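

# Illustrative client sketch: how a separate process might call the /generate
# endpoint above. It assumes the service is running locally on port 8000 and that
# the third-party `requests` package is installed; the payload values are
# placeholders. The function is defined here only as documentation and is never
# invoked by the server.
def _example_rest_client() -> str:
    import requests  # local import: only this illustrative client needs it

    response = requests.post(
        "http://localhost:8000/generate",
        json={
            "context": "You are a helpful assistant.",
            "messages": [{"role": "user", "content": "Tell me a joke."}],
            "config": {"strategy": "best_of_n", "num_samples": 3, "max_new_tokens": 64},
        },
        timeout=600,
    )
    response.raise_for_status()
    return response.json()["content"]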


@app.websocket("/generate/stream")
async def generate_stream(websocket: WebSocket):
    await websocket.accept()

    try:
        while True:
            # Receive and parse request
            request_data = await websocket.receive_text()
            request = GenerationRequest.parse_raw(request_data)

            # Fold the preceding messages into (user, assistant) pairs; the last
            # message is the current user input.
            history = request.messages[:-1]
            chat_history = [
                (history[i].content, history[i + 1].content)
                for i in range(0, len(history) - 1, 2)
                if history[i].role == "user" and history[i + 1].role == "assistant"
            ]
            user_input = request.messages[-1].content

            # Create generation config
            config = GenerationConfig(**request.config) if request.config else None

            # Stream response
            async for token in generator.generate_stream(
                prompt=generator._construct_prompt(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            # Send finished message
            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))

    except WebSocketDisconnect:
        # Client closed the connection; nothing more to send on this socket
        pass
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
        await websocket.close()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
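

# Illustrative streaming client sketch for the /generate/stream websocket above.
# It assumes the service is running locally on port 8000 and that the third-party
# `websockets` package is installed; the payload values are placeholders. In a
# real project this would live in a separate client script; it is defined here
# purely as documentation and is never invoked by the server.
async def _example_stream_client() -> str:
    import websockets  # local import: only this illustrative client needs it

    request = {
        "context": "You are a helpful assistant.",
        "messages": [{"role": "user", "content": "Stream me a haiku."}],
        "config": {"max_new_tokens": 64},
        "stream": True,
    }
    chunks: List[str] = []
    async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
        await ws.send(json.dumps(request))
        while True:
            message = json.loads(await ws.recv())
            if "error" in message:
                raise RuntimeError(message["error"])
            if message.get("finished"):
                break
            chunks.append(message["token"])
    return "".join(chunks)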