Upyaya commited on
Commit
37db1ce
·
1 Parent(s): 3773d42

Fixed the issue to load the model for inferance in CPU device

Browse files

The model is trained on GPU, with bitsandbytes, peft. But bitsandbytes does work only on GPU devices. So modify the inti model and input dtype to work on CPU

Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -4,25 +4,32 @@ from peft import PeftModel
4
  import streamlit as st
5
  from PIL import Image
6
  import torch
 
7
 
8
  preprocess_ckp = "Salesforce/blip2-opt-2.7b" #Checkpoint path used for perprocess image
9
  base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
10
  peft_model_ckp = "./model/blip2_peft" #PEFT model checkpoint path
11
-
 
12
  #init_model_required = True
13
- processor = None
14
- model = None
15
 
16
- def init_model():
17
 
18
  #if init_model_required:
19
 
20
- #Preprocess input
21
- processor = Blip2Processor.from_pretrained(preprocess_ckp)
 
 
 
 
22
 
23
- #Model
24
- model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)#, load_in_8bit = True, device_map = "auto")
25
- model = PeftModel.from_pretrained(model, peft_model_ckp)
 
26
 
27
  #init_model_required = False
28
 
@@ -32,10 +39,16 @@ def main():
32
 
33
  st.title("Fashion Image Caption using BLIP2")
34
 
35
- init_model()
36
 
 
 
37
  file_name = st.file_uploader("Upload image")
38
 
 
 
 
 
39
  if file_name is not None:
40
 
41
  image_col, caption_text = st.columns(2)
@@ -45,7 +58,12 @@ def main():
45
  image_col.image(image, use_column_width = True)
46
 
47
  #Preprocess the image
48
- inputs = processor(images = image, return_tensors = "pt").to('cuda', torch.float16)
 
 
 
 
 
49
  pixel_values = inputs.pixel_values
50
 
51
  #Predict the caption for the imahe
@@ -56,6 +74,5 @@ def main():
56
  caption_text.header("Generated Caption")
57
  caption_text.text(generated_caption)
58
 
59
-
60
  if __name__ == "__main__":
61
  main()
 
4
  import streamlit as st
5
  from PIL import Image
6
  import torch
7
+ import os
8
 
9
  preprocess_ckp = "Salesforce/blip2-opt-2.7b" #Checkpoint path used for perprocess image
10
  base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
11
  peft_model_ckp = "./model/blip2_peft" #PEFT model checkpoint path
12
+ sample_img_path = "./sample_images/"
13
+
14
  #init_model_required = True
15
+ #processor = None
16
+ #model = None
17
 
18
+ #def init_model():
19
 
20
  #if init_model_required:
21
 
22
+ #Preprocess input
23
+ processor = Blip2Processor.from_pretrained(preprocess_ckp)
24
+
25
+ #Model
26
+ #Inferance on GPU device. Will give error in CPU system, as "load_in_8bit" is an setting of bitsandbytes library and only works for GPU
27
+ #model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")
28
 
29
+ #Inferance on CPU device
30
+ model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
31
+
32
+ model = PeftModel.from_pretrained(model, peft_model_ckp)
33
 
34
  #init_model_required = False
35
 
 
39
 
40
  st.title("Fashion Image Caption using BLIP2")
41
 
42
+ #init_model()
43
 
44
+ #Select few sample images for the catagory of cloths
45
+ option = st.selectbox('Sample images ?', ('cap', 'tee', 'dress'))
46
  file_name = st.file_uploader("Upload image")
47
 
48
+ if file_name is None and option is not None:
49
+
50
+ file_name = os.join.path(sample_img_path, option)
51
+
52
  if file_name is not None:
53
 
54
  image_col, caption_text = st.columns(2)
 
58
  image_col.image(image, use_column_width = True)
59
 
60
  #Preprocess the image
61
+ #Inferance on GPU. When used this on GPU will get errors like: "slow_conv2d_cpu" not implemented for 'Half'" , " Input type (float) and bias type (struct c10::Half)"
62
+ #inputs = processor(images = image, return_tensors = "pt").to('cuda', torch.float16)
63
+
64
+ #Inferance on CPU
65
+ inputs = processor(images = image, return_tensors = "pt")
66
+
67
  pixel_values = inputs.pixel_values
68
 
69
  #Predict the caption for the imahe
 
74
  caption_text.header("Generated Caption")
75
  caption_text.text(generated_caption)
76
 
 
77
  if __name__ == "__main__":
78
  main()