Spaces:
Running
Running
import folder_paths | |
import impact.mmdet_nodes as mmdet_nodes | |
from impact.utils import * | |
from impact.core import SEG | |
import impact.core as core | |
import nodes | |
class NO_BBOX_MODEL:
    """Placeholder type used to fill the BBOX_MODEL output slot when no bbox model was loaded."""
class NO_SEGM_MODEL:
    """Placeholder type used to fill the SEGM_MODEL output slot when no segm model was loaded."""
class MMDetLoader:
    """Legacy node that loads one MMDet model picked from the bbox/segm file lists.

    The node always returns a ``(bbox_model, segm_model)`` pair. Only one slot
    holds the loaded model; the other is filled with a placeholder instance
    (``NO_BBOX_MODEL`` or ``NO_SEGM_MODEL``).
    """

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a single required ``model_name`` choice.

        Choices are the file names listed under "mmdets_bbox" and
        "mmdets_segm", prefixed with "bbox/" and "segm/" respectively.
        """
        bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("mmdets_bbox")]
        segms = ["segm/" + x for x in folder_paths.get_filename_list("mmdets_segm")]
        return {"required": {"model_name": (bboxs + segms, )}}

    RETURN_TYPES = ("BBOX_MODEL", "SEGM_MODEL")
    FUNCTION = "load_mmdet"

    CATEGORY = "ImpactPack/Legacy"

    def load_mmdet(self, model_name):
        """Load the model named ``model_name`` and return ``(bbox_model, segm_model)``.

        ``model_name`` is one of the prefixed choices from ``INPUT_TYPES``; the
        "bbox" prefix decides which output slot receives the real model.
        """
        mmdet_path = folder_paths.get_full_path("mmdets", model_name)
        model = mmdet_nodes.load_mmdet(mmdet_path)

        if model_name.startswith("bbox"):
            return model, NO_SEGM_MODEL()
        else:
            return NO_BBOX_MODEL(), model
class BboxDetectorForEach:
    """Legacy node that runs a bbox model on an image and emits one SEG per detection."""

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (model, image, threshold, dilation, crop_factor)."""
        return {"required": {
                    "bbox_model": ("BBOX_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    # Static: takes no ``self`` and is called as ``BboxDetectorForEach.detect(...)``;
    # the decorator also keeps it callable through an instance.
    @staticmethod
    def detect(bbox_model, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect boxes in ``image`` and build a cropped SEG for each one.

        Args:
            bbox_model: model handed to ``mmdet_nodes.inference_bbox``.
            image: input image; ``shape[1]`` and ``shape[2]`` are read as
                height and width (assumes a batch-first layout - TODO confirm).
            threshold: confidence threshold forwarded to the inference call.
            dilation: when > 0, masks are dilated by this amount.
            crop_factor: forwarded to ``make_crop_region``.
            drop_size: detections whose bbox extent is not larger than this
                in both directions are skipped.
            detailer_hook: accepted for interface compatibility; not used here.

        Returns:
            ``((h, w), items)`` where ``items`` is the list of SEG objects.
        """
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]

        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]
            confidence = x[2]

            # Both extents must exceed drop_size. The test is symmetric, so it
            # does not depend on which axis the bbox lists first.
            a1, b1, a2, b2 = item_bbox
            if b2 - b1 > drop_size and a2 - a1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)

                item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                items.append(item)

        shape = h, w
        return shape, items

    def doit(self, bbox_model, image, threshold, dilation, crop_factor):
        """Node entry point: wrap the ``detect`` result in a one-element tuple."""
        return (BboxDetectorForEach.detect(bbox_model, image, threshold, dilation, crop_factor), )
class SegmDetectorCombined:
    """Legacy node that runs a segm model on an image and merges all detections into one mask."""

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (model, image, threshold, dilation)."""
        return {"required": {
                    "segm_model": ("SEGM_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation):
        """Run segmentation inference and return the combined mask as a 1-tuple.

        Masks are dilated by ``dilation`` first when it is greater than zero.
        """
        # Note the argument order: inference_segm takes the image first.
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        return (mask,)
class BboxDetectorCombined(SegmDetectorCombined):
    """Bbox counterpart of ``SegmDetectorCombined``: merges all bbox detections into one mask.

    Only the input spec and ``doit`` are overridden; the remaining node
    attributes come from the base class.
    """

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (model, image, threshold, dilation)."""
        return {"required": {
                    "bbox_model": ("BBOX_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 4, "min": 0, "max": 255, "step": 1}),
                    }
                }

    def doit(self, bbox_model, image, threshold, dilation):
        """Run bbox inference and return the combined mask as a 1-tuple.

        Masks are dilated by ``dilation`` first when it is greater than zero.
        """
        # Note the argument order: inference_bbox takes the model first.
        mmdet_results = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        mask = combine_masks(segmasks)
        return (mask,)
class SegmDetectorForEach:
    """Legacy node that runs a segm model on an image and emits one SEG per detection."""

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (model, image, threshold, dilation, crop_factor)."""
        return {"required": {
                    "segm_model": ("SEGM_MODEL", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation, crop_factor):
        """Detect segments in ``image`` and build a cropped SEG for each one.

        ``image.shape[1]`` and ``image.shape[2]`` are read as height and width
        (assumes a batch-first layout - TODO confirm).

        Returns:
            A 1-tuple holding ``((h, w), items)``, where ``items`` is the list
            of SEG objects. No detection is filtered out by size here.
        """
        # Note the argument order: inference_segm takes the image first.
        mmdet_results = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]
            confidence = x[2]

            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(item_mask, crop_region)

            item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
            items.append(item)

        shape = h, w
        return ((shape, items), )
class SegsMaskCombine:
    """Legacy node that pastes every SEG's cropped mask back into one full-size mask."""

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (segs, image)."""
        return {"required": {
                    "segs": ("SEGS", ),
                    "image": ("IMAGE", ),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Legacy"

    # Static: takes no ``self`` and is called as ``SegsMaskCombine.combine(...)``;
    # the decorator also keeps it callable through an instance.
    @staticmethod
    def combine(segs, image):
        """OR all segment masks into a single mask sized like ``image``.

        Args:
            segs: pair whose second element is the iterable of segments; each
                segment exposes ``cropped_mask`` and ``crop_region``.
            image: only ``shape[1]`` and ``shape[2]`` are read, as the output
                height and width.

        Returns:
            A float32 torch tensor of shape ``(h, w)`` with values in [0, 1].
        """
        h = image.shape[1]
        w = image.shape[2]

        mask = np.zeros((h, w), dtype=np.uint8)

        for seg in segs[1]:
            cropped_mask = seg.cropped_mask
            crop_region = seg.crop_region
            # crop_region is indexed as (col_start, row_start, col_end, row_end).
            # Masks are scaled to 0..255 so overlapping segments merge via bitwise OR.
            mask[crop_region[1]:crop_region[3], crop_region[0]:crop_region[2]] |= (cropped_mask * 255).astype(np.uint8)

        return torch.from_numpy(mask.astype(np.float32) / 255.0)

    def doit(self, segs, image):
        """Node entry point: wrap the ``combine`` result in a one-element tuple."""
        return (SegsMaskCombine.combine(segs, image), )
class MaskPainter(nodes.PreviewImage):
    """Legacy node that previews an image and returns a mask loaded back from an image file.

    The ``image`` input carries state between runs: it is either the
    "#placeholder" choice or a dict describing a previously painted mask.
    """

    # Must be a classmethod: the node spec is queried on the class itself,
    # without an instance and without arguments.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec (required images, hidden prompt data, optional image)."""
        return {"required": {"images": ("IMAGE",), },
                "hidden": {
                    "prompt": "PROMPT",
                    "extra_pnginfo": "EXTRA_PNGINFO",
                },
                # The literal used to hold two "optional" keys. The first one
                # ({"mask_image": ("IMAGE_PATH",)}) was silently overwritten by
                # this entry, so only this entry ever took effect; the dead
                # duplicate is dropped to keep the effective spec unchanged.
                "optional": {"image": (["#placeholder"], )},
                }

    RETURN_TYPES = ("MASK",)

    FUNCTION = "save_painted_images"

    CATEGORY = "ImpactPack/Legacy"

    def save_painted_images(self, images, filename_prefix="impact-mask",
                            prompt=None, extra_pnginfo=None, mask_image=None, image=None):
        """Save/forward the preview image and return the associated mask.

        Args:
            images: input images; ``id(images)`` is compared with the stored
                ``image['image_hash']`` to detect a new input.
            filename_prefix: prefix passed to ``self.save_images``.
            prompt, extra_pnginfo: passed through to ``self.save_images``.
            mask_image: accepted for interface compatibility; not used here.
            image: ``None``, "#placeholder", or a dict describing the painted
                mask file and the preview image to forward.

        Returns:
            A dict with a 'ui' section ('images' and 'aux') and a 'result'
            1-tuple holding the mask.
        """
        # ``image`` is optional and defaults to None; treat that like the
        # placeholder instead of failing on the subscript below.
        if image is None or image == "#placeholder" or image['image_hash'] != id(images):
            # new input image
            res = self.save_images(images, filename_prefix, prompt, extra_pnginfo)

            item = res['ui']['images'][0]

            # Append the " [type]" annotation unless the name already ends with one.
            if not item['filename'].endswith(']'):
                filepath = f"{item['filename']} [{item['type']}]"
            else:
                filepath = item['filename']

            _, mask = nodes.LoadImage().load_image(filepath)

            # 'aux' records the identity of this input alongside the saved images.
            res['ui']['aux'] = [id(images), res['ui']['images']]
            res['result'] = (mask, )
            return res

        else:
            # new mask
            if '0' in image:  # fallback
                image = image['0']

            # Re-emit the previously saved preview image unchanged.
            forward = {'filename': image['forward_filename'],
                       'subfolder': image['forward_subfolder'],
                       'type': image['forward_type'], }

            res = {'ui': {'images': [forward]}}

            # Build "<subfolder>/<filename> [<type>]" for the painted mask file.
            imgpath = ""
            if 'subfolder' in image and image['subfolder'] != "":
                imgpath = image['subfolder'] + "/"

            imgpath += f"{image['filename']}"

            if 'type' in image and image['type'] != "":
                imgpath += f" [{image['type']}]"

            res['ui']['aux'] = [id(images), [forward]]
            _, mask = nodes.LoadImage().load_image(imgpath)
            res['result'] = (mask, )

            return res