# Source: commit 6fecfbe by azkavyro ("Added all files including vyro_workflows")
import folder_paths
from safetensors.torch import load_file
import torch
import torch.nn as nn
class Interposer(nn.Module):
    """
    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

    A small fully-convolutional network mapping a 4-channel latent to another
    4-channel latent. Every convolution uses stride 1 with padding equal to
    kernel_size // 2, so the spatial size (H, W) of the input is preserved.

    The layer order/indices inside ``self.sequential`` must stay as-is: the
    pretrained weight files are loaded with ``load_state_dict`` and their keys
    are ``sequential.{0,2,4,6}.{weight,bias}``.
    """

    version = 1.1  # network revision; used to build the weight file name

    def __init__(self):
        super().__init__()
        module_list = [
            nn.Conv2d(4, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 128, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(32, 4, kernel_size=5, padding=2),
        ]
        self.sequential = nn.Sequential(*module_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of 4-channel latents ``x`` to the target latent space.

        ``x`` may live on any device (CPU or CUDA); it must be on the same
        device and have the same dtype as this module's parameters.
        """
        return self.sequential(x)
class VyroLatentInterposer:
    """ComfyUI-style node converting latents between the "v1" and "xl" spaces.

    The interposer network for a ``latent_src -> latent_dst`` pair is loaded
    lazily from the ``"interposers"`` model folder and cached on the instance,
    so consecutive conversions with the same pair reuse the loaded weights.
    """

    def __init__(self):
        # Cached interposer network and the weight file name it was loaded from.
        self.model = None
        self.model_name = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_src": (["v1", "xl"],),
                "latent_dst": (["v1", "xl"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "convert"
    CATEGORY = "Vyro"

    def convert(self, samples, latent_src, latent_dst):
        """Convert latents from the ``latent_src`` space to ``latent_dst``.

        Args:
            samples: Either a LATENT dict whose ``"samples"`` entry holds the
                latent tensor, or a bare latent tensor.
            latent_src: Source latent space, ``"v1"`` or ``"xl"``.
            latent_dst: Target latent space, ``"v1"`` or ``"xl"``.

        Returns:
            ``(samples,)`` unchanged when source and target are equal;
            otherwise ``({"samples": converted},)`` for dict input, or the
            converted tensor itself for tensor input.

        Raises:
            FileNotFoundError: if ``folder_paths.get_full_path`` returns no
                path for the required interposer weight file.
        """
        if latent_src == latent_dst:
            return (samples,)

        is_dict = isinstance(samples, dict)
        latent = samples["samples"] if is_dict else samples
        device = latent.device

        # Derive the file name from the network revision so the cache key and
        # the file actually loaded can never disagree.
        model_name = f"{latent_src}-to-{latent_dst}_interposer-v{Interposer.version}.safetensors"
        if self.model is None or self.model_name != model_name:
            weights = folder_paths.get_full_path("interposers", model_name)
            if weights is None:
                raise FileNotFoundError(
                    f"Interposer weights '{model_name}' not found in the 'interposers' model folder"
                )
            model = Interposer()
            model.load_state_dict(load_file(weights))
            # Update the cache only after a successful load; otherwise a failed
            # load would leave a stale model registered under the new name.
            self.model = model
            self.model_name = model_name

        # nn.Module.to() moves the module in place and is a no-op when it is
        # already on `device`; doing it every call keeps the cached model in
        # sync when later latents arrive on a different device.
        model = self.model.to(device)

        # Run in the parameters' dtype (e.g. so half-precision latents don't
        # trip a dtype mismatch in the convolutions), then restore the
        # caller's dtype. Inference only, so no autograd graph is needed.
        param_dtype = next(model.parameters()).dtype
        with torch.no_grad():
            converted = model(latent.to(param_dtype)).to(latent.dtype)

        if is_dict:
            return ({"samples": converted},)
        # NOTE(review): tensor input yields a bare tensor here but a 1-tuple in
        # the early return above; kept as-is for existing callers — confirm
        # which shape programmatic callers expect.
        return converted
# Node registries in the ComfyUI custom-node format:
# node id -> node class, and node id -> display name shown in the UI.
NODE_CLASS_MAPPINGS = dict(VyroLatentInterposer=VyroLatentInterposer)
NODE_DISPLAY_NAME_MAPPINGS = dict(VyroLatentInterposer="Vyro Latent Interposer")