from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
from config.config import settings
from services.llama_generator import LlamaGenerator
import os

# Initialize Langfuse
# NOTE: credentials are hardcoded here for demo purposes; in a real deployment
# they should come from actual environment variables or Space secrets.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # self-hosted Langfuse Space
langfuse = None
try:
    langfuse = Langfuse()
except Exception as e:
    print(f"Langfuse offline: {e}")
from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager
class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )
class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Only the smallest set of tokens whose cumulative probability reaches top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only consider the top k tokens for text generation"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
    )
    # The following fields were referenced by the generation endpoint but not
    # declared on the model; declared here with the defaults the endpoint used.
    max_history_turns: int = Field(
        3,
        ge=0,
        description="Maximum number of past chat turns to include in the prompt"
    )
    depth: int = Field(
        3,
        ge=1,
        description="Search depth (used in the dvts strategy)"
    )
    breadth: int = Field(
        2,
        ge=1,
        description="Search breadth (used in the dvts strategy)"
    )
class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )
class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )
# Model and cache management
async def get_prm_model_path() -> str:
    """Download and cache the PRM model."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )
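# hf_hub_download caches the file under the local Hugging Face cache and
# returns the cached path immediately on subsequent calls, so this is only
# expensive on first startup. Gated models (e.g. the meta-llama checkpoint
# used below) may additionally require authenticating with a Hugging Face
# token.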
# Initialize generator globally
generator = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    # Startup: initialize the generator once per process
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        # Shutdown: clean up resources
        if generator:
            await asyncio.to_thread(generator.cleanup)
# FastAPI application
app = FastAPI(
    title="Inference Deluxe Service",
    description="""
    A service for generating text using LLaMA models with various generation strategies.

    Generation Strategies:
    - default: Standard autoregressive generation
    - majority_voting: Generates multiple responses and selects the most common one
    - best_of_n: Generates multiple responses and selects the best based on a scoring metric
    - beam_search: Uses beam search for more coherent text generation
    - dvts: Diverse verifier tree search guided by the process reward model (PRM)
    """,
    version="1.0.0",
    lifespan=lifespan
)
# CORS middleware
# NOTE: wide-open CORS is convenient for a demo but should be restricted in
# production; the CORS spec also disallows credentials with a wildcard origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator
@app.post("/generate", response_model=GenerationResponse)  # path name assumed; the original never registered this handler
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.
    """
    try:
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content

        # Fall back to the default configuration when none is provided
        config = request.config or GenerationConfig()
        model_kwargs = {
            "temperature": config.temperature,
            "max_new_tokens": config.max_new_tokens,
            "top_p": config.top_p,
            "top_k": config.top_k,
        }

        # generate_with_context is synchronous; run it off the event loop
        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=model_kwargs,
            max_history_turns=config.max_history_turns,
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=config.depth,
            breadth=config.breadth,
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
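# A request against this endpoint might look like the following sketch, which
# mirrors the GenerationRequest schema example above:
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "What is the capital of France?"}],
#          "config": {"temperature": 0.7, "max_new_tokens": 100}}'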
@app.websocket("/ws/generate")  # path name assumed; the original never registered this handler
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()
    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content
            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator.prompt_builder.format(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
    finally:
        await websocket.close()
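# Minimal client sketch for the streaming endpoint (path as registered above;
# the third-party `websockets` package used here is an assumption, any
# WebSocket client works):
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws/generate") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "Hello"}],
#               "stream": True,
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg.get("finished") or "error" in msg:
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(main())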
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
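# Equivalently, assuming this module is saved as app.py, the server can be
# started with the uvicorn CLI:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000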