Charmainemahachi committed on
Commit 8e2f9b6 · verified · 1 Parent(s): e2a326f

Upload image_caption_generator.py

Files changed (1)
  1. image_caption_generator.py +120 -0
image_caption_generator.py ADDED
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
"""Image caption generator.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1kJdblTHuqDn8HCKTuEpoApkN05Gzjpot
"""

# Install dependencies. These are notebook magics, not valid in a plain .py file;
# run them in Colab (or `pip install ...` in a shell) before executing this script.
# !pip install gradio  # used for creating the demo
# !pip install timm
# !pip install huggingface_hub

from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login (intended for notebook environments)
notebook_login()

import gradio as gr
import io
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Rectangle

# Load models directly
from transformers import AutoProcessor, BlipForConditionalGeneration, pipeline

# Load the BLIP model, which generates the caption
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Use a transformers pipeline to load the DETR model for object detection.
# This model produces a bounding box and label for each detected object.
object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
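# For reference, the object-detection pipeline returns a list of dicts, one per
# detected object (values below are illustrative, not real output):
#   [{'score': 0.99, 'label': 'dog',
#     'box': {'xmin': 34, 'ymin': 52, 'xmax': 310, 'ymax': 275}}]
# The 'label' strings come from the COCO classes DETR was trained on.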

# Generates the caption for the uploaded image
def caption_generator(input_img):
    inputs = processor(input_img, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=500)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption
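# Illustrative usage (the file name and output string are hypothetical):
#   caption_generator(Image.open("dog.jpg"))  # -> "a dog sitting in the grass"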

# Filters the generated caption by checking whether humans, cats and/or dogs are
# present, using the labels from the object detection results.
# This is the method used in this project.
def filter_caption(object_detection_results):
    labels = [result['label'] for result in object_detection_results]
    keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men",
                "woman", "women", "child", "children", "adult", "adults", "person"]
    return any(keyword in labels for keyword in keywords)
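# Illustrative checks (detection results below are hypothetical):
#   filter_caption([{'label': 'person', 'score': 0.98, 'box': {...}}])  # -> True
#   filter_caption([{'label': 'car', 'score': 0.97, 'box': {...}}])     # -> False
# Note that DETR emits COCO class names such as 'person', 'dog' and 'cat', so the
# remaining keywords mainly matter for the caption-based filter below.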

# Filters the generated caption by checking whether humans, cats and/or dogs are
# present, using the generated caption itself (the initial method considered).
# Note: this name shadows Python's built-in filter().
def filter(caption):
    # If any of these keywords are present, True is returned
    keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men",
                "woman", "women", "child", "children", "adult", "adults", "person"]
    caption = caption.lower()
    return any(keyword in caption for keyword in keywords)
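# Caveat with substring matching: `keyword in caption` also fires inside larger
# words, e.g. "cat" matches "scattered". Matching whole words avoids this:
#   any(keyword in caption.split() for keyword in keywords)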

# Creates the bounding boxes and labels.
# Takes an image and a list of detection results as inputs.
def create_image_bbx_w_label(image, results):

    # Set up the plot
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    # Plot the bounding boxes and labels
    for res in results:
        box = res['box']
        width = box['xmax'] - box['xmin']
        height = box['ymax'] - box['ymin']

        rect = Rectangle((box['xmin'], box['ymin']), width, height, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Position the label above the rectangle
        label_position = (box['xmin'], box['ymin'] - 10)

        # Display the label and score
        label_text = f"{res['label']}: {res['score']:.2f}"
        ax.text(*label_position, label_text, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    ax.axis('off')

    # Save the annotated figure to disk, then reload it as a PIL Image
    fname = './img.png'
    plt.savefig(fname, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    pil_img = Image.open(fname)

    # Return the PIL Image object
    return pil_img
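# A possible in-memory alternative (a sketch, not part of the original flow):
# plt.savefig and Image.open both accept file-like objects, so the temporary
# file on disk could be replaced with an io.BytesIO buffer:
#   buf = io.BytesIO()
#   plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
#   plt.close(fig)
#   buf.seek(0)
#   pil_img = Image.open(buf)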

def image_caption_generator(input_image):
    # Detect objects in the image
    object_detection_results = object_detector(input_image)
    annotated_img = create_image_bbx_w_label(input_image, object_detection_results)

    # Generate the caption of the input image
    caption = caption_generator(input_image)
    # Filter the captions for the specific case (humans and/or cats/dogs)
    # filtered_caption = filter(caption)  # uncomment this to filter using the generated caption instead
    filtered_caption = filter_caption(object_detection_results)  # uses the labels from object detection to filter the captions
    if filtered_caption:
        return caption, annotated_img
    else:
        return "There are no humans, cats or dogs in this image!", annotated_img

demo = gr.Interface(fn=image_caption_generator,
                    inputs=[gr.Image(label="Upload image", type="pil")],
                    outputs=[gr.Textbox(label="Caption"), 'image'],
                    title="CaptionPlus - Image Caption Generator",
                    description="Captioning images of humans, cats and/or dogs with object detection",
                    allow_flagging="never",
                    examples=["/content/Example.jpg", '/content/OIP.jpg'])

demo.launch(share=True)