# azkavyro's picture
# Added all files including vyro_workflows
# 6fecfbe
from functools import lru_cache
import json
import os
# from turtle import speed
import folder_paths
from folder_paths import folder_names_and_paths, supported_pt_extensions
import nodes
import spacy
import torch
def update_paths():
    """Register the Vyro folder keys in ``folder_names_and_paths``.

    Every entry is written as a ``(search_path_list, extensions)`` pair and
    replaces whatever was previously stored under the same key.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    models_root = folder_paths.models_dir

    # JSON configs live in a "configs" directory one level above this file.
    folder_names_and_paths["vyro_configs"] = (
        [os.path.join(package_dir, "../configs")],
        ['.json'],
    )

    # These two share the same layout: <models_dir>/<key> with the stock
    # checkpoint extension collection.
    for key in ("spacy", "interposers"):
        folder_names_and_paths[key] = (
            [os.path.join(models_root, key)],
            supported_pt_extensions,
        )

    # The extension entry here is an empty *string* (the original spelled it
    # `("")`, which is not a tuple); it is reproduced unchanged.
    folder_names_and_paths['oneflow_graphs'] = (
        [os.path.join(models_root, "oneflow_graphs")],
        "",
    )


# Register the folders once at import time.
update_paths()
class VyroConfigLoader:
    """Node that loads a Vyro JSON config and a spaCy pipeline with a style scorer.

    Outputs (see ``RETURN_NAMES``): the raw ``styles`` list, the
    ``prompt_tree`` and ``model_config`` sections of the config, the spaCy
    pipeline with a rule-based ``style_matcher`` component appended, and the
    style names with any ``:suffix`` removed.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a config file and a spaCy pipeline directory.

        Every directory below a registered ``spacy`` folder that contains a
        ``config.cfg`` file is offered as a ``classifier_path`` choice,
        expressed relative to its search root.
        """
        update_paths()
        paths = []
        spacy_paths = folder_paths.get_folder_paths("spacy")
        print(f"[VyroConfigLoader] spacy paths: {spacy_paths}")  # Debug print
        for search_path in spacy_paths:
            if not os.path.exists(search_path):
                continue
            print(f"[VyroConfigLoader] Found spacy path: {search_path}")  # Debug print
            for root, _subdirs, files in os.walk(search_path, followlinks=True):
                if "config.cfg" not in files:
                    continue
                rel_path = os.path.relpath(root, start=search_path)
                paths.append(rel_path)
                print(f"[VyroConfigLoader] Added classifier path: {rel_path}")  # Debug print
        print(f"[VyroConfigLoader] Available classifier paths: {paths}")  # Debug print
        return {
            "required": {
                "config_path": (folder_paths.get_filename_list("vyro_configs"),),
                "classifier_path": (paths,),
            }
        }

    RETURN_TYPES = ("LIST", "DICT", "DICT", "TRANSFORMER", "LIST")
    RETURN_NAMES = ("styles", "prompt_tree", "model_config", "classifier", "unweighted_styles")
    FUNCTION = "load_config"
    CATEGORY = "Vyro/Loaders"

    @staticmethod
    def _read_config(config_path):
        """Parse the JSON config at *config_path*.

        The file is decoded as UTF-8 regardless of the platform default.

        Raises:
            ValueError: if the file is not valid JSON. Returning ``None``
                instead would only fail later, far from the real cause,
                because this node declares five outputs.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as json_file:
                return json.load(json_file)
        except json.JSONDecodeError as json_e:
            raise ValueError(
                f"[VyroConfigLoader] Error loading {config_path}: {json_e}"
            ) from json_e

    @staticmethod
    def _split_styles(styles):
        """Return ``(unweighted_styles, labels)`` for the configured styles.

        ``unweighted_styles`` keeps only the text before the first ``:`` of
        each style; ``labels`` are the same names with spaces replaced by
        underscores.
        """
        unweighted_styles = [style.split(':')[0] for style in styles]
        labels = [name.replace(' ', '_') for name in unweighted_styles]
        return unweighted_styles, labels

    @staticmethod
    def _score_styles(text, labels):
        """Score every label against *text* with simple substring rules.

        A label whose human form (underscores as spaces, lower-cased) occurs
        in the text scores 0.9; one where any single word occurs scores 0.5;
        everything else scores 0.1. When no label scores above 0.5, the first
        label is raised to 0.6 so that there is always a preferred style.
        """
        text_lower = text.lower()
        cats = {}
        for label in labels:
            style_human = label.replace('_', ' ').lower()
            if style_human in text_lower:
                score = 0.9
            elif any(word in text_lower for word in style_human.split()):
                score = 0.5
            else:
                score = 0.1
            cats[label] = score
        if labels and not any(score > 0.5 for score in cats.values()):
            cats[labels[0]] = 0.6
        return cats

    def load_config(self, config_path, classifier_path):
        """Load the config and build the spaCy pipeline used for style matching.

        Args:
            config_path: file name resolved through the ``vyro_configs`` folder.
            classifier_path: pipeline directory relative to the first
                registered ``spacy`` folder.

        Returns:
            ``(styles, prompt_tree, model_config, classifier, unweighted_styles)``.

        Raises:
            ValueError: if the config file is not valid JSON.
        """
        config_path = folder_paths.get_full_path("vyro_configs", config_path)
        classifier_path = os.path.join(folder_names_and_paths["spacy"][0][0], classifier_path)

        # Read the config first: the style labels drive the matcher below.
        config = self._read_config(config_path)
        unweighted_styles, styles_list = self._split_styles(config['styles'])

        spacy.prefer_gpu(gpu_id=0)
        classifier = spacy.load(classifier_path)

        # Drop any text-categorizer components shipped with the pipeline; the
        # rule-based matcher below fills ``doc.cats`` instead.
        for pipe_name in ("textcat", "textcat_multilabel"):
            if pipe_name in classifier.pipe_names:
                classifier.remove_pipe(pipe_name)

        score_styles = self._score_styles

        @classifier.component("style_matcher")
        def style_matcher(doc):
            doc.cats = score_styles(doc.text, styles_list)
            return doc

        if "style_matcher" not in classifier.pipe_names:
            classifier.add_pipe("style_matcher")
        print(f"[VyroConfigLoader] Successfully configured style_matcher with {len(styles_list)} labels")
        print(f"[VyroConfigLoader] Pipeline: {classifier.pipe_names}")

        return (config['styles'], config['prompt_tree'], config['model_config'], classifier, unweighted_styles)
