# Author: azkavyro
# Commit message: Added all files including vyro_workflows
# Commit: 6fecfbe
import base64
import io
from PIL import Image
import numpy as np
import simpleeval
import torch
from ..utils import VyroParams
class VyroParamUpdater:
    """Routing node that writes optional overrides back onto a VYRO_PARAMS object.

    The only override today is ``latents``: when a LATENT is supplied it is
    assigned to ``vyro_params.latents``; otherwise the params pass through as-is.
    """

    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)
    FUNCTION = "update"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required VYRO_PARAMS input and one optional LATENT input."""
        required = {"vyro_params": ("VYRO_PARAMS",)}
        optional = {"latents": ("LATENT",)}
        return {"required": required, "optional": optional}

    def update(self, vyro_params:VyroParams, latents=None):
        """Return ``(vyro_params,)``, first assigning ``latents`` when one was given.

        NOTE(review): the incoming object is modified in place rather than
        copied, so anything else holding the same object also sees the new
        latents -- confirm this is intended.
        """
        if latents is None:
            return (vyro_params,)
        vyro_params.latents = latents
        return (vyro_params,)
class VyroParamExtractor:
    """Routing node that fans a VYRO_PARAMS object out into its individual fields.

    Every output slot is read from the attribute of the same name on the
    incoming params, in ``RETURN_NAMES`` order.
    """

    RETURN_TYPES = (
        "LATENT", "STRING", "STRING", "STRING", "FLOAT",
        "INT", "INT", "INT", "INT", "INT",
        "FLOAT", "FLOAT", "FLOAT", "FLOAT",
        "STRING", "STRING", "STRING", "BOOLEAN",
    )
    RETURN_NAMES = (
        "latents", "user_prompt", "user_neg_prompt", "mode", "cfg",
        "batch_size", "steps", "width", "height", "seed",
        "denoise", "stage1_strength", "stage2_strength", "efficiency_multiplier",
        "style", "final_positive_prompt", "final_negative_prompt", "is_raw",
    )
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required VYRO_PARAMS input."""
        return {"required": {"vyro_params": ("VYRO_PARAMS",)}}

    def input(self, vyro_params:VyroParams):
        """Return the params' fields as a tuple aligned with RETURN_NAMES / RETURN_TYPES."""
        # Each output name is also the attribute name on the params object, so
        # the tuple is read straight off RETURN_NAMES (same order, same reads).
        return tuple(
            getattr(vyro_params, field) for field in VyroParamExtractor.RETURN_NAMES
        )
class VyroModelStreamInput:
    """Bundle a base model/CLIP pair and an optional refiner pair into one dict.

    The dict is emitted on a single wildcard ("*") output, using the keys
    "model", "clip", "refiner" and "refiner_clip".
    """

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "input"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare required model/clip inputs and optional refiner inputs."""
        return {
            "required": {"model": ("MODEL",), "clip": ("CLIP",)},
            "optional": {"refiner": ("MODEL",), "refiner_clip": ("CLIP",)},
        }

    def input(self, model, clip, refiner=None, refiner_clip=None):
        """Return a 1-tuple holding the stream dict.

        Without a refiner model both refiner slots are None, even if a
        ``refiner_clip`` was passed.
        """
        stream = {"model": model, "clip": clip, "refiner": None, "refiner_clip": None}
        if refiner is not None:
            stream["refiner"] = refiner
            stream["refiner_clip"] = refiner_clip
        return (stream,)
class VyroModelStreamOutput:
    """Unpack a model-stream dict back into separate MODEL/CLIP outputs.

    Reads the keys "model", "clip", "refiner" and (when a refiner is present)
    "refiner_clip" from the wildcard input.
    """

    RETURN_TYPES = ("MODEL", "CLIP", "MODEL", "CLIP")
    RETURN_NAMES = ("model", "clip", "refiner_model", "refiner_clip")
    FUNCTION = "output"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required wildcard input."""
        return {"required": {"input": ("*",)}}

    def output(self, input):
        """Return ``(model, clip, refiner_model, refiner_clip)``.

        Both refiner outputs are None whenever ``input["refiner"]`` is None;
        "refiner_clip" is not read in that case.
        """
        refiner = input["refiner"]
        if refiner is not None:
            return input["model"], input["clip"], refiner, input["refiner_clip"]
        return input["model"], input["clip"], None, None
class VyroStyleSwitcher:
    """Forward the wildcard input whose slot name matches the selected style.

    One optional input exists per entry of ``VyroParams.STYLES``; ``default``
    is used when the selected style's slot was not passed in.
    """

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare style/default inputs plus one optional wildcard input per style."""
        # NOTE(review): presumably unconnected optional slots are simply absent
        # from the call's kwargs, which is why switch() needs a default.
        per_style = {name: ("*",) for name in VyroParams.STYLES}
        return {
            "required": {"style": ("STYLE", ), "default": ("*",)},
            "optional": per_style,
        }

    def switch(self, style, default, **kwargs):
        """Return ``(kwargs[style],)`` if that slot was passed, else ``(default,)``."""
        if style in kwargs:
            return (kwargs[style],)
        return (default,)
class VyroModeLatentMuxer:
    """Select one of several per-mode LATENT inputs using ``vyro_params.mode``."""

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latents",)
    FUNCTION = "mux_latents"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare VYRO_PARAMS plus one required LATENT input per entry of VyroParams.MODE."""
        required = {"vyro_params": ("VYRO_PARAMS",)}
        required.update({mode: ("LATENT",) for mode in VyroParams.MODE})
        return {"required": required}

    def mux_latents(self, vyro_params:VyroParams, **kwargs):
        """Return the LATENT wired to the input named after the current mode.

        Raises:
            KeyError: no input was passed under the current mode's name.
        """
        return (kwargs[vyro_params.mode],)
