|
import re |
|
import time |
|
|
|
from colorama import Fore, Style |
|
|
|
import spacy |
|
|
|
from ..utils import VyroParams |
|
from ..utils.prompt import calc_prompt |
|
|
|
class VyroPromptAnalyzer:
    """ComfyUI node that chooses a style for the user prompt and builds the final prompts.

    Resolution order:

    1. Raw mode -- ``--raw`` appears in the prompt (any case), ``skip`` equals
       ``VyroParams.STATES[1]``, or ``vyro_params.is_raw`` is already truthy.
       The user prompts are copied to the final prompts unchanged and the style
       is ``"qr"`` when ``vyro_params.mode == "qr"``, otherwise ``""``.
    2. Explicit selection -- a ``--style:name`` / ``--style:"multi word"``
       directive (or a prompt that is exactly a style name) naming one of
       ``styles`` (with or without its ``:weight`` suffix).
    3. Classification -- the spaCy ``classifier`` scores the text; each style's
       score is multiplied by its optional ``name:weight`` suffix and the
       highest weighted score wins.
    4. Fallback -- if classification fails for any reason, the first style in
       ``styles`` (or ``""`` when the list is empty).
    """

    # ``--style:name`` or ``--style:"multi word name"``; group 1 is the value
    # (the quoted form still includes its quotes).
    _STYLE_RE = re.compile(r'\-\-style\:(\"[a-zA-Z0-9\s]+\"|[a-zA-Z0-9_]+)', re.IGNORECASE)
    # Bare ``--style:`` marker; also used to strip it when no valid value follows.
    _STYLE_FLAG_RE = re.compile(r'--style:', re.IGNORECASE)
    # ``--raw`` flag, matched case-insensitively so the rest of the prompt keeps its case.
    _RAW_FLAG_RE = re.compile(r'--raw', re.IGNORECASE)

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "styles": ("LIST",),
                "prompt_tree": ("DICT",),
                "classifier": ("TRANSFORMER",),
                "debug": (VyroParams.STATES, {"default": VyroParams.STATES[0]}),
                "skip": (VyroParams.STATES, {"default": VyroParams.STATES[0]})
            }
        }

    RETURN_TYPES = ("VYRO_PARAMS", "STYLE",)
    RETURN_NAMES = ("vyro_params", "style", )

    FUNCTION = "analyze_prompt"

    CATEGORY = "Vyro/Prompt"

    def analyze_prompt(self, vyro_params:VyroParams, styles:list, prompt_tree:dict, classifier:spacy.language.Language, debug, skip):
        """Pick a style for ``vyro_params.user_prompt`` and compute the final prompts.

        Args:
            vyro_params: Node parameters; ``user_prompt``, ``is_raw``, ``style``
                and (in raw mode) the final prompts are updated in place.
            styles: Style names, each optionally suffixed ``":<float weight>"``.
            prompt_tree: Passed through to ``calc_prompt``.
            classifier: Callable returning a doc whose ``cats`` dict maps
                labels (style names with spaces replaced by ``_``) to scores.
            debug: ``VyroParams.STATES[1]`` enables diagnostic printing.
            skip: ``VyroParams.STATES[1]`` forces raw mode.

        Returns:
            ``(vyro_params, style)``.

        Raises:
            ValueError: if a style weight is not a number.
        """
        user_prompt = vyro_params.user_prompt

        if self._RAW_FLAG_RE.search(user_prompt) or skip == VyroParams.STATES[1]:
            # Strip the flag only; the rest of the prompt keeps its original case.
            user_prompt = self._RAW_FLAG_RE.sub('', user_prompt)
            vyro_params.is_raw = True
            vyro_params.user_prompt = user_prompt

        if vyro_params.is_raw:
            vyro_params.final_positive_prompt = vyro_params.user_prompt
            vyro_params.final_negative_prompt = vyro_params.user_neg_prompt
            raw_style = "qr" if vyro_params.mode == "qr" else ""
            vyro_params.style = raw_style
            return (vyro_params, raw_style, )

        classifier_target = user_prompt
        if self._STYLE_FLAG_RE.search(user_prompt):
            match = self._STYLE_RE.search(user_prompt)
            if match:
                requested = match.group(1)
                if requested.startswith('"'):
                    requested = requested[1:-1]
                # Remove every directive so none of them leaks into the final prompt.
                user_prompt = self._STYLE_RE.sub('', user_prompt)
                classifier_target = requested
            else:
                # Malformed directive: drop the marker and classify the prompt itself
                # (classifier_target must stay a string).
                user_prompt = self._STYLE_FLAG_RE.sub('', user_prompt)
                classifier_target = user_prompt
            vyro_params.user_prompt = user_prompt

        # An exact style name short-circuits the classifier.
        if classifier_target in styles:
            return self._apply_style(vyro_params, prompt_tree, classifier_target)

        style_names, weights = self._split_weighted_styles(styles)
        # Weighted entries look like "name:1.2"; accept the bare name for them too.
        if classifier_target in style_names:
            return self._apply_style(vyro_params, prompt_tree, classifier_target)

        try:
            style = self._classify(classifier, classifier_target, style_names, weights, debug)
        except Exception as e:
            # Best-effort: any classifier failure falls back to the first configured style.
            if debug == VyroParams.STATES[1]:
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Classifier error: {str(e)}{Style.RESET_ALL}")
                print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Falling back to default style.{Style.RESET_ALL}")
            style = style_names[0] if style_names else ""
            print(f"[PromptAnalyzer] the style is {style}")

        # Kept outside the try so calc_prompt errors are not mistaken for classifier errors.
        return self._apply_style(vyro_params, prompt_tree, style)

    @staticmethod
    def _split_weighted_styles(styles):
        """Split ``"name:weight"`` entries into parallel lists of names and float weights.

        Entries without ``:`` get weight 1.0; only the first two ``:``-separated
        fields are used. Raises ValueError if a weight is not a number.
        """
        names = []
        weights = []
        for entry in styles:
            if ':' in entry:
                fields = entry.split(':')
                weights.append(float(fields[1]))
                names.append(fields[0])
            else:
                weights.append(1.0)
                names.append(entry)
        return names, weights

    @staticmethod
    def _classify(classifier, text, style_names, weights, debug):
        """Return the style whose classifier score multiplied by its weight is highest.

        Raises ValueError if the classifier output has no categories, or none of
        them correspond to a name in ``style_names``.
        """
        t1 = time.time()
        doc = classifier(text)
        t2 = time.time()

        cats = getattr(doc, 'cats', None)
        if not cats:
            raise ValueError("No categories found in classifier output")

        # (style name, weighted score, weight); the weight is applied exactly once.
        weighted_scores = []
        for name, weight in zip(style_names, weights):
            label = name.replace(' ', '_')  # classifier labels use '_' for spaces
            if label in cats:
                weighted_scores.append((name, cats[label] * weight, weight))

        if not weighted_scores:
            raise ValueError("No matching styles found in classifier output")

        # Stable sort: ties keep the order of ``styles``.
        weighted_scores.sort(key=lambda item: item[1], reverse=True)
        top3 = weighted_scores[:3]

        if debug == VyroParams.STATES[1]:
            print(f"[PromptAnalyzer] {Fore.LIGHTYELLOW_EX}Prompt analysis took {round(t2-t1,2)} seconds.{Style.RESET_ALL}")
            for rank, (label, score, weight) in enumerate(top3):
                if rank == 0:
                    print(f"[PromptAnalyzer] {label}: {Fore.RED}{round(score,2)} {Fore.LIGHTRED_EX}({weight}){Style.RESET_ALL}")
                else:
                    print(f"[PromptAnalyzer] {label}: {Fore.LIGHTBLUE_EX}{round(score,2)} {Fore.LIGHTGREEN_EX}({weight}){Style.RESET_ALL}")

        return weighted_scores[0][0]

    @staticmethod
    def _apply_style(vyro_params, prompt_tree, style):
        """Store ``style`` on ``vyro_params``, run ``calc_prompt`` and return the node outputs."""
        vyro_params.style = style
        # presumably fills the final prompts from prompt_tree for this style -- see calc_prompt
        calc_prompt(vyro_params, prompt_tree)
        return (vyro_params, style, )
