|
from functools import lru_cache |
|
import json |
|
import os |
|
|
|
|
|
import folder_paths |
|
from folder_paths import folder_names_and_paths, supported_pt_extensions |
|
import nodes |
|
import spacy |
|
import torch |
|
|
|
def update_paths(): |
|
model_path = folder_paths.models_dir |
|
folder_names_and_paths["vyro_configs"] = ([os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs")], ['.json']) |
|
folder_names_and_paths["spacy"] = ([(os.path.join(model_path, "spacy"))], supported_pt_extensions) |
|
folder_names_and_paths["interposers"] = ([(os.path.join(model_path, "interposers"))], supported_pt_extensions) |
|
folder_names_and_paths['oneflow_graphs'] = ([(os.path.join(model_path, "oneflow_graphs"))], ("")) |
|
|
|
# Register the Vyro model folders once at import time. VyroConfigLoader.INPUT_TYPES
# calls update_paths() again before listing files.
update_paths()
|
|
|
class VyroConfigLoader:
    """ComfyUI node: load a Vyro JSON config and a spaCy style classifier.

    The JSON file (from the ``vyro_configs`` folder) must contain ``styles``,
    ``prompt_tree`` and ``model_config`` keys. The classifier is a spaCy
    pipeline directory (one containing ``config.cfg``) under a registered
    ``spacy`` folder. Its ``textcat``/``textcat_multilabel`` pipes are removed
    and a keyword-based ``style_matcher`` component is added instead.
    """

    # spaCy factory that builds the keyword style matcher. The style labels are
    # passed through the factory config, so each loaded pipeline gets the labels
    # of its own config. Re-registering a per-call closure with
    # ``nlp.component`` does not work: spaCy keeps the first registered closure
    # for an identically-sourced function, so later configs reused the first
    # config's labels.
    _STYLE_MATCHER_FACTORY = "vyro_style_matcher"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node inputs: a config file name and a classifier directory.

        Classifier choices are the directories (relative to a spacy search path)
        that contain a ``config.cfg``. A directory present under several search
        paths is listed once.
        """
        paths = []

        update_paths()

        spacy_paths = folder_paths.get_folder_paths("spacy")
        print(f"[VyroConfigLoader] spacy paths: {spacy_paths}")

        for search_path in spacy_paths:
            if not os.path.exists(search_path):
                continue
            print(f"[VyroConfigLoader] Found spacy path: {search_path}")
            for root, _subdirs, files in os.walk(search_path, followlinks=True):
                if "config.cfg" not in files:
                    continue
                rel_path = os.path.relpath(root, start=search_path)
                # The same relative pipeline may exist under several search
                # paths; offer it only once.
                if rel_path not in paths:
                    paths.append(rel_path)
                    print(f"[VyroConfigLoader] Added classifier path: {rel_path}")

        print(f"[VyroConfigLoader] Available classifier paths: {paths}")

        return {
            "required": {
                "config_path": (folder_paths.get_filename_list("vyro_configs"),),
                "classifier_path": (paths,),
            }
        }

    RETURN_TYPES = ("LIST", "DICT", "DICT", "TRANSFORMER", "LIST")

    RETURN_NAMES = ("styles", "prompt_tree", "model_config", "classifier", "unweighted_styles")

    FUNCTION = "load_config"

    CATEGORY = "Vyro/Loaders"

    @staticmethod
    def _resolve_classifier_path(rel_path):
        """Return the pipeline directory for ``rel_path`` under the spacy folders.

        The first spacy search path whose ``rel_path`` contains a ``config.cfg``
        wins. This matches how INPUT_TYPES discovers pipelines in every search
        path. If none matches, fall back to the first registered spacy folder
        (the previous behaviour) so that spacy.load reports the missing path.
        """
        for search_path in folder_paths.get_folder_paths("spacy"):
            candidate = os.path.join(search_path, rel_path)
            if os.path.isfile(os.path.join(candidate, "config.cfg")):
                return candidate
        return os.path.join(folder_names_and_paths["spacy"][0][0], rel_path)

    @staticmethod
    def _create_style_matcher(nlp, name, styles):
        """spaCy factory: build a component that fills ``doc.cats`` by keyword match.

        For every label in ``styles``, the human form (underscores replaced by
        spaces, lower-cased) is compared with the lower-cased document text:

        * 0.9 if the whole human form occurs in the text;
        * 0.5 if any single word of it occurs in the text;
        * 0.1 otherwise.

        If no label scores above 0.5 and ``styles`` is non-empty, the first
        label is set to 0.6 so there is always a preferred style.
        ``nlp`` and ``name`` are required by spaCy's factory signature and are
        unused.
        """
        labels = list(styles)

        def style_matcher(doc):
            doc.cats = {}
            text_lower = doc.text.lower()

            for style in labels:
                style_human = style.replace('_', ' ').lower()
                base_score = 0.1
                if style_human in text_lower:
                    base_score = 0.9
                elif any(word in text_lower for word in style_human.split()):
                    base_score = 0.5
                doc.cats[style] = base_score

            if not any(score > 0.5 for score in doc.cats.values()) and labels:
                doc.cats[labels[0]] = 0.6

            return doc

        return style_matcher

    @classmethod
    def _ensure_style_matcher_factory(cls):
        """Register the style-matcher factory with spaCy once per process."""
        if not spacy.Language.has_factory(cls._STYLE_MATCHER_FACTORY):
            spacy.Language.factory(
                cls._STYLE_MATCHER_FACTORY,
                default_config={"styles": []},
                func=cls._create_style_matcher,
            )

    def load_config(self, config_path, classifier_path):
        """Load the selected JSON config and spaCy classifier.

        Returns:
            ``(styles, prompt_tree, model_config, classifier, unweighted_styles)``.
            ``styles`` is the raw ``config['styles']`` list, whose entries may be
            ``name:weight``. ``unweighted_styles`` holds the names with any
            ``:weight`` suffix removed. ``classifier`` is the spaCy pipeline with
            a ``style_matcher`` pipe whose labels are those names with spaces
            replaced by underscores.

        Raises:
            FileNotFoundError: if ``config_path`` cannot be resolved in
                ``vyro_configs``.
            ValueError: if the config file is not valid JSON.
        """
        config_file = folder_paths.get_full_path("vyro_configs", config_path)
        if config_file is None:
            # Fail with a clear message instead of calling open(None).
            raise FileNotFoundError(f"[VyroConfigLoader] Config not found: {config_path}")

        classifier_dir = self._resolve_classifier_path(classifier_path)

        with open(config_file, 'r', encoding='utf-8') as json_file:
            try:
                config = json.load(json_file)
            except json.JSONDecodeError as json_e:
                # This node declares five outputs (RETURN_TYPES); returning None
                # (the previous behaviour) only moves the failure downstream.
                raise ValueError(f"[VyroConfigLoader] Error loading {config_file}: {json_e}") from json_e

        # Strip an optional ":weight" suffix from each style entry.
        unweighted_styles = [style.split(':')[0] for style in config['styles']]
        # Pipeline labels use underscores in place of spaces.
        styles_list = [style_name.replace(' ', '_') for style_name in unweighted_styles]

        spacy.prefer_gpu(gpu_id=0)
        classifier = spacy.load(classifier_dir)

        # Drop the trained text classifiers; style_matcher sets doc.cats instead.
        for pipe_name in ("textcat", "textcat_multilabel"):
            if pipe_name in classifier.pipe_names:
                classifier.remove_pipe(pipe_name)

        self._ensure_style_matcher_factory()
        if "style_matcher" not in classifier.pipe_names:
            classifier.add_pipe(
                self._STYLE_MATCHER_FACTORY,
                name="style_matcher",
                config={"styles": styles_list},
            )

        print(f"[VyroConfigLoader] Successfully configured style_matcher with {len(styles_list)} labels")
        print(f"[VyroConfigLoader] Pipeline: {classifier.pipe_names}")

        return (config['styles'], config['prompt_tree'], config['model_config'], classifier, unweighted_styles)
|
|
|
|
|
|
|
class VyroModelLoader:
    """ComfyUI node: load base/refiner checkpoints (plus base LoRAs) for a style.

    The style is mapped through ``prompt_tree`` to a named entry of
    ``model_config['configs']``. That entry supplies ``base``, ``refiner``,
    ``loras`` and ``tonemap``. Both returned models get a Reinhard-tonemapped
    CFG function.
    """

    # Maximum number of loaded model sets kept in the per-instance cache.
    _CACHE_SIZE = 6

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        # (cfg name, JSON of the config entry) -> loaded model tuple. Dict
        # insertion order doubles as least-recently-used order.
        self._model_cache = {}
        print("\n\nInitializing VyroModelLoader")

    @staticmethod
    def _make_tonemap_cfg_function(tonemap_multiplier):
        """Build a sampler CFG function that tonemaps the guidance with a Reinhard curve.

        The returned callable reads ``cond``, ``uncond`` and ``cond_scale`` from
        its ``args`` dict. It returns
        ``uncond + direction * tonemapped_magnitude * cond_scale``. The per-element
        guidance magnitude (norm over dim 1) is compressed with ``x / (x + 1)``
        in units of ``top = (3 * std + mean) * tonemap_multiplier``.
        Presumably the inputs are 4-D latents (batch, channels, H, W); the
        (1, 2, 3) reductions need at least 4 dims (TODO confirm).
        """
        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)

            top = (std * 3 + mean) * tonemap_multiplier

            # Reinhard operator x / (x + 1), applied in units of `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def get_model(self, cfg):
        """Return ``(base_model, refiner_model, base_clip, refiner_clip)`` for config ``cfg``.

        Reads ``self.config['configs'][cfg]``. Results are cached per instance
        on the config name *and* the entry's contents, so a different
        model_config that reuses a name loads its own models instead of
        returning stale ones. At most ``_CACHE_SIZE`` sets are kept, and the
        least recently used set is evicted.

        Raises:
            KeyError: if ``cfg`` or one of its required keys is missing.
        """
        entry = self.config['configs'][cfg]
        cache_key = (cfg, json.dumps(entry, sort_keys=True))
        cache = self._model_cache
        if cache_key in cache:
            # Re-insert to mark the entry as most recently used.
            cache[cache_key] = cache.pop(cache_key)
            return cache[cache_key]

        result = self._load_models(entry)
        cache[cache_key] = result
        if len(cache) > self._CACHE_SIZE:
            cache.pop(next(iter(cache)))
        return result

    def _load_models(self, entry):
        """Load the checkpoints and LoRAs described by one config entry."""
        base = entry['base']
        refiner = entry['refiner']
        loras = entry['loras']
        tonemap_multiplier = entry['tonemap']

        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        # LoRAs are applied to the base model/CLIP only.
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(base_model, base_clip, lora['name'], lora['unet'], lora['clip'])

        cfg_function = self._make_tonemap_cfg_function(tonemap_multiplier)

        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)

        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)

        return (base_model, refiner_model, base_clip, refiner_clip)

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")

    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")

    FUNCTION = "load"

    CATEGORY = "Vyro/Loaders"

    @staticmethod
    def _resolve_cfg(style, prompt_tree):
        """Map a style to the name of its entry in ``model_config['configs']``.

        Raises:
            ValueError: if a non-empty, non-"qr" style is not in ``prompt_tree``.
        """
        # NOTE(review): None is routed to 'default_qr', as in the original code.
        # The "Style is none" message below suggests 'default' may have been
        # intended for None; confirm with the workflow authors.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Node entry point: load the models configured for ``style``.

        Raises:
            ValueError: if ``prompt_tree`` is None or the style is unknown.
        """
        print("\n\nExecuting VyroModelLoader load function...")
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")

        cfg = self._resolve_cfg(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)
|
|
|
|
|
|
|
class VyroOneflowModelLoader:
    """ComfyUI node: load base/refiner checkpoints (plus base LoRAs) for a style.

    The style is mapped through ``prompt_tree`` to a named entry of
    ``model_config['configs']``. That entry supplies ``base``, ``refiner``,
    ``loras`` and ``tonemap``. Both returned models get a Reinhard-tonemapped
    CFG function.
    """

    # Maximum number of loaded model sets kept in the per-instance cache.
    _CACHE_SIZE = 6

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        # (cfg name, JSON of the config entry) -> loaded model tuple. Dict
        # insertion order doubles as least-recently-used order.
        self._model_cache = {}
        print("\n\nInitializing VyroOneflowModelLoader")

    @staticmethod
    def _make_tonemap_cfg_function(tonemap_multiplier):
        """Build a sampler CFG function that tonemaps the guidance with a Reinhard curve.

        The returned callable reads ``cond``, ``uncond`` and ``cond_scale`` from
        its ``args`` dict. It returns
        ``uncond + direction * tonemapped_magnitude * cond_scale``. The per-element
        guidance magnitude (norm over dim 1) is compressed with ``x / (x + 1)``
        in units of ``top = (3 * std + mean) * tonemap_multiplier``.
        Presumably the inputs are 4-D latents (batch, channels, H, W); the
        (1, 2, 3) reductions need at least 4 dims (TODO confirm).
        """
        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)

            top = (std * 3 + mean) * tonemap_multiplier

            # Reinhard operator x / (x + 1), applied in units of `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def get_model(self, cfg):
        """Return ``(base_model, refiner_model, base_clip, refiner_clip)`` for config ``cfg``.

        Reads ``self.config['configs'][cfg]``. Results are cached per instance
        on the config name *and* the entry's contents, so a different
        model_config that reuses a name loads its own models instead of
        returning stale ones. At most ``_CACHE_SIZE`` sets are kept, and the
        least recently used set is evicted.

        Raises:
            KeyError: if ``cfg`` or one of its required keys is missing.
        """
        entry = self.config['configs'][cfg]
        cache_key = (cfg, json.dumps(entry, sort_keys=True))
        cache = self._model_cache
        if cache_key in cache:
            # Re-insert to mark the entry as most recently used.
            cache[cache_key] = cache.pop(cache_key)
            return cache[cache_key]

        result = self._load_models(entry)
        cache[cache_key] = result
        if len(cache) > self._CACHE_SIZE:
            cache.pop(next(iter(cache)))
        return result

    def _load_models(self, entry):
        """Load the checkpoints and LoRAs described by one config entry."""
        base = entry['base']
        refiner = entry['refiner']
        loras = entry['loras']
        tonemap_multiplier = entry['tonemap']

        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        # LoRAs are applied to the base model/CLIP only.
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(base_model, base_clip, lora['name'], lora['unet'], lora['clip'])

        cfg_function = self._make_tonemap_cfg_function(tonemap_multiplier)

        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)

        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)

        return (base_model, refiner_model, base_clip, refiner_clip)

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")

    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")

    FUNCTION = "load"

    CATEGORY = "Vyro/Loaders"

    @staticmethod
    def _resolve_cfg(style, prompt_tree):
        """Map a style to the name of its entry in ``model_config['configs']``.

        Raises:
            ValueError: if a non-empty, non-"qr" style is not in ``prompt_tree``.
        """
        # NOTE(review): None is routed to 'default_qr', as in the original code.
        # The "Style is none" message below suggests 'default' may have been
        # intended for None; confirm with the workflow authors.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Node entry point: load the models configured for ``style``.

        Raises:
            ValueError: if ``prompt_tree`` is None or the style is unknown.
        """
        print("\n\nExecuting VyroOneflowModelLoader load function...")

        if prompt_tree is None:
            raise ValueError("Prompt tree is None")

        cfg = self._resolve_cfg(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)
|
|
|
|
|
class VyroOneFlowBaseModelLoader:
    """ComfyUI node: load one base checkpoint and attach a tonemapped CFG function."""

    def __init__(self) -> None:
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        print("INITIALIZING ONEFLOW BASE MODEL LOADER")

    @staticmethod
    def _make_tonemap_cfg_function(tonemap_multiplier):
        """Build a sampler CFG function that tonemaps the guidance with a Reinhard curve.

        The returned callable reads ``cond``, ``uncond`` and ``cond_scale`` from
        its ``args`` dict. It returns
        ``uncond + direction * tonemapped_magnitude * cond_scale``. The per-element
        guidance magnitude (norm over dim 1) is compressed with ``x / (x + 1)``
        in units of ``top = (3 * std + mean) * tonemap_multiplier``.
        Presumably the inputs are 4-D latents (batch, channels, H, W); the
        (1, 2, 3) reductions need at least 4 dims (TODO confirm).
        """
        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)

            top = (std * 3 + mean) * tonemap_multiplier

            # Reinhard operator x / (x + 1), applied in units of `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    @lru_cache(maxsize=6)
    def get_base_model(self, base_model):
        """Load checkpoint ``base_model`` and return ``(model, clip)``.

        The model is cloned before the CFG function is attached, so the loader's
        own model object is left unmodified. Results are memoized per
        (instance, checkpoint name).
        """
        print(f"\n\nLoading Base Model: {base_model}\n\n")
        model, clip, _ = self.chkp_loader.load_checkpoint(base_model)

        # This node has no per-style config, so the tonemap multiplier is fixed.
        model = model.clone()
        model.set_model_sampler_cfg_function(self._make_tonemap_cfg_function(1.0))

        return (model, clip)

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "base_model": (folder_paths.get_filename_list("checkpoints"), ),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP",)

    RETURN_NAMES = ("base_model", "base_clip",)

    FUNCTION = "load"

    # ComfyUI reads CATEGORY; the original misspelling left the node uncategorised.
    CATEGORY = "Vyro/Loaders/Oneflow"

    # Misspelled alias kept so any existing references to it still resolve.
    CATERGORY = CATEGORY

    def load(self, base_model):
        """Node entry point: return ``(base_model, base_clip)`` for the checkpoint."""
        return self.get_base_model(base_model)
|
|
|
|
|
class VyroLoraLoader:
    """ComfyUI node: apply the LoRAs of a style's model config to a model/CLIP pair."""

    def __init__(self) -> None:
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        print("\n\nInitializing Vyro LORA Loader\n\n")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "base_model": ("MODEL",),
                "base_clip": ("CLIP",),
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP")

    RETURN_NAMES = ("base_model", "base_clip")

    FUNCTION = "load_loras"

    CATEGORY = "Vyro/Loaders/Lora"

    @staticmethod
    def _resolve_cfg(style, prompt_tree):
        """Map a style to the name of its entry in ``model_config['configs']``.

        Raises:
            ValueError: if a non-empty, non-"qr" style is not in ``prompt_tree``.
        """
        # NOTE(review): None is routed to 'default_qr', matching the sibling
        # loaders; confirm whether 'default' was intended for None.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load_loras(self, base_model, base_clip, style, prompt_tree, model_config):
        """Apply ``model_config['configs'][cfg]['loras']`` in order and return ``(model, clip)``.

        ``cfg`` is resolved from ``style`` the same way as in the model loaders:
        None or "qr" selects 'default_qr', "" selects 'default', and any other
        style uses its prompt-tree node's ``model_config`` (falling back to
        'default'). The qr and empty-style cases previously called a
        nonexistent ``get_model`` method; they now apply that config's LoRAs.

        Raises:
            ValueError: if ``prompt_tree`` is None or the style is unknown.
            KeyError: if the resolved config has no ``loras`` entry.
        """
        print("\n\nExecuting VyroLORA load function...")
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")

        cfg = self._resolve_cfg(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config

        loras = self.config['configs'][cfg]['loras']
        for lora in loras:
            print(f"\nLoading Lora: {lora['name']}\n")
            base_model, base_clip = self.lora_loader.load_lora(base_model, base_clip, lora['name'], lora['unet'], lora['clip'])

        return base_model, base_clip
|
|
|
|
|
class VyroOneFlowRefinerModelLoader:
    """ComfyUI node: load the refiner checkpoint named by a style's model config."""

    # Maximum number of loaded refiners kept in the per-instance cache.
    _CACHE_SIZE = 6

    def __init__(self) -> None:
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.tree = None
        self.config = None
        # JSON of [refiner name, tonemap] -> (model, clip). Dict insertion order
        # doubles as least-recently-used order.
        self._model_cache = {}
        print("INITIALIZING ONEFLOW REFINER MODEL LOADER")

    @staticmethod
    def _make_tonemap_cfg_function(tonemap_multiplier):
        """Build a sampler CFG function that tonemaps the guidance with a Reinhard curve.

        The returned callable reads ``cond``, ``uncond`` and ``cond_scale`` from
        its ``args`` dict. It returns
        ``uncond + direction * tonemapped_magnitude * cond_scale``. The per-element
        guidance magnitude (norm over dim 1) is compressed with ``x / (x + 1)``
        in units of ``top = (3 * std + mean) * tonemap_multiplier``.
        Presumably the inputs are 4-D latents (batch, channels, H, W); the
        (1, 2, 3) reductions need at least 4 dims (TODO confirm).
        """
        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)

            top = (std * 3 + mean) * tonemap_multiplier

            # Reinhard operator x / (x + 1), applied in units of `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def get_refiner_model(self, cfg):
        """Return ``(refiner_model, refiner_clip)`` for config ``cfg``.

        Reads ``refiner`` and ``tonemap`` from ``self.config['configs'][cfg]``.
        The cache is keyed on those values rather than on ``cfg`` alone. A new
        model_config that reuses a config name therefore does not return a
        stale refiner. At most ``_CACHE_SIZE`` entries are kept.

        Raises:
            KeyError: if ``cfg`` or one of its required keys is missing.
        """
        entry = self.config['configs'][cfg]
        refiner = entry['refiner']
        tonemap_multiplier = entry['tonemap']

        cache_key = json.dumps([refiner, tonemap_multiplier])
        cache = self._model_cache
        if cache_key in cache:
            # Re-insert to mark the entry as most recently used.
            cache[cache_key] = cache.pop(cache_key)
            return cache[cache_key]

        print(f"\n\nLoading Refiner Model: {refiner}\n\n")
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)

        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(self._make_tonemap_cfg_function(tonemap_multiplier))

        result = (refiner_model, refiner_clip)
        cache[cache_key] = result
        if len(cache) > self._CACHE_SIZE:
            cache.pop(next(iter(cache)))
        return result

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",)
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP")

    RETURN_NAMES = ("refiner_model", "refiner_clip")

    FUNCTION = "load"

    # ComfyUI reads CATEGORY; the original misspelling left the node uncategorised.
    CATEGORY = "Vyro/Loaders/Oneflow"

    # Misspelled alias kept so any existing references to it still resolve.
    CATERGORY = CATEGORY

    @staticmethod
    def _resolve_cfg(style, prompt_tree):
        """Map a style to the name of its entry in ``model_config['configs']``.

        Raises:
            ValueError: if a non-empty, non-"qr" style is not in ``prompt_tree``.
        """
        # NOTE(review): None is routed to 'default_qr', as in the original code;
        # confirm whether 'default' was intended for None.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return "default_qr"
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Node entry point: return the refiner configured for ``style``.

        The qr and empty-style cases previously called a nonexistent
        ``get_model`` method; all cases now go through ``get_refiner_model``.

        Raises:
            ValueError: if ``prompt_tree`` is None or the style is unknown.
        """
        if prompt_tree is None:
            raise ValueError("Prompt Tree is None")

        cfg = self._resolve_cfg(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_refiner_model(cfg)
|
|
|
|
|
|
|
|
|
# Node id (the key saved in workflows) -> node class. Keys are kept unchanged
# so existing workflows still resolve.
NODE_CLASS_MAPPINGS = {
    "Vyro Config Loader": VyroConfigLoader,
    "Vyro Model Loader": VyroModelLoader,
    "Vyro OneFlow Model Loader": VyroOneflowModelLoader,
    "Vyro Oneflow Base Model Loader": VyroOneFlowBaseModelLoader,
    "Vyro Oneflow Refiner Model Loader": VyroOneFlowRefinerModelLoader,
    "Vyro LoRa Loader": VyroLoraLoader,
}

# Node id -> display name. The keys must match NODE_CLASS_MAPPINGS. Previously
# they used class-name-like keys (with mismatched casing) that matched no node
# id, so none of these display names took effect.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Config Loader": "Vyro Config Loader",
    "Vyro Model Loader": "Vyro Model Loader",
    "Vyro OneFlow Model Loader": "Vyro Oneflow Model Loader",
    "Vyro Oneflow Base Model Loader": "Vyro Oneflow Base Model Loader",
    "Vyro Oneflow Refiner Model Loader": "Vyro Oneflow Refiner Model Loader",
    "Vyro LoRa Loader": "Vyro LoRa Loader",
}