"""ComfyUI routing nodes for the Vyro pipeline.

Every node in this module is registered under the ``Vyro/Routing`` category
and reads from, writes to, or branches on a ``VyroParams`` object (or
bundles/unbundles model streams).
"""
import base64
import io

from PIL import Image
import numpy as np
import simpleeval
import torch

from ..utils import VyroParams


class VyroParamUpdater:
    """Overwrite selected fields of a ``VyroParams`` object.

    Only ``latents`` can currently be updated; when it is not connected the
    params are passed through unchanged.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
            },
            "optional": {
                "latents": ("LATENT",),
            },
        }

    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)
    FUNCTION = "update"
    CATEGORY = "Vyro/Routing"

    def update(self, vyro_params: VyroParams, latents=None):
        """Return ``vyro_params`` with ``latents`` replaced when provided."""
        # NOTE(review): this mutates the incoming object in place instead of
        # returning a copy, so anything else holding the same VyroParams
        # instance also sees the new latents. Confirm that sharing is intended.
        if latents is not None:
            vyro_params.latents = latents
        return (vyro_params,)


class VyroParamExtractor:
    """Expose the individual fields of a ``VyroParams`` object as outputs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
            }
        }

    RETURN_TYPES = (
        "LATENT", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT",
        "INT", "INT", "FLOAT", "FLOAT", "FLOAT", "FLOAT", "STRING", "STRING",
        "STRING", "BOOLEAN",
    )
    # Each output name is also the VyroParams attribute it is read from.
    RETURN_NAMES = (
        "latents", "user_prompt", "user_neg_prompt", "mode", "cfg",
        "batch_size", "steps", "width", "height", "seed", "denoise",
        "stage1_strength", "stage2_strength", "efficiency_multiplier", "style",
        "final_positive_prompt", "final_negative_prompt", "is_raw",
    )
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    def input(self, vyro_params: VyroParams):
        """Return the fields named in ``RETURN_NAMES``, in that order."""
        # Deriving the tuple from RETURN_NAMES keeps the output order and the
        # declared names from drifting apart.
        return tuple(getattr(vyro_params, name) for name in self.RETURN_NAMES)


class VyroModelStreamInput:
    """Bundle a model/CLIP pair (plus an optional refiner pair) into one dict."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
            },
            "optional": {
                "refiner": ("MODEL",),
                "refiner_clip": ("CLIP",),
            },
        }

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    def input(self, model, clip, refiner=None, refiner_clip=None):
        """Return a 1-tuple holding the stream dict.

        The dict always has the keys ``model``, ``clip``, ``refiner`` and
        ``refiner_clip``. ``refiner_clip`` is forced to ``None`` whenever
        ``refiner`` is ``None``.
        """
        stream = {
            "model": model,
            "clip": clip,
            "refiner": refiner,
            # A refiner CLIP is only kept when a refiner model is present.
            "refiner_clip": refiner_clip if refiner is not None else None,
        }
        return (stream,)


class VyroModelStreamOutput:
    """Unpack a stream dict built by ``VyroModelStreamInput``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input": ("*",),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "MODEL", "CLIP")
    RETURN_NAMES = ("model", "clip", "refiner_model", "refiner_clip")
    FUNCTION = "output"
    CATEGORY = "Vyro/Routing"

    def output(self, input):
        """Return ``(model, clip, refiner, refiner_clip)`` from the stream dict.

        When ``input["refiner"]`` is ``None`` both refiner outputs are ``None``
        and ``input["refiner_clip"]`` is not read.
        """
        # The parameter keeps the name `input` (shadowing the builtin) because
        # it must match the "input" key declared in INPUT_TYPES.
        refiner = input["refiner"]
        refiner_clip = input["refiner_clip"] if refiner is not None else None
        return input["model"], input["clip"], refiner, refiner_clip