class VyroModelLoader:
    """Node that loads the base and refiner checkpoints selected by a style.

    The style picks a named entry of ``model_config['configs']``; that entry
    names the base checkpoint, the refiner checkpoint, the LoRAs applied to
    the base model, and a tonemap multiplier used by the sampler CFG function
    installed on both models.
    """

    # Upper bound on cached model bundles (same size as the lru_cache this
    # replaces).
    _CACHE_SIZE = 6
    # Insertion-ordered cache: serialized settings -> loaded model tuple.
    # Keyed on the settings themselves rather than the config *name*, so a
    # changed model_config under the same name is never served stale models.
    _model_cache = {}

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        print("\n\nInitializing VyroModelLoader")

    @staticmethod
    def _make_tonemap_cfg(tonemap_multiplier):
        """Build the sampler CFG function parameterized by *tonemap_multiplier*."""

        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)

            # Split the guidance into a direction and a magnitude (norm over dim 1).
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 3 + mean) * tonemap_multiplier

            # reinhard: compress the magnitude with x / (x + 1), scaled by `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def _build_models(self, base, refiner, loras, tonemap_multiplier):
        """Load both checkpoints, apply the LoRAs to the base, install the CFG function."""
        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(
                base_model, base_clip, lora['name'], lora['unet'], lora['clip']
            )
        cfg_function = self._make_tonemap_cfg(tonemap_multiplier)
        # Patch clones so the objects returned by the loaders stay untouched.
        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)
        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)
        return (base_model, refiner_model, base_clip, refiner_clip)

    def get_model(self, cfg):
        """Return ``(base_model, refiner_model, base_clip, refiner_clip)`` for *cfg*.

        *cfg* names an entry of ``self.config['configs']``. Results are cached
        by the entry's settings, least recently used first out.

        Raises:
            KeyError: if *cfg* or one of its required fields is missing.
        """
        entry = self.config['configs'][cfg]
        base = entry['base']
        refiner = entry['refiner']
        loras = entry['loras']
        tonemap_multiplier = entry['tonemap']

        key = json.dumps([base, refiner, loras, tonemap_multiplier], sort_keys=True, default=repr)
        cache = self._model_cache
        if key in cache:
            # Re-insert to mark the entry as most recently used.
            models = cache.pop(key)
            cache[key] = models
            return models

        models = self._build_models(base, refiner, loras, tonemap_multiplier)
        cache[key] = models
        while len(cache) > self._CACHE_SIZE:
            del cache[next(iter(cache))]
        return models

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the style and the two config sections."""
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")
    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")
    FUNCTION = "load"
    CATEGORY = "Vyro/Loaders"

    @staticmethod
    def _select_config_name(style, prompt_tree):
        """Map *style* to the name of the model config entry to load."""
        # NOTE(review): None is treated like "qr" here; the second message
        # suggests None may have been meant to use 'default' - confirm.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Load the models configured for *style*.

        Returns:
            ``(base_model, refiner_model, base_clip, refiner_clip)``.

        Raises:
            ValueError: if *prompt_tree* is None or *style* is not in it.
        """
        print("\n\nExecuting VyroModelLoader load function...")
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")
        cfg = self._select_config_name(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)
