|
''' |
|
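This script was adapted from a conditional image-generation example (MNIST, diffusion model). The code below instead trains an LSTM baseline that maps CLIP text embeddings of captions to the sequences stored under the "cam" key of data.npy (presumably camera trajectories).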
|
|
|
This code is modified from, |
|
https://github.com/cloneofsimo/minDiffusion |
|
|
|
Diffusion model is based on DDPM, |
|
https://arxiv.org/abs/2006.11239 |
|
|
|
The conditioning idea is taken from 'Classifier-Free Diffusion Guidance', |
|
https://arxiv.org/abs/2207.12598 |
|
|
|
This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
|
https://arxiv.org/abs/2205.11487 |
|
|
|
''' |
|
|
|
from typing import Dict, Tuple |
|
from tqdm import tqdm |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
from torchvision import models, transforms |
|
from torchvision.datasets import MNIST |
|
from torchvision.utils import save_image, make_grid |
|
import matplotlib.pyplot as plt |
|
from matplotlib.animation import FuncAnimation, PillowWriter |
|
import numpy as np |
|
import os |
|
import clip |
|
|
|
class LSTM(nn.Module):
    """Text-conditioned LSTM that decodes a conditioning embedding into a sequence.

    The processing steps are:

    1. The embedding is passed through ``self.embed``.
    2. It is repeated at every time step.
    3. It is run through a (possibly stacked and/or bidirectional) LSTM.
    4. Each step is projected to ``output_size`` features by ``self.decoder``.

    Args:
        input_size: Accepted for API compatibility; not used by the model
            (``forward`` only reads the shape of its ``input`` argument).
        hidden_size: LSTM hidden-state size per direction.
        output_size: Number of features produced per time step.
        embed_size: Width of the conditioning embedding (default 512).
        n_layer: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, input_size, hidden_size, output_size, embed_size=512, n_layer=1, bidirectional=False):
        super().__init__()
        self.n_layer = n_layer
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size
        self.num_directions = 2 if bidirectional else 1

        # The LSTM consumes the projected embedding directly, so its input
        # width must be embed_size. With the defaults
        # (embed_size == hidden_size == 512) the parameter shapes are the
        # same as a hidden_size-wide input, so saved checkpoints still load.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=n_layer,
                            batch_first=True, bidirectional=bidirectional)

        # NOTE: not used in forward(); kept so that state_dict keys match
        # checkpoints saved with this architecture (load_state_dict is strict).
        self.encoder = nn.Sequential(nn.Linear(embed_size, hidden_size))

        # A bidirectional LSTM concatenates both directions at each step.
        self.decoder = nn.Sequential(nn.Linear(hidden_size * self.num_directions, output_size))

        self.embed = nn.Sequential(nn.Linear(embed_size, embed_size))

    def initHidden(self, batch_size=1, device=None):
        """Return zero initial states ``(h0, c0)`` for the LSTM.

        Each state has shape ``(n_layer * num_directions, batch_size, hidden_size)``.

        Args:
            batch_size: Batch dimension of the states.
            device: Device to allocate the states on. Defaults to the device
                of this module's parameters.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (self.n_layer * self.num_directions, batch_size, self.hidden_size)
        h0 = torch.zeros(shape, device=device)
        c0 = torch.zeros(shape, device=device)
        return (h0, c0)

    def forward(self, input, embed):
        """Decode ``embed`` into a sequence as long as ``input``.

        Args:
            input: Tensor of shape ``(batch, length, n_feat)``. Only its batch
                size and length are used; its values are ignored.
            embed: Conditioning tensor of shape ``(batch, embed_size)``.

        Returns:
            Tensor of shape ``(batch, length, output_size)``.
        """
        bs, length, _ = input.shape

        # Feed the same conditioning vector at every time step.
        embed = self.embed(embed).unsqueeze(1).repeat(1, length, 1)

        hidden = self.initHidden(bs, device=embed.device)
        output, _ = self.lstm(embed, hidden)

        return self.decoder(output)
|
|
|
import torch.utils.data as data |
|
class camdataset(data.Dataset):
    """Dataset pairing fixed-length sequences with randomly sampled captions.

    Args:
        data: Sequence of per-sample arrays whose axis 0 is the time/frame
            axis (presumably ``(n_frames, n_features)`` — TODO confirm).
        label: Sequence where ``label[i]`` is a list of caption strings for
            ``data[i]``.
        max_len: Number of frames every sample is padded to (by repeating
            its last frame) or truncated to, so samples can be batched.
    """

    def __init__(self, data, label, max_len=300):
        self.data = data
        self.label = label
        self.max_len = max_len

    def __getitem__(self, index):
        """Return ``(sequence, caption)`` for sample ``index``.

        The caption is built from the sample's caption list:

        * A random non-empty subset is drawn, with random size, random order
          and no repeats, and joined with spaces.
        * If the sample has no captions, the caption is ``""``.

        The sequence is a float32 array with exactly ``max_len`` frames.
        """
        captions = self.label[index]
        if len(captions) == 0:
            # np.random.randint(1, 1) would raise ValueError; use an empty caption.
            text = ""
        else:
            k = np.random.randint(1, len(captions) + 1)
            text = " ".join(np.random.choice(captions, k, replace=False))

        # Truncate long sequences: a negative repeat count would raise.
        seq = np.asarray(self.data[index])[: self.max_len]
        pad = self.max_len - len(seq)
        if pad > 0:
            seq = np.concatenate((seq, seq[-1:].repeat(pad, 0)), 0)

        return np.array(seq, dtype="float32"), text

    def __len__(self):
        return len(self.data)
|
|
|
|
|
def train():
    """Train the LSTM baseline on ``data.npy`` and checkpoint it to ``./result/``.

    ``data.npy`` holds a pickled dict with ``"cam"`` (per-sample arrays,
    concatenated along axis 0 for statistics) and ``"info"`` (per-sample
    caption lists). Both are presumed aligned by index; verify against the
    data producer.

    Sequences are z-score normalised per feature. The statistics are also
    written to ``Mean_Std.npy`` so ``eval()`` de-normalises with the same
    values. Captions are encoded by CLIP ViT-B/32 under ``torch.no_grad``.

    Checkpoints written:

    * ``./result/latest.pth`` after every epoch.
    * ``./result/model_{ep}.pth`` every 100 epochs.
    """
    # NOTE(review): allow_pickle=True unpickles file content; only load trusted files.
    raw = np.load("data.npy", allow_pickle=True)[()]

    d = np.concatenate(raw["cam"], 0)
    Mean, Std = np.mean(d, 0), np.std(d, 0)
    # Persist the stats so eval() de-normalises with exactly these values.
    np.save("Mean_Std", {"Mean": Mean, "Std": Std})

    for i in range(len(raw["cam"])):
        data_i = raw["cam"][i]
        raw["cam"][i] = (data_i - Mean[None, :]) / (Std[None, :] + 1e-8)

    n_epoch = 1000
    batch_size = 128
    # Use a single device for the model, the batches and CLIP.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_feature = 5  # presumably the per-frame feature width of "cam" — TODO confirm
    lrate = 1e-4
    save_model = True
    save_dir = './result/'
    os.makedirs(save_dir, exist_ok=True)

    criterion = torch.nn.MSELoss()
    trans = LSTM(input_size=n_feature, hidden_size=512, output_size=n_feature)
    trans.to(device)

    optim = torch.optim.Adam(trans.parameters(), lr=lrate)

    dataloader = DataLoader(camdataset(raw['cam'], raw['info']), batch_size=batch_size, shuffle=True, num_workers=5)

    clip_model, _ = clip.load("ViT-B/32", device=device)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        trans.train()

        # Linear learning-rate decay from lrate towards 0 over training.
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            with torch.no_grad():
                tokens = clip.tokenize(c, truncate=True).to(device)
                # .float(): the LSTM's parameters are float32.
                c = clip_model.encode_text(tokens).float()

            # The model reads only x's shape; x itself is the target.
            loss = criterion(trans(x, c), x)
            loss.backward()

            loss_ema = loss.item() if loss_ema is None else 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        torch.save(trans.state_dict(), save_dir + "latest.pth")
        if save_model and ep % 100 == 0:
            torch.save(trans.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")
|
|
|
def eval():
    """Generate sequences for every prompt in ``test_prompt.npy``.

    Steps:

    1. Load the checkpoint ``./result/latest.pth``.
    2. Encode each prompt's joined caption words with CLIP ViT-B/32, in
       chunks of 100.
    3. Decode 300-step sequences and de-normalise them with the statistics
       in ``Mean_Std.npy``. If that file is missing, the statistics are
       computed from ``data.npy`` and saved first.
    4. Save ``{"result": [...], "label": ...}`` to ``LSTM_test.npy``.
    """
    if not os.path.exists("Mean_Std.npy"):
        # NOTE(review): allow_pickle=True unpickles file content; trusted files only.
        raw = np.load("data.npy", allow_pickle=True)[()]
        d = np.concatenate(raw["cam"], 0)
        np.save("Mean_Std", {"Mean": np.mean(d, 0), "Std": np.std(d, 0)})
    stats = np.load("Mean_Std.npy", allow_pickle=True)[()]
    Mean, Std = stats["Mean"], stats["Std"]

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_feature = 5
    seq_len = 300  # same fixed length the training samples are padded to
    chunk = 100

    trans = LSTM(input_size=n_feature, hidden_size=512, output_size=n_feature)
    trans.to(device)
    # map_location lets a GPU-saved checkpoint load on whichever device is in use.
    trans.load_state_dict(torch.load("./result/latest.pth", map_location=device))
    trans.eval()

    os.makedirs("viz", exist_ok=True)

    clip_model, _ = clip.load("ViT-B/32", device=device)

    prompts = np.load("test_prompt.npy", allow_pickle=True)[()]

    result = []
    for i in tqdm(range(0, len(prompts['info']), chunk)):
        text = [" ".join(v) for v in prompts['info'][i:i + chunk]]

        with torch.no_grad():
            tokens = clip.tokenize(text, truncate=True).to(device)
            c = clip_model.encode_text(tokens).float()
            # The model only reads the shape of its first argument.
            sample = trans(torch.zeros(len(c), seq_len, n_feature), c)

        sample = sample.cpu().numpy()
        # Exact inverse of the training normalisation (x - Mean) / (Std + 1e-8).
        denorm = sample * (Std[None, None, :] + 1e-8) + Mean[None, None, :]
        result.extend(denorm)

    np.save("LSTM_test", {"result": result, "label": prompts["label"]})
|
|
|
if __name__ == "__main__":
    import sys

    # Dispatch on the first command-line argument: `train` or `eval`.
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == 'train':
        train()
    elif mode == 'eval':
        eval()
    else:
        # Doubled braces render the literal set of valid modes.
        print(f'Error, instruction {mode} is not in {{train, eval}}')
        sys.exit(1)
|
|