|
import folder_paths |
|
from safetensors.torch import load_file |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class Interposer(nn.Module):
    """
    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

    A small fully-convolutional network that maps a 4-channel latent to
    another 4-channel latent. Every convolution uses padding = kernel // 2,
    so the spatial size (H, W) of the input is preserved.

    NOTE: the layer order and the ``sequential`` attribute name must not
    change. They determine the state-dict keys (``sequential.0.weight``,
    ``sequential.2.weight``, ...) that pretrained weight files are loaded
    against. The parameter-free ReLUs still occupy indices 1, 3 and 5.
    """

    # Version tag of the weight layout. It is interpolated into the weight
    # file name ("...-v1.1.safetensors") by the loading code.
    version = 1.1

    def __init__(self):
        super().__init__()
        module_list = [
            nn.Conv2d(4, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 128, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(32, 4, kernel_size=5, padding=2),
        ]
        self.sequential = nn.Sequential(*module_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert a latent tensor.

        Args:
            x: latent with 4 channels, shaped (N, 4, H, W). Any device or
                float dtype is accepted, as long as it matches the module's
                parameters.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.sequential(x)
|
|
|
|
|
class VyroLatentInterposer:
    """
    ComfyUI node that converts a latent between the "v1" and "xl" latent
    spaces using a pretrained ``Interposer`` network.

    The most recently loaded network is cached on the instance. Repeated
    conversions in the same direction therefore do not reload the weights.
    """

    def __init__(self):
        # Cached network, and the weight file name it was loaded from.
        # The two are always updated together, and only after a successful load.
        self.model = None
        self.model_name = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_src": (["v1", "xl"],),
                "latent_dst": (["v1", "xl"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "convert"
    CATEGORY = "Vyro"

    def _get_model(self, latent_src, latent_dst, device):
        """Return the ``latent_src -> latent_dst`` interposer on ``device``, loading it if needed."""
        # Single source of truth for the weight file name (also the cache key).
        model_name = f'{latent_src}-to-{latent_dst}_interposer-v{Interposer.version}.safetensors'

        if self.model is None or self.model_name != model_name:
            weights = folder_paths.get_full_path("interposers", model_name)
            # NOTE(review): assumes get_full_path returns None when the file is
            # not present; fail with a clear message instead of load_file(None).
            if weights is None:
                raise FileNotFoundError(
                    f"Interposer weights '{model_name}' not found in the 'interposers' folder"
                )
            model = Interposer()
            model.load_state_dict(load_file(weights))
            model.eval()
            # Update the cache only once loading succeeded. Otherwise a failed
            # load would leave the old model paired with the new name.
            self.model = model
            self.model_name = model_name

        # Follow the input's device on every call. Module.to() is a no-op when
        # the module is already on that device.
        self.model = self.model.to(device)
        return self.model

    def convert(self, samples, latent_src, latent_dst):
        """
        Convert ``samples`` from the ``latent_src`` space to the ``latent_dst`` space.

        Args:
            samples: a LATENT dict holding the tensor under ``"samples"``,
                or a bare latent tensor.
            latent_src: source latent space, "v1" or "xl".
            latent_dst: target latent space, "v1" or "xl".

        Returns:
            A one-element tuple. If the source and target spaces are the same,
            it holds the unchanged input. For dict input, it holds a shallow
            copy of the dict with ``"samples"`` replaced, so other keys are
            kept. For tensor input, it holds the converted tensor.

        Raises:
            FileNotFoundError: the weight file for this direction could not be located.
        """
        if latent_src == latent_dst:
            return (samples,)

        is_dict = isinstance(samples, dict)
        latent = samples["samples"] if is_dict else samples

        model = self._get_model(latent_src, latent_dst, latent.device)

        # Run in the weights' dtype (e.g. fp16 latents vs fp32 weights), then
        # hand back the caller's dtype. No autograd graph is needed for inference.
        weight_dtype = next(model.parameters()).dtype
        with torch.no_grad():
            converted = model(latent.to(dtype=weight_dtype)).to(dtype=latent.dtype)

        if is_dict:
            out = samples.copy()
            out["samples"] = converted
            return (out,)
        return (converted,)
|
|
|
|
|
# Node registration tables (presumably consumed by the ComfyUI node loader):
# node identifier -> implementing class, and node identifier -> UI label.
NODE_CLASS_MAPPINGS = dict(VyroLatentInterposer=VyroLatentInterposer)

NODE_DISPLAY_NAME_MAPPINGS = dict(VyroLatentInterposer="Vyro Latent Interposer")