import numpy as np
import cv2
import torch
import torchvision.transforms.functional as TF
import sys as _sys
from keyword import iskeyword as _iskeyword
from operator import itemgetter as _itemgetter

from segment_anything import SamPredictor

from comfy import model_management

################################################################################
### namedtuple (vendored from CPython's collections module)
################################################################################

try:
    from _collections import _tuplegetter
except ImportError:
    _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)


def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
    """Returns a new subclass of tuple with named fields.

    >>> Point = namedtuple('Point', ['x', 'y'])
    >>> Point.__doc__                   # docstring for the new class
    'Point(x, y)'
    >>> p = Point(11, y=22)             # instantiate with positional args or keywords
    >>> p[0] + p[1]                     # indexable like a plain tuple
    33
    >>> x, y = p                        # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y                       # fields also accessible by name
    33
    >>> d = p._asdict()                 # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d)                      # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100)               # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

    """

    # Validate the field names.  At the user's option, either generate an error
    # message or automatically replace the field name with a valid name.
    if isinstance(field_names, str):
        field_names = field_names.replace(',', ' ').split()
    field_names = list(map(str, field_names))
    typename = _sys.intern(str(typename))

    if rename:
        seen = set()
        for index, name in enumerate(field_names):
            if (not name.isidentifier()
                    or _iskeyword(name)
                    or name.startswith('_')
                    or name in seen):
                field_names[index] = f'_{index}'
            seen.add(name)

    for name in [typename] + field_names:
        if type(name) is not str:
            raise TypeError('Type names and field names must be strings')
        if not name.isidentifier():
            raise ValueError('Type names and field names must be valid '
                             f'identifiers: {name!r}')
        if _iskeyword(name):
            raise ValueError('Type names and field names cannot be a '
                             f'keyword: {name!r}')

    seen = set()
    for name in field_names:
        if name.startswith('_') and not rename:
            raise ValueError('Field names cannot start with an underscore: '
                             f'{name!r}')
        if name in seen:
            raise ValueError(f'Encountered duplicate field name: {name!r}')
        seen.add(name)

    field_defaults = {}
    if defaults is not None:
        defaults = tuple(defaults)
        if len(defaults) > len(field_names):
            raise TypeError('Got more default values than field names')
        field_defaults = dict(reversed(list(zip(reversed(field_names),
                                                reversed(defaults)))))

    # Variables used in the methods and docstrings
    field_names = tuple(map(_sys.intern, field_names))
    num_fields = len(field_names)
    arg_list = ', '.join(field_names)
    if num_fields == 1:
        arg_list += ','
    repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
    tuple_new = tuple.__new__
    _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip

    # Create all the named tuple methods to be added to the class namespace

    namespace = {
        '_tuple_new': tuple_new,
        '__builtins__': {},
        '__name__': f'namedtuple_{typename}',
    }
    code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
    __new__ = eval(code, namespace)
    __new__.__name__ = '__new__'
    __new__.__doc__ = f'Create new instance of {typename}({arg_list})'
    if defaults is not None:
        __new__.__defaults__ = defaults

    @classmethod
    def _make(cls, iterable):
        result = tuple_new(cls, iterable)
        if _len(result) != num_fields:
            raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
        return result

    _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
                              'or iterable')

    def _replace(self, /, **kwds):
        result = self._make(_map(kwds.pop, field_names, self))
        if kwds:
            raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
        return result

    _replace.__doc__ = (f'Return a new {typename} object replacing specified '
                        'fields with new values')

    def __repr__(self):
        'Return a nicely formatted representation string'
        return self.__class__.__name__ + repr_fmt % self

    def _asdict(self):
        'Return a new dict which maps field names to their values.'
        return _dict(_zip(self._fields, self))

    def __getnewargs__(self):
        'Return self as a plain tuple.  Used by copy and pickle.'
        return _tuple(self)

    # Modify function metadata to help with introspection and debugging
    for method in (
        __new__,
        _make.__func__,
        _replace,
        __repr__,
        _asdict,
        __getnewargs__,
    ):
        method.__qualname__ = f'{typename}.{method.__name__}'

    # Build-up the class namespace dictionary
    # and use type() to build the result class
    class_namespace = {
        '__doc__': f'{typename}({arg_list})',
        '__slots__': (),
        '_fields': field_names,
        '_field_defaults': field_defaults,
        '__new__': __new__,
        '_make': _make,
        '_replace': _replace,
        '__repr__': __repr__,
        '_asdict': _asdict,
        '__getnewargs__': __getnewargs__,
        '__match_args__': field_names,
    }
    for index, name in enumerate(field_names):
        doc = _sys.intern(f'Alias for field number {index}')
        class_namespace[name] = _tuplegetter(index, doc)

    result = type(typename, (tuple,), class_namespace)

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the named tuple is created.  Bypass this step in environments where
    # sys._getframe is not defined (Jython for example) or sys._getframe is not
    # defined for arguments greater than 0 (IronPython), or where the user has
    # specified a particular module.
    if module is None:
        try:
            module = _sys._getframe(1).f_globals.get('__name__', '__main__')
        except (AttributeError, ValueError):
            pass
    if module is not None:
        result.__module__ = module

    return result
SEG = namedtuple("SEG",
                 ['cropped_image', 'cropped_mask', 'confidence', 'crop_region',
                  'bbox', 'label', 'control_net_wrapper'],
                 defaults=[None])
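# --- Illustrative usage (sketch, not part of the module) ---------------------
# How a SEG record is typically built and read. The shapes and values below are
# assumptions for the example, not anything the module requires.
def _seg_example():
    cropped_image = np.zeros((1, 64, 64, 3), dtype=np.float32)  # NHWC crop
    cropped_mask = np.zeros((64, 64), dtype=np.float32)
    seg = SEG(cropped_image, cropped_mask, 0.9,
              [10, 10, 74, 74],   # crop_region
              [20, 20, 60, 60],   # bbox
              "person")           # control_net_wrapper defaults to None
    # Fields are accessible by name, and _replace works as on any namedtuple.
    return seg.label, seg._replace(confidence=1.0)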
def crop_ndarray4(npimg, crop_region):
    x1 = crop_region[0]
    y1 = crop_region[1]
    x2 = crop_region[2]
    y2 = crop_region[3]

    cropped = npimg[:, y1:y2, x1:x2, :]

    return cropped


crop_tensor4 = crop_ndarray4


def crop_ndarray2(npimg, crop_region):
    x1 = crop_region[0]
    y1 = crop_region[1]
    x2 = crop_region[2]
    y2 = crop_region[3]

    cropped = npimg[y1:y2, x1:x2]

    return cropped


def crop_image(image, crop_region):
    return crop_tensor4(image, crop_region)


def normalize_region(limit, startp, size):
    if startp < 0:
        new_endp = min(limit, size)
        new_startp = 0
    elif startp + size > limit:
        new_startp = max(0, limit - size)
        new_endp = limit
    else:
        new_startp = startp
        new_endp = min(limit, startp + size)

    return int(new_startp), int(new_endp)


def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
    x1 = bbox[0]
    y1 = bbox[1]
    x2 = bbox[2]
    y2 = bbox[3]

    bbox_w = x2 - x1
    bbox_h = y2 - y1

    crop_w = bbox_w * crop_factor
    crop_h = bbox_h * crop_factor

    if crop_min_size is not None:
        crop_w = max(crop_min_size, crop_w)
        crop_h = max(crop_min_size, crop_h)

    kernel_x = x1 + bbox_w / 2
    kernel_y = y1 + bbox_h / 2

    new_x1 = int(kernel_x - crop_w / 2)
    new_y1 = int(kernel_y - crop_h / 2)

    # make sure the region stays inside (w, h)
    new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
    new_y1, new_y2 = normalize_region(h, new_y1, crop_h)

    return [new_x1, new_y1, new_x2, new_y2]


def create_segmasks(results):
    bboxs = results[1]
    segms = results[2]
    confidence = results[3]

    results = []
    for i in range(len(segms)):
        item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
        results.append(item)
    return results


def dilate_masks(segmasks, dilation_factor, iter=1):
    if dilation_factor == 0:
        return segmasks

    dilated_masks = []
    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
    kernel = cv2.UMat(kernel)

    for i in range(len(segmasks)):
        cv2_mask = segmasks[i][1]
        cv2_mask = cv2.UMat(cv2_mask)

        # A negative dilation_factor erodes instead of dilating. `iterations`
        # must be passed by keyword: the third positional argument of
        # cv2.dilate/cv2.erode is the output buffer, not the iteration count.
        if dilation_factor > 0:
            dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
        else:
            dilated_mask = cv2.erode(cv2_mask, kernel, iterations=iter)

        dilated_mask = dilated_mask.get()

        item = (segmasks[i][0], dilated_mask, segmasks[i][2])
        dilated_masks.append(item)

    return dilated_masks


def is_same_device(a, b):
    a_device = torch.device(a) if isinstance(a, str) else a
    b_device = torch.device(b) if isinstance(b, str) else b
    return a_device.type == b_device.type and a_device.index == b_device.index


class SafeToGPU:
    def __init__(self, size):
        self.size = size

    def to_device(self, obj, device):
        if is_same_device(device, 'cpu'):
            obj.to(device)
        else:
            if is_same_device(obj.device, 'cpu'):  # cpu to gpu
                # Free some headroom first, then only move if enough memory remains.
                model_management.free_memory(self.size * 1.3, device)
                if model_management.get_free_memory(device) > self.size * 1.3:
                    try:
                        obj.to(device)
                    except Exception:
                        print(f"WARN: The model was not moved to '{device}' due to insufficient memory. [1]")
                else:
                    print(f"WARN: The model was not moved to '{device}' due to insufficient memory. [2]")
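# --- Illustrative usage (sketch, not part of the module) ---------------------
# Expanding a detection bbox into a padded crop region and cropping an NHWC
# array with it. The image size, bbox, and crop factor are assumptions.
def _crop_example():
    image = np.zeros((1, 512, 768, 3), dtype=np.float32)  # (N, H, W, C)
    bbox = [300, 200, 400, 320]
    # Grow the bbox by 1.5x around its center, clamped to the image bounds.
    crop_region = make_crop_region(w=768, h=512, bbox=bbox, crop_factor=1.5)
    return crop_image(image, crop_region)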
def center_of_bbox(bbox):
    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    return bbox[0] + w / 2, bbox[1] + h / 2


def sam_predict(predictor, points, plabs, bbox, threshold):
    point_coords = None if not points else np.array(points)
    point_labels = None if not plabs else np.array(plabs)

    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(point_coords=point_coords,
                                             point_labels=point_labels,
                                             box=box)

    total_masks = []

    selected = False
    max_score = 0
    max_mask = None
    for idx in range(len(scores)):
        if scores[idx] > max_score:
            max_score = scores[idx]
            max_mask = cur_masks[idx]

        if scores[idx] >= threshold:
            selected = True
            total_masks.append(cur_masks[idx])

    # If no mask cleared the threshold, fall back to the highest-scoring one.
    if not selected and max_mask is not None:
        total_masks.append(max_mask)

    return total_masks


def make_2d_mask(mask):
    if len(mask.shape) == 4:
        return mask.squeeze(0).squeeze(0)
    elif len(mask.shape) == 3:
        return mask.squeeze(0)
    return mask


def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
    mask = make_2d_mask(mask)

    points = []
    plabs = []

    # minimum sampling step >= 3
    y_step = max(3, int(mask.shape[0] / 20))
    x_step = max(3, int(mask.shape[1] / 20))

    for i in range(0, len(mask), y_step):
        for j in range(0, len(mask[i]), x_step):
            if mask[i][j] > threshold:
                points.append((x + j, y + i))
                plabs.append(1)
            elif use_negative and mask[i][j] == 0:
                points.append((x + j, y + i))
                plabs.append(0)

    return points, plabs


def gen_negative_hints(w, h, x1, y1, x2, y2):
    npoints = []
    nplabs = []

    # minimum sampling step >= 3 (y follows height, x follows width)
    y_step = max(3, int(h / 20))
    x_step = max(3, int(w / 20))

    # Sample background points outside the bbox, skipping a 10px border around
    # both the image edge and the bbox.
    for i in range(10, h - 10, y_step):
        for j in range(10, w - 10, x_step):
            if not (x1 - 10 <= j <= x2 + 10 and y1 - 10 <= i <= y2 + 10):
                npoints.append((j, i))
                nplabs.append(0)

    return npoints, nplabs


def generate_detection_hints(image, seg, center, detection_hint,
                             dilated_bbox, mask_hint_threshold, use_small_negative,
                             mask_hint_use_negative):
    [x1, y1, x2, y2] = dilated_bbox

    points = []
    plabs = []
    if detection_hint == "center-1":
        points.append(center)
        plabs = [1]  # 1 = foreground point, 0 = background point

    elif detection_hint == "horizontal-2":
        gap = (x2 - x1) / 3
        points.append((x1 + gap, center[1]))
        points.append((x1 + gap * 2, center[1]))
        plabs = [1, 1]

    elif detection_hint == "vertical-2":
        gap = (y2 - y1) / 3
        points.append((center[0], y1 + gap))
        points.append((center[0], y1 + gap * 2))
        plabs = [1, 1]

    elif detection_hint == "rect-4":
        x_gap = (x2 - x1) / 3
        y_gap = (y2 - y1) / 3
        points.append((x1 + x_gap, center[1]))
        points.append((x1 + x_gap * 2, center[1]))
        points.append((center[0], y1 + y_gap))
        points.append((center[0], y1 + y_gap * 2))
        plabs = [1, 1, 1, 1]

    elif detection_hint == "diamond-4":
        x_gap = (x2 - x1) / 3
        y_gap = (y2 - y1) / 3
        points.append((x1 + x_gap, y1 + y_gap))
        points.append((x1 + x_gap * 2, y1 + y_gap))
        points.append((x1 + x_gap, y1 + y_gap * 2))
        points.append((x1 + x_gap * 2, y1 + y_gap * 2))
        plabs = [1, 1, 1, 1]

    elif detection_hint == "mask-point-bbox":
        center = center_of_bbox(seg.bbox)
        points.append(center)
        plabs = [1]

    elif detection_hint == "mask-area":
        points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
                                                           seg.cropped_mask,
                                                           mask_hint_threshold, use_small_negative)

    if mask_hint_use_negative == "Outter":  # (sic) matches the option string used by the nodes
        # image is (H, W, C) here, so pass width then height.
        npoints, nplabs = gen_negative_hints(image.shape[1], image.shape[0],
                                             seg.crop_region[0], seg.crop_region[1],
                                             seg.crop_region[2], seg.crop_region[3])

        points += npoints
        plabs += nplabs

    return points, plabs
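# --- Illustrative usage (sketch, not part of the module) ---------------------
# Generating point prompts for SAM from a bbox. `image` and `seg` are
# stand-ins, and "False" as the mask_hint_use_negative value is an assumption
# about the option strings used by the calling nodes; see
# make_sam_mask_segmented below for the real call site.
def _hint_example(image, seg):
    center = center_of_bbox(seg.bbox)
    points, plabs = generate_detection_hints(
        image, seg, center, "rect-4", seg.bbox,
        mask_hint_threshold=0.5, use_small_negative=False,
        mask_hint_use_negative="False")
    return points, plabs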
def combine_masks2(masks):
    if len(masks) == 0:
        return None

    initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
    combined_cv2_mask = initial_cv2_mask

    for i in range(1, len(masks)):
        cv2_mask = np.array(masks[i]).astype(np.uint8)

        if combined_cv2_mask.shape == cv2_mask.shape:
            combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
        else:
            # do nothing - skip masks with an incompatible shape
            pass

    mask = torch.from_numpy(combined_cv2_mask)
    return mask


def dilate_mask(mask, dilation_factor, iter=1):
    if dilation_factor == 0:
        return make_2d_mask(mask)

    mask = make_2d_mask(mask)

    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)

    mask = cv2.UMat(mask)
    kernel = cv2.UMat(kernel)

    # A negative dilation_factor erodes instead of dilating; as in
    # dilate_masks, `iterations` must be passed by keyword.
    if dilation_factor > 0:
        result = cv2.dilate(mask, kernel, iterations=iter)
    else:
        result = cv2.erode(mask, kernel, iterations=iter)

    return result.get()


def convert_and_stack_masks(masks):
    if len(masks) == 0:
        return None

    mask_tensors = []
    for mask in masks:
        mask_array = np.array(mask, dtype=np.uint8)
        mask_tensor = torch.from_numpy(mask_array)
        mask_tensors.append(mask_tensor)

    stacked_masks = torch.stack(mask_tensors, dim=0)
    stacked_masks = stacked_masks.unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)

    return stacked_masks


def merge_and_stack_masks(stacked_masks, group_size):
    if stacked_masks is None:
        return None

    num_masks = stacked_masks.size(0)
    merged_masks = []

    # OR-merge the masks in groups of `group_size`.
    for i in range(0, num_masks, group_size):
        subset_masks = stacked_masks[i:i + group_size]
        merged_mask = torch.any(subset_masks, dim=0)
        merged_masks.append(merged_mask)

    if len(merged_masks) > 0:
        merged_masks = torch.stack(merged_masks, dim=0)

    return merged_masks
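# --- Illustrative usage (sketch, not part of the module) ---------------------
# Combining per-detection SAM masks into one mask plus a batched stack, the
# same way make_sam_mask_segmented does below. `masks` is assumed to be a list
# of (H, W) arrays sharing one shape, e.g. from sam_predict above.
def _mask_merge_example(masks):
    combined = combine_masks2(masks)           # (H, W) torch tensor, or None
    stacked = convert_and_stack_masks(masks)   # (N, 1, H, W) uint8, or None
    merged = merge_and_stack_masks(stacked, group_size=3)
    return combined, merged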
def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
                            threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
    if sam_model.is_auto_mode:
        device = model_management.get_torch_device()
        sam_model.safe_to.to_device(sam_model, device=device)

    try:
        predictor = SamPredictor(sam_model)
        image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
        predictor.set_image(image, "RGB")

        total_masks = []

        use_small_negative = mask_hint_use_negative == "Small"

        # seg_shape = segs[0]
        segs = segs[1]
        if detection_hint == "mask-points":
            points = []
            plabs = []

            for i in range(len(segs)):
                bbox = segs[i].bbox
                center = center_of_bbox(bbox)
                points.append(center)

                # small point is background, big point is foreground
                if use_small_negative and bbox[2] - bbox[0] < 10:
                    plabs.append(0)
                else:
                    plabs.append(1)

            detected_masks = sam_predict(predictor, points, plabs, None, threshold)
            total_masks += detected_masks

        else:
            for i in range(len(segs)):
                bbox = segs[i].bbox
                center = center_of_bbox(bbox)

                x1 = max(bbox[0] - bbox_expansion, 0)
                y1 = max(bbox[1] - bbox_expansion, 0)
                x2 = min(bbox[2] + bbox_expansion, image.shape[1])
                y2 = min(bbox[3] + bbox_expansion, image.shape[0])

                dilated_bbox = [x1, y1, x2, y2]

                points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
                                                         mask_hint_threshold, use_small_negative,
                                                         mask_hint_use_negative)

                detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)

                total_masks += detected_masks

        # merge all collected masks into a single mask
        mask = combine_masks2(total_masks)

    finally:
        if sam_model.is_auto_mode:
            sam_model.cpu()

    mask_working_device = torch.device("cpu")

    if mask is not None:
        mask = mask.float()
        mask = dilate_mask(mask.cpu().numpy(), dilation)
        mask = torch.from_numpy(mask)
        mask = mask.to(device=mask_working_device)
    else:
        # no mask was detected: return an empty mask matching the image size
        height, width, _ = image.shape
        mask = torch.zeros(
            (height, width), dtype=torch.float32, device=mask_working_device
        )

    stacked_masks = convert_and_stack_masks(total_masks)

    return (mask, merge_and_stack_masks(stacked_masks, group_size=3))


def tensor2mask(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t
    if size[3] == 1:
        return t[:, :, :, 0]
    elif size[3] == 4:
        # Not sure what the right thing to do here is. Going to try to be a
        # little smart: use the alpha channel unless all alpha is 1, in which
        # case we fall back to RGB behavior.
        if torch.min(t[:, :, :, 3]).item() != 1.:
            return t[:, :, :, 3]

    return TF.rgb_to_grayscale(tensor2rgb(t).permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]


def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t.unsqueeze(3).repeat(1, 1, 1, 3)
    if size[3] == 1:
        return t.repeat(1, 1, 1, 3)
    elif size[3] == 4:
        return t[:, :, :, :3]
    else:
        return t


def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t.unsqueeze(3).repeat(1, 1, 1, 4)
    elif size[3] == 1:
        return t.repeat(1, 1, 1, 4)
    elif size[3] == 3:
        # keep the new alpha channel on the same device/dtype as the input
        alpha_tensor = torch.ones((size[0], size[1], size[2], 1), device=t.device, dtype=t.dtype)
        return torch.cat((t, alpha_tensor), dim=3)
    else:
        return t
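# --- Illustrative usage (sketch, not part of the module) ---------------------
# The tensor2* helpers normalize NHWC image batches between channel layouts.
# The input shape is an assumption for the example.
def _tensor_convert_example():
    rgb = torch.rand(1, 64, 64, 3)   # NHWC float batch
    rgba = tensor2rgba(rgb)          # adds an opaque alpha channel -> (1, 64, 64, 4)
    mask = tensor2mask(rgba)         # alpha is all 1.0 here, so this falls back
                                     # to grayscale of the RGB channels -> (1, 64, 64)
    return rgba.shape, mask.shape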