class VyroOneflowModelLoader:
    """Node that loads the base and refiner checkpoints selected by a style.

    The style picks a named entry of ``model_config['configs']`` which names
    the base checkpoint, the refiner checkpoint, the LoRAs applied to the base
    model, and the tonemap multiplier for the sampler CFG function that is
    installed on both models.
    """

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")
    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")
    FUNCTION = "load"
    CATEGORY = "Vyro/Loaders"

    # Upper bound on cached model bundles (same size as the lru_cache this
    # replaces).
    _CACHE_SIZE = 6
    # Insertion-ordered cache: serialized settings -> loaded model tuple.
    # The key is built from the settings, not the config name, so swapping in
    # a different model_config never returns models of the previous one.
    _model_cache = {}

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        # The original message named VyroModelLoader, which made the two
        # loaders indistinguishable in the log.
        print("\n\nInitializing VyroOneflowModelLoader")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the style and the two config sections."""
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    @staticmethod
    def _make_tonemap_cfg(tonemap_multiplier):
        """Build the sampler CFG function parameterized by *tonemap_multiplier*."""

        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)

            # Direction / magnitude split of the guidance (norm over dim 1).
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 3 + mean) * tonemap_multiplier

            # reinhard: compress the magnitude with x / (x + 1), scaled by `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def _build_models(self, base, refiner, loras, tonemap_multiplier):
        """Load both checkpoints, apply the LoRAs to the base, install the CFG function."""
        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(
                base_model, base_clip, lora['name'], lora['unet'], lora['clip']
            )
        cfg_function = self._make_tonemap_cfg(tonemap_multiplier)
        # Patch clones so the objects returned by the loaders stay untouched.
        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)
        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)
        return (base_model, refiner_model, base_clip, refiner_clip)

    def get_model(self, cfg):
        """Return ``(base_model, refiner_model, base_clip, refiner_clip)`` for *cfg*.

        *cfg* names an entry of ``self.config['configs']``. Results are cached
        by the entry's settings, least recently used first out.

        Raises:
            KeyError: if *cfg* or one of its required fields is missing.
        """
        entry = self.config['configs'][cfg]
        base = entry['base']
        refiner = entry['refiner']
        loras = entry['loras']
        tonemap_multiplier = entry['tonemap']

        key = json.dumps([base, refiner, loras, tonemap_multiplier], sort_keys=True, default=repr)
        cache = self._model_cache
        if key in cache:
            # Re-insert to mark the entry as most recently used.
            models = cache.pop(key)
            cache[key] = models
            return models

        models = self._build_models(base, refiner, loras, tonemap_multiplier)
        cache[key] = models
        while len(cache) > self._CACHE_SIZE:
            del cache[next(iter(cache))]
        return models

    @staticmethod
    def _select_config_name(style, prompt_tree):
        """Map *style* to the name of the model config entry to load."""
        # NOTE(review): None is treated like "qr" here; the second message
        # suggests None may have been meant to use 'default' - confirm.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Load the models configured for *style*.

        Returns:
            ``(base_model, refiner_model, base_clip, refiner_clip)``.

        Raises:
            ValueError: if *prompt_tree* is None or *style* is not in it.
        """
        print("\n\nExecuting VyroOneflowModelLoader load function...")
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")
        cfg = self._select_config_name(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)
