|
import folder_paths
|
|
from impact.core import *
|
|
import os
|
|
|
|
import mmcv
|
|
from mmdet.apis import (inference_detector, init_detector)
|
|
from mmdet.evaluation import get_classes
|
|
|
|
|
|
def load_mmdet(model_path):
    """Build an MMDetection detector from a checkpoint file.

    The config file is looked up next to the checkpoint: it is the
    checkpoint path with its extension replaced by ``.py``.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The detector returned by ``init_detector``, initialised on the CPU.
    """
    checkpoint_stem, _extension = os.path.splitext(model_path)
    config_path = checkpoint_stem + ".py"
    return init_detector(config_path, model_path, device="cpu")
|
|
|
|
|
|
def inference_segm_old(model, image, conf_threshold):
    """Run instance segmentation with the legacy tuple-style detector output.

    ``inference_detector`` is expected to return a ``(bbox_results,
    segm_results)`` pair, each holding one entry per class. The last column
    of every bbox array is compared against ``conf_threshold``.

    Args:
        model: Detector handed to ``inference_detector``.
        image: Batched image tensor; only the first item is used, multiplied
            by 255 (assumes values in [0, 1] -- TODO confirm against caller).
        conf_threshold: Detections whose last bbox column is not strictly
            greater than this value are dropped.

    Returns:
        ``[labels, bboxes, segms]`` as three parallel lists. Labels have the
        form ``"A-<coco class name>"``. All three lists are empty when
        nothing was detected.
    """
    image = image.numpy()[0] * 255
    bbox_results, segm_results = inference_detector(model, image)

    # Count detections across *all* classes. Looking only at class 0 would
    # discard valid detections of other classes whenever class 0 is empty,
    # and would raise on an empty per-class list.
    if sum(bbox.shape[0] for bbox in bbox_results) == 0:
        return [[], [], []]

    label = "A"
    classes = get_classes("coco")

    # One class index per detection, in the same order as the stacked boxes.
    labels = np.concatenate([
        np.full(bbox.shape[0], class_idx, dtype=np.int32)
        for class_idx, bbox in enumerate(bbox_results)
    ])
    bboxes = np.vstack(bbox_results)
    segms = mmcv.concat_list(segm_results)

    keep = np.where(bboxes[:, -1] > conf_threshold)[0]
    return [
        [label + "-" + classes[labels[i]] for i in keep],
        [bboxes[i] for i in keep],
        [segms[i] for i in keep],
    ]
|
|
|
|
|
|
def inference_segm(image, modelname, conf_thres, lab="A"):
    """Run instance segmentation and keep detections above a score threshold.

    Args:
        image: Batched image tensor; only the first item is used, multiplied
            by 255 (assumes values in [0, 1] -- TODO confirm against caller).
        modelname: Detector handed to ``inference_detector`` (despite the
            name, it is passed as the model argument).
        conf_thres: Detections whose score is not strictly greater than this
            value are dropped.
        lab: Prefix placed before the class name in each label.

    Returns:
        ``[labels, bboxes, masks, scores]`` as four parallel lists. Labels
        have the form ``"<lab>-<coco class name>"``. All four lists are empty
        when the detector produced no boxes.
    """
    frame = image.numpy()[0] * 255
    instances = inference_detector(modelname, frame).pred_instances

    boxes = instances.bboxes.numpy()
    masks = instances.masks.numpy()
    confidences = instances.scores.numpy()
    class_names = get_classes("coco")

    detection_count, _ = boxes.shape
    if detection_count == 0:
        return [[], [], [], []]

    class_ids = instances.labels
    kept = np.where(instances.scores > conf_thres)[0]

    names, kept_boxes, kept_masks, kept_scores = [], [], [], []
    for idx in kept:
        names.append(lab + "-" + class_names[class_ids[idx]])
        kept_boxes.append(boxes[idx])
        kept_masks.append(masks[idx])
        kept_scores.append(confidences[idx])

    return [names, kept_boxes, kept_masks, kept_scores]
|
|
|
|
|
|
def inference_bbox(modelname, image, conf_threshold):
    """Run bbox detection and return filled-rectangle masks for the hits.

    Args:
        modelname: Detector handed to ``inference_detector`` (despite the
            name, it is passed as the model argument).
        image: Batched image tensor; only the first item is used, multiplied
            by 255 (assumes values in [0, 1] -- TODO confirm against caller).
        conf_threshold: Detections whose score is not strictly greater than
            this value are dropped.

    Returns:
        ``[labels, bboxes, masks, scores]`` as four parallel lists. Every
        label is ``"A"``; every mask is a boolean array covering the first
        two axes of the frame, ``True`` inside the detection's rectangle.
        All four lists are empty when the detector produced no boxes.
    """
    image = image.numpy()[0] * 255
    label = "A"
    output = inference_detector(modelname, image).pred_instances

    n, _ = output.bboxes.shape
    if n == 0:
        return [[], [], [], []]

    bboxes = output.bboxes.numpy()
    scores = output.scores.numpy()

    # Masks only need the frame's height and width, so read them directly
    # instead of converting the image to grayscale just to get its shape.
    mask_shape = image.shape[:2]

    results = [[], [], [], []]
    # Rasterize masks only for detections that survive the threshold.
    for i in np.where(scores > conf_threshold)[0]:
        x0, y0, x1, y1 = bboxes[i]
        mask = np.zeros(mask_shape, np.uint8)
        # Thickness -1 fills the rectangle.
        cv2.rectangle(mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)

        results[0].append(label)
        results[1].append(bboxes[i])
        results[2].append(mask.astype(bool))
        results[3].append(scores[i])

    return results
