daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
5.19 kB
import comfy
import re
from impact.utils import *
# Preset model repo ids offered by HF_TransformersClassifierProvider; the
# selected id is passed to transformers.pipeline(model=...).  The node's
# INPUT_TYPES appends a 'Manual repo id' choice after these entries.
hf_transformer_model_urls = [
    "rizvandwiki/gender-classification-2",
    "NTQAI/pedestrian_gender_recognition",
    "Leilab/gender_class",
    "ProjectPersonal/GenderClassifier",
    "crangana/trained-gender",
    "cledoux42/GenderNew_v002",
    "ivensamdh/genderage2",
]
class HF_TransformersClassifierProvider:
    """ComfyUI node that builds a Hugging Face ``transformers`` pipeline.

    The model is either one of the preset repo ids or, when the preset is
    'Manual repo id', the free-text ``manual_repo_id`` input.  The resulting
    pipeline object is emitted as a TRANSFORMERS_CLASSIFIER.
    """

    RETURN_TYPES = ("TRANSFORMERS_CLASSIFIER",)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/HuggingFace"

    @classmethod
    def INPUT_TYPES(cls):
        # A fresh list is built on every call so the module-level preset
        # list itself is never mutated.
        repo_choices = hf_transformer_model_urls + ['Manual repo id']
        return {
            "required": {
                "preset_repo_id": (repo_choices,),
                "manual_repo_id": ("STRING", {"multiline": False}),
                "device_mode": (["AUTO", "Prefer GPU", "CPU"],),
            },
        }

    def doit(self, preset_repo_id, manual_repo_id, device_mode):
        """Create the pipeline for the chosen repo id on the chosen device.

        Returns a 1-tuple holding the ``transformers.pipeline`` object.
        """
        # Imported lazily so the heavy dependency is only loaded when used.
        from transformers import pipeline

        use_manual_id = preset_repo_id == 'Manual repo id'
        repo_id = manual_repo_id if use_manual_id else preset_repo_id

        # 'AUTO' and 'Prefer GPU' both defer to ComfyUI's torch device;
        # only 'CPU' forces the CPU.
        target_device = "cpu" if device_mode == 'CPU' else comfy.model_management.get_torch_device()

        return (pipeline(model=repo_id, device=target_device),)
# Preset filter expressions offered by SEGS_Classify.  Each has the form
# "<operand> <op> <operand>", where an operand is a numeric literal, a
# classifier label, or a '#'-prefixed symbolic label (see below).
preset_classify_expr = [
    '#Female > #Male',
    '#Female < #Male',
    'female > 0.5',
    'male > 0.5',
    'Age16to25 > 0.1',
    'Age50to69 > 0.1',
]

# Symbolic labels usable in expressions.  Each maps to the set of concrete
# label spellings it stands for, so one expression works across classifiers
# that name the same class differently.
symbolic_label_map = {
    '#Female': {'female', 'Female', 'Human Female', 'woman', 'women', 'girl'},
    '#Male': {'male', 'Male', 'Human Male', 'man', 'men', 'boy'},
}
def is_numeric_string(input_str):
    """Return True if ``input_str`` is a plain decimal number literal.

    Accepts an optional sign, integer and/or fractional digits (``5``,
    ``5.``, ``.5``, ``0.5``) and an optional exponent (``1e-3``).  Every
    accepted string is a valid ``float()`` argument.  Anything else, such
    as a classifier label like ``'female'`` or ``'#Male'``, yields False.
    """
    # Previously only '-?digits[.digits]' was accepted, so literals such as
    # '.5' or '+0.5' were mistaken for label names and never matched.
    return re.match(r'^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?$', input_str) is not None
