Spaces:
Running
Running
import comfy | |
import re | |
from impact.utils import * | |
# Hugging Face repo ids offered as presets by HF_TransformersClassifierProvider.
# Kept as a list: INPUT_TYPES concatenates it with ['Manual repo id'].
hf_transformer_model_urls = [
    "rizvandwiki/gender-classification-2",
    "NTQAI/pedestrian_gender_recognition",
    "Leilab/gender_class",
    "ProjectPersonal/GenderClassifier",
    "crangana/trained-gender",
    "cledoux42/GenderNew_v002",
    "ivensamdh/genderage2",
]
class HF_TransformersClassifierProvider:
    """ComfyUI node that builds a Hugging Face ``transformers`` pipeline.

    The pipeline is emitted as a ``TRANSFORMERS_CLASSIFIER`` value.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification for this node.

        Must be a classmethod: ComfyUI calls ``INPUT_TYPES()`` on the class
        itself, without an instance.
        """
        return {"required": {
                    "preset_repo_id": (hf_transformer_model_urls + ['Manual repo id'],),
                    "manual_repo_id": ("STRING", {"multiline": False}),
                    "device_mode": (["AUTO", "Prefer GPU", "CPU"],),
                },
        }

    RETURN_TYPES = ("TRANSFORMERS_CLASSIFIER",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/HuggingFace"

    def doit(self, preset_repo_id, manual_repo_id, device_mode):
        """Create a ``transformers.pipeline`` for the selected model repo.

        Args:
            preset_repo_id: One of ``hf_transformer_model_urls``, or
                ``'Manual repo id'`` to use *manual_repo_id* instead.
            manual_repo_id: Repo id typed by the user. Surrounding whitespace
                is ignored.
            device_mode: ``'CPU'`` forces the CPU. Any other value (``'AUTO'``,
                ``'Prefer GPU'``) uses ComfyUI's torch device.

        Returns:
            A one-element tuple holding the pipeline object.

        Raises:
            ValueError: if manual mode is selected but no repo id was entered.
        """
        if preset_repo_id == 'Manual repo id':
            url = manual_repo_id.strip()
            if not url:
                # Fail early with a clear message instead of letting
                # pipeline(model='') fail somewhere inside transformers.
                raise ValueError(
                    "HF_TransformersClassifierProvider: 'Manual repo id' is "
                    "selected but manual_repo_id is empty."
                )
        else:
            url = preset_repo_id

        if device_mode != 'CPU':
            # 'AUTO' and 'Prefer GPU' both defer to ComfyUI's device choice.
            # Import the submodule explicitly; `import comfy` alone does not
            # guarantee that `comfy.model_management` has been loaded.
            import comfy.model_management
            device = comfy.model_management.get_torch_device()
        else:
            device = "cpu"

        # Imported lazily, after input validation, so transformers is only
        # needed when this node actually runs.
        from transformers import pipeline
        classifier = pipeline(model=url, device=device)

        return (classifier,)
# Ready-made SEGS_Classify expressions of the form "<lhs> <op> <rhs>".
# '#Female' / '#Male' are symbolic labels expanded through symbolic_label_map.
# Bare names such as 'female' or 'Age16to25' are matched literally against
# classifier labels.
preset_classify_expr = [
    '#Female > #Male',
    '#Female < #Male',
    'female > 0.5',
    'male > 0.5',
    'Age16to25 > 0.1',
    'Age50to69 > 0.1',
]
# Symbolic labels accepted in classification expressions. Each maps to the set
# of raw label spellings that classifier outputs may use for that class; a
# lookup returns the score of the first result entry whose label is in the set.
symbolic_label_map = {
    '#Female': {'female', 'Female', 'Human Female', 'woman', 'women', 'girl'},
    '#Male': {'male', 'Male', 'Human Male', 'man', 'men', 'boy'},
}
def is_numeric_string(input_str):
    """Return True if *input_str* is a decimal number literal.

    Accepted forms:
    - an optional ``+``/``-`` sign;
    - digits with an optional fractional part (``'1'``, ``'0.5'``, ``'5.'``),
      or a leading-dot fraction (``'.5'``);
    - an optional exponent (``'1e-3'``).

    Every accepted string is parseable by ``float()``. Anything else,
    including words such as ``'inf'``/``'nan'``, returns False so it can be
    treated as a classifier label.
    """
    # Previously only '-?digits(.digits)?' was accepted, so '.5' or '+1' were
    # mistaken for label names and silently matched nothing.
    return re.match(r'^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?$', input_str) is not None
# Parses "<lhs> <op> <rhs>". Each operand is a run of characters other than
# '>', '<', '=' or space: a number, a literal label, or a '#Symbolic' label.
# Two-character operators come first in the alternation to make intent explicit.
# The match result is identical either way, because a bare '>'/'<' can never
# succeed when it is immediately followed by '='.
classify_expr_pattern = r'([^><= ]+)\s*(>=|<=|>|<|=)\s*([^><= ]+)'
class SEGS_Classify:
    """ComfyUI node that splits SEGS by comparing classifier scores.

    Each segment's image is run through a TRANSFORMERS_CLASSIFIER. The selected
    expression (e.g. ``'#Female > #Male'`` or ``'female > 0.5'``) is evaluated
    against the returned label scores.

    Segments for which the expression holds go to ``filtered_SEGS``. These go
    to ``remained_SEGS``:
    - segments where the expression is false;
    - segments where a referenced label has no score;
    - segments with no image available.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification for this node.

        Must be a classmethod: ComfyUI calls ``INPUT_TYPES()`` on the class
        itself, without an instance.
        """
        return {"required": {
                    "classifier": ("TRANSFORMERS_CLASSIFIER",),
                    "segs": ("SEGS",),
                    "preset_expr": (preset_classify_expr + ['Manual expr'],),
                    "manual_expr": ("STRING", {"multiline": False}),
                },
                "optional": {
                    "ref_image_opt": ("IMAGE", ),
                }
        }

    RETURN_TYPES = ("SEGS", "SEGS",)
    RETURN_NAMES = ("filtered_SEGS", "remained_SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/HuggingFace"

    @staticmethod
    def lookup_classified_label_score(score_infos, label):
        """Return the score for *label* in a classifier result, or ``None``.

        Args:
            score_infos: Iterable of dicts with ``'label'`` and ``'score'`` keys.
            label: A literal label (``'female'``) or a symbolic label starting
                with ``'#'`` (``'#Female'``). Symbolic labels are expanded via
                ``symbolic_label_map`` into a set of accepted spellings.

        Returns:
            The score of the first entry whose label is accepted. Returns
            ``None`` if the symbolic label is unknown or nothing matches.
        """
        if label.startswith('#'):
            accepted = symbolic_label_map.get(label)
            if accepted is None:
                return None
        else:
            accepted = {label}

        for info in score_infos:
            if info['label'] in accepted:
                return info['score']

        return None

    @staticmethod
    def _classify_segments(classifier, seg_list, ref_image_opt):
        """Run *classifier* on each segment that has an image available.

        Returns ``(classified, unclassified)``:
        - ``classified`` holds ``(seg, result)`` pairs, in input order;
        - ``unclassified`` lists segments with no obtainable image, in input
          order.
        """
        classified = []
        unclassified = []
        for seg in seg_list:
            if seg.cropped_image is not None:
                cropped_image = seg.cropped_image
            elif ref_image_opt is not None:
                # No crop stored on the segment: take it from the original image.
                cropped_image = crop_image(ref_image_opt, seg.crop_region)
            else:
                cropped_image = None

            if cropped_image is None:
                unclassified.append(seg)
                continue

            classified.append((seg, classifier(to_pil(cropped_image))))
        return classified, unclassified

    @staticmethod
    def _resolve_operand(score_infos, token, is_label):
        """Return one side of the expression as a float, or ``None`` if a label has no score."""
        if not is_label:
            return float(token)
        score = SEGS_Classify.lookup_classified_label_score(score_infos, token)
        return None if score is None else float(score)

    @staticmethod
    def _compare(lhs, op, rhs):
        """Apply *op* to *lhs* and *rhs*; *op* is one of ``>``, ``<``, ``>=``, ``<=``, ``=``."""
        if op == '>':
            return lhs > rhs
        if op == '<':
            return lhs < rhs
        if op == '>=':
            return lhs >= rhs
        if op == '<=':
            return lhs <= rhs
        return lhs == rhs

    def doit(self, classifier, segs, preset_expr, manual_expr, ref_image_opt=None):
        """Split *segs* into (filtered, remained) according to the expression.

        Args:
            classifier: Callable taking a PIL image and returning a list of
                ``{'label': ..., 'score': ...}`` dicts.
            segs: A ``(shape, [seg, ...])`` pair. Each seg exposes
                ``cropped_image`` and ``crop_region``.
            preset_expr: A preset expression, or ``'Manual expr'`` to use
                *manual_expr* instead.
            manual_expr: Expression ``LHS OP RHS``. OP is one of ``>``, ``<``,
                ``>=``, ``<=``, ``=``. Each side is a number, a label or a
                ``#Symbolic`` label.
            ref_image_opt: Optional image to crop from when a segment has no
                cropped image of its own.

        Returns:
            ``((shape, filtered_segs), (shape, remained_segs))``. If the
            expression cannot be parsed, nothing is filtered and *segs* is
            returned unchanged as the remainder.
        """
        expr_str = manual_expr if preset_expr == 'Manual expr' else preset_expr
        # The pattern is anchored at the start, so a hand-typed leading space
        # would otherwise make the whole expression unparseable.
        expr_str = expr_str.strip()

        match = re.match(classify_expr_pattern, expr_str)
        if match is None:
            print(f"[Impact Pack] SEGS_Classify: cannot parse expression {expr_str!r}; no segments were filtered.")
            return ((segs[0], []), segs)

        a, op, b = match.group(1, 2, 3)
        a_is_lab = not is_numeric_string(a)
        b_is_lab = not is_numeric_string(b)

        # Segments without an image lead the remainder, as before.
        classified, remained_SEGS = self._classify_segments(classifier, segs[1], ref_image_opt)

        filtered_SEGS = []
        for seg, res in classified:
            avalue = self._resolve_operand(res, a, a_is_lab)
            bvalue = self._resolve_operand(res, b, b_is_lab)
            if avalue is not None and bvalue is not None and self._compare(avalue, op, bvalue):
                filtered_SEGS.append(seg)
            else:
                # A missing label score or a false condition keeps the segment.
                remained_SEGS.append(seg)

        return ((segs[0], filtered_SEGS), (segs[0], remained_SEGS))