def tensor2pil(image):
    """Convert an image tensor to a PIL Image.

    Values are scaled by 255, clipped to [0, 255] and truncated to uint8. All
    size-1 dimensions are squeezed away first.
    """
    # NOTE(review): assumes float pixel values in [0, 1] -- TODO confirm.
    pixels = image.cpu().numpy().squeeze()
    scaled = np.clip(pixels * 255., 0, 255)
    return Image.fromarray(scaled.astype(np.uint8))
class VyroImageToString:
    """Encode an IMAGE tensor as base64-encoded PNG text.

    A batch of one image yields a single string. A batch of more than one
    yields a list of strings (one per image) in the same single STRING slot.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    FUNCTION = "image_to_string"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required IMAGE input."""
        return {"required": {"image": ("IMAGE",)}}

    @staticmethod
    def _encode_png_base64(frame):
        """Render one image tensor as PNG and return the bytes as base64 UTF-8 text."""
        buffered = io.BytesIO()
        tensor2pil(frame).save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def image_to_string(self, image):
        """Encode ``image`` as base64 PNG text.

        Args:
            image: batched image tensor; dimension 0 is treated as the batch.
                NOTE(review): presumably (batch, height, width, channels) with
                values in [0, 1] -- TODO confirm.

        Returns:
            A 1-tuple holding one base64 string when the batch size is not
            greater than one, or a list of base64 strings (batch order)
            otherwise.
        """
        if image.shape[0] > 1:
            # Iterating dim 0 yields each image with the batch dim removed,
            # equivalent to torch.split(image, 1, dim=0) followed by squeeze(0).
            return ([self._encode_png_base64(frame) for frame in image],)
        return (self._encode_png_base64(image.squeeze(0)),)
class VyroModeFilter:
    """Pass VYRO_PARAMS through, or raise, depending on the current mode.

    Each mode in ``VyroParams.MODE`` gets an input whose choices are
    ``VyroParams.ALLOWED`` (default: ``ALLOWED[0]``). If the input for the
    current ``vyro_params.mode`` is set to ``ALLOWED[1]``, the node raises.
    """

    RETURN_TYPES = ("VYRO_PARAMS",)
    RETURN_NAMES = ("vyro_params",)
    FUNCTION = "filter"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare VYRO_PARAMS plus one choice input per mode."""
        required = {"vyro_params": ("VYRO_PARAMS",)}
        for mode in VyroParams.MODE:
            required[f'{mode}'] = (VyroParams.ALLOWED, {"default": VyroParams.ALLOWED[0]})
        return {"required": required}

    def filter(self, vyro_params:VyroParams, **kwargs):
        """Return ``(vyro_params,)`` unless the current mode is marked ``ALLOWED[1]``.

        Raises:
            ValueError: the current mode's input equals ``VyroParams.ALLOWED[1]``.
            KeyError: no input was passed under the current mode's name.
        """
        state = kwargs[vyro_params.mode]
        if state == VyroParams.ALLOWED[1]:
            raise ValueError("If you are receiving this error, it's because you're trying to execute the workflow in Comfy without detaching the preview nodes for the inactive modes. If you are receiving this error in the API, you are selecting the wrong output node.")
        return (vyro_params,)
class VyroModelSwitcher:
    """Choose between two MODEL inputs by evaluating an expression over VyroParams fields."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "switch"
    CATEGORY = "Vyro/Routing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare VYRO_PARAMS, the two candidate models, and the condition string."""
        condition = ("STRING", {"default": "face_swap_img is not None", "multiline": False})
        return {
            "required": {
                "vyro_params": ("VYRO_PARAMS",),
                "a": ("MODEL",),
                "b": ("MODEL",),
                "return_a_if_true": condition,
            }
        }

    def switch(self, vyro_params:VyroParams, a, b, return_a_if_true):
        """Return ``(a,)`` if ``return_a_if_true`` evaluates truthy, else ``(b,)``.

        Names in the expression resolve to the fields listed in
        ``VyroParams.PARAMS``, read from ``vyro_params``.
        """
        # simpleeval is used instead of eval() to limit what the user-editable
        # expression can do.
        scope = {field: getattr(vyro_params, field) for field in VyroParams.PARAMS}
        chosen = a if simpleeval.simple_eval(return_a_if_true, names=scope) else b
        return (chosen,)
# Node-type key -> node class. These keys are the registered node type names,
# so changing them would break saved workflows that reference them.
NODE_CLASS_MAPPINGS = {
    "Vyro Mode Latent Muxer": VyroModeLatentMuxer,
    "Vyro Style Switcher": VyroStyleSwitcher,
    "Vyro Model Stream Input": VyroModelStreamInput,
    "Vyro Model Stream Output": VyroModelStreamOutput,
    "Vyro Param Extractor": VyroParamExtractor,
    "Vyro Image to String": VyroImageToString,
    "Vyro Mode Filter": VyroModeFilter,
    "Vyro Model Switcher": VyroModelSwitcher,
    "Vyro Param Updater": VyroParamUpdater,
}

# Node-type key -> display name. This must be keyed by the same keys as
# NODE_CLASS_MAPPINGS (not by Python class names), since display names are
# looked up by node type. The registered keys already are the intended
# display names, so each key maps to itself.
NODE_DISPLAY_NAME_MAPPINGS = {node_type: node_type for node_type in NODE_CLASS_MAPPINGS}