from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
from config.config import settings
import os

# Initialize Langfuse (self-hosted instance running on a Hugging Face Space)
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"

try:
    langfuse = Langfuse()
except Exception as e:
    print("Langfuse offline")
# model_manager.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_cpp import Llama
from typing import Any, Dict, Optional
import logging

from config.config import GenerationConfig, ModelConfig


class ModelManager:
    def __init__(self, device: Optional[str] = None):
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}
    def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
        """Load a model with the specified configuration."""
        try:
            # Different model families could be handled here via a factory pattern
            # (textgen, llama, gguf, text2video, text2image, etc.).
            if model_type == "llama":
                self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
                    model_path,
                    padding_side='left',
                    trust_remote_code=True,
                    **config.tokenizer_kwargs
                )
                if self.tokenizers[model_id].pad_token is None:
                    self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token
                self.models[model_id] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto",
                    trust_remote_code=True,
                    **config.model_kwargs
                )
            elif model_type == "gguf":
                # TODO: look the model up in the local cache first; if it is missing,
                # download it (e.g. via huggingface_hub.hf_hub_download) and cache it.
                # from huggingface_hub import hf_hub_download
                # prm_model_path = hf_hub_download(
                #     repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
                #     filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
                # )
                self.models[model_id] = self._load_quantized_model(
                    model_path,
                    **config.quantization_kwargs
                )
            else:
                raise ValueError(f"Unsupported model type: {model_type}")
        except Exception as e:
            self.logger.error(f"Failed to load model {model_id}: {str(e)}")
            raise
    def unload_model(self, model_id: str) -> None:
        """Unload a model and free its resources."""
        if model_id in self.models:
            del self.models[model_id]
        if model_id in self.tokenizers:
            del self.tokenizers[model_id]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
        """Load a quantized GGUF model."""
        try:
            n_gpu_layers = -1 if torch.cuda.is_available() else 0
            model = Llama(
                model_path=model_path,
                n_ctx=kwargs.get('n_ctx', 2048),
                n_batch=kwargs.get('n_batch', 512),
                n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers),
                verbose=kwargs.get('verbose', False)
            )
            return model
        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise
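# Usage sketch for the model manager (the model id and the default ModelConfig below
# are illustrative):
#
#   manager = ModelManager()
#   manager.load_model(
#       model_id="llama",
#       model_path="meta-llama/Llama-3.2-1B-Instruct",
#       model_type="llama",
#       config=ModelConfig(),
#   )
#   tokenizer = manager.tokenizers["llama"]
#   model = manager.models["llama"]
#   # ... run generation ...
#   manager.unload_model("llama")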
# cache.py
from collections import OrderedDict
from typing import Optional, Tuple

from config.config import GenerationConfig


class ResponseCache:
    """LRU cache mapping (prompt, config) pairs to (response, score) tuples."""

    def __init__(self, cache_size: int = 1000):
        self.cache_size = cache_size
        # Maps (prompt, config_hash) -> (response, score), ordered by recency of use.
        self._cache: OrderedDict = OrderedDict()

    def _make_key(self, prompt: str, config: GenerationConfig) -> Tuple[str, str]:
        config_hash = hash(str(config.__dict__))
        return (prompt, str(config_hash))

    def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
        key = self._make_key(prompt, config)
        self._cache[key] = (response, score)
        self._cache.move_to_end(key)
        # Evict the least recently used entry once the cache is full.
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)

    def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
        key = self._make_key(prompt, config)
        if key in self._cache:
            self._cache.move_to_end(key)
            return self._cache[key]
        return None
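# Usage sketch for the cache (the values below are illustrative):
#
#   cache = ResponseCache(cache_size=1000)
#   config = GenerationConfig()
#   cached = cache.get_response("What is the capital of France?", config)
#   if cached is None:
#       response, score = "Paris.", 0.98  # produced by a generator in practice
#       cache.cache_response("What is the capital of France?", config, response, score)
#   else:
#       response, score = cached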
# batch_processor.py
import asyncio
from typing import Any, Dict, List


class BatchProcessor:
    """Collects incoming requests and flushes them in batches.

    A batch is flushed either when it reaches max_batch_size or after
    max_wait_time seconds, whichever happens first.
    """

    def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1):
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.pending_requests: List[Dict] = []
        self.lock = asyncio.Lock()

    async def add_request(self, request: Dict) -> Any:
        async with self.lock:
            self.pending_requests.append(request)
            if len(self.pending_requests) >= self.max_batch_size:
                return await self._process_batch()
        # Release the lock while waiting so other requests can join the batch.
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.pending_requests:
                return await self._process_batch()

    async def _process_batch(self) -> List[Any]:
        batch = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]
        # TODO: run the actual batched inference here; for now the raw batch is returned.
        return batch
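# Usage sketch for the batch processor (runs inside an event loop; the request
# dictionaries below are illustrative):
#
#   processor = BatchProcessor(max_batch_size=4, max_wait_time=0.05)
#
#   async def submit_all():
#       requests = [{"prompt": f"question {i}"} for i in range(8)]
#       return await asyncio.gather(*(processor.add_request(r) for r in requests))
#
#   batches = asyncio.run(submit_all())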
# base_generator.py
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from logging import getLogger

from config.config import GenerationConfig, ModelConfig


class BaseGenerator(ABC):
    """Base class for all generator implementations."""

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        self.logger = getLogger(__name__)
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Yield generated tokens one at a time; subclasses are expected to override this."""
        raise NotImplementedError
        yield ""  # unreachable, but marks this coroutine as an async generator

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Map a GenerationConfig onto model-specific generation kwargs."""

    @abstractmethod
    def generate(self, prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a response for the given prompt."""
# strategy.py
# TODO: update import paths once the project layout is final.
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class GenerationStrategy(ABC):
    """Base class for generation strategies."""

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a response using the given generator."""


class DefaultStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)


class MajorityVotingStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Return the most frequent completion.
        return max(set(outputs), key=outputs.count)


class BestOfN(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score each candidate with the process reward model (PRM).
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]


class BeamSearch(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


class DVT(GenerationStrategy):
    """Tree search over candidates: sample a breadth of responses, score them with the PRM, and iteratively extend the best ones."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], breadth: int = 3, depth: int = 2, **kwargs) -> str:
        # The breadth/depth defaults are starting values; callers can override them via kwargs.
        results = []
        for _ in range(breadth):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]


class COT(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the chain-of-thought strategy.
        return "Not implemented yet"


class ReAct(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO: implement the ReAct framework.
        return "Not implemented yet"

# Add other strategy implementations...
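# Sketch of how an additional strategy could be plugged in (the class below is
# illustrative and assumes the same generator interface used by the strategies above):
#
#   class TemperatureSweep(GenerationStrategy):
#       """Try several temperatures and keep the highest-scoring completion."""
#       def generate(self, generator, prompt, model_kwargs, temperatures=(0.3, 0.7, 1.0), **kwargs):
#           scored = []
#           for t in temperatures:
#               kwargs_t = {**model_kwargs, "temperature": t, "do_sample": True}
#               candidate = DefaultStrategy().generate(generator, prompt, kwargs_t)
#               score = generator.prm_model(**generator.tokenizer(candidate, return_tensors="pt").to(generator.device)).logits.mean().item()
#               scored.append((candidate, score))
#           return max(scored, key=lambda x: x[1])[0]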
# prompt_builder.py
from typing import List, Protocol, Tuple

from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Protocol for prompt templates."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        ...


class LlamaPromptTemplate:
    """Builds Llama 3 style prompts by hand using the special header tokens."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
        # Only keep the most recent turns to stay within the context window.
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt


class TransformersPromptTemplate:
    """Delegates prompt construction to the tokenizer's chat template."""

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]
        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg}
            ])
        messages.append({"role": "user", "content": user_input})
        tokenized_chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return tokenized_chat
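# Usage sketch for the Llama prompt template (the history and question are illustrative):
#
#   template = LlamaPromptTemplate()
#   prompt = template.format(
#       context="You are a helpful assistant",
#       user_input="And its population?",
#       chat_history=[("What is the capital of France?", "The capital of France is Paris.")],
#   )
#   # prompt now contains the system message, the last history turn and the new user
#   # question wrapped in <|start_header_id|>...<|eot_id|> blocks, ending with an open
#   # assistant header ready for generation.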
# health_check.py
import psutil
import torch
from dataclasses import dataclass
from typing import Dict


@dataclass
class HealthStatus:
    status: str
    gpu_memory: Dict[str, float]
    cpu_usage: float
    ram_usage: float
    model_status: Dict[str, str]


class HealthCheck:
    @staticmethod
    def check_gpu_memory() -> Dict[str, float]:
        """Return allocated GPU memory (in GiB) per visible device."""
        if torch.cuda.is_available():
            return {
                f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3
                for i in range(torch.cuda.device_count())
            }
        return {}

    @staticmethod
    def check_system_resources() -> HealthStatus:
        return HealthStatus(
            status="healthy",
            gpu_memory=HealthCheck.check_gpu_memory(),
            cpu_usage=psutil.cpu_percent(),
            ram_usage=psutil.virtual_memory().percent,
            # TODO: add more system resources (disk, network, ...).
            model_status={}  # to be filled in by the model manager
        )
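# Usage sketch (prints a point-in-time snapshot of system resources):
#
#   status = HealthCheck.check_system_resources()
#   print(status.status, status.cpu_usage, status.ram_usage, status.gpu_memory)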
# llama_generator.py
from typing import Any, Dict, Optional

from config.config import GenerationConfig, ModelConfig


class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )

        # Initialize models
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )

        # Expose the loaded artifacts under the names the strategies expect.
        self.device = self.model_manager.device
        self.model = self.model_manager.models["llama"]
        self.tokenizer = self.model_manager.tokenizers["llama"]
        self.prm_model = self.model_manager.models["prm"]

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample"
            ]
            if hasattr(config, key)
        }

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        if strategy not in self.strategies:
            raise ValueError(f"Unknown strategy: {strategy}")
        return self.strategies[strategy].generate(
            self,
            prompt,
            model_kwargs,
            **kwargs
        )

    def cleanup(self) -> None:
        """Release the loaded models and free GPU memory."""
        self.model_manager.unload_model("llama")
        self.model_manager.unload_model("prm")

    def check_health(self) -> HealthStatus:
        """Check the health status of the generator."""
        # TODO: include per-model status from the model manager.
        return self.health_check.check_system_resources()
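# Usage sketch for the generator (the model name matches the one used in the FastAPI
# app below; the PRM path and prompt are illustrative):
#
#   generator = LlamaGenerator(
#       llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
#       prm_model_path="path/to/Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf",
#   )
#   config = GenerationConfig()
#   prompt = generator.prompt_builder.format(
#       context="You are a helpful assistant",
#       user_input="What is the capital of France?",
#       chat_history=[],
#   )
#   answer = generator.generate(prompt, generator._get_generation_kwargs(config), strategy="best_of_n")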
###################################
# FastAPI service
###################################
from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager
class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )
class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Only the smallest set of tokens whose cumulative probability reaches top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only the k most likely tokens are considered at each generation step"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used by the majority_voting, best_of_n and beam_search strategies)"
    )
class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_items=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )
class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )


# Model and cache management
async def get_prm_model_path():
    """Download and cache the PRM model."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )
generator: Optional[LlamaGenerator] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    # Startup: initialize the generator
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        # Shutdown: clean up resources
        if generator:
            await asyncio.to_thread(generator.cleanup)
# FastAPI application
app = FastAPI(
    title="Inference Deluxe Service",
    description="""
A service for generating text using LLaMA models with various generation strategies.

Generation Strategies:
- default: Standard autoregressive generation
- majority_voting: Generates multiple responses and selects the most common one
- best_of_n: Generates multiple responses and selects the best based on a scoring metric
- beam_search: Uses beam search for more coherent text generation
- dvts: Tree search that scores candidate responses with a reward model and extends the best ones
    """,
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator
@app.post("/generate", response_model=GenerationResponse)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.

    The generation process can be customized using various parameters in the config:
    - temperature: Controls randomness (0.0 to 2.0)
    - max_new_tokens: Maximum length of generated text
    - top_p: Nucleus sampling parameter
    - top_k: Top-k sampling parameter
    - strategy: Generation strategy to use
    - num_samples: Number of samples for applicable strategies

    Generation Strategies:
    - default: Standard generation
    - majority_voting: Generates multiple responses and uses the most common one
    - best_of_n: Generates multiple responses and picks the best
    - beam_search: Uses beam search for coherent generation
    - dvts: Tree search guided by the reward model
    """
    try:
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content
        config = request.config or GenerationConfig()

        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            config=config
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
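# Example request against the endpoint above (path as registered in the decorator;
# shown with the requests library, values are illustrative):
#
#   import requests
#   payload = {
#       "context": "You are a helpful assistant",
#       "messages": [{"role": "user", "content": "What is the capital of France?"}],
#       "config": {"temperature": 0.7, "max_new_tokens": 100, "strategy": "default"},
#       "stream": False,
#   }
#   r = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
#   print(r.json()["content"])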
@app.websocket("/generate/stream")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()
    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.parse_raw(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content
            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator._construct_prompt(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
    finally:
        await websocket.close()
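# Example client for the WebSocket stream above (path as registered in the decorator,
# using the websockets package; the payload is illustrative):
#
#   import asyncio, json, websockets
#
#   async def stream_demo():
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "What is the capital of France?"}],
#               "stream": True,
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg.get("finished") or "error" in msg:
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(stream_demo())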
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)