import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import timestep_embedding
def dec2bin(xinp, bits):
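    """Decompose an integer tensor into `bits` binary digits (MSB first),
    returned as 0./1. floats in a new trailing dimension."""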
mask = 2 ** th.arange(bits - 1, -1, -1).to(xinp.device, xinp.dtype)
return xinp.unsqueeze(-1).bitwise_and(mask).ne(0).float()
class PositionalEncoding(nn.Module):
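    """Standard sinusoidal positional encoding. Not instantiated in this file
    (see the commented-out pos_encoder in TransformerModel)."""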
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = th.arange(max_len).unsqueeze(1)
div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = th.zeros(1, max_len, d_model)
pe[0, :, 0::2] = th.sin(position * div_term)
pe[0, :, 1::2] = th.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: Tensor, shape [batch_size, seq_len, embedding_dim]
"""
x = x + self.pe[0:1, :x.size(1)]
return self.dropout(x)
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout, activation):
super().__init__()
        # d_ff is the hidden width of the position-wise feed-forward block
        # (callers in this file pass d_model * 2)
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
self.activation = activation
def forward(self, x):
x = self.dropout(self.activation(self.linear_1(x)))
x = self.linear_2(x)
return x
def attention(q, k, v, d_k, mask=None, dropout=None):
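    """
    Scaled dot-product attention. q, k, v: [bs, heads, seq_len, d_k].
    Positions where mask == 1 are blocked: their scores are set to -1e9
    before the softmax.
    """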
scores = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 1, -1e9)
scores = F.softmax(scores, dim=-1)
if dropout is not None:
scores = dropout(scores)
output = th.matmul(scores, v)
return output
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
super().__init__()
self.d_model = d_model
self.d_k = d_model // heads
self.h = heads
self.q_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.out = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
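        # q, k, v: [bs, seq_len, d_model]; mask (optional): [bs, seq_len, seq_len]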
bs = q.size(0)
# perform linear operation and split into h heads
k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # transpose to get dimensions bs * h * sl * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # compute scaled dot-product attention across all heads at once
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
# concatenate heads and put through final linear layer
concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
output = self.out(concat)
return output
class EncoderLayer(nn.Module):
def __init__(self, d_model, heads, dropout, activation):
super().__init__()
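        # NOTE: InstanceNorm1d receives [bs, seq, d_model] inputs here, so it
        # normalizes over the d_model dimension of each token (no affine params).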
self.norm_1 = nn.InstanceNorm1d(d_model)
self.norm_2 = nn.InstanceNorm1d(d_model)
        self.self_attn = MultiHeadAttention(heads, d_model, dropout)
        self.door_attn = MultiHeadAttention(heads, d_model, dropout)
        self.gen_attn = MultiHeadAttention(heads, d_model, dropout)
self.ff = FeedForward(d_model, d_model*2, dropout, activation)
self.dropout = nn.Dropout(dropout)
def forward(self, x, door_mask, self_mask, gen_mask):
assert (gen_mask.max()==1 and gen_mask.min()==0), f"{gen_mask.max()}, {gen_mask.min()}"
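        # Three attention branches share the same normalized input, each
        # restricted by its own mask (door_mask, self_mask, gen_mask);
        # their residual outputs are summed.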
x2 = self.norm_1(x)
x = x + self.dropout(self.door_attn(x2,x2,x2,door_mask)) \
+ self.dropout(self.self_attn(x2, x2, x2, self_mask)) \
+ self.dropout(self.gen_attn(x2, x2, x2, gen_mask))
x2 = self.norm_2(x)
x = x + self.dropout(self.ff(x2))
return x
class TransformerModel(nn.Module):
"""
The full Transformer model with timestep embedding.
"""
def __init__(
self,
in_channels,
condition_channels,
model_channels,
out_channels,
dataset,
use_checkpoint,
use_unet,
analog_bit,
):
super().__init__()
self.in_channels = in_channels
self.condition_channels = condition_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.time_channels = model_channels
self.use_checkpoint = use_checkpoint
self.analog_bit = analog_bit
self.use_unet = use_unet
self.num_layers = 4
# self.pos_encoder = PositionalEncoding(model_channels, 0.001)
# self.activation = nn.SiLU()
self.activation = nn.ReLU()
self.time_embed = nn.Sequential(
nn.Linear(self.model_channels, self.model_channels),
nn.SiLU(),
nn.Linear(self.model_channels, self.time_channels),
)
self.input_emb = nn.Linear(self.in_channels, self.model_channels)
self.condition_emb = nn.Linear(self.condition_channels, self.model_channels)
        if use_unet:
            # NOTE: UNet is not imported or defined in this file; passing
            # use_unet=True requires supplying a UNet implementation.
            self.unet = UNet(self.model_channels, 1)
        self.transformer_layers = nn.ModuleList(
            [EncoderLayer(self.model_channels, 4, 0.1, self.activation) for _ in range(self.num_layers)]
        )
# self.transformer_layers = nn.ModuleList([nn.TransformerEncoderLayer(self.model_channels, 4, self.model_channels*2, 0.1, self.activation, batch_first=True) for x in range(self.num_layers)])
self.output_linear1 = nn.Linear(self.model_channels, self.model_channels)
self.output_linear2 = nn.Linear(self.model_channels, self.model_channels//2)
self.output_linear3 = nn.Linear(self.model_channels//2, self.out_channels)
if not self.analog_bit:
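            # input width 162 = 18 (x/y of the 9 expanded points) + 144 (16 bits * 9 points)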
self.output_linear_bin1 = nn.Linear(162+self.model_channels, self.model_channels)
self.output_linear_bin2 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
self.output_linear_bin3 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
self.output_linear_bin4 = nn.Linear(self.model_channels, 16)
print(f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
def expand_points(self, points, connections):
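        """
        Upsample each wall segment from 2 endpoints to 9 points by repeated
        midpoint interpolation between every corner (p1) and the corner it is
        connected to (p5, indexed via connections[:, :, 1]).
        """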
def average_points(point1, point2):
points_new = (point1+point2)/2
return points_new
p1 = points
p1 = p1.view([p1.shape[0], p1.shape[1], 2, -1])
p5 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()]
p5 = p5.view([p5.shape[0], p5.shape[1], 2, -1])
p3 = average_points(p1, p5)
p2 = average_points(p1, p3)
p4 = average_points(p3, p5)
p1_5 = average_points(p1, p2)
p2_5 = average_points(p2, p3)
p3_5 = average_points(p3, p4)
p4_5 = average_points(p4, p5)
points_new = th.cat((p1.view_as(points), p1_5.view_as(points), p2.view_as(points),
p2_5.view_as(points), p3.view_as(points), p3_5.view_as(points), p4.view_as(points), p4_5.view_as(points), p5.view_as(points)), 2)
return points_new.detach()
def create_image(self, points, connections, room_indices, img_size=256, res=200):
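        """
        Rasterize the corner-to-corner segments into a [N, 1, img_size, img_size]
        image by sampling `res` points along each segment and writing the room
        index at every sampled pixel; coordinates are assumed to lie in [-1, 1].
        """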
img = th.zeros((points.shape[0], 1, img_size, img_size), device=points.device)
points = (points+1)*(img_size//2)
points[points>=img_size] = img_size-1
points[points<0] = 0
p1 = points
p2 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()]
        slope = (p2[:, :, 1] - p1[:, :, 1]) / (p2[:, :, 0] - p1[:, :, 0])
slope[slope.isnan()] = 0
slope[slope.isinf()] = 1
m = th.linspace(0, 1, res, device=points.device)
new_shape = [p2.shape[0], res, p2.shape[1], p2.shape[2]]
new_p2 = p2.unsqueeze(1).expand(new_shape)
new_p1 = p1.unsqueeze(1).expand(new_shape)
new_room_indices = room_indices.unsqueeze(1).expand([p2.shape[0], res, p2.shape[1], 1])
inc = new_p2 - new_p1
xs = m.view(1,-1,1) * inc[:,:,:,0]
xs = xs + new_p1[:,:,:,0]
xs = xs.long()
x_inc = th.where(inc[:,:,:,0]==0, inc[:,:,:,1], inc[:,:,:,0])
x_inc = m.view(1,-1,1) * x_inc
ys = x_inc * slope.unsqueeze(1) + new_p1[:,:,:,1]
ys = ys.long()
img[th.arange(xs.shape[0])[:, None], :, xs.view(img.shape[0], -1), ys.view(img.shape[0], -1)] = new_room_indices.reshape(img.shape[0], -1, 1).float()
return img.detach()
def forward(self, x, timesteps, xtalpha, epsalpha, is_syn=False, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x S x C] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x S x C] Tensor of outputs.
"""
# prefix = 'syn_' if is_syn else ''
prefix = 'syn_' if is_syn else ''
x = x.permute([0, 2, 1]).float() # -> convert [N x C x S] to [N x S x C]
if not self.analog_bit:
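            # expand every 2-point segment to 9 points for the bit-prediction head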
x = self.expand_points(x, kwargs[f'{prefix}connections'])
# Different input embeddings (Input, Time, Conditions)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = time_emb.unsqueeze(1)
input_emb = self.input_emb(x)
if self.condition_channels>0:
cond = None
for key in [f'{prefix}room_types', f'{prefix}corner_indices', f'{prefix}room_indices']:
if cond is None:
cond = kwargs[key]
else:
cond = th.cat((cond, kwargs[key]), 2)
cond_emb = self.condition_emb(cond.float())
# PositionalEncoding and DM model
out = input_emb + cond_emb + time_emb.repeat((1, input_emb.shape[1], 1))
for layer in self.transformer_layers:
out = layer(out, kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask'])
out_dec = self.output_linear1(out)
out_dec = self.activation(out_dec)
out_dec = self.output_linear2(out_dec)
out_dec = self.output_linear3(out_dec)
if not self.analog_bit:
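            # Reconstruct x_0 from x_t and the predicted noise, quantize each
            # coordinate to 8 bits, and refine the bits with the binary head.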
out_bin_start = x*xtalpha.repeat([1,1,9]) - out_dec.repeat([1,1,9]) * epsalpha.repeat([1,1,9])
out_bin = (out_bin_start/2 + 0.5) # -> [0,1]
out_bin = out_bin * 256 #-> [0, 256]
out_bin = dec2bin(out_bin.round().int(), 8)
out_bin_inp = out_bin.reshape([x.shape[0], x.shape[1], 16*9])
out_bin_inp[out_bin_inp==0] = -1
out_bin = th.cat((out_bin_start, out_bin_inp, cond_emb), 2)
out_bin = self.activation(self.output_linear_bin1(out_bin))
out_bin = self.output_linear_bin2(out_bin, kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask'])
out_bin = self.output_linear_bin3(out_bin, kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask'])
out_bin = self.output_linear_bin4(out_bin)
out_bin = out_bin.permute([0, 2, 1]) # -> convert back [N x S x C] to [N x C x S]
out_dec = out_dec.permute([0, 2, 1]) # -> convert back [N x S x C] to [N x C x S]
if not self.analog_bit:
return out_dec, out_bin
else:
return out_dec, None
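
if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original module. All sizes
    # below are made-up toy values, and `dataset` is unused by the constructor.
    # The relative import of timestep_embedding means this must be run as a
    # module (python -m ...). InstanceNorm1d may warn about num_features vs.
    # sequence length; see the note in EncoderLayer.
    N, S = 2, 8                                  # batch size, number of corners
    model = TransformerModel(
        in_channels=18, condition_channels=10, model_channels=64,
        out_channels=2, dataset=None, use_checkpoint=False,
        use_unet=False, analog_bit=True,
    )
    eye = th.eye(S).unsqueeze(0).expand(N, S, S)
    out_dec, _ = model(
        th.randn(N, 18, S),                      # x: [N x C x S]
        th.randint(0, 1000, (N,)),               # timesteps
        xtalpha=th.ones(N, S, 2), epsalpha=th.ones(N, S, 2),
        door_mask=eye.clone(), self_mask=1 - eye, gen_mask=eye.clone(),
        room_types=th.randn(N, S, 4), corner_indices=th.randn(N, S, 4),
        room_indices=th.randn(N, S, 2),
    )
    print(out_dec.shape)                         # expected: [N, 2, S]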