jatingocodeo commited on
Commit
b81df6e
·
verified ·
1 Parent(s): 0bdc84a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -4,8 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
  from PIL import Image
6
  import torchvision.datasets as datasets
 
7
 
8
  def load_model(model_id):
 
 
 
9
  # First load the base model
10
  base_model_id = "microsoft/Phi-3-mini-4k-instruct"
11
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
@@ -14,15 +18,29 @@ def load_model(model_id):
14
  if tokenizer.pad_token is None:
15
  tokenizer.pad_token = tokenizer.eos_token
16
 
 
17
  base_model = AutoModelForCausalLM.from_pretrained(
18
  base_model_id,
19
- torch_dtype=torch.float16, # Use float16 like assignment22
20
- device_map="auto",
 
 
 
 
 
 
 
21
  trust_remote_code=True
22
  )
23
 
24
- # Load the LoRA adapter
25
- model = PeftModel.from_pretrained(base_model, model_id)
 
 
 
 
 
 
26
  return model, tokenizer
27
 
28
  def generate_description(image, model, tokenizer, max_length=100, temperature=0.7, top_p=0.9):
 
4
  from peft import PeftModel
5
  from PIL import Image
6
  import torchvision.datasets as datasets
7
+ import os
8
 
9
  def load_model(model_id):
10
+ # Create offload directory
11
+ os.makedirs("offload", exist_ok=True)
12
+
13
  # First load the base model
14
  base_model_id = "microsoft/Phi-3-mini-4k-instruct"
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
18
  if tokenizer.pad_token is None:
19
  tokenizer.pad_token = tokenizer.eos_token
20
 
21
+ # Load base model with 8-bit quantization and offloading
22
  base_model = AutoModelForCausalLM.from_pretrained(
23
  base_model_id,
24
+ load_in_8bit=True, # Use 8-bit quantization
25
+ torch_dtype=torch.float16,
26
+ device_map={
27
+ "model.embed_tokens": 0,
28
+ "model.layers": "auto",
29
+ "model.norm": "cpu",
30
+ "lm_head": 0
31
+ },
32
+ offload_folder="offload",
33
  trust_remote_code=True
34
  )
35
 
36
+ # Load the LoRA adapter with same device mapping
37
+ model = PeftModel.from_pretrained(
38
+ base_model,
39
+ model_id,
40
+ offload_folder="offload",
41
+ device_map="auto"
42
+ )
43
+
44
  return model, tokenizer
45
 
46
  def generate_description(image, model, tokenizer, max_length=100, temperature=0.7, top_p=0.9):