deepseek-7b / app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
# Set a writable cache directory
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
# Model setup
MODEL_NAME = "google/gemma-2b"  # Smaller, CPU-friendly model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit quantization: bitsandbytes requires a CUDA GPU, so quantize only when one is
# available and fall back to a plain (unquantized) load on CPU.
quantization_config = None
if DEVICE == "cuda":
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

# Load model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map=DEVICE,
)
# Set generation config
model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
model.generation_config.pad_token_id = model.generation_config.eos_token_id
# FastAPI app
app = FastAPI()
# Request payload
class TextGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = Field(default=100, ge=1, le=512)  # Cap generation length to reject overly large requests
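# Example payload (illustrative values only):
# {"prompt": "Write a haiku about the sea", "max_tokens": 64}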
@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
    try:
        # Tokenize the prompt and move the tensors to the model's device
        inputs = tokenizer(request.prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model.generate(**inputs, max_new_tokens=request.max_tokens, do_sample=True)
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
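
# Local entrypoint: a minimal sketch assuming this file is run directly with uvicorn
# on port 7860 (the usual HF Spaces port); adjust or drop this if the Space's
# Dockerfile already launches the server.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (hypothetical local test):
# curl -X POST http://localhost:7860/generate \
#   -H "Content-Type: application/json" \
#   -d '{"prompt": "Write a haiku about the sea", "max_tokens": 64}'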