|
|
|
|
|
|
|
|
|
import numpy as np |
|
import os, math, gc |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision as vision |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only |
|
from pytorch_lightning.strategies import DeepSpeedStrategy |
|
import deepspeed |
|
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam |
|
|
|
|
|
def __nop(ob):
    """Identity helper: hand back *ob* untouched.

    Not referenced in this part of the file; presumably kept as a no-op
    stand-in for ``MyFunction`` when TorchScript is disabled (TODO confirm).
    """
    return ob


# Base class / method decorator for the TorchScript-compiled encoder and
# decoder. Here they are bound to TorchScript unconditionally.
MyModule, MyFunction = torch.jit.ScriptModule, torch.jit.script_method
|
|
|
import clip |
|
from transformers import CLIPModel |
|
|
|
class L2pooling(nn.Module):
    """L2 pooling: sqrt of a per-channel Hann-weighted blur of the squared input.

    Args:
        filter_size: length of the 1-D Hann window before its first and last
            taps are dropped; the resulting square kernel is
            ``(filter_size - 2) x (filter_size - 2)``, normalised to sum to 1.
        stride: stride of the depthwise convolution (downsampling factor).
        channels: number of input channels; the kernel is replicated once per
            channel and stored in the ``filter`` buffer.
        pad_off: accepted for signature compatibility; not used.
    """

    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super().__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels

        # Outer product of the trimmed window gives the separable 2-D kernel.
        window = np.hanning(filter_size)[1:-1]
        kernel = torch.Tensor(np.outer(window, window))
        kernel = kernel / kernel.sum()

        # Shape (channels, 1, k, k): one copy per channel for a grouped conv.
        per_channel = kernel.view(1, 1, *kernel.shape).repeat(self.channels, 1, 1, 1)
        self.register_buffer("filter", per_channel)

    def forward(self, input):
        squared = input**2
        blurred = F.conv2d(
            squared,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=squared.shape[1],  # depthwise: each channel blurred independently
        )
        # Small epsilon keeps sqrt (and its gradient) finite at zero.
        return torch.sqrt(blurred + 1e-12)
