Tagmir Gilyazov
upd
3dd80c3
# -*- coding: utf-8 -*-
"""## hugging face funcs"""
import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image
def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.get``. Without it a stalled server would block forever.

    Returns:
        The image object produced by ``Image.open``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error response instead of handing an error page to PIL.
    response.raise_for_status()
    # ``response.content`` is the fully downloaded, content-decoded body,
    # unlike the raw socket stream used previously.
    return Image.open(io.BytesIO(response.content))
def render_results_in_image(in_pil_img, in_results):
    """Draw detection boxes and labels over an image and return the result.

    Args:
        in_pil_img: Image to draw on (anything ``imshow`` accepts).
        in_results: Iterable of predictions. Each one is read as
            ``prediction['box']`` with ``xmin``/``ymin``/``xmax``/``ymax``,
            plus ``prediction['label']`` and ``prediction['score']``.

    Returns:
        A PIL image containing the rendered figure (PNG, tight bounding box).
    """
    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so unrelated figures cannot be drawn on or closed.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(in_pil_img)
        for prediction in in_results:
            box = prediction['box']
            x, y = box['xmin'], box['ymin']
            w = box['xmax'] - box['xmin']
            h = box['ymax'] - box['ymin']
            ax.add_patch(plt.Rectangle((x, y),
                                       w,
                                       h,
                                       fill=False,
                                       color="green",
                                       linewidth=2))
            ax.text(
                x,
                y,
                f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
                color='red'
            )
        ax.axis("off")
        # Save the modified image to a BytesIO object
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png',
                    bbox_inches='tight',
                    pad_inches=0)
    finally:
        # Always release the figure, even when a malformed prediction raises;
        # this also prevents it from being displayed.
        plt.close(fig)
    img_buf.seek(0)
    return Image.open(img_buf)
def summarize_predictions_natural_language(predictions):
    """Describe detection results as one English sentence.

    Predictions are counted per ``prediction['label']`` in first-seen order.
    Counts are spelled out with ``inflect`` and labels seen more than once get
    a plain ``"s"`` suffix. The word "and" is inserted before the last label.

    Args:
        predictions: Iterable of mappings that each have a ``'label'`` key.

    Returns:
        A sentence such as ``"In this image, there are two cats and one dog."``.
        When there are no predictions, a sentence saying so is returned
        instead of a dangling ``"there are."``.
    """
    summary = {}
    for prediction in predictions:
        label = prediction['label']
        summary[label] = summary.get(label, 0) + 1

    # Nothing detected: return a complete sentence rather than a truncated one.
    if not summary:
        return "In this image, there are no objects."

    p = inflect.engine()
    # Index of the second-to-last label, after which "and" is inserted.
    and_index = len(summary) - 2
    fragments = []
    for i, (label, count) in enumerate(summary.items()):
        fragment = f"{p.number_to_words(count)} {label}"
        if count > 1:
            fragment += "s"
        fragment += " "
        if i == and_index:
            fragment += "and "
        fragments.append(fragment)

    result_string = "In this image, there are " + "".join(fragments)
    # Drop the trailing separator before closing the sentence.
    return result_string.rstrip(', ') + "."
##### To ignore warnings #####
import warnings
import logging
from transformers import logging as hf_logging
def ignore_warnings():
    """Silence known noisy warnings and lower library log verbosity.

    Installs ``warnings`` filters for three specific messages, sets the
    standard-library root logger configuration to ERROR level and switches the
    transformers logger to error-only verbosity.
    """
    # Import the standard library module under a private alias: further down
    # this file the module-level name ``logging`` is rebound by
    # ``from transformers.utils import logging``, so the bare name cannot be
    # trusted to still be the stdlib module when this function runs.
    import logging as std_logging

    # Ignore specific Python warnings (the pattern is matched against the
    # start of the warning message).
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    warnings.filterwarnings("ignore", message="Could not find image processor class")
    warnings.filterwarnings("ignore", message="The `max_size` parameter is deprecated")
    # Adjust logging for libraries using the logging module
    std_logging.basicConfig(level=std_logging.ERROR)
    hf_logging.set_verbosity_error()
