from collections import OrderedDict
from functools import lru_cache
import json
import os

import folder_paths
from folder_paths import folder_names_and_paths, supported_pt_extensions
import nodes
import spacy
import torch


def _register_folder(name, path, extensions):
    """Register *path* under *name* in ComfyUI's folder registry.

    Unlike a plain assignment, search paths that were already registered for
    *name* (for example by another config source) are kept; the default
    *path* is inserted first so it stays the primary search location.
    """
    registered = folder_names_and_paths.get(name)
    paths = list(registered[0]) if registered else []
    if path not in paths:
        paths.insert(0, path)
    folder_names_and_paths[name] = (paths, extensions)


def update_paths():
    """Register the ``vyro_configs``, ``spacy`` and ``interposers`` folders.

    Safe to call repeatedly (it is called on import and from
    ``VyroConfigLoader.INPUT_TYPES``): existing extra search paths survive.
    """
    model_path = folder_paths.models_dir
    _register_folder(
        "vyro_configs",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs"),
        ['.json'],
    )
    _register_folder("spacy", os.path.join(model_path, "spacy"), supported_pt_extensions)
    _register_folder("interposers", os.path.join(model_path, "interposers"), supported_pt_extensions)


update_paths()


def _find_spacy_model_dirs():
    """Return sorted, de-duplicated dirs (relative to their spacy search path)
    that contain a ``config.cfg`` file."""
    found = set()
    for search_path in folder_paths.get_folder_paths("spacy"):
        if not os.path.exists(search_path):
            continue
        for root, _subdirs, files in os.walk(search_path, followlinks=True):
            if "config.cfg" in files:
                found.add(os.path.relpath(root, start=search_path))
    return sorted(found)


def _resolve_spacy_model_dir(relative_dir):
    """Map a dir listed by ``_find_spacy_model_dirs`` back to an absolute path.

    Checks every spacy search path, not just the first, because the dropdown
    is built from all of them. Falls back to the primary search path.
    """
    for search_path in folder_paths.get_folder_paths("spacy"):
        candidate = os.path.join(search_path, relative_dir)
        if os.path.isfile(os.path.join(candidate, "config.cfg")):
            return candidate
    return os.path.join(folder_names_and_paths["spacy"][0][0], relative_dir)


