# manojapinew / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import torch
from typing import Optional, List
app = FastAPI(title="LLM API", description="API for interacting with LLaMA model")
# Model configuration
class ModelConfig:
    model_name = "ManojINaik/Strength_weakness"  # fine-tuned model on the Hugging Face Hub
    device = "cuda" if torch.cuda.is_available() else "cpu"  # informational; device_map="auto" handles placement
    max_length = 200   # default generation budget (tokens), used as the request default below
    temperature = 0.7  # default sampling temperature

# Request/Response models
class GenerateRequest(BaseModel):
    prompt: str
    history: Optional[List[str]] = []
    system_prompt: Optional[str] = "You are a very powerful AI assistant."
    max_length: Optional[int] = ModelConfig.max_length
    temperature: Optional[float] = ModelConfig.temperature

class GenerateResponse(BaseModel):
    response: str

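# Illustrative payloads for /generate/ (the values are examples for clarity,
# not real output from this model):
#   request:  {"prompt": "What are my strengths?", "history": [], "temperature": 0.7}
#   response: {"response": "..."}
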
# Global variables for model and tokenizer
model = None
tokenizer = None
generator = None
@app.on_event("startup")
async def load_model():
global model, tokenizer, generator
try:
print("Loading model and tokenizer...")
# Configure quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=False
)
tokenizer = AutoTokenizer.from_pretrained(ModelConfig.model_name)
model = AutoModelForCausalLM.from_pretrained(
ModelConfig.model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device_map="auto"
)
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {str(e)}")
raise e
@app.post("/generate/", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
if generator is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
# Format the prompt with system prompt and chat history
formatted_prompt = f"{request.system_prompt}\n\n"
for msg in request.history:
formatted_prompt += f"{msg}\n"
formatted_prompt += f"Human: {request.prompt}\nAssistant:"
# Generate response
outputs = generator(
formatted_prompt,
max_length=request.max_length,
temperature=request.temperature,
num_return_sequences=1,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# Extract the generated text
generated_text = outputs[0]['generated_text']
# Remove the prompt from the response
response = generated_text.split("Assistant:")[-1].strip()
return {"response": response}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
@app.get("/")
def root():
return {"message": "LLM API is running. Use /generate endpoint for text generation."}