########
import numpy as np
import torch
import matplotlib.pyplot as plt
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent coloured layer.

    Args:
        mask: Array-like whose last two dimensions are height and width.
        ax: Axes-like object; only its ``imshow`` method is used.
        random_color: Use a random RGB colour instead of the fixed blue.
            Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3),
                               np.array([0.6])],
                              axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to get an RGBA image that is
    # transparent wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax):
    """Add a green, unfilled rectangle for ``box`` to ``ax``.

    Args:
        box: Indexable with at least four entries, read as
            ``(x0, y0, x1, y1)`` corner coordinates.
        ax: Axes to which the rectangle patch is added.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle((left, top),
                            width,
                            height,
                            edgecolor='green',
                            facecolor=(0, 0, 0, 0),  # fully transparent fill
                            lw=2)
    ax.add_patch(outline)
def show_boxes_on_image(raw_image, boxes):
    """Display ``raw_image`` in a new 10x10 figure with every box outlined.

    Args:
        raw_image: Image accepted by ``plt.imshow``.
        boxes: Iterable of boxes, each passed to ``show_box``.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    axes = plt.gca()
    for rect in boxes:
        show_box(rect, axes)
    plt.axis('on')
    plt.show()
def show_points_on_image(raw_image, input_points, input_labels=None):
    """Display ``raw_image`` in a new 10x10 figure with prompt points marked.

    Args:
        raw_image: Image accepted by ``plt.imshow``.
        input_points: Point coordinates, converted with ``np.array`` and
            indexed as a 2-D array whose columns are x and y.
        input_labels: Optional per-point labels. When omitted every point is
            labelled 1.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    points = np.array(input_points)
    labels = (np.ones_like(points[:, 0])
              if input_labels is None
              else np.array(input_labels))
    show_points(points, labels, plt.gca())
    plt.axis('on')
    plt.show()
def show_points_and_boxes_on_image(raw_image,
                                   boxes,
                                   input_points,
                                   input_labels=None):
    """Display ``raw_image`` with both prompt points and boxes drawn on it.

    This function used to be defined twice, verbatim; the redundant copy has
    been removed.

    Args:
        raw_image: Image accepted by ``plt.imshow``.
        boxes: Iterable of boxes, each passed to ``show_box``.
        input_points: Point coordinates, converted with ``np.array`` and
            indexed as a 2-D array whose columns are x and y.
        input_labels: Optional per-point labels. When omitted every point is
            labelled 1.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    ax = plt.gca()
    show_points(input_points, labels, ax)
    for box in boxes:
        show_box(box, ax)
    plt.axis('on')
    plt.show()
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars coloured by label.

    Points whose label equals 1 are drawn in green, then points whose label
    equals 0 in red. Other label values are not drawn.

    Args:
        coords: 2-D array; column 0 is x and column 1 is y.
        labels: Array compared element-wise against 1 and 0 to select rows.
        ax: Axes-like object; only its ``scatter`` method is used.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0],
                   selected[:, 1],
                   color=colour,
                   marker='*',
                   s=marker_size,
                   edgecolor='white',
                   linewidth=1.25)
def fig2img(fig):
    """Render a Matplotlib figure in memory and return it as a PIL image.

    The figure is written with ``fig.savefig`` (no explicit format, so
    savefig's default applies) into a bytes buffer, which is then opened with
    ``Image.open``.
    """
    from io import BytesIO

    stream = BytesIO()
    fig.savefig(stream)
    # Rewind so PIL reads from the start of the rendered data.
    stream.seek(0)
    return Image.open(stream)
def show_mask_on_image(raw_image, mask, return_image=False):
    """Show ``raw_image`` with a single mask overlaid in blue.

    Args:
        raw_image: Image convertible with ``np.array``.
        mask: Mask as a ``torch.Tensor`` or anything ``torch.Tensor(...)``
            accepts. A 4-D mask is squeezed before drawing.
        return_image: When true, also render the figure to a PIL image,
            close the figure and return the image.

    Returns:
        The rendered PIL image when ``return_image`` is true, else ``None``.
    """
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)
    if len(mask.shape) == 4:
        mask = mask.squeeze()
    fig, axes = plt.subplots(1, 1, figsize=(15, 15))
    mask = mask.cpu().detach()
    axes.imshow(np.array(raw_image))
    show_mask(mask, axes)
    axes.axis("off")
    # Render from the figure created above, and do it before ``plt.show()``:
    # re-fetching the "current figure" afterwards can hand back a different
    # (possibly empty) figure on backends that close figures once shown.
    rendered = fig2img(fig) if return_image else None
    plt.show()
    if return_image:
        # The caller wants the bitmap, not a live figure; close it so repeated
        # calls (e.g. from a long-running server) do not accumulate figures.
        plt.close(fig)
        return rendered
    return None