|
|
|
|
|
class DISTS(torch.nn.Module):
    """DISTS-style image distance built on an ImageNet-pretrained VGG16.

    The VGG16 feature stack is split into five stages; the VGG pooling
    layers at feature indices 4, 9, 16 and 23 are replaced by ``L2pooling``.
    For the raw input and each stage output, per-channel global statistics
    are compared:

    * ``S1 = (2*mx*my + c1) / (mx^2 + my^2 + c1)`` on spatial means,
    * ``S2 = (2*cov_xy + c2) / (var_x + var_y + c2)`` on (co)variances,

    weighted by the channel weights ``alpha`` / ``beta`` (normalised so that
    all weights sum to 1). The score is ``1 - (sum alpha*S1 + sum beta*S2)``,
    i.e. ~0 for identical inputs.

    Args:
        load_weights: if True (default), overwrite ``alpha``/``beta`` with the
            tensors stored in ``test/DISTS_weights.pt`` (path relative to the
            working directory). If False, keep the random N(0.1, 0.01) init.
    """

    def __init__(self, load_weights=True):
        super(DISTS, self).__init__()
        vgg_pretrained_features = vision.models.vgg16(
            weights="VGG16_Weights.IMAGENET1K_V1"
        ).features

        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Module names reuse the VGG feature indices so state_dict keys are stable.
        for x in range(0, 4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet normalisation applied before the VGG stages.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
        )

        # Channel counts of [input, stage1, ..., stage5] features.
        self.chns = [3, 64, 128, 256, 512, 512]
        # Plain tensors (not nn.Parameter): these are fixed metric weights.
        # Wrapping them in nn.Parameter left requires_grad=True on a buffer,
        # which the freeze loop below (over parameters only) never reached.
        self.register_buffer("alpha", torch.randn(1, sum(self.chns), 1, 1))
        self.register_buffer("beta", torch.randn(1, sum(self.chns), 1, 1))
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        if load_weights:
            # map_location keeps loading working on machines without the
            # device the weights were saved from; the module is moved later.
            weights = torch.load("test/DISTS_weights.pt", map_location="cpu")
            self.alpha.data = weights["alpha"]
            self.beta.data = weights["beta"]

        # The VGG backbone is a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False

    def forward_once(self, x):
        """Return ``[x, stage1(h), ..., stage5(...)]`` for one image batch.

        ``h`` is ``x`` normalised with ``mean``/``std``; the first list entry
        is the *unnormalised* input.
        """
        h = (x - self.mean) / self.std
        feats = [x]
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            h = stage(h)
            feats.append(h)
        return feats

    def forward(self, x, y, require_grad=False, batch_average=False):
        """Compute the distance between image batches ``x`` and ``y``.

        Assumes ``x``/``y`` are (B, 3, H, W) RGB tensors in the range the
        ImageNet mean/std expect (presumably [0, 1]) — TODO confirm with callers.

        Args:
            require_grad: if False, feature extraction runs under
                ``torch.no_grad()``.
            batch_average: if True return the batch mean, else per-sample
                scores (squeezed, so a 0-d tensor when B == 1).
        """
        if require_grad:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.no_grad():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)

        dist1 = 0
        dist2 = 0
        c1 = 1e-6
        c2 = 1e-6
        # Normalise all weights jointly so the similarities sum to at most 1.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)

        for f0, f1, a_k, b_k in zip(feats0, feats1, alpha, beta):
            x_mean = f0.mean([2, 3], keepdim=True)
            y_mean = f1.mean([2, 3], keepdim=True)
            S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (a_k * S1).sum(1, keepdim=True)

            x_var = ((f0 - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((f1 - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (f0 * f1).mean([2, 3], keepdim=True) - x_mean * y_mean
            S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (b_k * S2).sum(1, keepdim=True)

        score = 1 - (dist1 + dist2).squeeze()
        return score.mean() if batch_average else score
|
|
|
class ToBinary(torch.autograd.Function):
    """Round-half-up quantiser with a straight-through gradient.

    Forward maps each element to ``floor(x + 0.5)`` (so 0.5 -> 1, unlike
    ``torch.round``'s round-half-to-even). For inputs in [0, 1] the output is
    therefore 0 or 1. Backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        shifted = x + 0.5
        return shifted.floor()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the rounding as identity.
        return torch.clone(grad_output)
|
|
|
|
|
|
|
class R_ENCODER(MyModule):
    """Convolutional encoder: image -> ``args.my_img_bit``-channel code in (0, 1).

    Layout (``base`` = 8 channels):
      * stem: 3 -> base conv, plus one residual conv pair;
      * three stages, each ``pixel_unshuffle(2)`` (channels x4, H and W / 2)
        followed by two residual conv pairs, the first pre-activated by
        BatchNorm + Mish;
      * a long skip: BatchNorm of ``pixel_unshuffle(stem, 8)`` added before the
        output conv;
      * output conv to ``args.my_img_bit`` channels, then sigmoid.
    Spatial size is reduced 8x overall.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        base = 8
        ch1, ch2, ch3 = base * 4, base * 16, base * 64

        # Submodules are registered in the same order as before so parameter
        # initialisation draws and state_dict / optimizer ordering are unchanged.
        self.Bxx = nn.BatchNorm2d(ch3)

        self.CIN = nn.Conv2d(3, base, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(base, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, base, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(ch1)
        self.C00 = nn.Conv2d(ch1, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, ch1, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(ch1, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, ch1, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(ch2)
        self.C10 = nn.Conv2d(ch2, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, ch2, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(ch2, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, ch2, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(ch3)
        self.C20 = nn.Conv2d(ch3, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, ch3, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(ch3, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, ch3, kernel_size=3, padding=1)

        self.COUT = nn.Conv2d(ch3, args.my_img_bit, kernel_size=3, padding=1)

    @MyFunction
    def forward(self, img):
        stem = self.CIN(img)
        # Long skip: the stem features folded straight down to the code resolution.
        shortcut = self.Bxx(F.pixel_unshuffle(stem, 8))
        h = stem + self.Cx1(F.mish(self.Cx0(stem)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C01(F.mish(self.C00(F.mish(self.B00(h)))))
        h = h + self.C03(F.mish(self.C02(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C11(F.mish(self.C10(F.mish(self.B10(h)))))
        h = h + self.C13(F.mish(self.C12(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C21(F.mish(self.C20(F.mish(self.B20(h)))))
        h = h + self.C23(F.mish(self.C22(h)))

        logits = self.COUT(h + shortcut)
        return torch.sigmoid(logits)
|
|
|
|
|
|
|
class R_DECODER(MyModule):
    """Convolutional decoder: ``args.my_img_bit``-channel code -> RGB image in (0, 1).

    Layout (``base`` = 8 channels), mirroring the encoder:
      * input conv to ``base * 64`` channels;
      * three stages, each two residual conv pairs (the first pre-activated
        by BatchNorm + Mish) followed by ``pixel_shuffle(2)`` (channels / 4,
        H and W x2);
      * a final residual conv pair at ``base`` channels, output conv to 3
        channels, then sigmoid.
    Spatial size is increased 8x overall.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        base = 8
        ch1, ch2, ch3 = base * 4, base * 16, base * 64

        # Registration order matches the original so initialisation draws and
        # state_dict / optimizer ordering are unchanged.
        self.CIN = nn.Conv2d(args.my_img_bit, ch3, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(ch3)
        self.C00 = nn.Conv2d(ch3, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, ch3, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(ch3, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, ch3, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(ch2)
        self.C10 = nn.Conv2d(ch2, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, ch2, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(ch2, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, ch2, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(ch1)
        self.C20 = nn.Conv2d(ch1, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, ch1, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(ch1, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, ch1, kernel_size=3, padding=1)

        self.Cx0 = nn.Conv2d(base, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, base, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(base, 3, kernel_size=3, padding=1)

    @MyFunction
    def forward(self, code):
        h = self.CIN(code)

        h = h + self.C01(F.mish(self.C00(F.mish(self.B00(h)))))
        h = h + self.C03(F.mish(self.C02(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C11(F.mish(self.C10(F.mish(self.B10(h)))))
        h = h + self.C13(F.mish(self.C12(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C21(F.mish(self.C20(F.mish(self.B20(h)))))
        h = h + self.C23(F.mish(self.C22(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.Cx1(F.mish(self.Cx0(h)))
        return torch.sigmoid(self.COUT(h))
|
|
|
|
|
|
|
def cosine_loss(x, y):
    """Return ``1 - cosine_similarity(x, y)`` computed along the last dimension.

    Both inputs are L2-normalised with ``F.normalize`` (which divides by
    ``max(norm, eps)``, so an all-zero vector yields a loss of 1 instead of
    NaN). Any matching leading shape is accepted; the result has the inputs'
    shape without the last dimension. For 2-D ``(N, D)`` inputs this is the
    same ``(N,)`` result as before.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    # Row-wise dot product over the last axis (previously einsum 'ij,ij->i',
    # which only accepted 2-D inputs).
    return 1 - (x * y).sum(dim=-1)
|
|
|
class RWKV_IMG(pl.LightningModule):
    """Lightning module for a binary-code image autoencoder.

    ``forward`` encodes an image with ``R_ENCODER``, rounds the code with
    ``ToBinary`` (straight-through gradient) and decodes it with
    ``R_DECODER``. A frozen CLIP image encoder and a frozen ``DISTS`` metric
    serve only as loss networks.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = R_ENCODER(args)
        self.decoder = R_DECODER(args)

        # Short aliases for OpenAI CLIP models; 'OB32' selects the LAION
        # checkpoint via transformers. Any other value is passed to clip.load
        # unchanged.
        self.clip_model = None
        clip_name = args.my_img_clip
        if clip_name == 'B32':
            clip_name = 'ViT-B/32'
        elif clip_name == 'B16':
            clip_name = 'ViT-B/16'
        elif clip_name == 'L14':
            clip_name = 'ViT-L/14'
        elif clip_name == 'OB32':
            clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
            self.clip_model = CLIPModel.from_pretrained(clip_name)
            # Expose the same `encode_image` entry point as the OpenAI model.
            self.clip_model.encode_image = self.clip_model.get_image_features
        if self.clip_model is None:
            self.clip_model, _ = clip.load(clip_name, jit=True)
        # CLIP's image normalisation constants, applied in training_step.
        self.register_buffer(
            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
        )

        # CLIP is used only as a fixed loss network.
        for n, p in self.named_parameters():
            if 'clip_model' in n:
                p.requires_grad = False

        # DISTS freezes its own VGG parameters in its constructor.
        self.loss_dists = DISTS()

    def configure_optimizers(self):
        """Build Adam (DeepSpeedCPUAdam when offloading, else FusedAdam), no weight decay.

        Only tensors with ``requires_grad`` are passed. The frozen CLIP and
        DISTS weights never receive gradients, and leaving them out keeps them
        out of the optimizer's (and any ZeRO partitioning's) bookkeeping.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        optim_groups = [
            {"params": trainable, "weight_decay": 0.0},
        ]
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adamw_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        return FusedAdam(
            optim_groups,
            lr=self.args.lr_init,
            betas=self.args.betas,
            eps=self.args.adam_eps,
            bias_correction=True,
            adam_w_mode=False,
            weight_decay=0,
            amsgrad=False,
        )

    @property
    def deepspeed_offload(self) -> bool:
        """True if the DeepSpeed ZeRO config enables optimizer or parameter offload."""
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config["zero_optimization"]
            # The config values are dicts; coerce so the annotation holds.
            return bool(config.get("offload_optimizer") or config.get("offload_param"))
        return False

    def forward(self, img):
        z = self.encoder(img)
        z = ToBinary.apply(z)
        out = self.decoder(z)
        return out

    def training_step(self, batch, batch_idx):
        """Reconstruct ``img`` and return DISTS + scaled CLIP (+ optional scaled L1) loss.

        ``batch`` is unpacked as ``(img, txt)``; ``txt`` is not used here.
        Every ``100 * args.devices`` global steps, rank 0 writes the first
        four source and reconstructed images to ``test/image_model/<run_name>``.
        """
        args = self.args
        img, _txt = batch
        out = self(img)
        if self.trainer.is_global_zero:
            if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
                img_dir = f"test/image_model/{args.run_name}"
                os.makedirs(img_dir, exist_ok=True)
                vision.utils.save_image(
                    img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"
                )
                vision.utils.save_image(
                    out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"
                )

        loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)

        # Compare CLIP image embeddings of source and reconstruction.
        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
        loss_clip = torch.mean(cosine_loss(iii, ooo))

        loss = loss_dists + loss_clip * args.my_img_clip_scale
        if args.my_img_l1_scale > 0:
            loss = loss + F.l1_loss(out, img) * args.my_img_l1_scale
        return loss

    def training_step_end(self, batch_parts):
        """Gather step outputs across processes and stash them on the trainer (rank 0)."""
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return this module's state_dict with every tensor on CPU.

        Depending on the ``RWKV_FLOAT_MODE`` environment variable, each entry
        is cast to fp16 (``"fp16"``) or bf16 (``"bf16"``); any other value
        keeps the original dtypes. The cast applies to every entry, including
        integer buffers such as BatchNorm's ``num_batches_tracked``.

        Raises:
            KeyError: if ``RWKV_FLOAT_MODE`` is not set.
        """
        print(
            f"""

############################################################################

#

# Init model weight (slow for large models)...

#

############################################################################

"""
        )
        float_mode = os.environ["RWKV_FLOAT_MODE"]
        m = {}
        # Build the state_dict once; calling self.state_dict() per key made
        # this loop quadratic in the number of tensors.
        for n, p in self.state_dict().items():
            p = p.cpu()
            if float_mode == "fp16":
                p = p.half()
            elif float_mode == "bf16":
                p = p.bfloat16()
            m[n] = p

        gc.collect()
        torch.cuda.empty_cache()
        return m
|
|