daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
9 kB
import folder_paths
import impact.mmdet_nodes as mmdet_nodes
from impact.utils import *
from impact.core import SEG
import impact.core as core
import nodes
class NO_BBOX_MODEL:
    """Empty marker type; MMDetLoader returns an instance in the bbox slot when the model is not a bbox model."""
class NO_SEGM_MODEL:
    """Empty marker type; MMDetLoader returns an instance in the segm slot when a bbox model was loaded."""
class MMDetLoader:
    """Legacy node that loads one MMDet model and exposes it on a bbox or a segm output."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema: a single choice among the listed model files.

        File names from the "mmdets_bbox" folder are prefixed with "bbox/" and
        those from "mmdets_segm" with "segm/"; bbox entries come first.
        """
        choices = []
        for prefix, folder in (("bbox/", "mmdets_bbox"), ("segm/", "mmdets_segm")):
            choices.extend(prefix + name for name in folder_paths.get_filename_list(folder))
        return {"required": {"model_name": (choices, )}}

    RETURN_TYPES = ("BBOX_MODEL", "SEGM_MODEL")
    FUNCTION = "load_mmdet"
    CATEGORY = "ImpactPack/Legacy"

    def load_mmdet(self, model_name):
        """Load the selected model and return a ``(bbox_model, segm_model)`` pair.

        The slot that does not match the model kind is filled with a
        NO_BBOX_MODEL / NO_SEGM_MODEL marker instance. The kind is decided by
        whether ``model_name`` starts with "bbox".
        """
        full_path = folder_paths.get_full_path("mmdets", model_name)
        loaded = mmdet_nodes.load_mmdet(full_path)

        is_bbox = model_name.startswith("bbox")
        bbox_out = loaded if is_bbox else NO_BBOX_MODEL()
        segm_out = NO_SEGM_MODEL() if is_bbox else loaded
        return bbox_out, segm_out
class BboxDetectorForEach:
    """Legacy node that runs a bbox detector and emits one SEG per sufficiently large detection."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema (model, image, threshold, dilation, crop factor)."""
        required = {
            "bbox_model": ("BBOX_MODEL", ),
            "image": ("IMAGE", ),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
            "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
        }
        return {"required": required}

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def detect(bbox_model, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect boxes in ``image`` and build the SEGS value ``((h, w), [SEG, ...])``.

        Each detection is read as ``(bbox, mask, confidence)``. Detections
        whose bbox extent is not strictly larger than ``drop_size`` along both
        axes are skipped. Masks are dilated first when ``dilation > 0``.
        ``detailer_hook`` is accepted but not used here.
        """
        detections = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(detections)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        # image.shape[1] / image.shape[2] are used as height / width.
        h, w = image.shape[1], image.shape[2]

        items = []
        for entry in segmasks:
            item_bbox, item_mask = entry[0], entry[1]
            y1, x1, y2, x2 = item_bbox

            big_enough = x2 - x1 > drop_size and y2 - y1 > drop_size
            if not big_enough:
                continue

            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            items.append(
                SEG(
                    crop_image(image, crop_region),
                    crop_ndarray2(item_mask, crop_region),
                    entry[2],  # confidence
                    crop_region,
                    item_bbox,
                    None,
                    None,
                )
            )

        return (h, w), items

    def doit(self, bbox_model, image, threshold, dilation, crop_factor):
        """Node entry point: wrap the result of ``detect`` in a one-element tuple."""
        segs = BboxDetectorForEach.detect(bbox_model, image, threshold, dilation, crop_factor)
        return (segs, )
class SegmDetectorCombined:
    """Legacy node that runs a segm detector and merges all detections into a single mask."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema (model, image, threshold, dilation)."""
        required = {
            "segm_model": ("SEGM_MODEL", ),
            "image": ("IMAGE", ),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "dilation": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
        }
        return {"required": required}

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation):
        """Run segmentation inference, optionally dilate the masks, and return ``(combined_mask,)``."""
        # NOTE: inference_segm is called with the image first and the model second.
        detections = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(detections)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)
        return (combine_masks(segmasks),)
class BboxDetectorCombined(SegmDetectorCombined):
    """Bbox counterpart of SegmDetectorCombined: merges all bbox detections into a single mask.

    RETURN_TYPES, FUNCTION and CATEGORY are inherited from the base class;
    only the input schema and ``doit`` are overridden.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema (model, image, threshold, dilation)."""
        required = {
            "bbox_model": ("BBOX_MODEL", ),
            "image": ("IMAGE", ),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "dilation": ("INT", {"default": 4, "min": 0, "max": 255, "step": 1}),
        }
        return {"required": required}

    def doit(self, bbox_model, image, threshold, dilation):
        """Run bbox inference, optionally dilate the masks, and return ``(combined_mask,)``."""
        detections = mmdet_nodes.inference_bbox(bbox_model, image, threshold)
        segmasks = core.create_segmasks(detections)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)
        return (combine_masks(segmasks),)
