File size: 5,901 Bytes
d6d7648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from unet_utils import FFInflatedConv3d


class FFSpatioTempResUpsample3D(nn.Module):
	"""Nearest-neighbour upsampling layer with an optional trailing convolution.

	When no explicit ``output_size`` is given, the input is interpolated with
	``scale_factor=[1.0, 2.0, 2.0]``: of the three trailing dimensions the
	first keeps its length and the last two are doubled.

	Args:
		channels: Expected size of dimension 1 of the input tensor.
		use_conv: If true, an ``FFInflatedConv3d`` (kernel 3, padding 1) is
			applied after the interpolation.
		use_conv_transpose: Unsupported; a true value raises
			``NotImplementedError``.
		out_channels: Output channels of the convolution. Falls back to
			``channels`` when falsy.
		name: ``"conv"`` stores the convolution as ``self.conv``; any other
			value stores it as ``self.Conv2d_0``.
	"""

	def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
		super().__init__()
		self.channels = channels
		self.out_channels = channels if not out_channels else out_channels
		self.use_conv = use_conv
		self.use_conv_transpose = use_conv_transpose
		self.name = name

		if use_conv_transpose:
			raise NotImplementedError

		# The attribute is always created (possibly as None) under the name
		# selected by `name`.
		post_conv = FFInflatedConv3d(self.channels, self.out_channels, 3, padding=1) if use_conv else None
		attr_name = "conv" if name == "conv" else "Conv2d_0"
		setattr(self, attr_name, post_conv)

	def forward(self, hidden_states, output_size=None):
		"""Upsample ``hidden_states`` and, if configured, convolve the result.

		Args:
			hidden_states: Input tensor whose dimension 1 equals ``self.channels``.
			output_size: Optional explicit target size forwarded to
				``F.interpolate``; when given, the fixed scale factor is not used.

		Returns:
			The upsampled (and optionally convolved) tensor, in the dtype of
			the input as far as the interpolation step is concerned.
		"""
		assert hidden_states.shape[1] == self.channels

		if self.use_conv_transpose:
			raise NotImplementedError

		# The 'upsample_nearest2d_out_frame' op does not support bfloat16, so
		# bfloat16 inputs take a round trip through float32.
		input_dtype = hidden_states.dtype
		is_bf16 = input_dtype == torch.bfloat16
		if is_bf16:
			hidden_states = hidden_states.to(torch.float32)

		# upsample_nearest_nhwc fails with large batch sizes,
		# see https://github.com/huggingface/diffusers/issues/984
		if hidden_states.shape[0] >= 64:
			hidden_states = hidden_states.contiguous()

		# An explicit `output_size` takes precedence over the fixed scale factor.
		if output_size is None:
			resize = {"scale_factor": [1.0, 2.0, 2.0]}
		else:
			resize = {"size": output_size}
		hidden_states = F.interpolate(hidden_states, mode="nearest", **resize)

		if is_bf16:
			hidden_states = hidden_states.to(input_dtype)

		if not self.use_conv:
			return hidden_states

		post_conv = self.conv if self.name == "conv" else self.Conv2d_0
		return post_conv(hidden_states)


class FFSpatioTempResDownsample3D(nn.Module):
	"""Downsampling layer implemented as a stride-2 ``FFInflatedConv3d``.

	Only the convolutional variant is implemented: constructing the layer with
	``use_conv=False`` raises ``NotImplementedError``.

	Args:
		channels: Expected size of dimension 1 of the input tensor.
		use_conv: Must be true (see above).
		out_channels: Output channels of the convolution. Falls back to
			``channels`` when falsy.
		padding: Padding passed to the convolution. ``0`` is accepted here but
			rejected in ``forward``.
		name: When equal to ``"conv"`` the convolution is additionally
			registered as ``self.Conv2d_0`` (same module object as
			``self.conv``); presumably for checkpoint-key compatibility.
	"""

	def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
		super().__init__()
		self.channels = channels
		self.out_channels = out_channels or channels
		self.use_conv = use_conv
		self.padding = padding
		self.name = name

		if not use_conv:
			raise NotImplementedError("FFSpatioTempResDownsample3D only supports use_conv=True")

		conv = FFInflatedConv3d(self.channels, self.out_channels, 3, stride=2, padding=padding)

		# `self.conv` is what `forward` uses, whatever `name` is. The alias is
		# registered first so the submodule order stays Conv2d_0, conv.
		if name == "conv":
			self.Conv2d_0 = conv
		self.conv = conv

	def forward(self, hidden_states):
		"""Apply the stride-2 convolution to ``hidden_states``.

		Args:
			hidden_states: Input tensor whose dimension 1 equals ``self.channels``.

		Returns:
			The output of ``self.conv``.

		Raises:
			NotImplementedError: If the layer was configured with ``padding == 0``.
		"""
		assert hidden_states.shape[1] == self.channels, (
			f"expected {self.channels} channels, got {hidden_states.shape[1]}"
		)
		if self.use_conv and self.padding == 0:
			raise NotImplementedError("FFSpatioTempResDownsample3D does not support padding=0")

		return self.conv(hidden_states)


