khushi1234455687 committed
Commit
4d80ee4
·
verified ·
1 Parent(s): 093b386

Upload app.py

Files changed (1)
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import PeftModel
+
+ import os
+ from huggingface_hub import login
+
+ # ✅ Read token from Hugging Face Secrets
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # ✅ Login only if token exists
+ if HF_TOKEN:
+     login(token=HF_TOKEN)
+
+
+ # ✅ Initialize FastAPI
+ app = FastAPI()
+
+ # ✅ Define Base Model & LoRA Adapter Repository
+ base_model_name = "mistralai/Mistral-7B-v0.1"
+ lora_repo_id = "khushi1234455687/fine-tuned-medical-qa-V8"
+
+ # ✅ Load Tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
+ # ✅ Configure 4-bit Quantization
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     llm_int8_enable_fp32_cpu_offload=True,
+     offload_buffers=True
+ )
+
+ # ✅ Load Base Model
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_name,
+     quantization_config=quantization_config,
+     device_map="auto",
+     torch_dtype=torch.float16
+ )
+
+ # ✅ Load LoRA Adapter
+ model = PeftModel.from_pretrained(base_model, lora_repo_id)
+ model.eval()
+
+ print("✅ Model is loaded and API is ready!")
+
+ # ✅ Define Request Body Format
+ class QueryRequest(BaseModel):
+     question: str
+
+ @app.post("/generate")
+ async def generate_answer(request: QueryRequest):
+     """Generate an answer for a given medical question."""
+     inputs = tokenizer(request.question, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         output = model.generate(**inputs, max_length=256)
+     answer = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     return {"question": request.question, "answer": answer}
+
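For reference, once app.py is running (for example via uvicorn app:app --host 0.0.0.0 --port 7860; the actual launch command and port depend on how the Space is configured), the /generate endpoint can be exercised with a minimal client sketch. The base URL below is a placeholder and the example question is illustrative only; the request body must match the QueryRequest schema defined above.

import requests

# Placeholder base URL; replace with the real Space / server address.
API_URL = "http://localhost:7860"

# Body matches the QueryRequest model: a single "question" field.
payload = {"question": "What are common symptoms of iron-deficiency anemia?"}

response = requests.post(f"{API_URL}/generate", json=payload, timeout=120)
response.raise_for_status()

result = response.json()
print(result["answer"])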