import os

# ✅ Set a writable cache directory inside `/app`
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/huggingface"
os.environ["HF_HUB_CACHE"] = "/app/huggingface"

from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login

# ✅ Read token from Hugging Face Secrets
HF_TOKEN = os.getenv("HF_TOKEN")

# ✅ Login only if a token exists (the gated base model requires authentication).
# Note: login() accepts no cache_dir argument; the cache location is controlled
# by the HF_HOME / HF_HUB_CACHE variables set above.
if HF_TOKEN:
    login(token=HF_TOKEN)

# ✅ Initialize FastAPI
app = FastAPI()

# ✅ Define base model & LoRA adapter repository (use a smaller model)
base_model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # 🔹 Switched to a smaller model
lora_repo_id = "khushi1234455687/fine-tuned-medical-qa-V8"

# ✅ Use a GPU if one is available; free Hugging Face Spaces are CPU-only
device = "cuda" if torch.cuda.is_available() else "cpu"

# ✅ Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir="/app/huggingface")

# ✅ Configure 4-bit quantization. bitsandbytes 4-bit loading requires a CUDA GPU,
# so fall back to an unquantized load when running on CPU.
quantization_config = BitsAndBytesConfig(load_in_4bit=True) if device == "cuda" else None

# ✅ Load base model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map=device,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        cache_dir="/app/huggingface",
    )
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    raise

# ✅ Load LoRA adapter on top of the base model
try:
    model = PeftModel.from_pretrained(base_model, lora_repo_id, cache_dir="/app/huggingface")
    model.eval()
except Exception as e:
    print(f"❌ Error loading LoRA adapter: {e}")
    raise

print("✅ Model is loaded and API is ready!")


# ✅ Define request body format
class QueryRequest(BaseModel):
    question: str


@app.post("/generate")
async def generate_answer(request: QueryRequest):
    """Generate an answer for a given medical question."""
    try:
        inputs = tokenizer(request.question, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # max_new_tokens caps only the generated continuation, not the prompt length
            output = model.generate(**inputs, max_new_tokens=256)
        answer = tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        return {"error": str(e)}
    return {"question": request.question, "answer": answer}


# ✅ Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "running", "device": device}
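

# --- Optional local entry point: a minimal sketch, not part of the original file. ---
# Assumption: the app is launched directly with `python app.py`. A Space whose
# Dockerfile already starts uvicorn does not need this block. Port 7860 is the
# port Hugging Face Spaces expects the server to listen on.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (hypothetical question text):
#   curl -X POST "http://localhost:7860/generate" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What are the common symptoms of iron-deficiency anemia?"}'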