|
|
|
|
|
|
|
class VyroPromptEncoder:
    """ComfyUI node that encodes the final prompts for a base + refiner CLIP pair.

    Produces four conditionings:

    * base positive/negative, from ``base_clip``, carrying width/height,
      crop and target-size metadata;
    * refiner positive/negative, from ``refiner_clip``, carrying an aesthetic
      score (6.0 / 1.0) and width/height.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "base_clip": ("CLIP", ),
            "refiner_clip": ("CLIP", ),
            "params": ("VYRO_PARAMS",),
            "crop_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "CONDITIONING", )
    RETURN_NAMES = ("base_positive", "base_negative", "refiner_positive", "refiner_negative", )
    FUNCTION = "encode"

    CATEGORY = "Vyro/Prompt"

    def encode(self, base_clip, refiner_clip, params:VyroParams, crop_factor:float):
        """Encode ``params.final_positive_prompt`` / ``final_negative_prompt``.

        Args:
            base_clip: Base-model CLIP; its ``tokenize`` result is indexed with
                both ``"g"`` and ``"l"`` keys.
            refiner_clip: Refiner-model CLIP.
            params: Supplies the final prompts and ``width`` / ``height``.
            crop_factor: Fraction of width/height reported as ``crop_w`` / ``crop_h``.

        Returns:
            ``(base_positive, base_negative, refiner_positive, refiner_negative)``,
            each a conditioning list of the form ``[[cond, metadata]]``.
        """
        empty = base_clip.tokenize("")

        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base positive prompt: {params.final_positive_prompt}{Style.RESET_ALL}')
        print(f'[VyroPromptEncoder] {Fore.LIGHTYELLOW_EX}Base negative prompt: {params.final_negative_prompt}{Style.RESET_ALL}')

        positive = params.final_positive_prompt
        negative = params.final_negative_prompt

        # Metadata shared by both base conditionings; target size is 4x the generation size.
        base_meta = {
            "width": params.width,
            "height": params.height,
            "crop_w": int(params.width * crop_factor),
            "crop_h": int(params.height * crop_factor),
            "target_width": params.width * 4,
            "target_height": params.height * 4,
        }
        refiner_size = {"width": params.width, "height": params.height}

        base_positive = self._encode_base(base_clip, positive, positive, empty, base_meta)
        # Both token streams of the negative conditioning use the negative prompt.
        base_negative = self._encode_base(base_clip, negative, negative, empty, base_meta)

        refiner_positive = self._encode_refiner(refiner_clip, positive, {"aesthetic_score": 6.0, **refiner_size})
        refiner_negative = self._encode_refiner(refiner_clip, negative, {"aesthetic_score": 1.0, **refiner_size})

        return (base_positive, base_negative, refiner_positive, refiner_negative, )

    @staticmethod
    def _encode_base(clip, text_g, text_l, empty, meta):
        """Tokenize ``text_g`` / ``text_l`` into the g/l streams, equalize lengths, and encode.

        ``empty`` is the tokenization of ``""``; its chunks pad whichever of the
        g/l token lists is shorter so both end up the same length.
        """
        tokens = clip.tokenize(text_g)
        tokens["l"] = clip.tokenize(text_l)["l"]

        while len(tokens["l"]) < len(tokens["g"]):
            tokens["l"] += empty["l"]
        while len(tokens["l"]) > len(tokens["g"]):
            tokens["g"] += empty["g"]

        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        # Fresh dict per conditioning so the two base outputs never share metadata objects.
        return [[cond, {"pooled_output": pooled, **meta}]]

    @staticmethod
    def _encode_refiner(clip, text, meta):
        """Tokenize and encode ``text`` with the refiner CLIP, attaching ``meta``."""
        tokens = clip.tokenize(text)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled, **meta}]]
|
|
|
|
|
|
|
# Node type name -> implementing class. These keys are the node type names
# stored in saved workflows, so they must not change.
NODE_CLASS_MAPPINGS = {
    "Vyro Prompt Analyzer": VyroPromptAnalyzer,
    "Vyro Prompt Encoder": VyroPromptEncoder,
}

# Node type name -> UI label. ComfyUI matches these keys against the keys of
# NODE_CLASS_MAPPINGS, so they must use the same node type names.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Prompt Analyzer": "Vyro Prompt Analyzer",
    "Vyro Prompt Encoder": "Vyro Prompt Encoder",
}