# Author: azkavyro
# Commit 6fecfbe: "Added all files including vyro_workflows"
import re
import time
from colorama import Fore, Style
import spacy
from ..utils import VyroParams
from ..utils.prompt import calc_prompt
class VyroPromptAnalyzer:
    """ComfyUI node that chooses a style for the user's prompt and builds the final prompts.

    The prompt is resolved in this order:

    1. Raw mode -- triggered by a ``--raw`` flag anywhere in the prompt
       (case-insensitive), by ``skip == VyroParams.STATES[1]``, or by
       ``vyro_params.is_raw`` already being truthy. The user's positive and
       negative prompts are copied to the final prompts untouched and no
       style is applied (the returned style is ``"qr"`` when
       ``vyro_params.mode == "qr"``, otherwise ``""``).
    2. Explicit style -- ``--style:name`` or ``--style:"multi word name"``
       (case-insensitive flag). The flag is stripped from the prompt; if the
       name matches one of ``styles`` (ignoring any ``:weight`` suffix) it is
       used directly, otherwise the extracted name is handed to the classifier.
    3. Classification -- ``classifier`` scores the text and the style with the
       highest weighted score wins. If classification fails or matches no
       style, the first entry of ``styles`` is used ("" when there are none).
    """

    # --style:name or --style:"multi word name"; group 1 is the name, with the
    # surrounding quotes still attached in the quoted form.
    _STYLE_RE = re.compile(r'\-\-style\:(\"[a-zA-Z0-9\s]+\"|[a-zA-Z0-9_]+)', re.IGNORECASE)
    # Bare style flag, used to strip it when no valid style name follows it.
    _STYLE_FLAG_RE = re.compile(r'--style:', re.IGNORECASE)
    _RAW_FLAG_RE = re.compile(r'--raw', re.IGNORECASE)

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "styles": ("LIST",),
                "prompt_tree": ("DICT",),
                "classifier": ("TRANSFORMER",),
                "debug": (VyroParams.STATES, {"default": VyroParams.STATES[0]}),
                "skip": (VyroParams.STATES, {"default": VyroParams.STATES[0]})
            }
        }

    RETURN_TYPES = ("VYRO_PARAMS", "STYLE",)
    RETURN_NAMES = ("vyro_params", "style", )
    FUNCTION = "analyze_prompt"
    CATEGORY = "Vyro/Prompt"

    def analyze_prompt(self, vyro_params:VyroParams, styles:list, prompt_tree:dict, classifier:spacy.language.Language, debug, skip):
        """Resolve the style for ``vyro_params`` and update it in place.

        Args:
            vyro_params: Shared pipeline parameters. ``user_prompt``,
                ``is_raw`` and ``style`` are updated in place; in raw mode
                ``final_positive_prompt``/``final_negative_prompt`` are set
                directly, otherwise ``calc_prompt`` is called with it.
            styles: Candidate style names, each optionally suffixed with
                ``:weight`` (a float multiplier applied once to the
                classifier score).
            prompt_tree: Passed through to ``calc_prompt``.
            classifier: spaCy pipeline whose ``Doc.cats`` maps labels to
                scores; a style is looked up with spaces replaced by
                underscores.
            debug: ``VyroParams.STATES[1]`` enables diagnostic printing.
            skip: ``VyroParams.STATES[1]`` forces raw mode.

        Returns:
            ``(vyro_params, style)`` where ``style`` is the chosen style name.

        Raises:
            ValueError: if classification is needed and a ``styles`` entry
                has a non-numeric weight.
        """
        debug_on = debug == VyroParams.STATES[1]
        user_prompt = vyro_params.user_prompt

        # Strip the --raw flag without lowercasing the rest of the prompt.
        if self._RAW_FLAG_RE.search(user_prompt) or skip == VyroParams.STATES[1]:
            user_prompt = self._RAW_FLAG_RE.sub('', user_prompt)
            vyro_params.is_raw = True
            vyro_params.user_prompt = user_prompt

        if vyro_params.is_raw:
            return self._passthrough(vyro_params)

        classifier_target = user_prompt
        if self._STYLE_FLAG_RE.search(user_prompt):
            match = self._STYLE_RE.search(user_prompt)
            if match:
                requested = match.group(1)
                if requested.startswith('"'):
                    requested = requested[1:-1]
                user_prompt = self._STYLE_RE.sub('', user_prompt)
                vyro_params.user_prompt = user_prompt
                if requested in self._style_names(styles):
                    return self._apply_style(vyro_params, prompt_tree, requested)
                # Not an exact style: classify the requested name itself,
                # presumably to map a near-miss name onto a known style.
                classifier_target = requested
            else:
                # Flag with no usable name: drop it and classify the prompt.
                user_prompt = self._STYLE_FLAG_RE.sub('', user_prompt)
                vyro_params.user_prompt = user_prompt
                classifier_target = user_prompt

        names, weights = self._split_weights(styles)
        # The classifier is an external pipeline that may raise arbitrary
        # errors; the broad catch is deliberate but limited to classification
        # so failures in calc_prompt are not masked or retried.
        try:
            started = time.time()
            ranked = self._rank_styles(classifier, classifier_target, names, weights)
            elapsed = time.time() - started
        except Exception as e:
            if debug_on:
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Classifier error: {str(e)}{Style.RESET_ALL}")
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Falling back to default style.{Style.RESET_ALL}")
            # Use the first style as default.
            style = names[0] if names else ""
            print(f"[PromptAnalyzer] the style is {style}")
        else:
            if debug_on:
                self._print_ranking(ranked[:3], elapsed)
            style = ranked[0][0]
        return self._apply_style(vyro_params, prompt_tree, style)

    @staticmethod
    def _passthrough(vyro_params):
        """Raw mode: copy the user's prompts to the final prompts and apply no style."""
        vyro_params.final_positive_prompt = vyro_params.user_prompt
        vyro_params.final_negative_prompt = vyro_params.user_neg_prompt
        # QR-code mode keeps its own "qr" style marker; prompt analysis is skipped.
        style = "qr" if vyro_params.mode == "qr" else ""
        vyro_params.style = style
        return (vyro_params, style, )

    @staticmethod
    def _apply_style(vyro_params, prompt_tree, style):
        """Record ``style`` on ``vyro_params``, run ``calc_prompt`` and build the node output."""
        vyro_params.style = style
        calc_prompt(vyro_params, prompt_tree)
        return (vyro_params, style, )

    @staticmethod
    def _style_names(styles):
        """Return the style names with any ``:weight`` suffix removed."""
        return [entry.split(':')[0] for entry in styles]

    @staticmethod
    def _split_weights(styles):
        """Split ``name[:weight]`` entries into parallel name and float-weight lists (weight defaults to 1.0)."""
        names, weights = [], []
        for entry in styles:
            parts = entry.split(':')
            names.append(parts[0])
            weights.append(float(parts[1]) if len(parts) > 1 else 1.0)
        return names, weights

    @staticmethod
    def _rank_styles(classifier, text, names, weights):
        """Return ``(name, weighted_score, weight)`` for every scored style, best first.

        Raises:
            ValueError: if the classifier produced no categories, or none of
                them corresponds to a style in ``names``.
        """
        doc = classifier(text)
        cats = getattr(doc, 'cats', None)
        if not cats:
            raise ValueError("No categories found in classifier output")
        ranked = []
        for name, weight in zip(names, weights):
            label = name.replace(' ', '_')
            if label in cats:
                # Apply each style's weight exactly once.
                ranked.append((name, cats[label] * weight, weight))
        if not ranked:
            raise ValueError("No matching styles found in classifier output")
        ranked.sort(key=lambda item: item[1], reverse=True)
        return ranked

    @staticmethod
    def _print_ranking(top, elapsed):
        """Print the analysis time and the top-ranked styles (winner highlighted)."""
        print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Prompt analysis took {round(elapsed,2)} seconds.{Style.RESET_ALL}")
        for rank, (name, score, weight) in enumerate(top):
            if rank == 0:
                print(f"[PromptAnalyzer] {name}: {Fore.RED}{round(score,2)} {Fore.LIGHTRED_EX}({weight}){Style.RESET_ALL}")
            else:
                print(f"[PromptAnalyzer] {name}: {Fore.LIGHTBLUE_EX}{round(score,2)} {Fore.LIGHTGREEN_EX}({weight}){Style.RESET_ALL}")
