Spaces:
Runtime error
Runtime error
from inspect import isfunction | |
import math | |
import torch | |
import torch.nn.functional as F | |
from torch import nn, einsum | |
from einops import rearrange, repeat | |
from typing import Optional, Any | |
from ldm.modules.diffusionmodules.util import checkpoint | |
# Optional xformers backend. The (misspelled) flag name is kept because
# BasicTransformerBlock reads it to pick the attention implementation.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Any import failure (package missing, or a broken/incompatible build that
    # raises something other than ImportError) falls back to vanilla attention.
    # Unlike a bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    XFORMERS_IS_AVAILBLE = False
# CrossAttn precision handling | |
import os | |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") | |
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as present)."""
    if val is None:
        return False
    return True
def uniq(arr):
    """Return the distinct elements of *arr* in first-seen order, as a dict keys view."""
    return dict.fromkeys(arr, True).keys()
def default(val, d):
    """Return *val* unless it is None; otherwise return the fallback *d*.

    If *d* is a plain Python function (per ``inspect.isfunction``), it is called
    with no arguments to build the fallback lazily. Any other object, including
    classes and builtins, is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s floating dtype."""
    info = torch.finfo(t.dtype)
    return -info.max
def init_(tensor):
    """Fill *tensor* in place from U(-1/sqrt(d), 1/sqrt(d)), with d = size of its last dim.

    Returns the same tensor object.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit: ``value * gelu(gate)``, with value and gate from one Linear.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: size of the last dimension of the output.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Produces both halves at once: [value | gate], each of width dim_out.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise MLP: input projection (GELU or GEGLU), dropout, output Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden width as a multiple of ``dim`` (truncated with ``int``).
        glu: use a GEGLU input projection instead of Linear + GELU.
        dropout: dropout probability between the two projections.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        # Submodule indices (net.0 / net.1 / net.2) are part of the state_dict layout.
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """
    Set every parameter of *module* to zero in place and return the same module.
    """
    # Zeroing outside autograd so the in-place write is not recorded.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def Normalize(in_channels):
    """Return a 32-group, affine GroupNorm (eps=1e-6) over *in_channels* channels."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map, plus residual.

    The input is GroupNorm'd, q/k/v come from 1x1 convolutions, attention is
    taken over all h*w positions, and a 1x1 output projection is added back to
    the input. Input and output are (b, c, h, w).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)

        def pointwise_conv():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        # Creation order (q, k, v, proj_out) is kept so initialisation and
        # parameter order are unchanged.
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        # Attention weights: (b, hw, hw), softmax over the key axis (dim 2).
        b, c, h, w = query.shape
        query = rearrange(query, 'b c h w -> b (h w) c')
        key = rearrange(key, 'b c h w -> b c (h w)')
        weights = torch.einsum('bij,bjk->bik', query, key)
        weights = weights * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # Weighted sum of values, then back to image layout.
        value = rearrange(value, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', value, weights)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
        return x + self.proj_out(attended)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries come from ``x``; keys and values come from ``context``, or from
    ``x`` when no context is given. After each call, the softmaxed attention
    matrix is kept in ``self.attention_probs`` with shape
    ``(b * heads, n_query, n_context)``.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )
        self.attention_probs = None

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: query tokens, ``(b, n, query_dim)``.
            context: optional key/value tokens, ``(b, m, context_dim)``.
            mask: optional boolean mask (True = attend). It is flattened to
                ``(b, m)`` and broadcast over heads and query positions; key
                positions where it is False get the most negative score.

        Returns:
            ``(b, n, query_dim)`` tensor.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            neg_fill = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, neg_fill)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)
        self.attention_probs = sim

        # In the fp32 path `sim` is float32 while `v` may still be half precision.
        # Outside autocast the einsum/bmm would then reject the mixed dtypes, so
        # match v's dtype explicitly. Under autocast this is the cast it would
        # apply anyway, and it is a no-op when both tensors are already float32.
        out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
    """Multi-head (cross-)attention computed with xformers' memory-efficient kernel.

    Same projections and interface as ``CrossAttention``; masks are not
    supported. Because the fused kernel does not return the attention matrix,
    it is rebuilt explicitly and stored in ``self.attention_probs``, shape
    ``(b * heads, n_query, n_context)``, float32. That rebuild costs
    O(n_query * n_context) memory, which the previous diagnostic also spent.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout applied after the output projection.
    """
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None
        self.attention_probs = None

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (b, n, query_dim) to ``context`` (or ``x``); returns (b, n, query_dim)."""
        if exists(mask):
            # Not supported on this path: fail before doing any work.
            raise NotImplementedError

        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # (b, n, h*d) -> (b*h, n, d): split heads and fold them into the batch.
        q, k, v = map(
            lambda t: t.reshape(b, t.shape[1], h, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * h, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # Rebuild softmax(q k^T * dim_head**-0.5) for callers reading
        # `attention_probs`. The scale matches CrossAttention and the kernel's
        # default; fp32 matches CrossAttention's default precision. The old code
        # stored out @ v^T here, which is not a probability matrix.
        with torch.autocast(enabled=False, device_type='cuda'):
            sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * self.dim_head ** -0.5
        self.attention_probs = sim.softmax(dim=-1)

        # (b*h, n, d) -> (b, n, h*d): merge heads back into the feature dim.
        out = (
            out.reshape(b, h, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], h * self.dim_head)
        )
        return self.to_out(out)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention, feed-forward.

    Each sub-layer reads a LayerNorm'd input and is added back residually. The
    attention class is MemoryEfficientCrossAttention when the xformers import
    succeeded, CrossAttention otherwise.

    Args:
        dim: token feature size.
        n_heads, d_head: attention heads and per-head size.
        dropout: dropout for attention outputs and the feed-forward.
        context_dim: feature size of the cross-attention context.
        gated_ff: use GEGLU in the feed-forward.
        checkpoint: enable gradient checkpointing in ``forward``.
        disable_self_attn: make the first attention also attend to the context.
    """
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn

        # Submodule creation order is kept (attn1, ff, attn2, norms) so
        # initialisation and parameter order stay the same.
        attn1_context_dim = context_dim if self.disable_self_attn else None
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=attn1_context_dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # Falls back to self-attention when called with context=None.
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # The `checkpoint` argument shadows the imported helper only inside __init__.
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        first_context = context if self.disable_self_attn else None
        x = x + self.attn1(self.norm1(x), context=first_context)
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill *tensor* in place with N(mean, std^2) samples truncated to [a, b]; return it.

    Draws uniformly between the normal CDF values of the bounds and maps the
    draws back through the inverse CDF (via ``erfinv``), then clamps to [a, b].
    Warns when *mean* lies more than 2 std outside [a, b].
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # `warnings` is not imported at module level; without this the sanity
    # check below raised NameError exactly when it should have warned.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. Sampling uses an inverse-CDF transform of a uniform draw,
    followed by a clamp to :math:`[a, b]`; nothing is redrawn. It is most
    accurate when :math:`a \leq \text{mean} \leq b`. Runs without gradient
    tracking.

    NOTE: as in PyTorch's ``nn.init.trunc_normal_``, ``a`` and ``b`` are
    absolute bounds on the sampled values, not multiples of ``std``. Adjust
    them together with ``mean``/``std``.

    Args:
        tensor: an n-dimensional ``torch.Tensor``, overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: minimum cutoff value.
        b: maximum cutoff value.

    Returns:
        ``tensor`` itself.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> _ = trunc_normal_(w, std=.02)
    """
    with torch.no_grad():
        result = _trunc_normal_(tensor, mean, std, a, b)
    return result
class PostionalAttention(nn.Module):
    """Multi-head self-attention with an optional learned relative-position bias.

    (The class name keeps its original spelling so existing references keep working.)

    Args:
        dim: token feature size (input and output).
        num_heads: number of attention heads.
        qkv_bias: add learnable biases to q and v (k always gets a zero bias).
        qk_scale: score scale; falls back to ``head_dim ** -0.5`` when falsy.
        attn_drop: dropout on the attention matrix.
        proj_drop: dropout after the output projection.
        attn_head_dim: per-head size; defaults to ``dim // num_heads``.
        use_rpb: add a learned bias indexed by the relative (row, col) offset
            between positions of a ``window_size x window_size`` grid. The
            sequence length must then be ``window_size ** 2``.
        window_size: side of that grid (only used with ``use_rpb``).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., attn_head_dim=None, use_rpb=False, window_size=14):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Fused q/k/v projection without bias; biases live separately so that
        # k can be given a fixed zero bias in forward.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        # relative positional bias option
        self.use_rpb = use_rpb
        if use_rpb:
            self.window_size = window_size
            self.rpb_table = nn.Parameter(torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))
            trunc_normal_(self.rpb_table, std=.02)

            coords_h = torch.arange(window_size)
            coords_w = torch.arange(window_size)
            # indexing='ij' is the historical default; passing it explicitly
            # avoids the torch>=1.10 deprecation warning with identical output.
            coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, h, w
            coords_flatten = torch.flatten(coords, 1)  # 2, h*w
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, h*w, h*w
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # h*w, h*w, 2
            relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size - 1
            relative_coords[:, :, 0] *= 2 * window_size - 1  # row-major flat index into rpb_table
            relative_position_index = relative_coords.sum(-1)  # h*w, h*w
            self.register_buffer("relative_position_index", relative_position_index)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Self-attend over ``x`` of shape ``(B, N, C)``; returns ``(B, N, dim)``."""
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, heads, N, head_dim
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.use_rpb:
            relative_position_bias = self.rpb_table[self.relative_position_index.view(-1)].view(
                self.window_size * self.window_size, self.window_size * self.window_size, -1)  # h*w,h*w,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, h*w, h*w
            attn += relative_position_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: input size.
        hidden_features: hidden size; falls back to ``in_features`` when falsy.
        out_features: output size; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, applied once after the second Linear.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # No dropout between the two layers, following the original BERT MLP.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Pre-norm transformer block: ``PostionalAttention`` then ``Mlp``, each residual.

    Args:
        dim: token feature size.
        num_heads, qkv_bias, qk_scale, attn_drop, attn_head_dim, use_rpb,
        window_size: forwarded to ``PostionalAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        drop: dropout after the attention projection and inside the MLP.
        drop_path: accepted for interface compatibility only. Stochastic depth
            is disabled and the value is ignored.
        init_values: when a positive number, adds learnable per-channel scales
            ``gamma_1``/``gamma_2`` on the two residual branches, initialised
            to this value. ``None`` (the default) or a non-positive value
            disables them.
        act_layer: activation class used by the MLP.
        norm_layer: normalisation class applied before each sub-layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_head_dim=None, use_rpb=False, window_size=14):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = PostionalAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim,
            use_rpb=use_rpb, window_size=window_size)
        # NOTE: drop path (stochastic depth) is disabled; `drop_path` is ignored.
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # `init_values > 0` alone raised TypeError for the default None.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Splits a square image into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided ``Conv2d``. Returns one ``embed_dim`` vector per
    patch as a ``(B, num_patches, embed_dim)`` sequence.

    Args:
        img_size: side length of the square input; inputs must match it exactly.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
        mask_cent: if True, subtract 0.5 from the last input channel before
            projection. This is applied to a copy, so the caller's tensor is
            left untouched.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, mask_cent=False):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mask_cent = mask_cent
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        """Embed ``x`` of shape ``(B, C, H, W)``; extra keyword arguments are ignored."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if self.mask_cent:
            # Copy first: writing into `x` directly mutated the caller's tensor
            # and fails for leaf tensors that require grad.
            x = x.clone()
            x[:, -1] = x[:, -1] - 0.5
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
class CnnHead(nn.Module):
    """3x3 reflect-padded convolution applied to tokens laid out on a square grid.

    Expects tokens ``(b, window_size * window_size, embed_dim)``, per the
    rearrange pattern. Returns ``(b, window_size * window_size, num_classes)``.
    """

    def __init__(self, embed_dim, num_classes, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.window_size = window_size
        self.head = nn.Conv2d(embed_dim, num_classes, kernel_size=3, stride=1, padding=1,
                              padding_mode='reflect')

    def forward(self, x):
        side = self.window_size
        grid = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=side, p2=side)
        return rearrange(self.head(grid), 'b c p1 p2 -> b (p1 p2) c')
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
import numpy as np


def get_sinusoid_encoding_table(n_position, d_hid):
    """Build the fixed sine/cosine position-encoding table.

    Entry ``[pos, j]`` is ``pos / 10000 ** (2 * (j // 2) / d_hid)``, passed
    through ``sin`` for even ``j`` and ``cos`` for odd ``j``.

    Args:
        n_position: number of positions (rows).
        d_hid: encoding width (columns).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    # TODO: make it with torch instead of numpy
    angles = np.array([
        [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
        for position in range(n_position)
    ])
    even, odd = slice(0, None, 2), slice(1, None, 2)
    angles[:, even] = np.sin(angles[:, even])  # dim 2i
    angles[:, odd] = np.cos(angles[:, odd])  # dim 2i+1
    return torch.FloatTensor(angles).unsqueeze(0)
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.

    First, GroupNorm and project the input (1x1 conv, or Linear when
    ``use_linear``) and reshape to (b, h*w, inner_dim). Then apply ``depth``
    BasicTransformerBlocks. Finally, project back to ``in_channels``, reshape
    to an image, and add the input residually.
    NEW: use_linear for more efficiency instead of the 1x1 convs.

    Args:
        in_channels: channels of the input/output feature map.
        n_heads, d_head: attention heads and per-head size (inner_dim = n_heads * d_head).
        depth: number of transformer blocks.
        dropout: dropout inside each block.
        context_dim: cross-attention context size. Either a single value
            (or None) used for every block, or a list with one entry per block.
        disable_self_attn: forwarded to every block.
        use_linear: use Linear projections on tokens instead of 1x1 convs.
        use_checkpoint: enable gradient checkpointing in each block.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # Broadcast a single context_dim (including None) to every block. The
        # old code wrapped it in a 1-element list (IndexError for depth > 1)
        # and left None unwrapped (TypeError on None[d]).
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # Map inner_dim back to in_channels. The arguments were previously
            # swapped, which only worked when in_channels == inner_dim (in that
            # case the weight shape is unchanged, so checkpoints still load).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear
        # Spatial size (h, w) of the last forward's output.
        self.map_size = None

    def forward(self, x, context=None):
        """Apply the transformer to ``x`` of shape ``(b, c, h, w)``; returns the same shape.

        ``context`` is one tensor (or None) shared by every block, or a list
        with one entry per block. If no context is given, cross-attention
        defaults to self-attention.
        """
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        self.map_size = x.shape[-2:]
        return x + x_in