|
|
|
|
|
|
class BBoxDetector:
    """Detector that turns bounding-box predictions into SEG items or a mask."""

    # Class-level default; each instance overrides it in __init__.
    bbox_model = None

    def __init__(self, bbox_model):
        """Store the detector model used for every inference call."""
        self.bbox_model = bbox_model

    def _bbox_segmasks(self, image, threshold, dilation):
        """Run bbox inference and return the (optionally dilated) segmasks."""
        segmasks = create_segmasks(inference_bbox(self.bbox_model, image, threshold))
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)
        return segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects and build one SEG item per sufficiently large hit.

        Args:
            image: Batched image tensor; axis 1 is used as the height and
                axis 2 as the width.
            threshold: Score threshold forwarded to ``inference_bbox``.
            dilation: Masks are dilated by this amount when it is positive.
            crop_factor: Forwarded to ``make_crop_region``.
            drop_size: Detections whose bbox is not strictly larger than this
                in both dimensions are skipped; values below 1 are raised
                to 1.
            detailer_hook: Optional object; when it has a ``post_detection``
                method, the result is passed through it before returning.

        Returns:
            ``((height, width), items)`` where ``items`` is a list of SEG
            objects, or whatever ``detailer_hook.post_detection`` returns for
            that tuple.
        """
        drop_size = max(drop_size, 1)
        segmasks = self._bbox_segmasks(image, threshold, dilation)

        h = image.shape[1]
        w = image.shape[2]

        items = []
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            # Both extents are held to the same limit, so the size test does
            # not depend on which pair of coordinates is the x axis.
            y1, x1, y2, x2 = item_bbox

            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                items.append(
                    SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                )

        segs = (h, w), items

        # Honour the hook instead of silently ignoring the parameter.
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Detect objects and merge all their masks via ``combine_masks``."""
        return combine_masks(self._bbox_segmasks(image, threshold, dilation))

    def setAux(self, x):
        """Accept an auxiliary value and ignore it."""
        pass
|
|
|
|
|
|
class SegmDetector(BBoxDetector):
    """Detector that turns instance-segmentation predictions into SEG items or a mask."""

    # Class-level default; each instance overrides it in __init__.
    segm_model = None

    def __init__(self, segm_model):
        """Store the segmentation model.

        The parent initializer is not called, so ``bbox_model`` keeps its
        inherited ``None`` default; every method here uses ``segm_model``.
        """
        self.segm_model = segm_model

    def _segm_segmasks(self, image, threshold, dilation):
        """Run segm inference and return the (optionally dilated) segmasks."""
        segmasks = create_segmasks(inference_segm(image, self.segm_model, threshold))
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)
        return segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects and build one SEG item per sufficiently large hit.

        Args:
            image: Batched image tensor; axis 1 is used as the height and
                axis 2 as the width.
            threshold: Score threshold forwarded to ``inference_segm``.
            dilation: Masks are dilated by this amount when it is positive.
            crop_factor: Forwarded to ``make_crop_region``.
            drop_size: Detections whose bbox is not strictly larger than this
                in both dimensions are skipped; values below 1 are raised
                to 1.
            detailer_hook: Optional object; when it has a ``post_detection``
                method, the result is passed through it before returning.

        Returns:
            ``(image.shape, items)`` where ``items`` is a list of SEG
            objects, or whatever ``detailer_hook.post_detection`` returns for
            that tuple.
        """
        drop_size = max(drop_size, 1)
        segmasks = self._segm_segmasks(image, threshold, dilation)

        h = image.shape[1]
        w = image.shape[2]

        items = []
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            # Both extents are held to the same limit, so the size test does
            # not depend on which pair of coordinates is the x axis.
            y1, x1, y2, x2 = item_bbox

            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                items.append(
                    SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                )

        # NOTE(review): the full image shape is returned here, not a
        # (height, width) pair -- confirm which form callers expect.
        segs = image.shape, items

        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Detect objects and merge all their masks via ``combine_masks``.

        Uses the segmentation model; ``bbox_model`` is never set on this
        class, so bbox inference cannot be used here.
        """
        return combine_masks(self._segm_segmasks(image, threshold, dilation))

    def setAux(self, x):
        """Accept an auxiliary value and ignore it."""
        pass
|
|
|
|
|
|
class MMDetDetectorProvider:
    """Node that loads an MMDetection model and exposes it as a detector pair."""

    @classmethod
    def INPUT_TYPES(s):
        """List the selectable models, prefixed with ``bbox/`` or ``segm/``."""
        bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("mmdets_bbox")]
        segms = ["segm/" + x for x in folder_paths.get_filename_list("mmdets_segm")]
        return {"required": {"model_name": (bboxs + segms, )}}

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "load_mmdet"

    CATEGORY = "ImpactPack"

    def load_mmdet(self, model_name):
        """Load ``model_name`` and wrap it in the matching detector.

        Args:
            model_name: An entry from ``INPUT_TYPES``; names starting with
                ``bbox`` yield a bbox detector, everything else a segm
                detector.

        Returns:
            ``(bbox_detector, segm_detector)``. The slot that does not match
            the model kind is filled with ``NO_BBOX_DETECTOR()`` or
            ``NO_SEGM_DETECTOR()``.
        """
        mmdet_path = folder_paths.get_full_path("mmdets", model_name)
        # Resolves to the module-level loader, not this method.
        model = load_mmdet(mmdet_path)

        if model_name.startswith("bbox"):
            return BBoxDetector(model), NO_SEGM_DETECTOR()

        # Wrap the model so the SEGM_DETECTOR output is a detector object,
        # matching the bbox branch, rather than a bare model.
        return NO_BBOX_DETECTOR(), SegmDetector(model)