def show_pipe_masks_on_image(raw_image, outputs, return_image=False):
    """Show ``raw_image`` with every mask in ``outputs["masks"]`` overlaid.

    Each mask is drawn in its own random colour.

    Args:
        raw_image: Image convertible with ``np.array``.
        outputs: Mapping with a ``"masks"`` entry that iterates over masks.
        return_image: When true, also render the figure to a PIL image,
            close the figure and return the image.

    Returns:
        The rendered PIL image when ``return_image`` is true, else ``None``.
    """
    # Always start from a fresh figure. Drawing on pyplot's implicit current
    # figure meant that, when earlier figures were never closed, each call
    # stacked its image and masks on top of the previous call's axes.
    fig, ax = plt.subplots()
    ax.imshow(np.array(raw_image))
    for mask in outputs["masks"]:
        show_mask(mask, ax=ax, random_color=True)
    ax.axis("off")
    # Capture the bitmap before showing; see the note on figure lifetime below.
    rendered = fig2img(fig) if return_image else None
    plt.show()
    if return_image:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        return rendered
    return None
"""## imports"""
from transformers import pipeline
from transformers import SamModel, SamProcessor
from transformers import BlipForImageTextRetrieval
from transformers import AutoProcessor
from transformers.utils import logging
logging.set_verbosity_error()
#ignore_warnings()
import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image
import os
import gradio as gr
import time
"""# Object detection
## hugging face model ("facebook/detr-resnet-50"). 167MB
"""
# Default detector; used unless the selected model name contains "chosen-model".
od_pipe = pipeline(task="object-detection",
                   model="facebook/detr-resnet-50")
# Alternative detector, picked when the selected model name contains
# "chosen-model".
chosen_model = pipeline(task="object-detection",
                        model="hustvl/yolos-small")
"""## gradio funcs"""
def get_object_detection_prediction(model_name, raw_image):
    """Run object detection on an image and draw the detections.

    Args:
        model_name: Selected model label. ``chosen_model`` is used when it
            contains ``"chosen-model"``; otherwise (including when nothing is
            selected and the value is ``None``) ``od_pipe`` is used.
        raw_image: Input image handed to the detection pipeline.

    Returns:
        ``[processed_image, elapsed_result]``: the annotated image and a
        message with the inference time in seconds.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        # Surface a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please upload an image first.")
    # ``model_name`` can be None when no dropdown entry is selected; treat
    # that as the default model rather than failing on the ``in`` test.
    use_chosen = "chosen-model" in (model_name or "")
    model = chosen_model if use_chosen else od_pipe
    # perf_counter is monotonic, so the interval cannot be skewed by system
    # clock adjustments.
    start = time.perf_counter()
    pipeline_output = model(raw_image)
    end = time.perf_counter()
    elapsed_result = f'{model_name} object detection elapsed {end-start} seconds'
    print(elapsed_result)
    processed_image = render_results_in_image(raw_image, pipeline_output)
    return [processed_image, elapsed_result]
"""# Image segmentation
## hugging face models: Zigeng/SlimSAM-uniform-77(segmentation) 39MB, Intel/dpt-hybrid-midas(depth) 490MB
"""
# Checkpoint behind the default ("hugging-face") segmentation option; the same
# id feeds the mask-generation pipeline and the model/processor pair.
_SLIM_SAM_CHECKPOINT = "Zigeng/SlimSAM-uniform-77"
hugging_face_segmentation_pipe = pipeline(task="mask-generation",
                                          model=_SLIM_SAM_CHECKPOINT)
hugging_face_segmentation_model = SamModel.from_pretrained(_SLIM_SAM_CHECKPOINT)
hugging_face_segmentation_processor = SamProcessor.from_pretrained(_SLIM_SAM_CHECKPOINT)
# Default depth estimator.
hugging_face_depth_estimator = pipeline(task="depth-estimation",
                                        model="Intel/dpt-hybrid-midas")
"""## chosen models: facebook/sam-vit-base(segmentation) 375MB, LiheYoung/depth-anything-small-hf(depth) 100MB"""
# Alternative models, used when the selected model name contains "chosen".
chosen_name = "facebook/sam-vit-base"
chosen_segmentation_pipe = pipeline(task="mask-generation", model=chosen_name)
chosen_segmentation_model = SamModel.from_pretrained(chosen_name)
chosen_segmentation_processor = SamProcessor.from_pretrained(chosen_name)
chosen_depth_estimator = pipeline(task="depth-estimation",
                                  model="LiheYoung/depth-anything-small-hf")