# Grammar for a filter expression: "<operand> <op> <operand>".  Operands are
# runs of characters other than '>', '<', '=' and space; leading whitespace is
# skipped.  Two-character operators are listed before their one-character
# prefixes so they are matched directly, and '==' is accepted as an alias of
# '='.  The pattern is not anchored at the end, so with re.match any text
# after the second operand is ignored.
classify_expr_pattern = r'\s*([^><= ]+)\s*(>=|<=|==|>|<|=)\s*([^><= ]+)'
class SEGS_Classify:
    """ComfyUI node that partitions SEGS using a transformers classifier.

    Each segment's image (its own ``cropped_image`` or, failing that, a crop
    of ``ref_image_opt``) is passed to ``classifier``.  A single comparison
    expression ``<operand> <op> <operand>`` is then evaluated per segment.
    Each operand is a numeric literal, a classifier label, or a '#'-prefixed
    symbolic label from ``symbolic_label_map``.  Segments for which the
    expression holds go to ``filtered_SEGS``; all others go to
    ``remained_SEGS``.
    """

    RETURN_TYPES = ("SEGS", "SEGS",)
    RETURN_NAMES = ("filtered_SEGS", "remained_SEGS",)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/HuggingFace"

    # Operator token -> comparison.  '=' and '==' both mean equality; doit
    # also falls back to equality for any token not listed here.
    _COMPARATORS = {
        '>': lambda x, y: x > y,
        '<': lambda x, y: x < y,
        '>=': lambda x, y: x >= y,
        '<=': lambda x, y: x <= y,
        '=': lambda x, y: x == y,
        '==': lambda x, y: x == y,
    }

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "classifier": ("TRANSFORMERS_CLASSIFIER",),
                "segs": ("SEGS",),
                "preset_expr": (preset_classify_expr + ['Manual expr'],),
                "manual_expr": ("STRING", {"multiline": False}),
            },
            "optional": {
                "ref_image_opt": ("IMAGE", ),
            }
        }

    @staticmethod
    def lookup_classified_label_score(score_infos, label):
        """Return the score for ``label`` in a classifier result, or None.

        ``score_infos`` is iterated as dicts with 'label' and 'score' keys.
        A label starting with '#' is resolved through ``symbolic_label_map``
        to a set of concrete spellings.  The score of the first entry whose
        label is in that set is returned.  Unknown symbolic labels and labels
        absent from ``score_infos`` yield None.
        """
        if label.startswith('#'):
            candidates = symbolic_label_map.get(label)
            if candidates is None:
                return None
        else:
            candidates = {label}

        for info in score_infos:
            if info['label'] in candidates:
                return info['score']
        return None

    @staticmethod
    def _classify_segs(classifier, seg_list, ref_image_opt):
        """Run ``classifier`` on every segment with an available image.

        Returns ``(classified, unclassified)``.  ``classified`` is a list of
        ``(seg, classifier_output)`` pairs.  ``unclassified`` lists the
        segments for which no image could be obtained, in input order.
        """
        classified = []
        unclassified = []
        for seg in seg_list:
            cropped_image = None
            if seg.cropped_image is not None:
                cropped_image = seg.cropped_image
            elif ref_image_opt is not None:
                # No stored crop: cut the region out of the reference image.
                cropped_image = crop_image(ref_image_opt, seg.crop_region)

            if cropped_image is None:
                unclassified.append(seg)
            else:
                classified.append((seg, classifier(to_pil(cropped_image))))
        return classified, unclassified

    @staticmethod
    def _resolve_operand(score_infos, token, is_label):
        """Turn one expression operand into a float, or None if unavailable.

        Label operands are looked up in ``score_infos``.  A missing label
        yields None.  Non-label operands are numeric literals converted
        with ``float``.
        """
        if is_label:
            score = SEGS_Classify.lookup_classified_label_score(score_infos, token)
            return None if score is None else float(score)
        return float(token)

    def doit(self, classifier, segs, preset_expr, manual_expr, ref_image_opt=None):
        """Split ``segs`` by evaluating the expression on each segment.

        ``segs`` is used as a pair: ``segs[0]`` is passed through unchanged
        into both outputs, and ``segs[1]`` is the list of segments.

        If the expression cannot be parsed, no segment passes: the result is
        ``((segs[0], []), segs)``.  Otherwise the result is
        ``((segs[0], filtered), (segs[0], remained))``.  A segment is
        filtered only if both operands resolve and the comparison holds.
        Within ``remained``, segments that had no image come first, followed
        by classified segments that did not pass.
        """
        expr_str = manual_expr if preset_expr == 'Manual expr' else preset_expr
        # Hand-typed expressions often carry stray leading/trailing spaces or
        # a newline.  Those would otherwise break parsing or end up inside a
        # label operand.
        match = re.match(classify_expr_pattern, expr_str.strip())
        if match is None:
            return ((segs[0], []), segs)

        a, op, b = match.group(1, 2, 3)
        a_is_lab = not is_numeric_string(a)
        b_is_lab = not is_numeric_string(b)
        compare = SEGS_Classify._COMPARATORS.get(op, SEGS_Classify._COMPARATORS['='])

        classified, remained_SEGS = SEGS_Classify._classify_segs(classifier, segs[1], ref_image_opt)

        filtered_SEGS = []
        for seg, res in classified:
            avalue = SEGS_Classify._resolve_operand(res, a, a_is_lab)
            bvalue = SEGS_Classify._resolve_operand(res, b, b_is_lab)
            if avalue is not None and bvalue is not None and compare(avalue, bvalue):
                filtered_SEGS.append(seg)
            else:
                remained_SEGS.append(seg)

        return ((segs[0], filtered_SEGS), (segs[0], remained_SEGS))