khushi1234455687 committed on
Commit 8e67518 · verified · 1 Parent(s): 4d80ee4

Upload app.py

Files changed (1)
  1. app.py +36 -21
app.py CHANGED
@@ -3,7 +3,6 @@ from pydantic import BaseModel
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
-
 import os
 from huggingface_hub import login
 
@@ -14,35 +13,44 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
 
-
 # βœ… Initialize FastAPI
 app = FastAPI()
 
-# βœ… Define Base Model & LoRA Adapter Repository
-base_model_name = "mistralai/Mistral-7B-v0.1"
+# βœ… Define Base Model & LoRA Adapter Repository (Smaller Model for Hugging Face Spaces)
+base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # πŸ”Ή Using a smaller model
 lora_repo_id = "khushi1234455687/fine-tuned-medical-qa-V8"
 
+# βœ… Automatically Select CPU (Hugging Face Spaces Does NOT Support GPU)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 # βœ… Load Tokenizer
 tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
 # βœ… Configure 4-bit Quantization
 quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    llm_int8_enable_fp32_cpu_offload=True,
-    offload_buffers=True
+    load_in_4bit=True
 )
 
-# βœ… Load Base Model
-base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_name,
-    quantization_config=quantization_config,
-    device_map="auto",
-    torch_dtype=torch.float16
-)
+# βœ… Load Base Model (Optimized for CPU)
+try:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_name,
+        quantization_config=quantization_config,
+        device_map="auto",  # βœ… Automatically assigns layers to CPU
+        torch_dtype=torch.float16
+    )
+except Exception as e:
+    print(f"❌ Error loading base model: {e}")
+    raise e
 
 # βœ… Load LoRA Adapter
-model = PeftModel.from_pretrained(base_model, lora_repo_id)
-model.eval()
+try:
+    model = PeftModel.from_pretrained(base_model, lora_repo_id)
+    model.to(device)  # βœ… Ensure model is on the correct device
+    model.eval()
+except Exception as e:
+    print(f"❌ Error loading LoRA adapter: {e}")
+    raise e
 
 print("βœ… Model is loaded and API is ready!")
 
@@ -53,10 +61,17 @@ class QueryRequest(BaseModel):
 @app.post("/generate")
 async def generate_answer(request: QueryRequest):
     """Generate an answer for a given medical question."""
-    inputs = tokenizer(request.question, return_tensors="pt").to("cuda")
-    with torch.no_grad():
-        output = model.generate(**inputs, max_length=256)
-    answer = tokenizer.decode(output[0], skip_special_tokens=True)
-
+    try:
+        inputs = tokenizer(request.question, return_tensors="pt").to(device)  # βœ… Move to device
+        with torch.no_grad():
+            output = model.generate(**inputs, max_length=256)
+        answer = tokenizer.decode(output[0], skip_special_tokens=True)
+    except Exception as e:
+        return {"error": str(e)}
+
     return {"question": request.question, "answer": answer}
 
+# βœ… Health Check Endpoint
+@app.get("/health")
+async def health_check():
+    return {"status": "running", "device": device}