# Source revision: 6fecfbe (original file size: 2,007 bytes)
import folder_paths
from safetensors.torch import load_file
import torch
import torch.nn as nn
class Interposer(nn.Module):
    """
    Small convolutional network that maps a latent from one latent space to another.

    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

    The layer order and the ``sequential`` attribute name must not change:
    together they define the ``state_dict`` keys (``sequential.0.weight``,
    ``sequential.0.bias``, ...) that pretrained weight files are loaded into.
    """

    version = 1.1  # network revision; part of the weights filename

    def __init__(self):
        super().__init__()
        module_list = [
            nn.Conv2d(4, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 128, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(32, 4, kernel_size=5, padding=2),
        ]
        self.sequential = nn.Sequential(*module_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert a 4-channel latent, e.g. of shape (N, 4, H, W).

        Every convolution has stride 1 and padding ``kernel_size // 2``, so the
        output has the same shape as ``x``. Works on any device/dtype the
        module's parameters live on (not only CUDA float tensors).
        """
        return self.sequential(x)
class VyroLatentInterposer:
    """Node that converts latents between the "v1" and "xl" latent spaces.

    The conversion network (``Interposer``) is loaded lazily from the
    ``interposers`` model folder and cached on the instance. Repeated runs
    with the same source/destination pair therefore reuse the loaded weights.
    """

    def __init__(self):
        self.model = None  # cached Interposer for the most recent direction
        self.model_name = None  # weights filename ``self.model`` was loaded from

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_src": (["v1", "xl"],),
                "latent_dst": (["v1", "xl"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "convert"
    CATEGORY = "Vyro"

    def _get_model(self, latent_src, latent_dst):
        """Return the Interposer for ``latent_src`` -> ``latent_dst``, loading it on a cache miss.

        Raises:
            FileNotFoundError: if the weights file cannot be located.
        """
        # Derive the filename from the network revision in one place so the
        # cache key and the file actually loaded can never disagree.
        model_name = f"{latent_src}-to-{latent_dst}_interposer-v{Interposer.version}.safetensors"
        if self.model is None or self.model_name != model_name:
            weights = folder_paths.get_full_path("interposers", model_name)
            # NOTE(review): get_full_path is assumed to return None when the
            # file (or folder) is missing; fail with a clear message instead of
            # letting load_file(None) raise an obscure error.
            if weights is None:
                raise FileNotFoundError(
                    f"Interposer weights '{model_name}' not found in the 'interposers' model folder"
                )
            model = Interposer()
            model.load_state_dict(load_file(weights))
            model.eval()
            # Update the cache only after a successful load. Otherwise a failed
            # load would leave the previous (wrong-direction) model cached under
            # the new name.
            self.model = model
            self.model_name = model_name
        return self.model

    def convert(self, samples, latent_src, latent_dst):
        """Convert ``samples`` from the ``latent_src`` space to the ``latent_dst`` space.

        Args:
            samples: a latent dict holding the tensor under ``"samples"``, or a
                bare tensor.
            latent_src: source latent space name ("v1" or "xl").
            latent_dst: destination latent space name ("v1" or "xl").

        Returns:
            A 1-tuple, matching ``RETURN_TYPES``. For dict input this is a
            shallow copy of ``samples`` with ``"samples"`` replaced, so other
            keys are preserved. For tensor input it is the converted tensor.
            When source and destination are equal, ``samples`` is returned
            unchanged.
        """
        if latent_src == latent_dst:
            return (samples,)

        is_dict = isinstance(samples, dict)
        latent = samples["samples"] if is_dict else samples

        # nn.Module.to() moves the module in place and is a no-op when it is
        # already on that device. This also relocates a cached model after the
        # input device changes.
        model = self._get_model(latent_src, latent_dst).to(latent.device)
        weight_dtype = next(model.parameters()).dtype
        with torch.no_grad():
            # Run in the weights' dtype so half-precision latents don't hit a
            # dtype mismatch, then hand back the caller's dtype.
            converted = model(latent.to(dtype=weight_dtype)).to(dtype=latent.dtype)

        if not is_dict:
            return (converted,)
        out = samples.copy()
        out["samples"] = converted
        return (out,)
# Node registration tables (ComfyUI custom-node convention): internal node
# name -> implementing class, and internal node name -> label shown in the UI.
NODE_CLASS_MAPPINGS = {
    "VyroLatentInterposer": VyroLatentInterposer,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VyroLatentInterposer": "Vyro Latent Interposer",
}