"""## gradio funcs"""
# Default prompt point handed to the SAM processor as ``input_points``.
# NOTE(review): presumably pixel (x, y) coordinates; the fixed value falls
# outside images smaller than about 1600x700 - confirm this is intended.
input_points = [[[1600, 700]]]


def segment_image_pretrained(model_name, raw_image, points=None):
    """Segment the object at a prompt point with a SAM model/processor pair.

    Args:
        model_name: Selected model label. The "chosen" model/processor are
            used when it contains ``"chosen"``; otherwise (including ``None``)
            the default pair is used.
        raw_image: Input image handed to the processor.
        points: Optional prompt points in the nesting the processor expects
            for ``input_points``. Defaults to the module-level
            ``input_points``.

    Returns:
        A list with three rendered mask images followed by a message with the
        elapsed time in seconds.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image first.")
    if points is None:
        points = input_points
    processor = hugging_face_segmentation_processor
    model = hugging_face_segmentation_model
    # ``model_name`` can be None when no dropdown entry is selected.
    if "chosen" in (model_name or ""):
        processor = chosen_segmentation_processor
        model = chosen_segmentation_model
    # perf_counter is monotonic, which is what an interval measurement needs.
    start = time.perf_counter()
    inputs = processor(raw_image,
                       input_points=points,
                       return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"])
    # Only the first entry (the single input image) is used.
    predicted_mask = predicted_masks[0]
    end = time.perf_counter()
    elapsed_result = f'{model_name} pretrained image segmentation elapsed {end-start} seconds'
    print(elapsed_result)
    # Three candidate masks are rendered, selected along the second axis.
    results = [
        show_mask_on_image(raw_image, predicted_mask[:, i], return_image=True)
        for i in range(3)
    ]
    results.append(elapsed_result)
    return results
def segment_image(model_name, raw_image):
    """Segment a whole image with a mask-generation pipeline.

    Args:
        model_name: Selected model label. ``chosen_segmentation_pipe`` is used
            when it contains ``"chosen"``; otherwise (including ``None``) the
            default pipeline is used.
        raw_image: Input image handed to the pipeline.

    Returns:
        ``[image, elapsed_result]``: the image with all masks overlaid and a
        message with the inference time in seconds.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        # Surface a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please upload an image first.")
    model = hugging_face_segmentation_pipe
    # ``model_name`` can be None when no dropdown entry is selected.
    if "chosen" in (model_name or ""):
        print("chosen model used")
        model = chosen_segmentation_pipe
    # perf_counter is monotonic, which is what an interval measurement needs.
    start = time.perf_counter()
    output = model(raw_image, points_per_batch=32)
    end = time.perf_counter()
    elapsed_result = f'{model_name} raw image segmentation elapsed {end-start} seconds'
    print(elapsed_result)
    rendered = show_pipe_masks_on_image(raw_image, output, return_image=True)
    return [rendered, elapsed_result]
