|
import base64 |
|
import io |
|
|
|
from PIL import Image |
|
|
|
import numpy as np |
|
import simpleeval |
|
import torch |
|
|
|
from ..utils import VyroParams |
|
|
|
class VyroParamUpdater:
    """Node that optionally replaces the latents stored on a VYRO_PARAMS object."""

    CATEGORY = "Vyro/Routing"
    FUNCTION = "update"
    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required ``vyro_params`` input and the optional ``latents`` input."""
        required = {"vyro_params": ("VYRO_PARAMS",)}
        optional = {"latents": ("LATENT",)}
        return {"required": required, "optional": optional}

    def update(self, vyro_params: VyroParams, latents=None):
        """Return ``vyro_params`` as a 1-tuple, overwriting its latents when given.

        The object is modified in place: when ``latents`` is not ``None`` it is
        assigned to ``vyro_params.latents``; otherwise nothing is changed.
        """
        if latents is None:
            return (vyro_params,)

        vyro_params.latents = latents
        return (vyro_params,)
|
|
|
class VyroParamExtractor:
    """Node that unpacks a VYRO_PARAMS object into its individual fields."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``vyro_params`` input."""
        return {"required": {"vyro_params": ("VYRO_PARAMS",)}}

    # RETURN_TYPES and RETURN_NAMES are positional pairs; keep them aligned
    # with the tuple built in ``input`` below.
    RETURN_TYPES = (
        "LATENT",   # latents
        "STRING",   # user_prompt
        "STRING",   # user_neg_prompt
        "STRING",   # mode
        "FLOAT",    # cfg
        "INT",      # batch_size
        "INT",      # steps
        "INT",      # width
        "INT",      # height
        "INT",      # seed
        "FLOAT",    # denoise
        "FLOAT",    # stage1_strength
        "FLOAT",    # stage2_strength
        "FLOAT",    # efficiency_multiplier
        "STRING",   # style
        "STRING",   # final_positive_prompt
        "STRING",   # final_negative_prompt
        "BOOLEAN",  # is_raw
    )
    RETURN_NAMES = (
        "latents",
        "user_prompt",
        "user_neg_prompt",
        "mode",
        "cfg",
        "batch_size",
        "steps",
        "width",
        "height",
        "seed",
        "denoise",
        "stage1_strength",
        "stage2_strength",
        "efficiency_multiplier",
        "style",
        "final_positive_prompt",
        "final_negative_prompt",
        "is_raw",
    )
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    def input(self, vyro_params: VyroParams):
        """Return the attributes of ``vyro_params`` in ``RETURN_NAMES`` order."""
        p = vyro_params
        return (
            p.latents,
            p.user_prompt,
            p.user_neg_prompt,
            p.mode,
            p.cfg,
            p.batch_size,
            p.steps,
            p.width,
            p.height,
            p.seed,
            p.denoise,
            p.stage1_strength,
            p.stage2_strength,
            p.efficiency_multiplier,
            p.style,
            p.final_positive_prompt,
            p.final_negative_prompt,
            p.is_raw,
        )
