import os
import gc
import glob
import copy
from PIL import Image
from collections import OrderedDict
import numpy as np
import torch
import cv2
from sam_hq.automatic import SamAutomaticMaskGeneratorHQ
from modules import scripts, shared
from modules.paths import extensions_dir
from modules.devices import torch_gc


global_sam: SamAutomaticMaskGeneratorHQ = None
sem_seg_cache = OrderedDict()
sam_annotator_dir = os.path.join(scripts.basedir(), "annotator")
original_uniformer_inference_segmentor = None


def pad64(x):
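    """Return the padding needed to round x up to the next multiple of 64."""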
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def safer_memory(x):
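    """Return a freshly allocated, C-contiguous copy of x (presumably a defensive copy against backend memory-layout quirks)."""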
    return np.ascontiguousarray(x.copy()).copy()


def resize_image_with_pad(input_image, resolution):
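    """Resize so the short side equals `resolution`, then edge-pad H and W up to
    multiples of 64. Returns the padded image and a `remove_pad` callback that
    crops an array back to the unpadded size."""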
    from annotator.util import HWC3
    img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    k = float(resolution) / float(min(H_raw, W_raw))
    interpolation = cv2.INTER_CUBIC if k > 1 else cv2.INTER_AREA
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')

    def remove_pad(x):
        return safer_memory(x[:H_target, :W_target])

    return safer_memory(img_padded), remove_pad


def blend_image_and_seg(image, seg, alpha=0.5):
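    """Alpha-blend an image with its segmentation map and return a PIL image."""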
    image_blend = image * (1 - alpha) + np.array(seg) * alpha
    return Image.fromarray(image_blend.astype(np.uint8))


def create_symbolic_link():
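    """Symlink the ControlNet extension's annotator directory into this extension.

    Returns True if the ControlNet annotator directory exists (creating the link
    if it is not already present), False otherwise.
    """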
    cnet_annotator_dir = os.path.join(extensions_dir, "sd-webui-controlnet/annotator")
    if os.path.isdir(cnet_annotator_dir):
        if not os.path.isdir(sam_annotator_dir):
            os.symlink(cnet_annotator_dir, sam_annotator_dir, target_is_directory=True)
        return True
    return False


def clear_sem_sam_cache():
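    """Drop all cached semantic segmentation models and reclaim memory."""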
    sem_seg_cache.clear()
    gc.collect()
    torch_gc()


def sem_sam_garbage_collect():
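    """In lowvram mode, unload every cached segmentation model from VRAM."""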
    if shared.cmd_opts.lowvram:
        # Iterating the dict directly would yield keys only; .items() is required here.
        for model_key, model in sem_seg_cache.items():
            if model_key == "uniformer":
                from annotator.uniformer import unload_uniformer_model
                unload_uniformer_model()
            else:
                model.unload_model()
    gc.collect()
    torch_gc()


def strengthen_sem_seg(class_ids, img):
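    """Refine a semantic segmentation map with SAM masks.

    For each SAM-generated mask (largest area first), reassign all pixels under
    the mask to the most frequent class id found there, snapping ragged class
    boundaries to SAM's mask edges.
    """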
    print("Auto SAM strengthening semantic segmentation")
    import pycocotools.mask as maskUtils
    semantic_mask = copy.deepcopy(class_ids)
    annotations = global_sam.generate(img)
    annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
    print(f"Auto SAM generated {len(annotations)} masks")
    for ann in annotations:
        valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
        propose_classes_ids = torch.tensor(class_ids[valid_mask])
        num_class_proposals = len(torch.unique(propose_classes_ids))
        if num_class_proposals == 1:
            semantic_mask[valid_mask] = propose_classes_ids[0].numpy()
            continue
        top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
        semantic_mask[valid_mask] = top_1_propose_class_ids.numpy()
    print("Auto SAM strengthen process end")
    return semantic_mask


def random_segmentation(img):
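    """Run SAM over the image and color each mask randomly for Edit-Anything.

    Returns a gallery of (blended image, random color map, control input that
    encodes each mask id across the R/G channels) plus a status message.
    """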
    print("Auto SAM generating random segmentation for Edit-Anything")
    img_np = np.array(img.convert("RGB"))
    annotations = global_sam.generate(img_np)
    print(f"Auto SAM generated {len(annotations)} masks")
    H, W, _ = img_np.shape
    color_map = np.zeros((H, W, 3), dtype=np.uint8)
    detected_map_tmp = np.zeros((H, W), dtype=np.uint16)
    for idx, annotation in enumerate(annotations):
        current_seg = annotation['segmentation']
        color_map[current_seg] = np.random.randint(0, 256, (3,))  # random RGB color; randint's upper bound is exclusive
        detected_map_tmp[current_seg] = idx + 1
    # Split the 16-bit mask ids across the R (low byte) and G (high byte) channels.
    detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3))
    detected_map[:, :, 0] = detected_map_tmp % 256
    detected_map[:, :, 1] = detected_map_tmp // 256
    from annotator.util import HWC3
    detected_map = HWC3(detected_map.astype(np.uint8))
    print("Auto SAM generation process end")
    return [blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \
        "Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."


def image_layer_image(layout_input_image, layout_output_path):
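    """Split one image into per-mask layers saved as numbered PNGs, writing
    whatever no mask covered to leftover.png."""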
    img_np = np.array(layout_input_image.convert("RGB"))
    annotations = global_sam.generate(img_np)
    print(f"AutoSAM generated {len(annotations)} annotations")
    annotations = sorted(annotations, key=lambda x: x['area'])
    for idx, annotation in enumerate(annotations):
        img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3))
        img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']]
        img_np[annotation['segmentation']] = np.array([0, 0, 0])
        img_tmp = Image.fromarray(img_tmp.astype(np.uint8))
        img_tmp.save(os.path.join(layout_output_path, f"{idx}.png"))
    img_np = Image.fromarray(img_np.astype(np.uint8))
    img_np.save(os.path.join(layout_output_path, "leftover.png"))


def image_layer_internal(layout_input_image_or_path, layout_output_path):
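    """Layer-split a single PIL image, or every image in a directory when given a path."""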
    if isinstance(layout_input_image_or_path, str):
        print("Image layer division batch processing")
        all_files = glob.glob(os.path.join(layout_input_image_or_path, "*"))
        for image_index, input_image_file in enumerate(all_files):
            print(f"Processing {image_index}/{len(all_files)} {input_image_file}")
            try:
                input_image = Image.open(input_image_file)
                output_directory = os.path.join(layout_output_path, os.path.splitext(os.path.basename(input_image_file))[0])
                from pathlib import Path
                Path(output_directory).mkdir(exist_ok=True)
            except Exception:
                print(f"File {input_image_file} is not an image, skipped.")
                continue
            image_layer_image(input_image, output_directory)
    else:
        image_layer_image(layout_input_image_or_path, layout_output_path)
    return "Done"


def inject_uniformer_inference_segmentor(model, img):
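    """Replacement for uniformer.inference_segmentor that post-processes the result with strengthen_sem_seg."""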
    original_result = original_uniformer_inference_segmentor(model, img)
    original_result[0] = strengthen_sem_seg(original_result[0], img)
    return original_result


def inject_uniformer_show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10), opacity=0.5, title='', block=True):
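    """Replacement for uniformer.show_result_pyplot that skips rendering and returns the raw class-id map."""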
    return result[0]


def inject_oneformer_semantic_run(img, predictor, metadata):
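    """Run OneFormer semantic prediction, strengthen it with SAM, and render the colored map."""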
    predictions = predictor(img[:, :, ::-1], "semantic")
    original_result = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
    from annotator.oneformer.oneformer.demo.visualizer import Visualizer, ColorMode
    visualizer_map = Visualizer(img, is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
    sam_result = torch.tensor(strengthen_sem_seg(original_result, img), device='cpu')
    out_map = visualizer_map.draw_sem_seg(sam_result, alpha=1, is_text=False).get_image()
    return out_map


def inject_oneformer_semantic_run_categorical_mask(img, predictor, metadata):
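    """Like inject_oneformer_semantic_run, but return the raw strengthened class-id map instead of a rendered image."""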
    predictions = predictor(img[:, :, ::-1], "semantic")
    original_result = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
    return strengthen_sem_seg(original_result, img)


def inject_oneformer_detector_call(self, img):
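    """Replacement for OneformerDetector.__call__ that routes through the SAM-strengthened semantic run."""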
    if self.model is None:
        self.load_model()
    self.model.model.to(self.device)
    return inject_oneformer_semantic_run(img, self.model, self.metadata)


def inject_oneformer_detector_call_categorical_mask(self, img):
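    """Replacement for OneformerDetector.__call__ that returns the raw class-id map for categorical masking."""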
    if self.model is None:
        self.load_model()
    self.model.model.to(self.device)
    return inject_oneformer_semantic_run_categorical_mask(img, self.model, self.metadata)


def _uniformer(img):
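    """Run UniFormer on img, loading and caching the annotator on first use."""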
    if "uniformer" not in sem_seg_cache:
        from annotator.uniformer import apply_uniformer
        sem_seg_cache["uniformer"] = apply_uniformer
    result = sem_seg_cache["uniformer"](img)
    return result


def _oneformer(img, dataset="coco"):
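    """Run OneFormer for the given dataset, caching one detector per dataset."""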
    oneformer_key = f"oneformer_{dataset}"
    if oneformer_key not in sem_seg_cache:
        from annotator.oneformer import OneformerDetector
        sem_seg_cache[oneformer_key] = OneformerDetector(OneformerDetector.configs[dataset])
    result = sem_seg_cache[oneformer_key](img)
    return result


def semantic_segmentation(input_image, annotator_name, processor_res,
                          use_pixel_perfect, resize_mode, target_W, target_H):
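    """Generate semantic segmentation with and without SAM strengthening.

    When annotator_name contains "seg", runs the chosen segmentor twice (once
    unmodified, once with the SAM-injected hook) and returns a four-image
    gallery plus a status message; any other annotator name falls through to
    random_segmentation.
    """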
    if input_image is None:
        return [], "No input image."
    if "seg" in annotator_name:
        if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
            return [], "ControlNet extension not found."
        global original_uniformer_inference_segmentor
        input_image_np = np.array(input_image)
        processor_res = pixel_perfect_lllyasviel(input_image_np, processor_res, use_pixel_perfect, resize_mode, target_W, target_H)
        input_image, remove_pad = resize_image_with_pad(input_image_np, processor_res)
        print("Generating semantic segmentation without SAM")
        if annotator_name == "seg_ufade20k":
            original_semseg = _uniformer(input_image)
            print("Generating semantic segmentation with SAM")
            import annotator.uniformer as uniformer
            original_uniformer_inference_segmentor = uniformer.inference_segmentor
            uniformer.inference_segmentor = inject_uniformer_inference_segmentor
            sam_semseg = _uniformer(input_image)
            uniformer.inference_segmentor = original_uniformer_inference_segmentor
            original_semseg = remove_pad(original_semseg)
            sam_semseg = remove_pad(sam_semseg)
            input_image = remove_pad(input_image)
            output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
            return output_gallery, "UniFormer semantic segmentation of ade20k done. Left is segmentation before SAM, right is segmentation after SAM."
        else:
            dataset = annotator_name.split('_')[-1][2:]  # e.g. "seg_ofcoco" -> "coco"
            original_semseg = _oneformer(input_image, dataset=dataset)
            print("Generating semantic segmentation with SAM")
            from annotator.oneformer import OneformerDetector
            original_oneformer_call = OneformerDetector.__call__
            OneformerDetector.__call__ = inject_oneformer_detector_call
            sam_semseg = _oneformer(input_image, dataset=dataset)
            OneformerDetector.__call__ = original_oneformer_call
            original_semseg = remove_pad(original_semseg)
            sam_semseg = remove_pad(sam_semseg)
            input_image = remove_pad(input_image)
            output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
            return output_gallery, f"OneFormer semantic segmentation of {dataset} done. Left is segmentation before SAM, right is segmentation after SAM."
    else:
        return random_segmentation(input_image)


def categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image,
                           crop_pixel_perfect, crop_resize_mode, target_W, target_H):
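    """Build a boolean mask covering the requested class ids.

    Runs the chosen segmentor with the SAM-strengthening hooks installed, then
    marks every pixel whose class id appears in crop_category_input (ids joined
    by '+'). Returns (mask, unpadded input image) on success, or an error string.
    """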
    if crop_input_image is None:
        return "No input image."
    if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
        return "ControlNet extension not found."
    # split('+') never yields an empty list, so discard empty entries before checking.
    filter_classes = [c for c in crop_category_input.split('+') if c.strip()]
    if len(filter_classes) == 0:
        return "No class selected."
    try:
        filter_classes = [int(i) for i in filter_classes]
    except ValueError:
        return "Illegal class id. You may have input some string."
    crop_input_image_np = np.array(crop_input_image)
    crop_processor_res = pixel_perfect_lllyasviel(crop_input_image_np, crop_processor_res, crop_pixel_perfect, crop_resize_mode, target_W, target_H)
    crop_input_image, remove_pad = resize_image_with_pad(crop_input_image_np, crop_processor_res)
    crop_input_image_copy = copy.deepcopy(crop_input_image)
    global original_uniformer_inference_segmentor
    print(f"Generating categories with processor {crop_processor}")
    if crop_processor == "seg_ufade20k":
        import annotator.uniformer as uniformer
        original_uniformer_inference_segmentor = uniformer.inference_segmentor
        uniformer.inference_segmentor = inject_uniformer_inference_segmentor
        original_show_result_pyplot = uniformer.show_result_pyplot
        uniformer.show_result_pyplot = inject_uniformer_show_result_pyplot
        sam_semseg = _uniformer(crop_input_image)
        uniformer.inference_segmentor = original_uniformer_inference_segmentor
        uniformer.show_result_pyplot = original_show_result_pyplot
    else:
        dataset = crop_processor.split('_')[-1][2:]
        from annotator.oneformer import OneformerDetector
        original_oneformer_call = OneformerDetector.__call__
        OneformerDetector.__call__ = inject_oneformer_detector_call_categorical_mask
        sam_semseg = _oneformer(crop_input_image, dataset=dataset)
        OneformerDetector.__call__ = original_oneformer_call
    sam_semseg = remove_pad(sam_semseg)
    mask = np.zeros(sam_semseg.shape, dtype=np.bool_)
    for i in filter_classes:
        mask[sam_semseg == i] = True
    return mask, remove_pad(crop_input_image_copy)


def register_auto_sam(sam,
                      auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
                      auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
                      auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
                      auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode):
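    """Create the global SamAutomaticMaskGeneratorHQ from the UI parameters.

    The positional None below appears to be the generator's point_grids
    argument, left unset so the point grid is derived from points_per_side.
    """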
    global global_sam
    global_sam = SamAutomaticMaskGeneratorHQ(
        sam, auto_sam_points_per_side, auto_sam_points_per_batch,
        auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh,
        auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
        auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
        auto_sam_crop_n_points_downscale_factor, None,  # point_grids
        auto_sam_min_mask_region_area, auto_sam_output_mode)


def pixel_perfect_lllyasviel(input_image, processor_res, use_pixel_perfect, resize_mode, target_W, target_H):
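    """Pick the preprocessor resolution.

    With Pixel Perfect enabled, estimate the resolution ControlNet's heuristic
    (by lllyasviel) would use: scale the raw image's short side by the target
    resize factor (min of the two axis ratios for resize_mode 2, max otherwise)
    and round to a multiple of 64. Otherwise return processor_res unchanged.
    """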
    preprocessor_resolution = processor_res
    if use_pixel_perfect:
        raw_H, raw_W, _ = input_image.shape
        k0 = float(target_H) / float(raw_H)
        k1 = float(target_W) / float(raw_W)
        if resize_mode == 2:
            estimation = min(k0, k1) * float(min(raw_H, raw_W))
        else:
            estimation = max(k0, k1) * float(min(raw_H, raw_W))
        preprocessor_resolution = int(np.round(float(estimation) / 64.0)) * 64
        print('Pixel Perfect Mode (written by lllyasviel) Enabled.')
        print(f'resize_mode = {resize_mode}')
        print(f'raw_H = {raw_H}')
        print(f'raw_W = {raw_W}')
        print(f'target_H = {target_H}')
        print(f'target_W = {target_W}')
        print(f'estimation = {estimation}')
        print(f'preprocessor resolution = {preprocessor_resolution}')
    return preprocessor_resolution