# -*- coding: utf-8 -*-
"""Image caption generator.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1kJdblTHuqDn8HCKTuEpoApkN05Gzjpot
"""
# Install dependencies before running (in a notebook cell, prefix these with "!"):
#   pip install transformers
#   pip install gradio            # used for creating the demo
#   pip install timm
#   pip install huggingface_hub

#from huggingface_hub import notebook_login


#notebook_login()

import gradio as gr
import requests
import matplotlib.pyplot as plt
import io
from PIL import Image
from matplotlib.patches import Rectangle

#Load model directly
from transformers import AutoProcessor, BlipForConditionalGeneration, pipeline

# Loading the BLIP model directly which generates the caption
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

#Using transformers to load DETR model for object detection
#This model adds a bounding box and label to detected objects
object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
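
# Illustrative note (not from the original notebook): the object-detection pipeline returns
# a list of dicts, roughly of this shape, which the helper functions below rely on:
#   [{'score': 0.998, 'label': 'dog', 'box': {'xmin': 10, 'ymin': 20, 'xmax': 200, 'ymax': 180}}, ...]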

# Generates a caption for the uploaded image using BLIP
def caption_generator(input_img):
  inputs = processor(input_img, return_tensors="pt")
  out = model.generate(**inputs, max_new_tokens=500)
  caption = processor.decode(out[0], skip_special_tokens=True)
  return caption
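
# Example usage (hypothetical local path; uncomment to try outside the Gradio demo):
# sample = Image.open("sample.jpg").convert("RGB")
# print(caption_generator(sample))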

# Filters the result by checking whether a human, cat and/or dog is present,
# using the labels returned by object detection (the method used in this project)
def filter_caption(object_detection_results):
  labels = [result['label'] for result in object_detection_results]
  # DETR is trained on COCO, whose labels are singular (e.g. "person", "dog", "cat");
  # the plural and other human-related keywords are kept as a safety net
  keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men", "woman", "women", "child", "children", "adult", "adults", "person"]
  return any(keyword in labels for keyword in keywords)
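
# For example, results containing a 'person', 'dog' or 'cat' label make filter_caption return True,
# while results with only labels such as 'car' or 'bench' return False (illustrative, not exhaustive).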

# Filters by checking whether a human, cat and/or dog is mentioned in the generated caption itself
# (the initial method considered; note that this shadows the built-in filter() and is only used
# via the commented-out call further below)
def filter(caption):
  # If any of these keywords are present, True is returned
  keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men", "woman", "women", "child", "children", "adult", "adults", "person"]
  caption = caption.lower()
  return any(keyword in caption for keyword in keywords)

# Function to draw bounding boxes and labels on the image
# Takes an image and a list of detection results as inputs
def create_image_bbx_w_label(image, results):

  # Set up the plot
  fig, ax = plt.subplots(figsize=(12, 8))
  ax.imshow(image)

  # Plot the bounding boxes and labels
  for res in results:
      box = res['box']
      width = box['xmax'] - box['xmin']
      height = box['ymax'] - box['ymin']

      rect = Rectangle((box['xmin'], box['ymin']), width, height, linewidth=1, edgecolor='r', facecolor='none')
      ax.add_patch(rect)

      # Position the label above the rectangle
      label_position = (box['xmin'], box['ymin'] - 10)

      # Display the label and score
      label_text = f"{res['label']}: {res['score']:.2f}"
      ax.text(*label_position, label_text, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

  ax.axis('off')

  fname = './img.png'
  plt.savefig(fname, format='png', bbox_inches='tight', pad_inches=0)

  plt.close(fig)
  # Load the saved figure back into a PIL Image
  pil_img = Image.open(fname)

  # Return the PIL Image object
  return pil_img

def image_caption_generator(input_image):
  #detecting objects in image
  object_detection_results = object_detector(input_image)
  annotated_img = create_image_bbx_w_label(input_image, object_detection_results)

  #Generating caption of input image
  caption = caption_generator(input_image)
  # Filtering the caption for the specific case (humans and/or cats/dogs)
  #filtered_caption = filter(caption)  # uncomment this to filter using the generated caption text instead
  filtered_caption = filter_caption(object_detection_results)  # uses the labels from object detection to filter the caption
  if filtered_caption:
    return caption, annotated_img
  else:
    return "There are no humans, cats or dogs in this image!", annotated_img

demo = gr.Interface(fn = image_caption_generator,
                    inputs=[gr.Image(label="Upload image", type="pil")],
                    outputs=[gr.Textbox(label="Caption"), 'image'],
                    title="CaptionPlus - Image Caption Generator",
                    description="Captioning images of humans, cats and/or dogs with object detection",
                    allow_flagging="never",
                    examples=["/content/Example.jpg", '/content/OIP.jpg'])

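# share=True asks Gradio to create a temporary public URL alongside the local one,
# which makes the demo reachable from outside the Colab runtime.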
demo.launch(share=True)