|
|
|
class VyroModelStreamInput:
    """Node that bundles a model, its CLIP and an optional refiner pair into one dict."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required model/clip inputs and the optional refiner inputs."""
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
        }
        optional = {
            "refiner": ("MODEL",),
            "refiner_clip": ("CLIP",),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    def input(self, model, clip, refiner=None, refiner_clip=None):
        """Return a 1-tuple holding a dict with the four stream entries.

        When no refiner model is supplied, the refiner CLIP is dropped as well
        so that both refiner entries are ``None``.
        """
        if refiner is None:
            refiner_clip = None

        stream = {
            "model": model,
            "clip": clip,
            "refiner": refiner,
            "refiner_clip": refiner_clip,
        }
        return (stream,)
|
|
|
class VyroModelStreamOutput:
    """Node that splits a model-stream dict back into model, CLIP and refiner outputs."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``input`` stream."""
        return {"required": {"input": ("*",)}}

    RETURN_TYPES = ("MODEL", "CLIP", "MODEL", "CLIP")
    RETURN_NAMES = ("model", "clip", "refiner_model", "refiner_clip")
    FUNCTION = "output"
    CATEGORY = "Vyro/Routing"

    def output(self, input):
        """Return ``(model, clip, refiner, refiner_clip)`` read from the ``input`` dict.

        If the stream carries no refiner, both refiner outputs are ``None``
        and the ``"refiner_clip"`` entry is not read.
        """
        refiner = input["refiner"]
        model, clip = input["model"], input["clip"]

        if refiner is None:
            return model, clip, None, None
        return model, clip, refiner, input["refiner_clip"]
|
|
|
class VyroStyleSwitcher:
    """Node that forwards the input wired to the selected style, or a default."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare ``style`` and ``default`` plus one optional input per entry of ``VyroParams.STYLES``."""
        return {
            "required": {
                "style": ("STYLE",),
                "default": ("*",),
            },
            "optional": {style_name: ("*",) for style_name in VyroParams.STYLES},
        }

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    def switch(self, style, default, **kwargs):
        """Return a 1-tuple with the keyword input named ``style``.

        Falls back to ``default`` when no keyword argument with that name was
        passed.
        """
        selected = kwargs.get(style, default)
        return (selected,)
|
|
|
|
|
class VyroModeLatentMuxer:
    """Node that picks the latent input whose name matches the active mode."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare ``vyro_params`` plus one required LATENT input per entry of ``VyroParams.MODE``."""
        required = {
            "vyro_params": ("VYRO_PARAMS",),
            **{mode: ("LATENT",) for mode in VyroParams.MODE},
        }
        return {"required": required}

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latents",)
    FUNCTION = "mux_latents"
    CATEGORY = "Vyro/Routing"

    def mux_latents(self, vyro_params: VyroParams, **kwargs):
        """Return a 1-tuple with the keyword input named after ``vyro_params.mode``.

        Raises:
            KeyError: if no keyword argument matches the current mode.
        """
        active_mode = vyro_params.mode
        selected = kwargs[active_mode]
        return (selected,)
|
|
|
def tensor2pil(image):
    """Convert a tensor to a PIL image.

    The tensor is moved to the CPU, stripped of all size-1 dimensions,
    multiplied by 255, clipped to [0, 255] and cast to ``uint8`` before being
    handed to ``Image.fromarray``.
    """
    pixels = image.cpu().numpy().squeeze()
    scaled = np.clip(255. * pixels, 0, 255)
    return Image.fromarray(scaled.astype(np.uint8))
|
|
|
class VyroImageToString:
    """Node that encodes an image batch as base64 PNG text."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required ``image`` input."""
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    FUNCTION = "image_to_string"
    CATEGORY = "Vyro/Routing"

    def image_to_string(self, image):
        """Return the base64-encoded PNG form of ``image`` as a 1-tuple.

        ``image.shape[0]`` is treated as the batch size. For a batch larger
        than one the tuple holds a list with one string per batch item;
        otherwise it holds a single string.
        """

        def encode(frame):
            # Remove the leading batch dimension (when it has size 1), render
            # the frame as PNG in memory and return it as base64 text.
            picture = tensor2pil(frame.squeeze(0))
            png_buffer = io.BytesIO()
            picture.save(png_buffer, format="PNG")
            return base64.b64encode(png_buffer.getvalue()).decode("utf-8")

        if image.shape[0] > 1:
            # One size-1 chunk per batch item, encoded in batch order.
            encoded = [encode(chunk) for chunk in torch.split(image, 1, dim=0)]
            return (encoded,)

        return (encode(image),)
|
|
|
class VyroModeFilter:
    """Node that passes VYRO_PARAMS through unless the active mode is blocked."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare ``vyro_params`` plus one ``VyroParams.ALLOWED`` selector per mode.

        Each selector defaults to the first entry of ``VyroParams.ALLOWED``.
        """
        required = {"vyro_params": ("VYRO_PARAMS",)}
        for mode in VyroParams.MODE:
            selector = (VyroParams.ALLOWED, {"default": VyroParams.ALLOWED[0]})
            required[f"{mode}"] = selector
        return {"required": required}

    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)
    FUNCTION = "filter"
    CATEGORY = "Vyro/Routing"

    def filter(self, vyro_params: VyroParams, **kwargs):
        """Return ``vyro_params`` unchanged as a 1-tuple, or abort for a blocked mode.

        Raises:
            KeyError: if no keyword argument matches ``vyro_params.mode``.
            ValueError: if the selector for the active mode equals
                ``VyroParams.ALLOWED[1]``.
        """
        selected_state = kwargs[vyro_params.mode]
        blocked_state = VyroParams.ALLOWED[1]

        if selected_state != blocked_state:
            return (vyro_params,)

        raise ValueError("If you are receiving this error, it's because you're trying to execute the workflow in Comfy without detaching the preview nodes for the inactive modes. If you are reciving this error in the API, you are selecting the wrong output node.")
