import os

# βœ… Set a writable cache directory inside `/app` (must be set before importing transformers/huggingface_hub so the env vars take effect)
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/huggingface"
os.environ["HF_HUB_CACHE"] = "/app/huggingface"

from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login

# βœ… Read token from Hugging Face Secrets
HF_TOKEN = os.getenv("HF_TOKEN")

# βœ… Log in only if a token is provided (the token is saved under HF_HOME, which points at the writable /app/huggingface)
if HF_TOKEN:
    login(token=HF_TOKEN)  # login() does not take a cache_dir argument; HF_HOME controls where the token is written

# βœ… Initialize FastAPI
app = FastAPI()

# βœ… Define Base Model & LoRA Adapter Repository (Use a Smaller Model)
base_model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # πŸ”Ή Switched to a smaller model
lora_repo_id = "khushi1234455687/fine-tuned-medical-qa-V8"

# βœ… Force CPU usage (free Hugging Face Spaces hardware is CPU-only; GPUs require paid hardware tiers)
device = "cpu"

# βœ… Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir="/app/huggingface")

# βœ… Configure 4-bit quantization only when a CUDA GPU is available (bitsandbytes quantization is not supported on CPU)
quantization_config = BitsAndBytesConfig(load_in_4bit=True) if torch.cuda.is_available() else None

# βœ… Load Base Model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        torch_dtype=torch.float16,  # half precision halves the memory footprint; some CPU ops may not support float16 on older torch builds
        cache_dir="/app/huggingface"
    )
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    raise e

# βœ… Load LoRA Adapter
try:
    model = PeftModel.from_pretrained(base_model, lora_repo_id, cache_dir="/app/huggingface")
    model.to(device)
    model.eval()
except Exception as e:
    print(f"❌ Error loading LoRA adapter: {e}")
    raise e

print("βœ… Model is loaded and API is ready!")

# βœ… Define Request Body Format
class QueryRequest(BaseModel):
    question: str

@app.post("/generate")
async def generate_answer(request: QueryRequest):
    """Generate an answer for a given medical question."""
    try:
        inputs = tokenizer(request.question, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=256)  # cap new tokens so long questions still leave room for the answer
        answer = tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        return {"error": str(e)}

    return {"question": request.question, "answer": answer}

# βœ… Health Check Endpoint
@app.get("/health")
async def health_check():
    return {"status": "running", "device": device}