class SegmDetectorForEach:
    """Legacy node that runs a segm detector and emits one SEG per detection."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema (model, image, threshold, dilation, crop factor)."""
        required = {
            "segm_model": ("SEGM_MODEL", ),
            "image": ("IMAGE", ),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "dilation": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
            "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
        }
        return {"required": required}

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/Legacy"

    def doit(self, segm_model, image, threshold, dilation, crop_factor):
        """Detect segments in ``image`` and return ``(((h, w), [SEG, ...]),)``.

        Each detection is read as ``(bbox, mask, confidence)``. Masks are
        dilated first when ``dilation > 0``. No detection is dropped by size.
        """
        # NOTE: inference_segm is called with the image first and the model second.
        detections = mmdet_nodes.inference_segm(image, segm_model, threshold)
        segmasks = core.create_segmasks(detections)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        # image.shape[1] / image.shape[2] are used as height / width.
        height, width = image.shape[1], image.shape[2]

        def build_seg(entry):
            # Crop both the image and the detection mask to the same region.
            bbox, mask = entry[0], entry[1]
            region = make_crop_region(width, height, bbox, crop_factor)
            return SEG(
                crop_image(image, region),
                crop_ndarray2(mask, region),
                entry[2],  # confidence
                region,
                bbox,
                None,
                None,
            )

        items = [build_seg(entry) for entry in segmasks]
        return (((height, width), items), )
class SegsMaskCombine:
    """Legacy node that pastes every segment's cropped mask onto one image-sized mask."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema (segs, image)."""
        required = {
            "segs": ("SEGS", ),
            "image": ("IMAGE", ),
        }
        return {"required": required}

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/Legacy"

    @staticmethod
    def combine(segs, image):
        """Merge the cropped masks of ``segs[1]`` into a float32 tensor of shape ``(h, w)``.

        ``h`` and ``w`` are taken from ``image.shape[1]`` and ``image.shape[2]``.
        Each cropped mask is scaled by 255, cast to uint8 and bitwise-OR-ed into
        the canvas at ``[region[1]:region[3], region[0]:region[2]]``; the canvas
        is finally divided by 255.0.
        """
        height, width = image.shape[1], image.shape[2]
        canvas = np.zeros((height, width), dtype=np.uint8)

        for seg in segs[1]:
            patch = (seg.cropped_mask * 255).astype(np.uint8)
            region = seg.crop_region
            left, top, right, bottom = region[0], region[1], region[2], region[3]
            # The slice is a view, so the in-place OR writes into the canvas.
            window = canvas[top:bottom, left:right]
            window |= patch

        return torch.from_numpy(canvas.astype(np.float32) / 255.0)

    def doit(self, segs, image):
        """Node entry point: return ``(combined_mask,)``."""
        combined = SegsMaskCombine.combine(segs, image)
        return (combined, )
class MaskPainter(nodes.PreviewImage):
    """Legacy node that previews images and returns a mask loaded from a saved or painted image file.

    The mask is always obtained by calling ``nodes.LoadImage().load_image`` on
    a path string and taking the second element of its result.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input schema.

        The dict literal used to declare the "optional" key twice (first with
        ``mask_image``, then with ``image``). Python keeps only the last value
        of a repeated key, so the effective optional input was always ``image``
        alone. That effective schema is written out once here; the shadowed
        ``mask_image`` entry never reached the caller and is not declared.
        """
        return {"required": {"images": ("IMAGE",), },
                "hidden": {
                    "prompt": "PROMPT",
                    "extra_pnginfo": "EXTRA_PNGINFO",
                },
                "optional": {"image": (["#placeholder"], )},
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "save_painted_images"
    CATEGORY = "ImpactPack/Legacy"

    def save_painted_images(self, images, filename_prefix="impact-mask",
                            prompt=None, extra_pnginfo=None, mask_image=None, image=None):
        """Save/forward the preview and return ``{'ui': ..., 'result': (mask,)}``.

        Args:
            images: images handed to ``self.save_images``; ``id(images)`` is
                compared against ``image['image_hash']`` to detect a new input.
            filename_prefix, prompt, extra_pnginfo: forwarded to ``self.save_images``.
            mask_image: accepted for signature compatibility; not used.
            image: ``None``, the string "#placeholder", or a dict. The dict
                keys read are ``image_hash``, ``forward_filename``,
                ``forward_subfolder``, ``forward_type``, ``filename`` and,
                when present and non-empty, ``subfolder`` and ``type``. The
                payload may also be nested under key ``'0'``; note that
                ``image_hash`` is read from the outer dict before that
                fallback is applied.

        Returns:
            A dict whose ``'ui'`` entry holds ``'images'`` and ``'aux'`` and
            whose ``'result'`` entry is a one-element tuple with the mask.
        """
        # `image` defaults to None and subscripting None raises TypeError, so a
        # missing value is handled exactly like the placeholder.
        is_new_input = (image is None
                        or image == "#placeholder"
                        or image['image_hash'] != id(images))

        if is_new_input:
            # new input image
            res = self.save_images(images, filename_prefix, prompt, extra_pnginfo)

            item = res['ui']['images'][0]

            # Append the " [type]" annotation unless the name already ends with ']'.
            if not item['filename'].endswith(']'):
                filepath = f"{item['filename']} [{item['type']}]"
            else:
                filepath = item['filename']

            _, mask = nodes.LoadImage().load_image(filepath)

            res['ui']['aux'] = [id(images), res['ui']['images']]
            res['result'] = (mask, )

            return res

        # new mask
        if '0' in image:  # fallback
            image = image['0']

        forward = {'filename': image['forward_filename'],
                   'subfolder': image['forward_subfolder'],
                   'type': image['forward_type'], }

        res = {'ui': {'images': [forward]}}

        # Build "<subfolder>/<filename> [<type>]", omitting empty parts.
        imgpath = ""
        if 'subfolder' in image and image['subfolder'] != "":
            imgpath = image['subfolder'] + "/"

        imgpath += f"{image['filename']}"

        if 'type' in image and image['type'] != "":
            imgpath += f" [{image['type']}]"

        res['ui']['aux'] = [id(images), [forward]]
        _, mask = nodes.LoadImage().load_image(imgpath)
        res['result'] = (mask, )

        return res