azkavyro's picture
Added all files including vyro_workflows
6fecfbe
from functools import lru_cache
import json
import os
import folder_paths
from folder_paths import folder_names_and_paths, supported_pt_extensions
import nodes
import spacy
import torch
def update_paths():
    """Register the Vyro-specific folders with ComfyUI's ``folder_paths`` table.

    Three folder names are (re)defined, each with exactly one search path:

    * ``vyro_configs`` -> ``<this file's dir>/../configs``, ``.json`` files only;
    * ``spacy`` and ``interposers`` -> sub-folders of ``folder_paths.models_dir``
      using ``supported_pt_extensions``.

    Any existing entry for these names is replaced, not extended.
    """
    models_root = folder_paths.models_dir
    configs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs")
    folder_names_and_paths["vyro_configs"] = ([configs_dir], ['.json'])
    for folder_name in ("spacy", "interposers"):
        folder_names_and_paths[folder_name] = (
            [os.path.join(models_root, folder_name)],
            supported_pt_extensions,
        )
# Register the Vyro config/model folders once at import time.
# VyroConfigLoader.INPUT_TYPES calls update_paths() again before listing files.
update_paths()
class VyroConfigLoader:
    """Load a Vyro JSON config together with a spaCy pipeline used as classifier.

    Outputs the config's ``styles``, ``prompt_tree`` and ``model_config``
    entries, the loaded spaCy pipeline, and the style names with any
    ``:weight`` suffix removed.
    """

    def __init__(self):
        pass

    @staticmethod
    def _find_spacy_models(search_paths):
        """Return sorted, de-duplicated spaCy model dirs found under *search_paths*.

        A directory counts as a model when it contains ``config.cfg``; each hit
        is reported relative to the search path in which it was found.
        """
        found = set()
        for search_path in search_paths:
            if not os.path.exists(search_path):
                continue
            for root, _dirs, files in os.walk(search_path, followlinks=True):
                if "config.cfg" in files:
                    found.add(os.path.relpath(root, start=search_path))
        # Sorted so the dropdown order does not depend on filesystem walk order.
        return sorted(found)

    @staticmethod
    def _resolve_classifier_path(classifier_path, search_paths):
        """Map a relative model name (as listed by INPUT_TYPES) to a directory.

        Returns the first ``<search_path>/<classifier_path>`` that is an existing
        directory. If none exists, falls back to joining with the first search
        path so ``spacy.load`` reports the missing model itself; with no search
        paths at all, *classifier_path* is returned unchanged.
        """
        for search_path in search_paths:
            candidate = os.path.join(search_path, classifier_path)
            if os.path.isdir(candidate):
                return candidate
        if search_paths:
            return os.path.join(search_paths[0], classifier_path)
        return classifier_path

    @staticmethod
    def _strip_weights(styles):
        """Return *styles* with everything from the first ``:`` onward removed."""
        return [style.split(':', 1)[0] for style in styles]

    @classmethod
    def INPUT_TYPES(s):
        """Offer the available config files and spaCy model directories."""
        update_paths()
        return {
            "required": {
                "config_path": (folder_paths.get_filename_list("vyro_configs"),),
                "classifier_path": (s._find_spacy_models(folder_paths.get_folder_paths("spacy")),),
            }
        }

    RETURN_TYPES = ("LIST", "DICT", "DICT", "TRANSFORMER", "LIST")
    RETURN_NAMES = ("styles", "prompt_tree", "model_config", "classifier", "unweighted_styles")
    FUNCTION = "load_config"
    CATEGORY = "Vyro/Loaders"

    def load_config(self, config_path, classifier_path):
        """Load the selected config and spaCy model.

        Args:
            config_path: file name within the ``vyro_configs`` folder.
            classifier_path: model directory relative to one of the ``spacy``
                search paths.

        Returns:
            ``(styles, prompt_tree, model_config, classifier, unweighted_styles)``.

        Raises:
            FileNotFoundError: if *config_path* cannot be located.
            ValueError: if the config file is not valid JSON.
        """
        full_config_path = folder_paths.get_full_path("vyro_configs", config_path)
        if full_config_path is None:
            raise FileNotFoundError(f"[VyroConfigLoader] Config file not found: {config_path}")

        # Parse the config before loading the classifier so a broken config
        # fails immediately instead of after a model load.
        with open(full_config_path, 'r', encoding='utf-8') as json_file:
            try:
                config = json.load(json_file)
            except json.JSONDecodeError as json_e:
                # ComfyUI expects a tuple; returning None would only fail later
                # with an unrelated error, so surface the real cause here.
                raise ValueError(
                    f"[VyroConfigLoader] Error loading {full_config_path}: {json_e}"
                ) from json_e

        # INPUT_TYPES lists models relative to *every* spaCy search path, so
        # resolve against all of them rather than only the first.
        model_dir = self._resolve_classifier_path(
            classifier_path, folder_paths.get_folder_paths("spacy")
        )
        spacy.prefer_gpu(gpu_id=0)
        classifier = spacy.load(model_dir)

        styles = config['styles']
        return (
            styles,
            config['prompt_tree'],
            config['model_config'],
            classifier,
            self._strip_weights(styles),
        )
