# medical_model/app.py — Hugging Face Space by LAWSA07 (commit 738974d, verified)
# NOTE(review): the lines above were non-code residue scraped from the HF file
# viewer ("Update app.py" header); kept here as a comment so the file parses.
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
app = FastAPI()
@app.on_event("startup")
async def load_model():
    """Load the 4-bit base model, the PEFT adapter, and the tokenizer.

    Runs once at application startup and stores the artifacts on
    ``app.state`` (``base_model``, ``model``, ``tokenizer``) so request
    handlers can reach them without re-loading.

    Raises:
        RuntimeError: if downloading/loading any artifact fails. A plain
            exception is used instead of ``HTTPException`` because startup
            events are not request handlers — FastAPI does not translate
            exceptions raised here into HTTP responses; they should abort
            startup loudly instead.
    """
    try:
        # NF4 double-quantized 4-bit config; fp16 compute dtype keeps the
        # memory footprint small for an 8B model.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        # Base model: pre-quantized DeepSeek-R1 distill checkpoint.
        # trust_remote_code is required by the unsloth repo's custom code.
        app.state.base_model = AutoModelForCausalLM.from_pretrained(
            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit",
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # Attach the fine-tuned LoRA adapter on top of the frozen base.
        app.state.model = PeftModel.from_pretrained(
            app.state.base_model,
            "LAWSA07/medical_fine_tuned_deepseekR1",
        )
        # Tokenizer comes from the same base checkpoint as the model.
        app.state.tokenizer = AutoTokenizer.from_pretrained(
            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
        )
    except Exception as e:
        # Chain the original cause so the real failure (network, OOM,
        # missing repo) stays visible in the startup traceback.
        raise RuntimeError(f"Model loading failed: {str(e)}") from e