class VyroPromptEncoder:
    """ComfyUI node that encodes the final prompts into base and refiner conditioning.

    The base CLIP gets the same text on its ``"g"`` and ``"l"`` token streams
    plus size/crop conditioning. The refiner CLIP gets the text plus an
    aesthetic score (6.0 for the positive prompt, 1.0 for the negative).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "base_clip": ("CLIP", ),
            "refiner_clip": ("CLIP", ),
            "params": ("VYRO_PARAMS",),
            "crop_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "CONDITIONING", )
    RETURN_NAMES = ("base_positive", "base_negative", "refiner_positive", "refiner_negative", )
    FUNCTION = "encode"
    CATEGORY = "Vyro/Prompt"

    def encode(self, base_clip, refiner_clip, params:VyroParams, crop_factor:float):
        """Encode ``params.final_positive_prompt``/``final_negative_prompt``.

        Args:
            base_clip: CLIP used for the base conditioning (``"g"`` and ``"l"`` streams).
            refiner_clip: CLIP used for the refiner conditioning.
            params: Source of the final prompts and of ``width``/``height``.
            crop_factor: Fraction of width/height reported as ``crop_w``/``crop_h``.

        Returns:
            ``(base_positive, base_negative, refiner_positive, refiner_negative)``
            conditioning lists.
        """
        empty = base_clip.tokenize("")
        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base positive prompt: {params.final_positive_prompt}{Style.RESET_ALL}')
        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base negative prompt: {params.final_negative_prompt}{Style.RESET_ALL}')
        positive = params.final_positive_prompt
        negative = params.final_negative_prompt

        base_size = {
            "width": params.width,
            "height": params.height,
            "crop_w": int(params.width * crop_factor),
            "crop_h": int(params.height * crop_factor),
            # Target size is hard-coded to 4x the generation size.
            "target_width": params.width * 4,
            "target_height": params.height * 4,
        }
        base_positive = self._encode_base(base_clip, positive, positive, empty, base_size)
        # Both streams of the negative conditioning use the negative prompt
        # (the "l" stream previously encoded the positive prompt by mistake).
        base_negative = self._encode_base(base_clip, negative, negative, empty, base_size)
        refiner_positive = self._encode_refiner(refiner_clip, positive, 6.0, params.width, params.height)
        refiner_negative = self._encode_refiner(refiner_clip, negative, 1.0, params.width, params.height)
        return (base_positive, base_negative, refiner_positive, refiner_negative, )

    @staticmethod
    def _encode_base(clip, text_g, text_l, empty, size_cond):
        """Encode ``text_g``/``text_l`` on the "g"/"l" streams and attach ``size_cond``."""
        tokens = clip.tokenize(text_g)
        tokens["l"] = clip.tokenize(text_l)["l"]
        # Both streams must have the same number of chunks; pad the shorter
        # one with chunks from an empty prompt.
        while len(tokens["l"]) < len(tokens["g"]):
            tokens["l"] += empty["l"]
        while len(tokens["l"]) > len(tokens["g"]):
            tokens["g"] += empty["g"]
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled, **size_cond}]]

    @staticmethod
    def _encode_refiner(clip, text, aesthetic_score, width, height):
        """Encode ``text`` for the refiner with the given aesthetic score and size."""
        tokens = clip.tokenize(text)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled, "aesthetic_score": aesthetic_score, "width": width, "height": height}]]
# Node type -> implementing class. ComfyUI identifies nodes by these keys.
NODE_CLASS_MAPPINGS = {
    "Vyro Prompt Analyzer": VyroPromptAnalyzer,
    "Vyro Prompt Encoder": VyroPromptEncoder,
}

# Node type -> UI label. ComfyUI looks labels up by the NODE_CLASS_MAPPINGS
# key, so these keys must be the same strings used above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Prompt Analyzer": "Vyro Prompt Analyzer",
    "Vyro Prompt Encoder": "Vyro Prompt Encoder",
}