chats-bug committed
Commit fbee9c4 · 1 Parent(s): 207272c

Updated error handling

Files changed (1):
  1. app.py +28 -12
app.py CHANGED

@@ -14,8 +14,8 @@ preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-c
 model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
 
 # Load the GIT coco model
-# preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
-# model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
+preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
 
 # Load the CLIP model
 model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
@@ -27,7 +27,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # Transfer the models to the device
 model_blip_base.to(device)
 model_blip_large.to(device)
-# model_git_large_coco.to(device)
+model_git_large_coco.to(device)
 model_oc_coca.to(device)
 
 
@@ -103,11 +103,10 @@ def generate_captions_clip(
     str
         The generated caption.
     """
-    img = transform(image).unsqueeze(0).to(device)
+    im = transform(image).unsqueeze(0).to(device)
     with torch.no_grad(), torch.cuda.amp.autocast():
-        generated = model.generate(img, seq_len=32, do_sample=True, temperature=0.9)
-
-        generated_caption = model.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
+        generated = model.generate(im, seq_len=20)
+        generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
     return generated_caption
 
 
@@ -129,17 +128,34 @@ def generate_captions(
     str
         The generated caption.
     """
+    caption_blip_base = ""
+    caption_blip_large = ""
+    caption_git_large_coco = ""
+    caption_oc_coca = ""
+
     # Generate captions for the image using the Blip base model
-    caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image).strip()
+    try:
+        caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image).strip()
+    except Exception as e:
+        print(e)
 
     # Generate captions for the image using the Blip large model
-    caption_blip_large = generate_caption(preprocessor_blip_large, model_blip_large, image).strip()
+    try:
+        caption_blip_large = generate_caption(preprocessor_blip_large, model_blip_large, image).strip()
+    except Exception as e:
+        print(e)
 
     # Generate captions for the image using the GIT coco model
-    # caption_git_large_coco = generate_caption(preprocessor_git_large_coco, model_git_large_coco, image).strip()
+    try:
+        caption_git_large_coco = generate_caption(preprocessor_git_large_coco, model_git_large_coco, image).strip()
+    except Exception as e:
+        print(e)
 
     # Generate captions for the image using the CLIP model
-    caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
+    try:
+        caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
+    except Exception as e:
+        print(e)
 
     return caption_blip_base, caption_blip_large, caption_git_large_coco, caption_oc_coca
 
@@ -157,7 +173,7 @@ iface = gr.Interface(
     outputs=[
         gr.outputs.Textbox(label="Blip base"),
         gr.outputs.Textbox(label="Blip large"),
-        # gr.outputs.Textbox(label="GIT large coco"),
+        gr.outputs.Textbox(label="GIT large coco"),
         gr.outputs.Textbox(label="CLIP"),
     ],
     title="Image Captioning",
 