class VyroModelLoader:
    """Load the base/refiner checkpoints (plus LoRAs) configured for a style.

    The style is mapped to a config name, which is looked up in
    ``model_config['configs']``. That entry names the base and refiner
    checkpoints, the LoRAs applied to the base model, and the multiplier used
    by the tonemapping CFG function patched onto both models.
    """

    # Maximum number of distinct model configurations kept loaded at once.
    _CACHE_SIZE = 6

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        # serialized config entry -> loaded model tuple, least recently used first
        self._model_cache = {}

    @staticmethod
    def _make_tonemap_cfg_function(tonemap_multiplier):
        """Return a sampler CFG function that Reinhard-tonemaps the guidance.

        The returned function reads ``cond``, ``uncond`` and ``cond_scale`` from
        its ``args`` dict and returns ``uncond + guidance * cond_scale``, where
        the magnitude of ``guidance = cond - uncond`` (norm over dim 1) has been
        compressed with ``x / (x + 1)`` relative to
        ``top = (mean + 3 * std) * tonemap_multiplier``.
        """
        def sampler_tonemap_reinhard(args):
            # NOTE(review): the dim=(1, 2, 3) reductions suggest 4-D latents
            # (batch, channels, h, w) — confirm against the sampler.
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            noise_pred = (cond - uncond)
            # Per-position magnitude across dim 1; the epsilon avoids /0.
            noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
            noise_pred /= noise_pred_vector_magnitude

            mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            top = (std * 3 + mean) * tonemap_multiplier

            # Reinhard: map magnitude/top through x / (x + 1), then rescale by top.
            noise_pred_vector_magnitude *= (1.0 / top)
            new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
            new_magnitude *= top

            return uncond + noise_pred * new_magnitude * cond_scale

        return sampler_tonemap_reinhard

    def _load_models(self, model_spec):
        """Load, LoRA-patch and tonemap-patch the models for one config entry."""
        # Read every key up front so a malformed entry fails before any
        # checkpoint is loaded.
        base = model_spec['base']
        refiner = model_spec['refiner']
        loras = model_spec['loras']
        tonemap_multiplier = model_spec['tonemap']

        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(
                base_model, base_clip, lora['name'], lora['unet'], lora['clip']
            )

        cfg_function = self._make_tonemap_cfg_function(tonemap_multiplier)
        # Clone before patching so the CFG function is set on a copy rather
        # than on the object returned by the loader.
        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)
        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)
        return (base_model, refiner_model, base_clip, refiner_clip)

    def get_model(self, cfg):
        """Return ``(base_model, refiner_model, base_clip, refiner_clip)`` for *cfg*.

        *cfg* names an entry in ``self.config['configs']``. Results are cached
        per instance (up to ``_CACHE_SIZE`` entries, least recently used evicted),
        keyed on the entry's *contents*. Keying on the name alone would return
        stale models when a different ``model_config`` reuses a name.

        Raises:
            KeyError: if *cfg* is not present in ``self.config['configs']``.
        """
        configs = self.config['configs']
        if cfg not in configs:
            raise KeyError(f"Model config '{cfg}' not found in model_config['configs']")
        model_spec = configs[cfg]
        cache_key = json.dumps(model_spec, sort_keys=True, default=str)

        cache = getattr(self, "_model_cache", None)
        if cache is None:
            cache = self._model_cache = {}

        if cache_key in cache:
            # Re-insert to mark as most recently used (dicts keep insertion order).
            models = cache.pop(cache_key)
            cache[cache_key] = models
            return models

        models = self._load_models(model_spec)
        cache[cache_key] = models
        while len(cache) > self._CACHE_SIZE:
            del cache[next(iter(cache))]
        return models

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")
    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")
    FUNCTION = "load"
    CATEGORY = "Vyro/Loaders"

    def load(self, style, prompt_tree, model_config):
        """Resolve *style* to a config name and return that config's models.

        ``"qr"`` and ``None`` select ``default_qr``; ``""`` selects ``default``.
        Any other style must be a key of *prompt_tree*; it uses that node's
        ``model_config`` value, or ``default`` when the node has none.

        Raises:
            ValueError: if *prompt_tree* is None or *style* is not in it.
        """
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")
        if style is None or style == "qr":
            # NOTE(review): None is routed to default_qr here, which made the
            # None check in the "default" branch unreachable. Confirm that None
            # should select default_qr rather than default.
            print("⛔ Style is qr changing to default qr models")
            cfg = 'default_qr'
        elif style == "":
            print("⛔ Style is none using Default config")
            cfg = 'default'
        elif style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        else:
            node = prompt_tree[style]
            cfg = node['model_config'] if 'model_config' in node else 'default'
        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)
# Node type name -> implementing class. These keys are the node identifiers
# registered with ComfyUI (presumably also referenced by saved workflows), so
# they are left unchanged.
NODE_CLASS_MAPPINGS = {
    "Vyro Config Loader": VyroConfigLoader,
    "Vyro Model Loader": VyroModelLoader,
}

# Display names must be keyed by the same names as NODE_CLASS_MAPPINGS. The
# previous "VyroConfigLoader"/"VyroModelLoader" keys matched no registered
# node, so those entries never took effect.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Config Loader": "Vyro Config Loader",
    "Vyro Model Loader": "Vyro Model Loader",
}