import os
import random
import subprocess
import textwrap
from collections import OrderedDict

import cv2
import mmcv
import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image

from annotator.util import resize_image, HWC3
|
device = "cpu"
use_blip = True
use_gradio = True

# fp16 is poorly supported on CPU, so fall back to full precision there.
if device == 'cpu':
    data_type = torch.float32
else:
    data_type = torch.float16
|
from diffusers.utils import load_image

# Checkpoint options (display name -> Hugging Face model id).
base_model_path = "stabilityai/stable-diffusion-2-inpainting"
config_dict = OrderedDict([
    ('SAM Pretrained(v0-1): Good Natural Sense', 'shgao/edit-anything-v0-1-1'),
    ('LAION Pretrained(v0-3): Good Face', 'shgao/edit-anything-v0-3'),
    ('SD Inpainting: Not keep position', 'stabilityai/stable-diffusion-2-inpainting'),
])
|
# Install segment_anything on the fly if it is missing.
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    print('segment_anything not installed')
    result = subprocess.run(
        ['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'],
        check=True)
    print(f'Install segment_anything {result}')
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Download the ViT-H SAM checkpoint on first run.
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
    result = subprocess.run(
        ['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'],
        check=True)
    print(f'Download sam_vit_h_4b8939.pth {result}')
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
model_type = "default"  # "default" maps to the ViT-H variant
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
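# The default generator settings are usually fine; mask granularity can be
# tuned through SamAutomaticMaskGenerator's constructor. A sketch using its
# real arguments (the values below are illustrative, not what this script
# runs with):
#
# mask_generator = SamAutomaticMaskGenerator(
#     sam,
#     points_per_side=32,           # density of the sampled point grid
#     pred_iou_thresh=0.88,         # drop masks the model scores poorly
#     stability_score_thresh=0.95,  # drop masks unstable under threshold shifts
#     min_mask_region_area=100,     # filter out tiny disconnected regions
# )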
|
if use_blip:
    # BLIP-2 captions each segmented region.
    from transformers import AutoProcessor, Blip2ForConditionalGeneration
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=data_type)
    blip_model.to(device)
|
def region_classify_w_blip2(image):
    """Caption a single image region with BLIP-2."""
    inputs = processor(image, return_tensors="pt").to(device, data_type)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=15)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True)[0].strip()
    return generated_text
|
def region_level_semantic_api(image, topk=5):
    """
    Rank SAM regions by area and caption each one with BLIP-2.
    Args:
        image: numpy array (H, W, 3)
        topk: number of largest regions to keep
    Returns:
        topk_region_w_class_label: list of SAM annotation dicts with an
        added 'class_label' key
    """
    topk_region_w_class_label = []
    anns = mask_generator.generate(image)
    if len(anns) == 0:
        return []
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    for i in range(min(topk, len(sorted_anns))):
        ann = sorted_anns[i]  # was anns[i], which ignored the sorting
        m = ann['segmentation']
        m_3c = m[:, :, np.newaxis]
        m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
        bbox = ann['bbox']  # SAM bboxes are XYWH
        region = mmcv.imcrop(
            image * m_3c,
            np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]),
            scale=1)
        region_class_label = region_classify_w_blip2(region)
        ann['class_label'] = region_class_label
        print(ann['class_label'], str(bbox))
        topk_region_w_class_label.append(ann)
    return topk_region_w_class_label
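# For reference, each dict produced by SamAutomaticMaskGenerator carries
# (among other keys) 'segmentation' (HxW bool mask), 'area', 'bbox' in
# XYWH format, and 'predicted_iou'; 'class_label' is added above.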
|
def show_semantic_image_label(anns):
    """
    Render a random color per region and draw its BLIP-2 label at the
    region's center of mass.
    Args:
        anns: list of annotation dicts with key 'class_label'
    Returns:
        full_img: numpy array (H, W, 3), values in [0, 255]
    """
    full_img = None

    # Paint each region with a random color, then scale to [0, 255].
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    # Draw each label, centered on the region's mean pixel position.
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        class_label = anns[i]['class_label']

        y, x = np.where(m != 0)
        x_center, y_center = int(np.mean(x)), int(np.mean(y))

        # Wrap long labels onto multiple lines.
        max_width = 20
        wrapped_text = textwrap.wrap(class_label, width=max_width)

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.2
        font_thickness = 2
        font_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        line_spacing = 40

        for idx, line in enumerate(wrapped_text):
            y_offset = y_center - (len(wrapped_text) - 1) * line_spacing // 2 + idx * line_spacing
            text_size = cv2.getTextSize(line, font, font_scale, font_thickness)[0]
            x_offset = x_center - text_size[0] // 2

            # Re-draw the text at eight one-pixel offsets for a bolder look.
            offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
            for off_x, off_y in offsets:
                cv2.putText(full_img, line, (x_offset + off_x, y_offset + off_y),
                            font, font_scale, font_color, font_thickness, cv2.LINE_AA)

    return full_img
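# To overlay the labels on the source photo instead of a flat color mask,
# the rendered mask can be alpha-blended back in. A minimal sketch, assuming
# `image` is the uint8 RGB array fed to region_level_semantic_api:
#
# overlay = show_semantic_image_label(anns).astype(np.uint8)
# blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)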
|
# Demo: segment one image, caption its 5 largest regions, render the labels.
image_path = "images/sa_224577.jpg"
input_image = Image.open(image_path)
detect_resolution = 1024
input_image = resize_image(np.array(input_image, dtype=np.uint8), detect_resolution)
region_level_annots = region_level_semantic_api(input_image, topk=5)
output = show_semantic_image_label(region_level_annots)
|
# Pair the resized input with the labeled output for a side-by-side grid.
image_list = []
input_image = resize_image(input_image, 512)
output = resize_image(output, 512)
input_image = np.array(input_image, dtype=np.uint8)
output = np.array(output, dtype=np.uint8)
image_list.append(torch.tensor(input_image).float())
image_list.append(torch.tensor(output).float())
for each in image_list:
    print(each.shape, type(each))
    print(each.max(), each.min())
|
# (N, H, W, C) -> (N, C, H, W), the layout torchvision's save_image expects.
image_list = torch.stack(image_list).permute(0, 3, 1, 2)
print(image_list.shape)

save_image(image_list, "images/sample_semantic.jpg", nrow=2,
           normalize=True)
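# Note: normalize=True rescales the batch to [0, 1] from its own min/max
# before writing, so passing the raw [0, 255] float tensors above is safe.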
|