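"""VQ-VAE image tokenizer: timm ViT encoder, multi-codebook vector quantizer, ViTamin decoder.

Usage sketch (argument names follow this file; the concrete values are only illustrative
and should match your training config):

    args = argparse.Namespace(
        model='vitamin_large', img_size=256, drop_path=0.1, grad_ckpt=False,
        quant_proj='attn', vocab_size=32768, vocab_width=64, num_codebooks=8,
        vq_beta=0.25, le=0.0, e_temp=0.01,
    )
    vqvae = VQVAE(args)
    rec_img, vq_loss, entropy_loss, usages = vqvae(img)  # img: (B, 3, H, W), typically in [-1, 1]
"""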
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext
from torch.nn.functional import scaled_dot_product_attention
from unitok.quant import VectorQuantizerM
from unitok.vitamin import ViTaminDecoder, GeGluMlp
class PlainAttention(nn.Module):
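    """Multi-head self-attention that can change the channel width from in_dim to out_dim.

    When narrowing (in_dim > out_dim), the qkv projection keeps the input width and the head
    outputs are averaged (and pooled if needed); when widening, qkv projects directly to out_dim.
    """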
def __init__(self, in_dim, out_dim, num_heads):
super().__init__()
if in_dim > out_dim:
# assert in_dim // num_heads == out_dim
self.head_dim = in_dim // num_heads
self.qkv = nn.Linear(in_dim, in_dim * 3, bias=False)
self.q_bias = nn.Parameter(torch.zeros(in_dim))
self.v_bias = nn.Parameter(torch.zeros(in_dim))
self.register_buffer('zero_k_bias', torch.zeros(in_dim))
else:
# assert out_dim // num_heads == in_dim
self.head_dim = out_dim // num_heads
self.qkv = nn.Linear(in_dim, out_dim * 3, bias=False)
self.q_bias = nn.Parameter(torch.zeros(out_dim))
self.v_bias = nn.Parameter(torch.zeros(out_dim))
self.register_buffer('zero_k_bias', torch.zeros(out_dim))
self.in_dim = in_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.scale = self.head_dim ** -0.5
self.proj = nn.Linear(out_dim, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
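        # fused qkv projection: k gets a fixed zero bias so only the q and v biases are learned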
qkv = F.linear(input=x, weight=self.qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)))
q, k, v = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
x = scaled_dot_product_attention(q, k, v)
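        # narrowing path: average over heads and pool channels if needed; otherwise concatenate heads as usual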
if self.in_dim > self.out_dim:
x = torch.mean(x, dim=1)
if self.in_dim // self.num_heads != self.out_dim:
x = nn.functional.adaptive_avg_pool1d(x, self.out_dim)
else:
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
return x
class AttnProjection(nn.Module):
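    """Projection block mapping between the encoder width and the codebook width:
    a linear shortcut plus PlainAttention, followed by a residual GeGLU MLP.
    """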
def __init__(self, in_dim, out_dim, num_heads, norm_layer=nn.LayerNorm, mlp_ratio=2):
super().__init__()
assert out_dim % in_dim == 0 or in_dim % out_dim == 0
self.in_dim = in_dim
self.out_dim = out_dim
self.norm1 = norm_layer(in_dim)
self.attn = PlainAttention(in_dim, out_dim, num_heads)
self.proj = nn.Linear(in_dim, out_dim)
self.norm3 = norm_layer(in_dim)
self.norm2 = norm_layer(out_dim)
hidden_dim = int(out_dim * mlp_ratio)
self.mlp = GeGluMlp(
in_features=out_dim,
hidden_features=hidden_dim
)
def forward(self, x):
x = self.proj(self.norm3(x)) + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class VQVAE(nn.Module):
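    """VQ-VAE built from a timm ViT encoder, a multi-codebook vector quantizer, and a ViTamin decoder.

    Pipeline: encoder -> (linear or attention) projection -> VectorQuantizerM ->
    projection back to the encoder width -> ViTaminDecoder reconstruction.
    """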
def __init__(self, args):
super().__init__()
# 1. build encoder
self.encoder = timm.create_model(
args.model,
patch_size=1,
fc_norm=True,
drop_rate=0.0,
num_classes=0,
global_pool='',
pos_embed='none',
class_token=False,
mlp_layer=GeGluMlp,
img_size=args.img_size,
drop_path_rate=args.drop_path,
)
self.encoder.set_grad_checkpointing(args.grad_ckpt)
# 2. build conv before quant
if args.quant_proj == 'linear':
self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
elif args.quant_proj == 'attn':
self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, args.num_codebooks)
else:
raise NotImplementedError
# 3. build quant
self.quantize = VectorQuantizerM(
vocab_size=args.vocab_size,
vocab_width=args.vocab_width,
beta=args.vq_beta,
use_entropy_loss=args.le > 0,
entropy_temp=args.e_temp,
num_codebooks=args.num_codebooks,
)
# 4. build conv after quant
if args.quant_proj == 'linear':
self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
elif args.quant_proj == 'attn':
self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, args.num_codebooks)
else:
raise NotImplementedError
# 5. build decoder
self.decoder = ViTaminDecoder(
args.model,
depths=(4, 2),
img_size=args.img_size,
drop_path=args.drop_path,
grad_ckpt=args.grad_ckpt
)
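        # no-op profiling context; presumably a stand-in for torch.profiler.record_function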
self.maybe_record_function = nullcontext
def forward(self, img):
features = self.encoder(img).float()
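        # run the projection and quantization in fp32 (autocast disabled) for numerical stability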
with torch.cuda.amp.autocast(enabled=False):
features = self.quant_proj(features)
quant_out = self.quantize(features)
features, vq_loss, entropy_loss, usages = quant_out
features = self.post_quant_proj(features)
rec_img = self.decoder(features).float()
return rec_img, vq_loss, entropy_loss, usages
def img_to_idx(self, img):
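        """Encode an image batch into codebook indices."""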
features = self.encoder(img).float()
features = self.quant_proj(features)
return self.quantize.f_to_idx(features)
def idx_to_img(self, indices):
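        """Decode codebook indices back into an image clamped to [-1, 1]."""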
features = self.quantize.idx_to_f(indices)
features = self.post_quant_proj(features)
img = self.decoder(features).clamp_(-1, 1)
return img
def img_to_reconstructed_img(self, img) -> torch.Tensor:
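        """Full encode -> quantize -> decode round trip for reconstruction."""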
features = self.encoder(img).float()
with torch.cuda.amp.autocast(enabled=False):
features = self.quant_proj(features)
quant_out = self.quantize(features)
features, _, _, _ = quant_out
features = self.post_quant_proj(features)
rec_img = self.decoder(features).float().clamp_(-1, 1)
return rec_img
if __name__ == '__main__':
    from argparse import Namespace
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d,
                nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)
    # VQVAE expects an argparse-style namespace; the values below are placeholders for a quick smoke test
    args = Namespace(model='vitamin_large', img_size=256, drop_path=0.0, grad_ckpt=False,
                     quant_proj='attn', vocab_size=32768, vocab_width=64, num_codebooks=8,
                     vq_beta=0.25, le=0.0, e_temp=0.01)
    vqvae = VQVAE(args)
    from models import init_weights
    init_weights(vqvae, -0.5)
    torch.save(vqvae.state_dict(), r'C:\Users\16333\Desktop\PyCharm\vlip\local_output\cnn.pth')