class FFSpatioTempResnetBlock3D(nn.Module):
	"""Residual block: two GroupNorm -> activation -> ``FFInflatedConv3d`` stages
	with an optional time-embedding injection between them.

	Args:
		in_channels: Channels of the input tensor.
		out_channels: Channels of the output; defaults to ``in_channels``.
		conv_shortcut: Stored as ``self.use_conv_shortcut``; not otherwise used
			by this block.
		dropout: Dropout probability applied before the second convolution.
		temb_channels: Feature size of the time embedding. ``None`` disables
			the projection (``self.time_emb_proj`` is ``None``).
		groups: Number of groups of the first GroupNorm.
		groups_out: Number of groups of the second GroupNorm; defaults to ``groups``.
		pre_norm: Accepted for signature compatibility only; ``self.pre_norm``
			is always ``True``.
		eps: Epsilon of both GroupNorm layers.
		non_linearity: ``"swish"`` or ``"silu"`` (both are SiLU).
		time_embedding_norm: ``"default"`` adds the projected embedding before
			the second norm; ``"scale_shift"`` applies it as scale and shift
			after the second norm.
		output_scale_factor: Divisor applied to the residual sum.
		use_in_shortcut: Forces (``True``) or suppresses (``False``) the 1x1
			shortcut convolution; ``None`` enables it exactly when
			``in_channels != out_channels``.

	Raises:
		ValueError: For an unknown ``non_linearity``, or for an unknown
			``time_embedding_norm`` while ``temb_channels`` is not ``None``.
	"""

	def __init__(
		self,
		*,
		in_channels,
		out_channels=None,
		conv_shortcut=False,
		dropout=0.0,
		temb_channels=512,
		groups=32,
		groups_out=None,
		pre_norm=True,
		eps=1e-6,
		non_linearity="swish",
		time_embedding_norm="default",
		output_scale_factor=1.0,
		use_in_shortcut=None
	):
		super().__init__()

		# Resolve the activation first so that a bad name fails at construction
		# time instead of surfacing later as an AttributeError in `forward`.
		if non_linearity == "swish":
			activation = F.silu
		elif non_linearity == "silu":
			activation = nn.SiLU()
		else:
			raise ValueError(f"unknown non_linearity : {non_linearity} ")

		# The `pre_norm` argument is ignored; the block always pre-normalizes.
		self.pre_norm = True
		self.in_channels = in_channels
		out_channels = in_channels if out_channels is None else out_channels
		self.out_channels = out_channels
		self.use_conv_shortcut = conv_shortcut
		self.time_embedding_norm = time_embedding_norm
		self.output_scale_factor = output_scale_factor

		if groups_out is None:
			groups_out = groups

		self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

		self.conv1 = FFInflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

		if temb_channels is not None:
			if self.time_embedding_norm == "default":
				time_emb_proj_out_channels = out_channels
			elif self.time_embedding_norm == "scale_shift":
				# One half is used as scale, the other as shift.
				time_emb_proj_out_channels = out_channels * 2
			else:
				raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

			self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
		else:
			self.time_emb_proj = None

		self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
		self.dropout = torch.nn.Dropout(dropout)
		self.conv2 = FFInflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

		self.nonlinearity = activation

		self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

		self.conv_shortcut = None
		if self.use_in_shortcut:
			self.conv_shortcut = FFInflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

	def forward(self, input_tensor, temb):
		"""Run the residual block.

		Args:
			input_tensor: Block input with ``in_channels`` on dimension 1
				(presumably ``(batch, channels, frames, height, width)`` --
				TODO confirm against callers).
			temb: Time embedding or ``None``. When given it is passed through
				the activation and ``self.time_emb_proj``, and the projection
				must be 3-D, read as ``(b, f, c)``; the block must then have
				been built with ``temb_channels`` set.

		Returns:
			``(shortcut(input_tensor) + residual) / output_scale_factor``.
		"""
		hidden_states = self.norm1(input_tensor)
		hidden_states = self.nonlinearity(hidden_states)
		hidden_states = self.conv1(hidden_states)

		if temb is not None:
			temb = self.time_emb_proj(self.nonlinearity(temb))
			# (b, f, c) -> (b, c, f, 1, 1) so it broadcasts over the last two dims.
			temb = rearrange(temb, "b f c -> b c f")[:, :, :, None, None]

			if self.time_embedding_norm == "default":
				hidden_states = hidden_states + temb

		hidden_states = self.norm2(hidden_states)

		if temb is not None and self.time_embedding_norm == "scale_shift":
			scale, shift = torch.chunk(temb, 2, dim=1)
			hidden_states = hidden_states * (1 + scale) + shift

		hidden_states = self.nonlinearity(hidden_states)

		hidden_states = self.dropout(hidden_states)
		hidden_states = self.conv2(hidden_states)

		if self.conv_shortcut is not None:
			input_tensor = self.conv_shortcut(input_tensor)

		return (input_tensor + hidden_states) / self.output_scale_factor