File size: 2,007 Bytes
6fecfbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import folder_paths
from safetensors.torch import load_file
import torch
import torch.nn as nn


class Interposer(nn.Module):
	"""
		Small convolutional network that maps a 4-channel tensor to another
		4-channel tensor.

		Basic NN layout, ported from:
		https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

		Every convolution uses the default stride of 1 with padding equal to
		half the kernel size (rounded down), so height and width are preserved
		and the output has the same shape as the input.

		The layers live under `self.sequential`; that attribute name and the
		layer order define the state-dict keys (`sequential.0.weight`, ...),
		so neither may change without breaking existing weight files.
	"""
	version = 1.1 # network revision; formatted into the weights file name

	def __init__(self):
		super().__init__()
		self.sequential = nn.Sequential(
			nn.Conv2d(4, 32, kernel_size=5, padding=2),
			nn.ReLU(),
			nn.Conv2d(32, 128, kernel_size=7, padding=3),
			nn.ReLU(),
			nn.Conv2d(128, 32, kernel_size=7, padding=3),
			nn.ReLU(),
			nn.Conv2d(32, 4, kernel_size=5, padding=2),
		)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""
			Run `x` through the convolution stack.

			x: tensor whose channel dimension is 4 (required by the first
			   Conv2d); it may live on any device the module was moved to,
			   hence the device-agnostic `torch.Tensor` annotation.
			Returns a tensor of the same shape as `x`.
		"""
		return self.sequential(x)


class VyroLatentInterposer:
	"""
		Node that converts a latent between the "v1" and "xl" latent spaces
		by running it through a pretrained Interposer network.

		The most recently loaded network is cached on the instance together
		with the name of the weights file it came from, so repeated
		conversions in the same direction do not reload the weights.
	"""
	def __init__(self):
		# Cached network and the weights file name it was loaded from.
		self.model = None
		self.model_name = None

	@classmethod
	def INPUT_TYPES(cls):
		"""Describe the node inputs: the latent plus source/target space names."""
		return {
			"required": {
				"samples": ("LATENT", ),
				"latent_src": (["v1", "xl"],),
				"latent_dst": (["v1", "xl"],),
			}
		}

	RETURN_TYPES = ("LATENT",)
	FUNCTION = "convert"
	CATEGORY = "Vyro"

	def _get_model(self, latent_src, latent_dst):
		"""
			Return the Interposer for `latent_src` -> `latent_dst`, loading
			its weights only when the cached network is for another direction.

			Raises FileNotFoundError when the weights file cannot be located.
		"""
		# One name serves as both cache key and file name, so the two cannot
		# drift apart when Interposer.version changes.
		model_name = f'{latent_src}-to-{latent_dst}_interposer-v{Interposer.version}.safetensors'
		if self.model is not None and self.model_name == model_name:
			return self.model

		weights = folder_paths.get_full_path("interposers", model_name)
		if weights is None:
			# NOTE(review): assumes get_full_path signals "not found" with
			# None - confirm against the folder_paths module in use.
			raise FileNotFoundError(
				f"Interposer weights '{model_name}' not found in the 'interposers' model folder"
			)

		model = Interposer()
		model.load_state_dict(load_file(weights))

		# Publish the cache entry only after a successful load; otherwise a
		# failed load would leave the previous network cached under the new
		# name and later calls would silently use the wrong model.
		self.model = model
		self.model_name = model_name
		return model

	def convert(self, samples, latent_src, latent_dst):
		"""
			Convert `samples` from the `latent_src` space to `latent_dst`.

			samples: either a dict holding the latent tensor under the
			         "samples" key, or the latent tensor itself.
			latent_src / latent_dst: names of the source and target spaces.

			Returns a 1-tuple, matching RETURN_TYPES. For dict input the
			element is a shallow copy of the dict with "samples" replaced by
			the converted tensor (other keys are carried over unchanged); for
			tensor input it is the converted tensor. When source and target
			are equal, `samples` is returned as-is.
		"""
		if latent_src == latent_dst:
			return (samples,)

		is_wrapped = isinstance(samples, dict)
		latent = samples["samples"] if is_wrapped else samples

		model = self._get_model(latent_src, latent_dst)
		# Move on every call, not just on load: a cached network may sit on
		# a different device than this input. Module.to() is a no-op when
		# the device already matches.
		model.to(latent.device)
		converted = model(latent)

		if not is_wrapped:
			# Wrapped in a tuple for consistency with the other return paths.
			return (converted,)

		result = samples.copy()
		result["samples"] = converted
		return (result,)
		

# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader - confirm.
NODE_CLASS_MAPPINGS = {"VyroLatentInterposer": VyroLatentInterposer}

# Node identifier -> human-readable title.
# NOTE(review): presumably used as the label shown in the UI - confirm.
NODE_DISPLAY_NAME_MAPPINGS = {"VyroLatentInterposer": "Vyro Latent Interposer"}