class VyroStyleSwitcher:
    """Forward the input connected for the selected style, else ``default``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "style": ("STYLE",),
                "default": ("*",),
            },
            # One optional wildcard socket per style in VyroParams.STYLES.
            "optional": {style: ("*",) for style in VyroParams.STYLES},
        }

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    def switch(self, style, default, **kwargs):
        """Return the value wired to ``style``, falling back to ``default``."""
        return (kwargs.get(style, default),)


class VyroModeLatentMuxer:
    """Select the latent input that corresponds to ``vyro_params.mode``."""

    @classmethod
    def INPUT_TYPES(cls):
        types = {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
            }
        }
        # One required LATENT socket per mode in VyroParams.MODE.
        for mode in VyroParams.MODE:
            types["required"][mode] = ("LATENT",)
        return types

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latents",)
    FUNCTION = "mux_latents"
    CATEGORY = "Vyro/Routing"

    def mux_latents(self, vyro_params: VyroParams, **kwargs):
        """Return the latent connected for the current mode.

        Raises:
            KeyError: if no input exists for ``vyro_params.mode``.
        """
        mode = vyro_params.mode
        try:
            return (kwargs[mode],)
        except KeyError:
            raise KeyError(
                f"No latent input for mode {mode!r}; "
                f"available modes: {list(kwargs)}"
            ) from None


def tensor2pil(image):
    """Convert an image tensor to an 8-bit PIL image.

    Values are scaled by 255 and clipped to [0, 255], so the tensor is
    expected to hold values roughly in [0, 1]. All size-1 dimensions are
    squeezed away before conversion.
    """
    return Image.fromarray(
        np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
    )


def _image_to_base64_png(image):
    """Encode a single image tensor as a base64 PNG ``str``."""
    with io.BytesIO() as buffer:
        tensor2pil(image).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


class VyroImageToString:
    """Encode an IMAGE batch as base64 PNG string(s)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    FUNCTION = "image_to_string"
    CATEGORY = "Vyro/Routing"

    def image_to_string(self, image):
        """Return a 1-tuple with the encoded image(s).

        For a batch of one, the value is a single base64 string. For a larger
        batch (first dimension > 1), it is a list with one string per image.
        """
        if image.shape[0] > 1:
            # NOTE(review): a batch produces a *list* of strings even though the
            # declared output type is STRING; consumers must handle both.
            return ([_image_to_base64_png(frame) for frame in image.unbind(0)],)
        return (_image_to_base64_png(image.squeeze(0)),)


class VyroModeFilter:
    """Let ``vyro_params`` through only if its mode's widget is not blocked."""

    @classmethod
    def INPUT_TYPES(cls):
        req = {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
            }
        }
        # One combo widget per mode, choosing from VyroParams.ALLOWED.
        for mode in VyroParams.MODE:
            req["required"][f'{mode}'] = (
                VyroParams.ALLOWED,
                {"default": VyroParams.ALLOWED[0]},
            )
        return req

    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)
    FUNCTION = "filter"
    CATEGORY = "Vyro/Routing"

    def filter(self, vyro_params: VyroParams, **kwargs):
        """Return ``vyro_params`` unless the current mode is set to ALLOWED[1].

        Raises:
            ValueError: when the widget for ``vyro_params.mode`` equals
                ``VyroParams.ALLOWED[1]`` (presumably the "blocked" state;
                confirm against VyroParams).
        """
        state = kwargs[vyro_params.mode]
        if state == VyroParams.ALLOWED[1]:
            raise ValueError(
                "If you are receiving this error, it's because you're trying "
                "to execute the workflow in Comfy without detaching the "
                "preview nodes for the inactive modes. If you are reciving "
                "this error in the API, you are selecting the wrong output "
                "node."
            )
        return (vyro_params,)


class VyroModelSwitcher:
    """Choose model ``a`` or ``b`` by evaluating an expression over params."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "a": ("MODEL",),
                "b": ("MODEL",),
                "return_a_if_true": (
                    "STRING",
                    {"default": "face_swap_img is not None", "multiline": False},
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    def switch(self, vyro_params: VyroParams, a, b, return_a_if_true):
        """Return ``a`` if the expression is truthy, else ``b``.

        The expression may reference any attribute listed in
        ``VyroParams.PARAMS`` by name.
        """
        names = {name: getattr(vyro_params, name) for name in VyroParams.PARAMS}
        # The expression text can come from the workflow/API, so it is
        # evaluated with simpleeval's restricted evaluator, never builtin eval().
        result = simpleeval.simple_eval(return_a_if_true, names=names)
        return (a,) if result else (b,)


NODE_CLASS_MAPPINGS = {
    "Vyro Mode Latent Muxer": VyroModeLatentMuxer,
    "Vyro Style Switcher": VyroStyleSwitcher,
    "Vyro Model Stream Input": VyroModelStreamInput,
    "Vyro Model Stream Output": VyroModelStreamOutput,
    "Vyro Param Extractor": VyroParamExtractor,
    "Vyro Image to String": VyroImageToString,
    "Vyro Mode Filter": VyroModeFilter,
    "Vyro Model Switcher": VyroModelSwitcher,
    "Vyro Param Updater": VyroParamUpdater,
}

# Display names are looked up by the node-type keys of NODE_CLASS_MAPPINGS,
# so these keys must match them exactly.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Mode Latent Muxer": "Vyro Mode Latent Muxer",
    "Vyro Style Switcher": "Vyro Style Switcher",
    "Vyro Model Stream Input": "Vyro Model Stream Input",
    "Vyro Model Stream Output": "Vyro Model Stream Output",
    "Vyro Param Extractor": "Vyro Param Extractor",
    "Vyro Image to String": "Vyro Image to String",
    "Vyro Mode Filter": "Vyro Mode Filter",
    "Vyro Model Switcher": "Vyro Model Switcher",
    "Vyro Param Updater": "Vyro Param Updater",
}