from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

app = FastAPI()


@app.on_event("startup")
async def load_model():
    try:
        # 4-bit quantization config
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # Load base model
        app.state.base_model = AutoModelForCausalLM.from_pretrained(
            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit",
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach PEFT adapter
        app.state.model = PeftModel.from_pretrained(
            app.state.base_model,
            "LAWSA07/medical_fine_tuned_deepseekR1",
        )

        # Load tokenizer
        app.state.tokenizer = AutoTokenizer.from_pretrained(
            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Model loading failed: {str(e)}"
        )
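The startup hook above only loads the weights into app.state; a request handler is still needed to actually serve generations. Below is a minimal sketch of such an endpoint. The route name, request schema, and sampling parameters are illustrative assumptions, not part of the original code.

from pydantic import BaseModel


class GenerateRequest(BaseModel):
    # Hypothetical request body: a prompt plus an optional token budget
    prompt: str
    max_new_tokens: int = 256


@app.post("/generate")
async def generate(req: GenerateRequest):
    tokenizer = app.state.tokenizer
    model = app.state.model

    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)

    # Run generation without tracking gradients
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )

    # Strip the prompt tokens and decode only the newly generated text
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return {"response": tokenizer.decode(new_tokens, skip_special_tokens=True)}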