import re
import time
from colorama import Fore, Style
import spacy
from ..utils import VyroParams
from ..utils.prompt import calc_prompt


# ``--style:name`` or ``--style:"multi word name"``; group 1 holds the value
# (quotes included for the quoted form). Case-insensitive so it agrees with
# the ``.lower()`` presence check in ``analyze_prompt``.
_STYLE_FLAG_RE = re.compile(r'\-\-style\:(\"[a-zA-Z0-9\s]+\"|[a-zA-Z0-9_]+)', re.IGNORECASE)
# Bare ``--style:`` prefix, used to strip a flag that carries no parsable value.
_STYLE_PREFIX_RE = re.compile(r'--style:', re.IGNORECASE)
# ``--raw`` flag in any casing; removed without lowercasing the rest of the prompt.
_RAW_FLAG_RE = re.compile(r'--raw', re.IGNORECASE)


class VyroPromptAnalyzer:
    """ComfyUI node that chooses a style for the prompt held in ``vyro_params``.

    Resolution order: raw mode (``--raw`` flag or ``skip`` enabled) returns no
    style, or ``"qr"`` in QR mode; otherwise an explicit ``--style:`` value or a
    prompt that is itself a style name is used directly; otherwise the style
    with the highest weighted ``classifier`` score wins. Any classifier failure
    falls back to the first configured style.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "styles": ("LIST",),
                "prompt_tree": ("DICT",),
                "classifier": ("TRANSFORMER",),
                "debug": (VyroParams.STATES, {"default": VyroParams.STATES[0]}),
                "skip": (VyroParams.STATES, {"default": VyroParams.STATES[0]}),
            }
        }

    RETURN_TYPES = ("VYRO_PARAMS", "STYLE",)
    RETURN_NAMES = ("vyro_params", "style", )
    FUNCTION = "analyze_prompt"
    CATEGORY = "Vyro/Prompt"

    @staticmethod
    def _parse_weighted_styles(styles):
        """Split ``"name:weight"`` entries into parallel name and weight lists.

        Entries without ``:`` get weight 1.0. A non-numeric weight raises
        ``ValueError`` from ``float``.
        """
        names = []
        weights = []
        for entry in styles:
            if ':' in entry:
                parts = entry.split(':')
                name, weight = parts[0], float(parts[1])
            else:
                name, weight = entry, 1.0
            names.append(name)
            weights.append(weight)
        return names, weights

    @staticmethod
    def _classify_style(classifier, text, style_names, weights, debug):
        """Return the style name with the highest weighted classifier score.

        Args:
            classifier: Callable returning a doc whose ``cats`` maps labels
                (spaces written as underscores) to scores.
            text: Text to classify.
            style_names: Candidate style names (weights already stripped).
            weights: Weight for each entry of ``style_names``, same order.
            debug: ``VyroParams.STATES[1]`` prints timing and the top 3 scores.

        Raises:
            ValueError: If the classifier yields no categories, or none of them
                correspond to a candidate style.
        """
        t1 = time.time()
        doc = classifier(text)
        t2 = time.time()

        if not getattr(doc, 'cats', None):
            raise ValueError("No categories found in classifier output")

        # (label, weighted score, weight). Each candidate uses its own weight,
        # applied exactly once.
        weighted_scores = []
        for name, weight in zip(style_names, weights):
            label_key = name.replace(' ', '_')
            if label_key in doc.cats:
                weighted_scores.append((name, doc.cats[label_key] * weight, weight))

        if not weighted_scores:
            raise ValueError("No matching styles found in classifier output")

        weighted_scores.sort(key=lambda x: x[1], reverse=True)
        top3 = weighted_scores[:3]

        if debug == VyroParams.STATES[1]:
            print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Prompt analysis took {round(t2-t1,2)} seconds.{Style.RESET_ALL}")
            for i, (label, score, weight) in enumerate(top3):
                if i == 0:
                    print(f"[PromptAnalyzer] {label}: {Fore.RED}{round(score,2)} {Fore.LIGHTRED_EX}({weight}){Style.RESET_ALL}")
                else:
                    print(f"[PromptAnalyzer] {label}: {Fore.LIGHTBLUE_EX}{round(score,2)} {Fore.LIGHTGREEN_EX}({weight}){Style.RESET_ALL}")

        return top3[0][0]

    def analyze_prompt(self, vyro_params:VyroParams, styles:list, prompt_tree:dict, classifier:spacy.language.Language, debug, skip):
        """Resolve the style for ``vyro_params`` and build its prompts.

        Args:
            vyro_params: Shared generation parameters. ``user_prompt``,
                ``is_raw`` and ``style`` are updated in place; in raw mode,
                ``final_positive_prompt`` and ``final_negative_prompt`` are
                set directly from the user prompts.
            styles: Candidate style names, optionally weighted as
                ``"name:weight"``.
            prompt_tree: Passed through to ``calc_prompt`` (presumably the
                style -> prompt template data; not visible here).
            classifier: spaCy pipeline used when no style is given explicitly.
            debug: ``VyroParams.STATES[1]`` enables diagnostic printing.
            skip: ``VyroParams.STATES[1]`` forces raw mode.

        Returns:
            ``(vyro_params, style)``. ``style`` is ``""`` in raw mode and
            ``"qr"`` in raw QR mode.
        """
        user_prompt = vyro_params.user_prompt

        if '--raw' in user_prompt.lower() or skip == VyroParams.STATES[1]:
            # Strip the flag case-insensitively while keeping the prompt's own
            # casing (case-sensitive tokens such as embedding names survive).
            user_prompt = _RAW_FLAG_RE.sub('', user_prompt)
            vyro_params.is_raw = True
            vyro_params.user_prompt = user_prompt

        classifier_target = user_prompt

        if vyro_params.is_raw:
            vyro_params.user_prompt = user_prompt
            vyro_params.final_positive_prompt = vyro_params.user_prompt
            vyro_params.final_negative_prompt = vyro_params.user_neg_prompt
            # In QR-code mode the style is fixed, so prompt analysis is skipped.
            if vyro_params.mode == "qr":
                vyro_params.style = "qr"
                return (vyro_params, "qr", )
            vyro_params.style = ""
            return (vyro_params, "", )

        if '--style:' in user_prompt.lower():
            match = _STYLE_FLAG_RE.search(user_prompt)
            if match:
                extracted_style = match.group(1)
                if extracted_style.startswith('"'):
                    extracted_style = extracted_style[1:-1]
                user_prompt = _STYLE_FLAG_RE.sub('', user_prompt)
                classifier_target = extracted_style
            else:
                # Flag present but without a usable value: drop it and classify
                # the remaining prompt text instead of an empty match list.
                user_prompt = _STYLE_PREFIX_RE.sub('', user_prompt)
                classifier_target = user_prompt
            vyro_params.user_prompt = user_prompt

        deweighted_styles, weights = self._parse_weighted_styles(styles)

        # Direct hit: accept either a raw entry or a weight-stripped style name.
        if classifier_target in styles or classifier_target in deweighted_styles:
            vyro_params.style = classifier_target
            calc_prompt(vyro_params, prompt_tree)
            return (vyro_params, classifier_target, )

        try:
            style = self._classify_style(classifier, classifier_target, deweighted_styles, weights, debug)
            vyro_params.style = style
            calc_prompt(vyro_params, prompt_tree)
            return (vyro_params, style, )
        except Exception as e:
            # Best-effort boundary: any classifier (or prompt-building) error
            # falls back to the first configured style.
            if debug == VyroParams.STATES[1]:
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Classifier error: {str(e)}{Style.RESET_ALL}")
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Falling back to default style.{Style.RESET_ALL}")

            default_style = deweighted_styles[0] if deweighted_styles else ""
            print(f"[PromptAnalyzer] the style is {default_style}")
            vyro_params.style = default_style
            calc_prompt(vyro_params, prompt_tree)
            return (vyro_params, default_style, )


class VyroPromptEncoder:
    """ComfyUI node that encodes the final prompts for a base + refiner pair.

    Produces positive/negative conditioning for the base CLIP (separate "g" and
    "l" token streams plus size/crop metadata) and for the refiner CLIP (with
    an aesthetic score).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "base_clip": ("CLIP", ),
            "refiner_clip": ("CLIP", ),
            "params": ("VYRO_PARAMS",),
            "crop_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        },
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "CONDITIONING", )
    RETURN_NAMES = ("base_positive", "base_negative", "refiner_positive", "refiner_negative", )
    FUNCTION = "encode"
    CATEGORY = "Vyro/Prompt"

    @staticmethod
    def _encode_base(clip, prompt_g, prompt_l, empty, size_info):
        """Encode a base prompt with separate "g"/"l" texts into conditioning.

        The shorter token stream is padded with ``empty`` chunks so both
        streams have the same length before encoding.
        """
        tokens = clip.tokenize(prompt_g)
        tokens["l"] = clip.tokenize(prompt_l)["l"]
        while len(tokens["l"]) < len(tokens["g"]):
            tokens["l"] += empty["l"]
        while len(tokens["l"]) > len(tokens["g"]):
            tokens["g"] += empty["g"]
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled, **size_info}]]

    @staticmethod
    def _encode_refiner(clip, prompt, aesthetic_score, width, height):
        """Encode a refiner prompt, tagging it with an aesthetic score."""
        tokens = clip.tokenize(prompt)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled, "aesthetic_score": aesthetic_score, "width": width, "height": height}]]

    def encode(self, base_clip, refiner_clip, params:VyroParams, crop_factor:float):
        """Return ``(base_positive, base_negative, refiner_positive, refiner_negative)``.

        ``crop_factor`` scales ``params.width``/``params.height`` into the
        crop offsets. The target size is 4x the generation size.
        """
        empty = base_clip.tokenize("")

        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base positive prompt: {params.final_positive_prompt}{Style.RESET_ALL}')
        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base negative prompt: {params.final_negative_prompt}{Style.RESET_ALL}')

        pos_prompt = params.final_positive_prompt
        # The negative prompt feeds BOTH the "g" and the "l" streams of the
        # base negative conditioning (previously "l" received the positive).
        neg_prompt = params.final_negative_prompt

        base_size_info = {
            "width": params.width,
            "height": params.height,
            "crop_w": int(params.width * crop_factor),
            "crop_h": int(params.height * crop_factor),
            "target_width": params.width * 4,
            "target_height": params.height * 4,
        }
        pos_ascore = 6.0
        neg_ascore = 1.0
        refiner_width = params.width
        refiner_height = params.height

        res1 = self._encode_base(base_clip, pos_prompt, pos_prompt, empty, base_size_info)
        res2 = self._encode_base(base_clip, neg_prompt, neg_prompt, empty, base_size_info)
        res3 = self._encode_refiner(refiner_clip, pos_prompt, pos_ascore, refiner_width, refiner_height)
        res4 = self._encode_refiner(refiner_clip, neg_prompt, neg_ascore, refiner_width, refiner_height)

        return (res1, res2, res3, res4, )


NODE_CLASS_MAPPINGS = {
    "Vyro Prompt Analyzer": VyroPromptAnalyzer,
    "Vyro Prompt Encoder": VyroPromptEncoder,
}

# NOTE(review): these keys ("VyroPromptAnalyzer") differ from the
# NODE_CLASS_MAPPINGS keys ("Vyro Prompt Analyzer"), so these display names
# may never be matched. Left unchanged because the class-mapping keys are
# presumably stored in saved workflows — confirm before aligning them.
NODE_DISPLAY_NAME_MAPPINGS = {
    "VyroPromptAnalyzer": "Vyro Prompt Analyzer",
    "VyroPromptEncoder": "Vyro Prompt Encoder",
}