from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import torch
from typing import Optional, List

app = FastAPI(title="LLM API", description="API for interacting with LLaMA model")

# Model configuration
class ModelConfig:
    model_name = "ManojINaik/Strength_weakness"  # Your fine-tuned model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_length = 200
    temperature = 0.7

# Request/Response models
class GenerateRequest(BaseModel):
    prompt: str
    history: Optional[List[str]] = []
    system_prompt: Optional[str] = "You are a very powerful AI assistant."
    max_length: Optional[int] = 200
    temperature: Optional[float] = 0.7


class GenerateResponse(BaseModel):
    response: str

# Global variables for model and tokenizer
model = None
tokenizer = None
generator = None

@app.on_event("startup")
async def load_model():
    """Load the tokenizer, quantized model, and generation pipeline once at startup."""
    global model, tokenizer, generator
    try:
        print("Loading model and tokenizer...")

        # Configure 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False
        )

        tokenizer = AutoTokenizer.from_pretrained(ModelConfig.model_name)
        # LLaMA-style tokenizers often ship without a pad token; fall back to EOS
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            ModelConfig.model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto"
        )

        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise e

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    if generator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        # Format the prompt with the system prompt and chat history
        formatted_prompt = f"{request.system_prompt}\n\n"
        for msg in request.history:
            formatted_prompt += f"{msg}\n"
        formatted_prompt += f"Human: {request.prompt}\nAssistant:"

        # Generate a response; max_length counts prompt plus completion tokens
        outputs = generator(
            formatted_prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        # Extract the generated text and strip the prompt from the response
        generated_text = outputs[0]['generated_text']
        response = generated_text.split("Assistant:")[-1].strip()

        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
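
# Example call to the /generate endpoint (illustrative only; the host and port
# below are assumptions, 7860 being the port Hugging Face Spaces exposes):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "List my strengths and weaknesses as a developer."}'
#
# The response is a JSON object of the form {"response": "..."}.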

@app.get("/")
def root():
    return {"message": "LLM API is running. Use /generate endpoint for text generation."}
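
# Local entry point: a minimal sketch for running the API outside of Spaces.
# It assumes uvicorn is installed alongside FastAPI; the host and port are
# assumptions, with 7860 chosen to match the default Spaces port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)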