def depth_image(model_name, input_image):
    """Estimate a depth map for an image and return it as a greyscale image.

    Args:
        model_name: Selected model label. ``chosen_depth_estimator`` is used
            when it contains ``"chosen"``; otherwise (including ``None``) the
            default estimator is used.
        input_image: PIL image; its ``size`` determines the output resolution.

    Returns:
        ``[depth, elapsed_result]``: an 8-bit image scaled so the largest
        depth value maps to 255, and a message with the elapsed seconds.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please upload an image first.")
    depth_estimator = hugging_face_depth_estimator
    print(model_name)
    # ``model_name`` can be None when no dropdown entry is selected.
    if "chosen" in (model_name or ""):
        print("chosen model used")
        depth_estimator = chosen_depth_estimator
    # perf_counter is monotonic, which is what an interval measurement needs.
    start = time.perf_counter()
    out = depth_estimator(input_image)
    predicted = out["predicted_depth"]
    # ``interpolate`` in bicubic mode needs a 4-D (N, C, H, W) tensor. Reshape
    # from the trailing two dimensions so both an (H, W) map and a
    # single-image (1, H, W) map work.
    # NOTE(review): which rank is returned appears to depend on the installed
    # transformers version - confirm.
    predicted = predicted.reshape(1, 1, *predicted.shape[-2:])
    prediction = torch.nn.functional.interpolate(
        predicted,
        size=input_image.size[::-1],  # PIL size is (width, height)
        mode="bicubic",
        align_corners=False,
    )
    end = time.perf_counter()
    elapsed_result = f'{model_name} Depth Estimation elapsed {end-start} seconds'
    print(elapsed_result)
    # Index instead of squeeze() so a 1-pixel-wide/high image keeps two dims.
    output = prediction[0, 0].detach().cpu().numpy()
    peak = np.max(output)
    if peak > 0:
        scaled = output * 255 / peak
    else:
        # Flat, non-positive map: avoid dividing by zero.
        scaled = np.zeros_like(output)
    # Bicubic resampling can overshoot below zero; clip so the uint8 cast
    # cannot wrap around.
    formatted = np.clip(scaled, 0, 255).astype("uint8")
    depth = Image.fromarray(formatted)
    return [depth, elapsed_result]
"""# Image retrieval
## hugging face model: Salesforce/blip-itm-base-coco 900MB
"""
# Default image-text matching checkpoint (model and its processor).
_COCO_ITM_CHECKPOINT = "Salesforce/blip-itm-base-coco"
hugging_face_retrieval_model = BlipForImageTextRetrieval.from_pretrained(_COCO_ITM_CHECKPOINT)
hugging_face_retrieval_processor = AutoProcessor.from_pretrained(_COCO_ITM_CHECKPOINT)
"""## chosen model: Salesforce/blip-itm-base-flickr 900MB"""
# Alternative checkpoint, used when the selected model name contains "chosen".
_FLICKR_ITM_CHECKPOINT = "Salesforce/blip-itm-base-flickr"
chosen_retrieval_model = BlipForImageTextRetrieval.from_pretrained(_FLICKR_ITM_CHECKPOINT)
chosen_retrieval_processor = AutoProcessor.from_pretrained(_FLICKR_ITM_CHECKPOINT)
"""## gradio funcs"""
def retrieve_image(model_name, raw_image, predict_text):
    """Score how well a text description matches an image.

    Args:
        model_name: Selected model label. The "chosen" model/processor are
            used when it contains ``"chosen"``; otherwise (including ``None``)
            the default pair is used.
        raw_image: Input image handed to the processor.
        predict_text: Text to compare with the image.

    Returns:
        ``[message, elapsed_result]``: a sentence with the match probability
        (softmax over the model's first output, column 1) and a message with
        the elapsed time in seconds.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image first.")
    processor = hugging_face_retrieval_processor
    model = hugging_face_retrieval_model
    # ``model_name`` can be None when no dropdown entry is selected.
    if "chosen" in (model_name or ""):
        processor = chosen_retrieval_processor
        model = chosen_retrieval_model
    # perf_counter is monotonic, which is what an interval measurement needs.
    start = time.perf_counter()
    inputs = processor(images=raw_image,
                       text=predict_text,
                       return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        itm_scores = model(**inputs)[0]
    # Stop the timer after the forward pass. Previously it stopped right after
    # preprocessing, so the reported time excluded the model itself.
    end = time.perf_counter()
    elapsed_result = f"{model_name} image retrieval elapsed {end-start} seconds"
    print(elapsed_result)
    itm_score = torch.nn.functional.softmax(itm_scores, dim=1)
    message = (
        "The image and text are matched "
        f"with a probability of {itm_score[0][1]:.4f}"
    )
    return [message, elapsed_result]
"""# gradio"""
# Object-detection tab: image + model picker on the left, timing + annotated
# image on the right, one button that runs get_object_detection_prediction.
# NOTE(review): container nesting is reconstructed from statement order -
# confirm it matches the intended layout.
with gr.Blocks() as object_detection_tab:
    gr.Markdown(value="# Detect objects on image")
    gr.Markdown(value="Upload an image, choose model, press button.")

    with gr.Row():
        with gr.Column():
            # Inputs
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(
                choices=[
                    "hugging-face(facebook/detr-resnet-50)",
                    "chosen-model(hustvl/yolos-small)",
                ],
                label="Select Model",
            )

        with gr.Column():
            # Outputs
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_image = gr.Image(label="Output Image", type="pil")

    # Trigger
    process_btn = gr.Button(value="Detect objects")

    # Wire the button to the detection callback.
    process_btn.click(
        fn=get_object_detection_prediction,
        inputs=[model_selector, input_image],
        outputs=[output_image, elapsed_result],
    )
