kartikay24 committed on
Commit c78e02a · 1 Parent(s): de7b165

Update app.py

Files changed (1): app.py +4 -4
app.py CHANGED
@@ -3,9 +3,9 @@ import requests
 from PIL import Image
 from transformers import BlipProcessor, BlipForConditionalGeneration
 import gradio as gr
-
+device="cpu"
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to("cuda")
+model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to(device)
 
 # Function to process the image and generate captions
 def generate_caption(image, caption_type, text):
@@ -20,14 +20,14 @@ def generate_caption(image, caption_type, text):
 
 # Conditional image captioning
 def conditional_image_captioning(raw_image, text):
-    inputs = processor(raw_image, text, return_tensors="pt").to("cuda", torch.float16)
+    inputs = processor(raw_image, text, return_tensors="pt").to(device, torch.float16)
     out = model.generate(**inputs)
     caption = processor.decode(out[0], skip_special_tokens=True)
     return caption
 
 # Unconditional image captioning
 def unconditional_image_captioning(raw_image):
-    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
+    inputs = processor(raw_image, return_tensors="pt").to(device, torch.float16)
     out = model.generate(**inputs)
     caption = processor.decode(out[0], skip_special_tokens=True)
     return caption
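The commit pins device="cpu" but keeps torch_dtype=torch.float16, and PyTorch does not provide half-precision kernels for every CPU op, so inference may error or run slowly. A minimal device-agnostic sketch of the loading step (not part of this commit; the runtime device check and the dtype switch are my assumptions) would pick the dtype to match the device:

import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# Sketch, not the committed code: use the GPU when one is present, otherwise CPU,
# and use float32 on CPU since some CPU ops lack float16 kernels.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=dtype
).to(device)

# The captioning functions would then move inputs the same way:
# processor(raw_image, return_tensors="pt").to(device, dtype)

With this pattern the hard-coded "cpu" string and the two .to(device, torch.float16) call sites in the diff would all follow a single device/dtype pair instead of being edited by hand.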