class VyroConfigLoader:
    """ComfyUI node: load a Vyro JSON config plus a spaCy classifier model."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        update_paths()
        return {
            "required": {
                "config_path": (folder_paths.get_filename_list("vyro_configs"),),
                "classifier_path": (_find_spacy_model_dirs(),),
            }
        }

    RETURN_TYPES = ("LIST", "DICT", "DICT", "TRANSFORMER", "LIST")
    RETURN_NAMES = ("styles", "prompt_tree", "model_config", "classifier", "unweighted_styles")
    FUNCTION = "load_config"
    CATEGORY = "Vyro/Loaders"

    def load_config(self, config_path, classifier_path):
        """Load the config file and the classifier.

        Args:
            config_path: file name inside the ``vyro_configs`` folder.
            classifier_path: spaCy model dir relative to a ``spacy`` search path.

        Returns:
            (styles, prompt_tree, model_config, classifier, unweighted_styles)

        Raises:
            FileNotFoundError: the config name does not resolve to a file.
            ValueError: the config file is not valid JSON.
        """
        config_file = folder_paths.get_full_path("vyro_configs", config_path)
        if config_file is None:
            raise FileNotFoundError(f"[VyroConfigLoader] Config not found: {config_path}")

        # Parse the (cheap) config before loading the (expensive) classifier so
        # a malformed file fails fast.
        with open(config_file, 'r', encoding='utf-8') as json_file:
            try:
                config = json.load(json_file)
            except json.JSONDecodeError as json_e:
                print(f"[VyroConfigLoader] Error loading {config_file}:", json_e)
                # Returning None would violate RETURN_TYPES and fail later
                # with a far less helpful error.
                raise ValueError(f"[VyroConfigLoader] Invalid JSON in {config_file}") from json_e

        spacy.prefer_gpu(gpu_id=0)
        classifier = spacy.load(_resolve_spacy_model_dir(classifier_path))

        styles = config['styles']
        # Drop everything after the first ':' (presumably a weight suffix).
        unweighted_styles = [style.split(':')[0] for style in styles]
        return (styles, config['prompt_tree'], config['model_config'], classifier, unweighted_styles)


def _make_tonemap_reinhard_cfg(tonemap_multiplier):
    """Build a sampler CFG function that Reinhard-tonemaps the guidance magnitude.

    The magnitude of ``cond - uncond`` is compressed with x / (1 + x) relative
    to a ceiling of ``(3 * std + mean) * tonemap_multiplier``.
    """
    def sampler_tonemap_reinhard(args):
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = args["cond_scale"]
        noise_pred = cond - uncond
        # Norm over dim 1; the later reductions over dims (1, 2, 3) assume 4-D
        # latents (presumably batch, channels, h, w — TODO confirm). The
        # epsilon avoids division by zero.
        magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
        noise_pred /= magnitude
        mean = torch.mean(magnitude, dim=(1, 2, 3), keepdim=True)
        std = torch.std(magnitude, dim=(1, 2, 3), keepdim=True)
        top = (std * 3 + mean) * tonemap_multiplier
        # Reinhard operator in units of `top`, then scaled back.
        magnitude *= (1.0 / top)
        new_magnitude = magnitude / (magnitude + 1.0)
        new_magnitude *= top
        return uncond + noise_pred * new_magnitude * cond_scale

    return sampler_tonemap_reinhard


class VyroModelLoader:
    """ComfyUI node: pick a model config for a style and load base/refiner models."""

    # Maximum number of built model tuples kept per node instance.
    _CACHE_SIZE = 6

    def __init__(self):
        self.chkp_loader = nodes.CheckpointLoaderSimple()
        self.lora_loader = nodes.LoraLoader()
        self.tree = None
        self.config = None
        # Per-instance LRU. An @lru_cache on the method would key only on
        # (self, cfg) and serve stale models after model_config changes; it
        # would also keep every instance alive in a class-level cache.
        self._model_cache = OrderedDict()

    def get_model(self, cfg):
        """Return (base_model, refiner_model, base_clip, refiner_clip) for *cfg*.

        *cfg* is a key of ``self.config['configs']``; results are cached per
        distinct config entry content.

        Raises:
            KeyError: *cfg* or one of its required keys is missing.
        """
        entry = self.config['configs'][cfg]
        # Entry comes from JSON; repr is only a fallback so the key is always
        # computable — it is never written anywhere.
        cache_key = (cfg, json.dumps(entry, sort_keys=True, default=repr))
        cached = self._model_cache.get(cache_key)
        if cached is not None:
            self._model_cache.move_to_end(cache_key)
            return cached

        result = self._build_models(entry)
        self._model_cache[cache_key] = result
        if len(self._model_cache) > self._CACHE_SIZE:
            self._model_cache.popitem(last=False)
        return result

    def _build_models(self, entry):
        """Load checkpoints, apply LoRAs to the base, and attach the tonemap CFG."""
        # Read every key up front so a malformed entry fails before any
        # expensive checkpoint load.
        base = entry['base']
        refiner = entry['refiner']
        loras = entry['loras']
        tonemap_multiplier = entry['tonemap']

        base_model, base_clip, _ = self.chkp_loader.load_checkpoint(base)
        refiner_model, refiner_clip, _ = self.chkp_loader.load_checkpoint(refiner)
        for lora in loras:
            base_model, base_clip = self.lora_loader.load_lora(
                base_model, base_clip, lora['name'], lora['unet'], lora['clip'])

        cfg_function = _make_tonemap_reinhard_cfg(tonemap_multiplier)
        # Clone before patching so the loaders' objects are not mutated.
        base_model = base_model.clone()
        base_model.set_model_sampler_cfg_function(cfg_function)
        refiner_model = refiner_model.clone()
        refiner_model.set_model_sampler_cfg_function(cfg_function)
        return (base_model, refiner_model, base_clip, refiner_clip)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "style": ("STYLE",),
                "prompt_tree": ("DICT",),
                "model_config": ("DICT",),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CLIP", "CLIP")
    RETURN_NAMES = ("base_model", "refiner_model", "base_clip", "refiner_clip")
    FUNCTION = "load"
    CATEGORY = "Vyro/Loaders"

    def load(self, style, prompt_tree, model_config):
        """Select the model config name for *style* and return its models.

        ``None`` and ``"qr"`` map to ``default_qr``; ``""`` maps to
        ``default``; otherwise the style's ``model_config`` entry in
        *prompt_tree* is used (falling back to ``default``).

        Raises:
            ValueError: *prompt_tree* is None or *style* is not in it.
        """
        if prompt_tree is None:
            raise ValueError("Prompt tree is None")

        if style is None or style == "qr":
            print("⛔ Style is qr changing to default qr models")
            cfg = 'default_qr'
        elif style == "":
            print("⛔ Style is none using Default config")
            cfg = 'default'
        elif style not in prompt_tree:
            raise ValueError("Style not in prompt tree")
        else:
            node = prompt_tree[style]
            cfg = node['model_config'] if 'model_config' in node else 'default'

        self.tree = prompt_tree
        self.config = model_config
        return self.get_model(cfg)


NODE_CLASS_MAPPINGS = {
    "Vyro Config Loader": VyroConfigLoader,
    "Vyro Model Loader": VyroModelLoader,
}

# ComfyUI looks display names up by the NODE_CLASS_MAPPINGS key, so the keys
# here must match those above exactly.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Config Loader": "Vyro Config Loader",
    "Vyro Model Loader": "Vyro Model Loader",
}