chats-bug committed on
Commit
8236a85
·
1 Parent(s): 826388b

Changed blip2 model to 2.7b

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -16,8 +16,8 @@ device_map = {
16
  }
17
 
18
  # Load the Blip2 model
19
- preprocessor_blip2_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
20
- model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map=device_map)
21
 
22
  # Load the Blip base model
23
  # preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -79,19 +79,22 @@ def generate_caption(
79
  if use_float_16:
80
  inputs = inputs.to(torch.float16)
81
 
82
- generated_ids = model.generate(
83
- pixel_values=inputs.pixel_values,
84
- # attention_mask=inputs.attention_mask,
85
- max_length=32,
86
- use_cache=True,
87
- )
88
-
89
- if tokenizer is None:
90
- generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
91
- else:
92
- generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
93
 
94
- return generated_caption
95
 
96
 
97
  def generate_captions_clip(
 
16
  }
17
 
18
  # Load the Blip2 model
19
+ preprocessor_blip2_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
20
+ model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True)
21
 
22
  # Load the Blip base model
23
  # preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 
79
  if use_float_16:
80
  inputs = inputs.to(torch.float16)
81
 
82
+ # generated_ids = model.generate(
83
+ # pixel_values=inputs.pixel_values,
84
+ # # attention_mask=inputs.attention_mask,
85
+ # max_length=32,
86
+ # use_cache=True,
87
+ # )
88
+
89
+ # if tokenizer is None:
90
+ # generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
91
+ # else:
92
+ # generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
93
+
94
+ generated_ids = model.generate(**inputs, max_new_tokens=32)
95
+ generated_text = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
96
 
97
+ return generated_text
98
 
99
 
100
  def generate_captions_clip(