|
|
|
class VyroModelSwitcher:
    """Node that chooses between two models based on an expression over VYRO_PARAMS."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare ``vyro_params``, the two candidate models and the condition string."""
        condition_options = {"default": "face_swap_img is not None", "multiline": False}
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "a": ("MODEL",),
                "b": ("MODEL",),
                "return_a_if_true": ("STRING", condition_options),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    def switch(self, vyro_params: VyroParams, a, b, return_a_if_true):
        """Return ``(a,)`` when the expression evaluates truthy, else ``(b,)``.

        ``return_a_if_true`` is evaluated with ``simpleeval.simple_eval``; every
        name listed in ``VyroParams.PARAMS`` is available to the expression,
        bound to the matching attribute of ``vyro_params``.
        """
        names = {param: getattr(vyro_params, param) for param in VyroParams.PARAMS}
        chosen = a if simpleeval.simple_eval(return_a_if_true, names=names) else b
        return (chosen,)
|
|
|
# Registry of the node classes defined in this module, keyed by the
# identifier each node is registered under.
NODE_CLASS_MAPPINGS = {
    "Vyro Mode Latent Muxer": VyroModeLatentMuxer,
    "Vyro Style Switcher": VyroStyleSwitcher,
    "Vyro Model Stream Input": VyroModelStreamInput,
    "Vyro Model Stream Output": VyroModelStreamOutput,
    "Vyro Param Extractor": VyroParamExtractor,
    "Vyro Image to String": VyroImageToString,
    "Vyro Mode Filter": VyroModeFilter,
    "Vyro Model Switcher": VyroModelSwitcher,
    "Vyro Param Updater": VyroParamUpdater,
}
|
|
|
|
|
# Human-readable titles for the nodes registered in NODE_CLASS_MAPPINGS.
#
# The keys must be the identifiers used in NODE_CLASS_MAPPINGS, because display
# names are looked up by node identifier. The previous keys were the Python
# class names ("VyroModeLatentMuxer", ...), which match no registered
# identifier, so the mapping could never take effect.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Vyro Mode Latent Muxer": "Vyro Mode Latent Muxer",
    "Vyro Style Switcher": "Vyro Style Switcher",
    "Vyro Model Stream Input": "Vyro Model Stream Input",
    "Vyro Model Stream Output": "Vyro Model Stream Output",
    "Vyro Param Extractor": "Vyro Param Extractor",
    "Vyro Image to String": "Vyro Image to String",
    "Vyro Mode Filter": "Vyro Mode Filter",
    "Vyro Model Switcher": "Vyro Model Switcher",
    "Vyro Param Updater": "Vyro Param Updater",
}