chats-bug committed on
Commit b8b6ade · 1 Parent(s): 245a3fa

Blip Base testing

Files changed (1): app.py (+29 -30)
app.py CHANGED
@@ -1,5 +1,5 @@
  import gradio as gr
- from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration, VisionEncoderDecoderModel, BitsAndBytesConfig
+ from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration, VisionEncoderDecoderModel, BitsAndBytesConfig, BlipProcessor
  import torch
  import open_clip
 
@@ -16,17 +16,17 @@ device_map = {
  }
 
  # Load the Blip2 model
- preprocessor_blip2_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
- model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained(
-     "Salesforce/blip2-opt-2.7b",
-     device_map="auto",
-     quantization_config=quantization_config,
-     load_in_8bit=True
- )
+ # preprocessor_blip2_8_bit = BlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+ # model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained(
+ #     "Salesforce/blip2-opt-2.7b",
+ #     device_map="auto",
+ #     quantization_config=quantization_config,
+ #     load_in_8bit=True
+ # )
 
  # Load the Blip base model
- # preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- # model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
  # # Load the Blip large model
  # preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
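The Blip2 block commented out above relied on 8-bit loading through bitsandbytes. For reference, a minimal sketch of what the quantization_config it references typically looks like; the actual config object is defined outside this hunk, so the exact settings below are an assumption, not part of the commit.

# Hedged sketch only: the Space defines its own `quantization_config` outside the
# changed lines; this is one common way to build it for 8-bit loading.
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration

quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # assumed settings

processor_blip2 = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model_blip2 = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto",
    quantization_config=quantization_config,
)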
@@ -44,8 +44,8 @@ model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained(
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
  # Transfer the models to the device
- model_blip2_8_bit.to(device)
- # model_blip_base.to(device)
+ # model_blip2_8_bit.to(device)
+ model_blip_base.to(device)
  # model_blip_large.to(device)
  # model_git_large_coco.to(device)
  # model_oc_coca.to(device)
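With the Blip base processor and model loaded and moved to the device as above, the captioning path this commit enables can be exercised on its own. A minimal, self-contained sketch, assuming a local image file "example.jpg" (in the Space itself the image arrives from Gradio):

# Sketch of the Blip base captioning path enabled by this commit.
# "example.jpg" is a placeholder input; the checkpoint name comes from the diff above.
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)

generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=64)
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(caption)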
@@ -84,22 +84,21 @@ def generate_caption(
      if use_float_16:
          inputs = inputs.to(torch.float16)
 
-     # generated_ids = model.generate(
-     #     pixel_values=inputs.pixel_values,
-     #     # attention_mask=inputs.attention_mask,
-     #     max_length=32,
-     #     use_cache=True,
-     # )
+     generated_ids = model.generate(
+         pixel_values=inputs.pixel_values,
+         # attention_mask=inputs.attention_mask,
+         max_length=64,
+     )
 
-     # if tokenizer is None:
-     #     generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     # else:
-     #     generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     if tokenizer is None:
+         generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     else:
+         generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-     generated_ids = model.generate(**inputs, max_new_tokens=32)
-     generated_text = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+     # generated_ids = model.generate(**inputs, max_new_tokens=32)
+     # generated_text = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
-     return generated_text
+     return generated_caption
 
 
  def generate_captions_clip(
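The re-enabled call bounds generation with max_length=64, whereas the now-commented variant used max_new_tokens=32; both are standard transformers generate() arguments but they cap different things. A small sketch of the distinction, assuming the preprocessor_blip_base/model_blip_base objects loaded earlier and an image and device already in scope:

# max_length caps the total sequence length; max_new_tokens caps only the tokens
# generated beyond the prompt. Assumes image, device, and the Blip base objects exist.
inputs = preprocessor_blip_base(images=image, return_tensors="pt").to(device)

ids_total_cap = model_blip_base.generate(pixel_values=inputs.pixel_values, max_length=64)
ids_new_cap = model_blip_base.generate(**inputs, max_new_tokens=32)

print(preprocessor_blip_base.batch_decode(ids_total_cap, skip_special_tokens=True)[0].strip())
print(preprocessor_blip_base.batch_decode(ids_new_cap, skip_special_tokens=True)[0].strip())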
@@ -149,10 +148,10 @@ def generate_captions(
          The generated caption.
      """
      # Generate captions for the image using the Blip2 model
-     caption_blip2_8_bit = generate_caption(preprocessor_blip2_8_bit, model_blip2_8_bit, image, use_float_16=True).strip()
+     # caption_blip2_8_bit = generate_caption(preprocessor_blip2_8_bit, model_blip2_8_bit, image, use_float_16=True).strip()
 
      # Generate captions for the image using the Blip base model
-     # caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image).strip()
+     caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image).strip()
 
      # # Generate captions for the image using the Blip large model
      # caption_blip_large = generate_caption(preprocessor_blip_large, model_blip_large, image).strip()
@@ -163,7 +162,7 @@ def generate_captions(
      # # Generate captions for the image using the CLIP model
      # caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
 
-     return caption_blip2_8_bit
+     return caption_blip_base
 
 
  # Create the interface
  # Create the interface
@@ -177,8 +176,8 @@ iface = gr.Interface(
177
  ],
178
  # Define the outputs
179
  outputs=[
180
- gr.outputs.Textbox(label="Blip2 8-bit"),
181
- # gr.outputs.Textbox(label="Blip base"),
182
  # gr.outputs.Textbox(label="Blip large"),
183
  # gr.outputs.Textbox(label="GIT large coco"),
184
  # gr.outputs.Textbox(label="CLIP"),
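For context, the outputs list above sits inside a gr.Interface whose other arguments fall outside the changed lines. A hedged sketch of the overall wiring with only the Blip base textbox enabled, in the same legacy gr.inputs/gr.outputs style as the rest of the file; the input component and title are assumptions, not part of the commit:

# Sketch of the surrounding interface; only the outputs list is visible in this diff,
# so the inputs and title below are assumptions in the same legacy Gradio style.
import gradio as gr

iface = gr.Interface(
    fn=generate_captions,
    inputs=[
        gr.inputs.Image(type="pil", label="Image"),  # assumed input component
    ],
    outputs=[
        gr.outputs.Textbox(label="Blip base"),
    ],
    title="Image Captioning",  # hypothetical title
)

iface.launch()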
 