LAWSA07 committed
Commit 738974d · verified · 1 Parent(s): 73dc516

Update app.py

Files changed (1)
  1. app.py +18 -48
app.py CHANGED
@@ -1,72 +1,42 @@
 from fastapi import FastAPI, HTTPException
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
 import torch
 
 app = FastAPI()
 
-# Load model once at startup
 @app.on_event("startup")
 async def load_model():
     try:
-        # Configuration
-        model_name = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
-        adapter_name = "LAWSA07/medical_fine_tuned_deepseekR1"
-
-        # Load base model with 4-bit quantization
-        app.state.base_model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+        # 4-bit config
+        bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
-            torch_dtype=torch.float16,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+        )
+
+        # Load base model
+        app.state.base_model = AutoModelForCausalLM.from_pretrained(
+            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit",
+            quantization_config=bnb_config,
             device_map="auto",
-            trust_remote_code=True,
+            trust_remote_code=True
         )
-
+
         # Attach PEFT adapter
         app.state.model = PeftModel.from_pretrained(
             app.state.base_model,
-            adapter_name,
-            adapter_weight_name="adapter_model.safetensors"
+            "LAWSA07/medical_fine_tuned_deepseekR1"
         )
 
         # Load tokenizer
-        app.state.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-    except Exception as e:
-        raise HTTPException(
-            status_code=500,
-            detail=f"Model loading failed: {str(e)}"
+        app.state.tokenizer = AutoTokenizer.from_pretrained(
+            "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
         )
 
-@app.get("/")
-def health_check():
-    return {"status": "OK"}
-
-@app.post("/generate")
-async def generate_text(prompt: str, max_length: int = 200):
-    try:
-        inputs = app.state.tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True
-        ).to("cuda")
-
-        outputs = app.state.model.generate(
-            **inputs,
-            max_length=max_length,
-            temperature=0.7,
-            do_sample=True
-        )
-
-        decoded = app.state.tokenizer.decode(
-            outputs[0],
-            skip_special_tokens=True
-        )
-
-        return {"response": decoded}
-
     except Exception as e:
         raise HTTPException(
             status_code=500,
-            detail=f"Generation failed: {str(e)}"
+            detail=f"Model loading failed: {str(e)}"
         )
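
Note that after this commit app.py only loads the model: the "/" health check and the "/generate" route that existed before the change are removed along with the rest of the old file, so the app exposes no endpoints. A minimal sketch of how they could be reattached on top of the new startup loader, carried over from the code removed above; the one assumption is replacing the hardcoded .to("cuda") with app.state.model.device, so input tensors follow wherever device_map="auto" placed the weights:

@app.get("/")
def health_check():
    # Simple liveness probe.
    return {"status": "OK"}

@app.post("/generate")
async def generate_text(prompt: str, max_length: int = 200):
    try:
        # Tokenize the prompt and move the tensors to the model's device
        # (assumption: the removed code hardcoded .to("cuda") here).
        inputs = app.state.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True
        ).to(app.state.model.device)

        # Sampled decoding; parameters carried over from the removed endpoint.
        outputs = app.state.model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True
        )

        decoded = app.state.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        return {"response": decoded}

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )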
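Because generate_text declares prompt and max_length as plain scalar parameters, FastAPI reads them from the query string rather than a JSON body, so a quick smoke test against a local run (host and port are assumptions) would look like:

curl -X POST "http://localhost:8000/generate?prompt=hello&max_length=120"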