# Segmentation tab: one shared input image feeds three independent actions -
# pipeline segmentation, point-prompted segmentation and depth estimation -
# each with its own model picker, button and outputs.
# NOTE(review): container nesting is reconstructed from statement order -
# confirm it matches the intended layout.
with gr.Blocks() as image_segmentation_detection_tab:
    gr.Markdown(value="# Image segmentation")
    gr.Markdown(value="Upload an image, choose model, press button.")

    # --- pipeline segmentation: inputs and outputs ---
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(
                choices=[
                    "hugging-face(Zigeng/SlimSAM-uniform-77)",
                    "chosen-model(facebook/sam-vit-base)",
                ],
                label="Select Model",
            )
        with gr.Column():
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_image = gr.Image(label="Segmented image", type="pil")
    with gr.Row():
        with gr.Column():
            segment_btn = gr.Button(value="Segment image(not pretrained)")

    # --- point-prompted segmentation: timing plus three candidate masks ---
    with gr.Row():
        elapsed_result_pretrained_segment = gr.Textbox(label="Seconds elapsed", lines=1)
        with gr.Column():
            segment_pretrained_output_image_1 = gr.Image(
                label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_2 = gr.Image(
                label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_3 = gr.Image(
                label="Segmented image by pretrained model", type="pil")
    with gr.Row():
        with gr.Column():
            segment_pretrained_model_selector = gr.Dropdown(
                choices=[
                    "hugging-face(Zigeng/SlimSAM-uniform-77)",
                    "chosen-model(facebook/sam-vit-base)",
                ],
                label="Select Model",
            )
            segment_pretrained_btn = gr.Button(value="Segment image(pretrained)")

    # --- depth estimation ---
    with gr.Row():
        with gr.Column():
            depth_output_image = gr.Image(label="Depth image", type="pil")
            elapsed_result_depth = gr.Textbox(label="Seconds elapsed", lines=1)
    with gr.Row():
        with gr.Column():
            depth_model_selector = gr.Dropdown(
                choices=[
                    "hugging-face(Intel/dpt-hybrid-midas)",
                    "chosen-model(LiheYoung/depth-anything-small-hf)",
                ],
                label="Select Model",
            )
            depth_btn = gr.Button(value="Get image depth")

    # Wire each button to its callback; all three read the same input image.
    segment_btn.click(
        fn=segment_image,
        inputs=[model_selector, input_image],
        outputs=[output_image, elapsed_result],
    )
    segment_pretrained_btn.click(
        fn=segment_image_pretrained,
        inputs=[segment_pretrained_model_selector, input_image],
        outputs=[
            segment_pretrained_output_image_1,
            segment_pretrained_output_image_2,
            segment_pretrained_output_image_3,
            elapsed_result_pretrained_segment,
        ],
    )
    depth_btn.click(
        fn=depth_image,
        inputs=[depth_model_selector, input_image],
        outputs=[depth_output_image, elapsed_result_depth],
    )
# Image-text matching tab: image, description and model picker on the left,
# probability and timing on the right, one button that runs retrieve_image.
# NOTE(review): container nesting is reconstructed from statement order -
# confirm it matches the intended layout.
with gr.Blocks() as image_retrieval_tab:
    gr.Markdown("# Check if text describes image")
    gr.Markdown("Upload an image, describe it, choose model, press button.")
    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            text_prediction = gr.TextArea(label="Describe image")
            model_selector = gr.Dropdown(["hugging-face(Salesforce/blip-itm-base-coco)", "chosen-model(Salesforce/blip-itm-base-flickr)"],
                                         label="Select Model")
        with gr.Column():
            # Output text
            output_result = gr.Textbox(label="Probability result", lines=3)
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
    # Process button. The label used to read "Detect objects", carried over
    # from the detection tab; this tab checks an image/text match.
    process_btn = gr.Button("Check match")
    # Connect the input components to the processing function
    process_btn.click(
        fn=retrieve_image,
        inputs=[
            model_selector,
            input_image,
            text_prediction
        ],
        outputs=[output_result, elapsed_result]
    )
# Combine the three tabs into one application, serve it with a public share
# link in debug mode, then close it once ``launch`` returns.
with gr.Blocks() as app:
    gr.TabbedInterface(
        interface_list=[
            object_detection_tab,
            image_segmentation_detection_tab,
            image_retrieval_tab,
        ],
        tab_names=[
            "Object detection",
            "Image segmentation",
            "Retrieve image",
        ],
    )

app.launch(share=True, debug=True)
app.close()