import numpy as np
import cv2
import torch
import torchvision.transforms.functional as TF
import sys as _sys
from keyword import iskeyword as _iskeyword
from operator import itemgetter as _itemgetter
from segment_anything import SamPredictor
from comfy import model_management
################################################################################
### namedtuple
################################################################################
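# NOTE: the block below appears to be a vendored (inline) copy of
# collections.namedtuple from the CPython standard library.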
try:
from _collections import _tuplegetter
except ImportError:
_tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
"""Returns a new subclass of tuple with named fields.
>>> Point = namedtuple('Point', ['x', 'y'])
>>> Point.__doc__ # docstring for the new class
'Point(x, y)'
>>> p = Point(11, y=22) # instantiate with positional args or keywords
>>> p[0] + p[1] # indexable like a plain tuple
33
>>> x, y = p # unpack like a regular tuple
>>> x, y
(11, 22)
>>> p.x + p.y # fields also accessible by name
33
>>> d = p._asdict() # convert to a dictionary
>>> d['x']
11
>>> Point(**d) # convert from a dictionary
Point(x=11, y=22)
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
Point(x=100, y=22)
"""
# Validate the field names. At the user's option, either generate an error
# message or automatically replace the field name with a valid name.
if isinstance(field_names, str):
field_names = field_names.replace(',', ' ').split()
field_names = list(map(str, field_names))
typename = _sys.intern(str(typename))
if rename:
seen = set()
for index, name in enumerate(field_names):
if (not name.isidentifier()
or _iskeyword(name)
or name.startswith('_')
or name in seen):
field_names[index] = f'_{index}'
seen.add(name)
for name in [typename] + field_names:
if type(name) is not str:
raise TypeError('Type names and field names must be strings')
if not name.isidentifier():
raise ValueError('Type names and field names must be valid '
f'identifiers: {name!r}')
if _iskeyword(name):
raise ValueError('Type names and field names cannot be a '
f'keyword: {name!r}')
seen = set()
for name in field_names:
if name.startswith('_') and not rename:
raise ValueError('Field names cannot start with an underscore: '
f'{name!r}')
if name in seen:
raise ValueError(f'Encountered duplicate field name: {name!r}')
seen.add(name)
field_defaults = {}
if defaults is not None:
defaults = tuple(defaults)
if len(defaults) > len(field_names):
raise TypeError('Got more default values than field names')
field_defaults = dict(reversed(list(zip(reversed(field_names),
reversed(defaults)))))
# Variables used in the methods and docstrings
field_names = tuple(map(_sys.intern, field_names))
num_fields = len(field_names)
arg_list = ', '.join(field_names)
if num_fields == 1:
arg_list += ','
repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
tuple_new = tuple.__new__
_dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip
# Create all the named tuple methods to be added to the class namespace
namespace = {
'_tuple_new': tuple_new,
'__builtins__': {},
'__name__': f'namedtuple_{typename}',
}
code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
__new__ = eval(code, namespace)
__new__.__name__ = '__new__'
__new__.__doc__ = f'Create new instance of {typename}({arg_list})'
if defaults is not None:
__new__.__defaults__ = defaults
@classmethod
def _make(cls, iterable):
result = tuple_new(cls, iterable)
if _len(result) != num_fields:
raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
return result
_make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
'or iterable')
def _replace(self, /, **kwds):
result = self._make(_map(kwds.pop, field_names, self))
if kwds:
raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
return result
_replace.__doc__ = (f'Return a new {typename} object replacing specified '
'fields with new values')
def __repr__(self):
'Return a nicely formatted representation string'
return self.__class__.__name__ + repr_fmt % self
def _asdict(self):
'Return a new dict which maps field names to their values.'
return _dict(_zip(self._fields, self))
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return _tuple(self)
# Modify function metadata to help with introspection and debugging
for method in (
__new__,
_make.__func__,
_replace,
__repr__,
_asdict,
__getnewargs__,
):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
# and use type() to build the result class
class_namespace = {
'__doc__': f'{typename}({arg_list})',
'__slots__': (),
'_fields': field_names,
'_field_defaults': field_defaults,
'__new__': __new__,
'_make': _make,
'_replace': _replace,
'__repr__': __repr__,
'_asdict': _asdict,
'__getnewargs__': __getnewargs__,
'__match_args__': field_names,
}
for index, name in enumerate(field_names):
doc = _sys.intern(f'Alias for field number {index}')
class_namespace[name] = _tuplegetter(index, doc)
result = type(typename, (tuple,), class_namespace)
# For pickling to work, the __module__ variable needs to be set to the frame
# where the named tuple is created. Bypass this step in environments where
# sys._getframe is not defined (Jython for example) or sys._getframe is not
# defined for arguments greater than 0 (IronPython), or where the user has
# specified a particular module.
if module is None:
try:
module = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
if module is not None:
result.__module__ = module
return result
SEG = namedtuple("SEG",
['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
defaults=[None])
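# Illustrative SEG usage (values are made up): every field except the last is
# required; control_net_wrapper defaults to None.
#
#   seg = SEG(cropped_image=None, cropped_mask=np.zeros((64, 64), np.float32),
#             confidence=0.9, crop_region=[0, 0, 64, 64],
#             bbox=[8, 8, 56, 56], label='face')
#   seg.bbox                       # -> [8, 8, 56, 56]
#   seg._replace(confidence=0.5)   # returns a new SEG; the original is unchanged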
def crop_ndarray4(npimg, crop_region):
x1 = crop_region[0]
y1 = crop_region[1]
x2 = crop_region[2]
y2 = crop_region[3]
cropped = npimg[:, y1:y2, x1:x2, :]
return cropped
crop_tensor4 = crop_ndarray4
def crop_ndarray2(npimg, crop_region):
x1 = crop_region[0]
y1 = crop_region[1]
x2 = crop_region[2]
y2 = crop_region[3]
cropped = npimg[y1:y2, x1:x2]
return cropped
def crop_image(image, crop_region):
return crop_tensor4(image, crop_region)
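# Shape sketch for the crop helpers (NHWC images and HW masks, per their use
# elsewhere in this file; crop_region is [x1, y1, x2, y2]):
#
#   img = np.zeros((1, 512, 512, 3), np.float32)
#   crop_ndarray4(img, [100, 100, 200, 228]).shape              # -> (1, 128, 100, 3)
#   crop_ndarray2(img[0, :, :, 0], [100, 100, 200, 228]).shape  # -> (128, 100)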
def normalize_region(limit, startp, size):
if startp < 0:
new_endp = min(limit, size)
new_startp = 0
elif startp + size > limit:
new_startp = max(0, limit - size)
new_endp = limit
else:
new_startp = startp
new_endp = min(limit, startp+size)
return int(new_startp), int(new_endp)
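# normalize_region clamps a 1-D interval of length `size` starting at `startp`
# so it stays inside [0, limit]. Illustrative values:
#
#   normalize_region(512, -20, 100)   # -> (0, 100)   shifted right
#   normalize_region(512, 480, 100)   # -> (412, 512) shifted left
#   normalize_region(512, 200, 100)   # -> (200, 300) unchanged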
def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
x1 = bbox[0]
y1 = bbox[1]
x2 = bbox[2]
y2 = bbox[3]
bbox_w = x2 - x1
bbox_h = y2 - y1
crop_w = bbox_w * crop_factor
crop_h = bbox_h * crop_factor
if crop_min_size is not None:
crop_w = max(crop_min_size, crop_w)
crop_h = max(crop_min_size, crop_h)
kernel_x = x1 + bbox_w / 2
kernel_y = y1 + bbox_h / 2
new_x1 = int(kernel_x - crop_w / 2)
new_y1 = int(kernel_y - crop_h / 2)
    # clamp the crop region so it stays inside the (w, h) image bounds
new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
return [new_x1, new_y1, new_x2, new_y2]
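# Illustrative call: a 100x100 bbox with crop_factor 2.0 yields a 200x200
# crop centered on the bbox, clamped to the image bounds:
#
#   make_crop_region(512, 512, [100, 100, 200, 200], 2.0)
#   # -> [50, 50, 250, 250]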
def create_segmasks(results):
bboxs = results[1]
segms = results[2]
confidence = results[3]
results = []
for i in range(len(segms)):
item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
results.append(item)
return results
def dilate_masks(segmasks, dilation_factor, iter=1):
if dilation_factor == 0:
return segmasks
dilated_masks = []
kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
kernel = cv2.UMat(kernel)
for i in range(len(segmasks)):
cv2_mask = segmasks[i][1]
cv2_mask = cv2.UMat(cv2_mask)
        if dilation_factor > 0:
            dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
        else:
            dilated_mask = cv2.erode(cv2_mask, kernel, iterations=iter)
dilated_mask = dilated_mask.get()
item = (segmasks[i][0], dilated_mask, segmasks[i][2])
dilated_masks.append(item)
return dilated_masks
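# dilate_masks grows (positive factor) or shrinks (negative factor) each mask
# in a list of (bbox, mask, confidence) tuples. A minimal sketch with made-up data:
#
#   m = np.zeros((64, 64), np.float32); m[30:34, 30:34] = 1.0
#   segmasks = [([30, 30, 34, 34], m, 0.9)]
#   grown = dilate_masks(segmasks, 5)      # 5x5 kernel dilation
#   shrunk = dilate_masks(segmasks, -5)    # 5x5 kernel erosion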
def is_same_device(a, b):
a_device = torch.device(a) if isinstance(a, str) else a
b_device = torch.device(b) if isinstance(b, str) else b
return a_device.type == b_device.type and a_device.index == b_device.index
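# is_same_device compares both device type and index, so an unindexed device
# spec does not match an indexed one:
#
#   is_same_device('cuda:0', torch.device('cuda', 0))  # -> True
#   is_same_device('cuda', 'cuda:0')                   # -> False (index None vs 0)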
class SafeToGPU:
def __init__(self, size):
self.size = size
def to_device(self, obj, device):
if is_same_device(device, 'cpu'):
obj.to(device)
else:
if is_same_device(obj.device, 'cpu'): # cpu to gpu
model_management.free_memory(self.size * 1.3, device)
if model_management.get_free_memory(device) > self.size * 1.3:
                    try:
                        obj.to(device)
                    except Exception:
                        print(f"WARN: The model could not be moved to '{device}' due to insufficient memory. [1]")
                else:
                    print(f"WARN: The model was not moved to '{device}' due to insufficient memory. [2]")
def center_of_bbox(bbox):
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
return bbox[0] + w/2, bbox[1] + h/2
def sam_predict(predictor, points, plabs, bbox, threshold):
point_coords = None if not points else np.array(points)
point_labels = None if not plabs else np.array(plabs)
box = np.array([bbox]) if bbox is not None else None
cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)
total_masks = []
selected = False
max_score = 0
max_mask = None
for idx in range(len(scores)):
if scores[idx] > max_score:
max_score = scores[idx]
max_mask = cur_masks[idx]
        if scores[idx] >= threshold:
            selected = True
            total_masks.append(cur_masks[idx])
if not selected and max_mask is not None:
total_masks.append(max_mask)
return total_masks
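# sam_predict keeps every candidate mask whose score clears `threshold`; if
# none do, it falls back to the single best-scoring mask, so the caller gets
# at least one mask whenever SAM returns scores. Illustrative call (the
# coordinates are made up; predictor must have had set_image() called):
#
#   masks = sam_predict(predictor, [(128, 128)], [1], [96, 96, 160, 160], 0.93)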
def make_2d_mask(mask):
if len(mask.shape) == 4:
return mask.squeeze(0).squeeze(0)
elif len(mask.shape) == 3:
return mask.squeeze(0)
return mask
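# make_2d_mask drops leading singleton dims from 3-D/4-D masks:
#
#   make_2d_mask(torch.zeros(1, 1, 64, 64)).shape  # -> torch.Size([64, 64])
#   make_2d_mask(torch.zeros(1, 64, 64)).shape     # -> torch.Size([64, 64])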
def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
mask = make_2d_mask(mask)
points = []
plabs = []
# minimum sampling step >= 3
y_step = max(3, int(mask.shape[0] / 20))
x_step = max(3, int(mask.shape[1] / 20))
for i in range(0, len(mask), y_step):
for j in range(0, len(mask[i]), x_step):
if mask[i][j] > threshold:
points.append((x + j, y + i))
plabs.append(1)
elif use_negative and mask[i][j] == 0:
points.append((x + j, y + i))
plabs.append(0)
return points, plabs
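# The mask is sampled on a coarse grid (about 20 steps per axis, never finer
# than every 3 px): values above `threshold` become foreground hints (label 1)
# and, with use_negative, exactly-zero values become background hints
# (label 0). (x, y) is the crop origin, so hints land in full-image coordinates.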
def gen_negative_hints(w, h, x1, y1, x2, y2):
npoints = []
nplabs = []
    # minimum sampling step >= 3
    y_step = max(3, int(h / 20))
    x_step = max(3, int(w / 20))
    for i in range(10, h - 10, y_step):
        for j in range(10, w - 10, x_step):
            if not (x1 - 10 <= j <= x2 + 10 and y1 - 10 <= i <= y2 + 10):
npoints.append((j, i))
nplabs.append(0)
return npoints, nplabs
def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative,
mask_hint_use_negative):
[x1, y1, x2, y2] = dilated_bbox
points = []
plabs = []
if detection_hint == "center-1":
points.append(center)
plabs = [1] # 1 = foreground point, 0 = background point
elif detection_hint == "horizontal-2":
gap = (x2 - x1) / 3
points.append((x1 + gap, center[1]))
points.append((x1 + gap * 2, center[1]))
plabs = [1, 1]
elif detection_hint == "vertical-2":
gap = (y2 - y1) / 3
points.append((center[0], y1 + gap))
points.append((center[0], y1 + gap * 2))
plabs = [1, 1]
elif detection_hint == "rect-4":
x_gap = (x2 - x1) / 3
y_gap = (y2 - y1) / 3
points.append((x1 + x_gap, center[1]))
points.append((x1 + x_gap * 2, center[1]))
points.append((center[0], y1 + y_gap))
points.append((center[0], y1 + y_gap * 2))
plabs = [1, 1, 1, 1]
elif detection_hint == "diamond-4":
x_gap = (x2 - x1) / 3
y_gap = (y2 - y1) / 3
points.append((x1 + x_gap, y1 + y_gap))
points.append((x1 + x_gap * 2, y1 + y_gap))
points.append((x1 + x_gap, y1 + y_gap * 2))
points.append((x1 + x_gap * 2, y1 + y_gap * 2))
plabs = [1, 1, 1, 1]
elif detection_hint == "mask-point-bbox":
center = center_of_bbox(seg.bbox)
points.append(center)
plabs = [1]
elif detection_hint == "mask-area":
points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
seg.cropped_mask,
mask_hint_threshold, use_small_negative)
    if mask_hint_use_negative == "Outter":  # 'Outter' is the literal option string expected from the caller
        npoints, nplabs = gen_negative_hints(image.shape[1], image.shape[0],
                                             seg.crop_region[0], seg.crop_region[1],
                                             seg.crop_region[2], seg.crop_region[3])
points += npoints
plabs += nplabs
return points, plabs
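# Summary of the detection_hint layouts above: "center-1" places one point at
# the bbox center; "horizontal-2"/"vertical-2" place two points at 1/3 and 2/3
# along the axis; "rect-4" places four points on the center lines and
# "diamond-4" at the four inner third-grid corners; "mask-point-bbox" and
# "mask-area" derive points from the seg's own mask.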
def combine_masks2(masks):
if len(masks) == 0:
return None
else:
initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
combined_cv2_mask = initial_cv2_mask
for i in range(1, len(masks)):
cv2_mask = np.array(masks[i]).astype(np.uint8)
if combined_cv2_mask.shape == cv2_mask.shape:
combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
            else:
                # skip masks whose shape does not match the combined mask
                pass
mask = torch.from_numpy(combined_cv2_mask)
return mask
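# combine_masks2 ORs same-shaped masks into one uint8 torch tensor; masks
# whose shape differs from the first are silently skipped:
#
#   combine_masks2([np.ones((4, 4)), np.zeros((4, 4))])  # -> 4x4 tensor of ones
#   combine_masks2([])                                   # -> None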
def dilate_mask(mask, dilation_factor, iter=1):
if dilation_factor == 0:
return make_2d_mask(mask)
mask = make_2d_mask(mask)
kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
mask = cv2.UMat(mask)
kernel = cv2.UMat(kernel)
    if dilation_factor > 0:
        result = cv2.dilate(mask, kernel, iterations=iter)
    else:
        result = cv2.erode(mask, kernel, iterations=iter)
return result.get()
def convert_and_stack_masks(masks):
if len(masks) == 0:
return None
mask_tensors = []
for mask in masks:
mask_array = np.array(mask, dtype=np.uint8)
mask_tensor = torch.from_numpy(mask_array)
mask_tensors.append(mask_tensor)
stacked_masks = torch.stack(mask_tensors, dim=0)
stacked_masks = stacked_masks.unsqueeze(1)
return stacked_masks
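# Shape sketch: N HxW masks become an (N, 1, H, W) uint8 tensor:
#
#   stacked = convert_and_stack_masks([np.ones((64, 64))] * 3)
#   stacked.shape   # -> torch.Size([3, 1, 64, 64])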
def merge_and_stack_masks(stacked_masks, group_size):
if stacked_masks is None:
return None
num_masks = stacked_masks.size(0)
merged_masks = []
for i in range(0, num_masks, group_size):
subset_masks = stacked_masks[i:i + group_size]
merged_mask = torch.any(subset_masks, dim=0)
merged_masks.append(merged_mask)
if len(merged_masks) > 0:
merged_masks = torch.stack(merged_masks, dim=0)
return merged_masks
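# Masks are OR-merged in groups of `group_size`; 7 masks with group_size=3
# merge down to ceil(7/3) = 3:
#
#   merged = merge_and_stack_masks(torch.zeros(7, 1, 64, 64), 3)
#   merged.shape    # -> torch.Size([3, 1, 64, 64])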
def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
if sam_model.is_auto_mode:
device = model_management.get_torch_device()
sam_model.safe_to.to_device(sam_model, device=device)
try:
predictor = SamPredictor(sam_model)
image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
predictor.set_image(image, "RGB")
total_masks = []
use_small_negative = mask_hint_use_negative == "Small"
        # segs[0] holds the original image shape; segs[1] is the list of SEG items
        segs = segs[1]
if detection_hint == "mask-points":
points = []
plabs = []
for i in range(len(segs)):
bbox = segs[i].bbox
center = center_of_bbox(bbox)
points.append(center)
                # points from very small boxes are treated as background, others as foreground
if use_small_negative and bbox[2] - bbox[0] < 10:
plabs.append(0)
else:
plabs.append(1)
detected_masks = sam_predict(predictor, points, plabs, None, threshold)
total_masks += detected_masks
else:
for i in range(len(segs)):
bbox = segs[i].bbox
center = center_of_bbox(bbox)
x1 = max(bbox[0] - bbox_expansion, 0)
y1 = max(bbox[1] - bbox_expansion, 0)
x2 = min(bbox[2] + bbox_expansion, image.shape[1])
y2 = min(bbox[3] + bbox_expansion, image.shape[0])
dilated_bbox = [x1, y1, x2, y2]
points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
mask_hint_threshold, use_small_negative,
mask_hint_use_negative)
detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
total_masks += detected_masks
        # merge all collected masks
mask = combine_masks2(total_masks)
finally:
        if sam_model.is_auto_mode:
            sam_model.cpu()
mask_working_device = torch.device("cpu")
if mask is not None:
mask = mask.float()
mask = dilate_mask(mask.cpu().numpy(), dilation)
mask = torch.from_numpy(mask)
mask = mask.to(device=mask_working_device)
else:
        # extract height and width from the (H, W, C) numpy image
        height, width, _ = image.shape
mask = torch.zeros(
(height, width), dtype=torch.float32, device=mask_working_device
) # empty mask
stacked_masks = convert_and_stack_masks(total_masks)
return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
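# Hedged usage sketch (parameter values are illustrative, not canonical;
# `segs` is the (shape, [SEG, ...]) pair used throughout this file):
#
#   combined, per_group = make_sam_mask_segmented(
#       sam_model, segs, image, detection_hint="center-1", dilation=0,
#       threshold=0.93, bbox_expansion=0, mask_hint_threshold=0.7,
#       mask_hint_use_negative="False")
#   # combined: (H, W) float32 mask on CPU; per_group: (G, 1, H, W) or None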
def tensor2mask(t: torch.Tensor) -> torch.Tensor:
size = t.size()
    if len(size) < 4:
        return t
    if size[3] == 1:
        return t[:, :, :, 0]
elif size[3] == 4:
        # Ambiguous case: prefer the alpha channel, but if alpha is uniformly 1
        # fall back to RGB-to-grayscale behavior.
        if torch.min(t[:, :, :, 3]).item() != 1.:
            return t[:, :, :, 3]
    return TF.rgb_to_grayscale(tensor2rgb(t).permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
size = t.size()
    if len(size) < 4:
return t.unsqueeze(3).repeat(1, 1, 1, 3)
if size[3] == 1:
return t.repeat(1, 1, 1, 3)
elif size[3] == 4:
return t[:, :, :, :3]
else:
return t
def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
size = t.size()
    if len(size) < 4:
return t.unsqueeze(3).repeat(1, 1, 1, 4)
elif size[3] == 1:
return t.repeat(1, 1, 1, 4)
elif size[3] == 3:
        alpha_tensor = torch.ones((size[0], size[1], size[2], 1), dtype=t.dtype, device=t.device)
return torch.cat((t, alpha_tensor), dim=3)
else:
return t
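# Channel-conversion sketch for the three tensor helpers (NHWC assumed):
#
#   t = torch.rand(1, 64, 64, 3)
#   tensor2rgb(t).shape    # -> (1, 64, 64, 3)  unchanged
#   tensor2rgba(t).shape   # -> (1, 64, 64, 4)  alpha channel of ones appended
#   tensor2mask(t).shape   # -> (1, 64, 64)     RGB reduced to grayscale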