Advait3009 committed
Commit a558a96 · verified · 1 Parent(s): 2821092

Update utils/model_loader.py

Files changed (1)
  1. utils/model_loader.py +16 -11
utils/model_loader.py CHANGED
@@ -1,24 +1,26 @@
-from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
 import torch
 from typing import Optional
 
 def load_llava_model():
     """Load LLaVA model with 4-bit quantization for HF Spaces"""
     model_id = "llava-hf/llava-1.5-7b-hf"
 
+    quant_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4"
+    )
+
     return pipeline(
         "image-to-text",
         model=model_id,
+        tokenizer=model_id,
         device_map="auto",
         model_kwargs={
             "torch_dtype": torch.float16,
-            "load_in_4bit": True,
-            "quantization_config": {
-                "load_in_4bit": True,
-                "bnb_4bit_compute_dtype": torch.float16,
-                "bnb_4bit_use_double_quant": True,
-                "bnb_4bit_quant_type": "nf4"
-            }
+            "quantization_config": quant_config
         }
     )
 
@@ -34,16 +36,19 @@ def load_caption_model():
 
 def load_retrieval_models():
     """Load encoders with shared weights"""
+    from sentence_transformers import SentenceTransformer
+    from transformers import AutoModel
+
     models = {}
     models['text_encoder'] = SentenceTransformer(
         'sentence-transformers/all-MiniLM-L6-v2',
         device="cuda" if torch.cuda.is_available() else "cpu"
     )
-
+
     models['image_encoder'] = AutoModel.from_pretrained(
         "openai/clip-vit-base-patch32",
         device_map="auto",
         torch_dtype=torch.float16
     )
-
+
     return models
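
For context, a minimal sketch of how the updated loaders might be exercised after this change. The driver code, example image path, and prompt text below are illustrative assumptions and are not part of the commit; only load_llava_model() and load_retrieval_models() come from utils/model_loader.py.

# Hypothetical usage sketch; assumes a CUDA GPU and bitsandbytes installed,
# since the pipeline is now built with a 4-bit NF4 BitsAndBytesConfig.
from PIL import Image
from utils.model_loader import load_llava_model, load_retrieval_models

captioner = load_llava_model()  # image-to-text pipeline, 4-bit quantized

# Prompt follows the usual LLaVA-1.5 chat convention; the image path is a placeholder.
image = Image.open("sample.jpg")
out = captioner(
    image,
    prompt="USER: <image>\nDescribe this image. ASSISTANT:",
    generate_kwargs={"max_new_tokens": 64},
)
print(out[0]["generated_text"])

# Retrieval encoders: MiniLM text encoder + CLIP image encoder.
models = load_retrieval_models()
text_emb = models["text_encoder"].encode("a photo of a cat")
print(text_emb.shape)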