class VyroOneFlowBaseModelLoader:
    """Node that loads one checkpoint and installs the tonemap CFG function on it."""

    RETURN_TYPES = ("MODEL", "CLIP",)
    RETURN_NAMES = ("base_model", "base_clip",)
    FUNCTION = "load"
    # The attribute was misspelled "CATERGORY", so the class had no CATEGORY
    # at all (the other node classes in this file use CATEGORY). The
    # misspelled name is kept as an alias in case anything reads it.
    CATEGORY = "Vyro/Loaders/Oneflow"
    CATERGORY = CATEGORY

    # Upper bound on cached checkpoints (same size as the lru_cache this
    # replaces). A class-level dict avoids caching on a bound method, which
    # keys on `self` and keeps instances alive.
    _CACHE_SIZE = 6
    # Insertion-ordered cache: checkpoint name -> (model, clip).
    _model_cache = {}

    def __init__(self) -> None:
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        print("INITIALIZING ONEFLOW BASE MODEL LOADER")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node input: one file from the ``checkpoints`` folder."""
        return {
            "required": {
                "base_model": (folder_paths.get_filename_list("checkpoints"), ),
            }
        }

    @staticmethod
    def _make_tonemap_cfg(tonemap_multiplier):
        """Build the sampler CFG function parameterized by *tonemap_multiplier*."""

        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)

            # Direction / magnitude split of the guidance (norm over dim 1).
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 3 + mean) * tonemap_multiplier

            # reinhard: compress the magnitude with x / (x + 1), scaled by `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def get_base_model(self, base_model):
        """Return ``(model, clip)`` for the checkpoint named *base_model*.

        Results are cached by checkpoint name, least recently used first out;
        the "Loading" message is printed only when the checkpoint is actually
        loaded.
        """
        cache = self._model_cache
        if base_model in cache:
            # Re-insert to mark the entry as most recently used.
            loaded = cache.pop(base_model)
            cache[base_model] = loaded
            return loaded

        print(f"\n\nLoading Base Model: {base_model}\n\n")
        model, base_clip, _ = self.chkp_loader.load_checkpoint(base_model)
        # This loader always uses a neutral multiplier.
        tonemap_multiplier = 1.0
        # Patch a clone so the object returned by the loader stays untouched.
        model = model.clone()
        model.set_model_sampler_cfg_function(self._make_tonemap_cfg(tonemap_multiplier))

        loaded = (model, base_clip)
        cache[base_model] = loaded
        while len(cache) > self._CACHE_SIZE:
            del cache[next(iter(cache))]
        return loaded

    def load(self, base_model):
        """Node entry point: return ``(base_model, base_clip)``."""
        return self.get_base_model(base_model)
class VyroLoraLoader:
    """Node that applies the LoRAs configured for a style to a model/clip pair."""

    RETURN_TYPES = ("MODEL", "CLIP")
    RETURN_NAMES = ("base_model", "base_clip")
    FUNCTION = "load_loras"
    CATEGORY = "Vyro/Loaders/Lora"

    def __init__(self) -> None:
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        print("\n\nInitializing Vyro LORA Loader\n\n")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model/clip pair, style and config sections."""
        return {
            "required": {
                "base_model": ("MODEL",),
                "base_clip": ("CLIP",),
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    @staticmethod
    def _select_config_name(style, prompt_tree):
        """Map *style* to the name of the model config entry to use."""
        # NOTE(review): None is treated like "qr" here; the second message
        # suggests None may have been meant to use 'default' - confirm.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return 'default_qr'
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load_loras(self, base_model, base_clip, style, prompt_tree, model_config):
        """Apply every LoRA of the style's config entry, in order.

        The "qr"/None and empty-style cases used to call ``self.get_model``,
        which this class does not define, so they raised ``AttributeError``.
        They now select ``default_qr`` / ``default`` and go through the same
        LoRA loop as every other style.

        Returns:
            ``(base_model, base_clip)`` after all LoRAs have been applied.

        Raises:
            ValueError: if *prompt_tree* is None or *style* is not in it.
            KeyError: if the selected config entry or its ``loras`` is missing.
        """
        print("\n\nExecuting VyroLORA load function...")
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")

        cfg = self._select_config_name(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config

        for lora in self.config['configs'][cfg]['loras']:
            print(f"\nLoading Lora: {lora['name']}\n")
            base_model, base_clip = self.lora_loader.load_lora(
                base_model, base_clip, lora['name'], lora['unet'], lora['clip']
            )
        return base_model, base_clip
class VyroOneFlowRefinerModelLoader:
    """Node that loads the refiner checkpoint selected by a style.

    The style picks a named entry of ``model_config['configs']``; its
    ``refiner`` checkpoint is loaded and the tonemap CFG function, scaled by
    the entry's ``tonemap`` value, is installed on a clone of the model.
    """

    RETURN_TYPES = ("MODEL", "CLIP")
    RETURN_NAMES = ("refiner_model", "refiner_clip")
    FUNCTION = "load"
    # The attribute was misspelled "CATERGORY", so the class had no CATEGORY
    # at all (the other node classes in this file use CATEGORY). The
    # misspelled name is kept as an alias in case anything reads it.
    CATEGORY = "Vyro/Loaders/Oneflow"
    CATERGORY = CATEGORY

    # Upper bound on cached refiners (same size as the lru_cache this replaces).
    _CACHE_SIZE = 6
    # Insertion-ordered cache: serialized (refiner, tonemap) -> (model, clip).
    # Keyed on the settings rather than the config name, so a changed
    # model_config under the same name is never served a stale model.
    _model_cache = {}

    def __init__(self) -> None:
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.tree = None
        self.config = None
        print("INITIALIZING ONEFLOW REFINER MODEL LOADER")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the style and the two config sections."""
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",)
            }
        }

    @staticmethod
    def _make_tonemap_cfg(tonemap_multiplier):
        """Build the sampler CFG function parameterized by *tonemap_multiplier*."""

        def sampler_tonemap_reinhard(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)

            # Direction / magnitude split of the guidance (norm over dim 1).
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 3 + mean) * tonemap_multiplier

            # reinhard: compress the magnitude with x / (x + 1), scaled by `top`.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def get_refiner_model(self, cfg):
        """Return ``(refiner_model, refiner_clip)`` for the config entry *cfg*.

        Results are cached by the entry's refiner name and tonemap value,
        least recently used first out; the "Loading" message is printed only
        when the checkpoint is actually loaded.

        Raises:
            KeyError: if *cfg* or one of its required fields is missing.
        """
        entry = self.config['configs'][cfg]
        refiner = entry['refiner']
        tonemap_multiplier = entry['tonemap']

        key = json.dumps([refiner, tonemap_multiplier], default=repr)
        cache = self._model_cache
        if key in cache:
            # Re-insert to mark the entry as most recently used.
            loaded = cache.pop(key)
            cache[key] = loaded
            return loaded

        print(f"\n\nLoading Refiner Model: {refiner}\n\n")
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        # Patch a clone so the object returned by the loader stays untouched.
        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(self._make_tonemap_cfg(tonemap_multiplier))

        loaded = (refiner_model, refiner_clip)
        cache[key] = loaded
        while len(cache) > self._CACHE_SIZE:
            del cache[next(iter(cache))]
        return loaded

    @staticmethod
    def _select_config_name(style, prompt_tree):
        """Map *style* to the name of the model config entry to load."""
        # NOTE(review): None is treated like "qr" here; the second message
        # suggests None may have been meant to use 'default' - confirm.
        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            return "default_qr"
        if style == "":
            print("⛔ Style is none using Default config")
            return 'default'
        if style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        node = prompt_tree[style]
        return node['model_config'] if 'model_config' in node else 'default'

    def load(self, style, prompt_tree, model_config):
        """Load the refiner configured for *style*.

        The "qr"/None and empty-style cases used to call ``self.get_model``,
        which this class does not define, so they raised ``AttributeError``;
        every case now goes through ``get_refiner_model``.

        Returns:
            ``(refiner_model, refiner_clip)``.

        Raises:
            ValueError: if *prompt_tree* is None or *style* is not in it.
        """
        if prompt_tree is None:
            raise ValueError("Prompt Tree is None")
        cfg = self._select_config_name(style, prompt_tree)
        self.tree = prompt_tree
        self.config = model_config
        return self.get_refiner_model(cfg)
# Registry of node identifiers -> implementing classes exported by this module.
# NOTE(review): presumably read by the host application's custom-node loader,
# with the keys acting as the node ids stored in saved workflows - confirm
# before renaming any key.
NODE_CLASS_MAPPINGS = {
    "Vyro Config Loader": VyroConfigLoader,
    "Vyro Model Loader": VyroModelLoader,
    "Vyro OneFlow Model Loader": VyroOneflowModelLoader,
    "Vyro Oneflow Base Model Loader": VyroOneFlowBaseModelLoader,
    "Vyro Oneflow Refiner Model Loader": VyroOneFlowRefinerModelLoader,
    "Vyro LoRa Loader": VyroLoraLoader,
}
# Display names for the registered nodes.
#
# The first group is keyed by the identifiers used in NODE_CLASS_MAPPINGS,
# which is what a display-name lookup by node id needs. The original table
# was keyed only by class-like names, none of which matches a registered
# identifier (two of them do not even match a class name by case), so no
# entry could ever be found by node id.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Config Loader": "Vyro Config Loader",
    "Vyro Model Loader": "Vyro Model Loader",
    "Vyro OneFlow Model Loader": "Vyro Oneflow Model Loader",
    "Vyro Oneflow Base Model Loader": "Vyro Oneflow Base Model Loader",
    "Vyro Oneflow Refiner Model Loader": "Vyro Oneflow Refiner Model Loader",
    "Vyro LoRa Loader": "Vyro LoRa Loader",
    # Legacy keys, kept unchanged for anything that still looks them up.
    "VyroConfigLoader": "Vyro Config Loader",
    "VyroModelLoader": "Vyro Model Loader",
    "VyroOneFlowModelLoader": "Vyro Oneflow Model Loader",
    "VyroOneFlowBaseModelLoader": "Vyro Oneflow Base Model Loader",
    "VyroOneflowRefinerModelLoader": "Vyro Oneflow Refiner Model Loader",
    "VyroLoraLoader": "Vyro LoRa Loader",
}