"""Image caption generator.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1kJdblTHuqDn8HCKTuEpoApkN05Gzjpot |
|
""" |
|
# These are Colab shell commands; the "!" prefix is needed for them to run from a notebook cell.
!pip install transformers gradio timm huggingface_hub

import io

import gradio as gr
import matplotlib.pyplot as plt
import requests
from matplotlib.patches import Rectangle
from PIL import Image

from transformers import AutoProcessor, BlipForConditionalGeneration, pipeline

# BLIP captioning model and its paired processor
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# DETR object detector, used both to annotate the image and to gate the caption
object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
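
# Each detection returned by the pipeline is a dict of the form
# {'score': 0.99, 'label': 'person',
#  'box': {'xmin': ..., 'ymin': ..., 'xmax': ..., 'ymax': ...}}
# (values here are illustrative); the helpers below rely on exactly these keys.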


def caption_generator(input_img):
    """Generate a caption for a PIL image with BLIP."""
    inputs = processor(input_img, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=500)  # generous upper bound for caption length
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption
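
# A minimal smoke test for caption_generator, assuming network access; the URL
# is a placeholder, not part of the original notebook:
#
#   resp = requests.get("https://example.com/sample.jpg")
#   test_img = Image.open(io.BytesIO(resp.content)).convert("RGB")
#   print(caption_generator(test_img))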


def filter_caption(object_detection_results):
    """Return True if the detector found a person, cat, or dog."""
    labels = [result['label'] for result in object_detection_results]
    keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men",
                "woman", "women", "child", "children", "adult", "adults", "person"]
    return any(keyword in labels for keyword in keywords)
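
# Example: filter_caption([{'label': 'dog', 'score': 0.98, 'box': {...}}]) -> True
# DETR here is trained on COCO, whose label set includes "person", "dog", and
# "cat" but not synonyms like "human" or "man", so those extra keywords never
# match a pipeline label; they are harmless but dead entries.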


def filter(caption):
    """Text-based variant: True if the caption itself mentions a person, cat, or dog.

    Unused below (filter_caption is used instead); note the name also shadows
    Python's built-in filter().
    """
    keywords = ["dog", "dogs", "cat", "cats", "human", "humans", "man", "men",
                "woman", "women", "child", "children", "adult", "adults", "person"]
    caption = caption.lower()
    return any(keyword in caption for keyword in keywords)


def create_image_bbx_w_label(image, results):
    """Draw the detected bounding boxes and labels on the image."""
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    for res in results:
        box = res['box']
        width = box['xmax'] - box['xmin']
        height = box['ymax'] - box['ymin']

        # Red rectangle for each detection
        rect = Rectangle((box['xmin'], box['ymin']), width, height,
                         linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Label and confidence score just above the box
        label_position = (box['xmin'], box['ymin'] - 10)
        label_text = f"{res['label']}: {res['score']:.2f}"
        ax.text(*label_position, label_text, color='white', fontsize=8,
                bbox=dict(facecolor='red', alpha=0.5))

    ax.axis('off')

    # Render to a file, then reload as a PIL image for Gradio
    fname = './img.png'
    plt.savefig(fname, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    pil_img = Image.open(fname)
    return pil_img
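
# The round-trip through './img.png' touches disk on every request; an
# equivalent in-memory variant (a sketch, not from the original notebook)
# would render into a buffer instead:
#
#   buf = io.BytesIO()
#   fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
#   plt.close(fig)
#   buf.seek(0)
#   pil_img = Image.open(buf)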


def image_caption_generator(input_image):
    """Detect objects, annotate the image, and caption it; the caption is only
    returned when a person, cat, or dog was detected."""
    object_detection_results = object_detector(input_image)
    annotated_img = create_image_bbx_w_label(input_image, object_detection_results)

    caption = caption_generator(input_image)

    if filter_caption(object_detection_results):
        return caption, annotated_img
    else:
        return "There are no humans, cats or dogs in this image!", annotated_img


demo = gr.Interface(fn=image_caption_generator,
                    inputs=[gr.Image(label="Upload image", type="pil")],
                    outputs=[gr.Textbox(label="Caption"), gr.Image(label="Annotated image")],
                    title="CaptionPlus - Image Caption Generator",
                    description="Captioning images of humans, cats and/or dogs with object detection",
                    allow_flagging="never",
                    examples=["/content/Example.jpg", "/content/OIP.jpg"])

demo.launch(share=True)