# medical-qa-api/app.py
import os
# ✅ Set a writable cache directory inside `/app`
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/huggingface"
os.environ["HF_HUB_CACHE"] = "/app/huggingface"
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login
# ✅ Read token from Hugging Face Secrets
HF_TOKEN = os.getenv("HF_TOKEN")
# ✅ Log in only if a token exists (avoids writing to protected directories);
#    note that login() does not take a cache_dir argument
if HF_TOKEN:
    login(token=HF_TOKEN)
# ✅ Initialize FastAPI
app = FastAPI()
# ✅ Define Base Model & LoRA Adapter Repository (Use a Smaller Model)
base_model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # 🔹 Switched to a smaller model
lora_repo_id = "khushi1234455687/fine-tuned-medical-qa-V8"
# ✅ Select device (free Hugging Face Spaces provide CPU only)
device = "cuda" if torch.cuda.is_available() else "cpu"
# ✅ Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir="/app/huggingface")
# ✅ Configure 4-bit Quantization (bitsandbytes requires CUDA, so skip it on CPU)
quantization_config = BitsAndBytesConfig(load_in_4bit=True) if device == "cuda" else None
# ✅ Load Base Model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map=device,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        cache_dir="/app/huggingface",
    )
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    raise
# ✅ Load LoRA Adapter
try:
    model = PeftModel.from_pretrained(base_model, lora_repo_id, cache_dir="/app/huggingface")
    model.eval()  # device placement is already handled by device_map above
except Exception as e:
    print(f"❌ Error loading LoRA adapter: {e}")
    raise
print("βœ… Model is loaded and API is ready!")
# ✅ Define Request Body Format
class QueryRequest(BaseModel):
question: str
@app.post("/generate")
async def generate_answer(request: QueryRequest):
    """Generate an answer for a given medical question."""
    try:
        inputs = tokenizer(request.question, return_tensors="pt").to(device)
        with torch.no_grad():
            # max_new_tokens caps the generated answer without counting the prompt length
            output = model.generate(**inputs, max_new_tokens=256)
        answer = tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        return {"error": str(e)}
    return {"question": request.question, "answer": answer}
# ✅ Health Check Endpoint
@app.get("/health")
async def health_check():
return {"status": "running", "device": device}