et-viewer / CCD /src /LSTM.py
azizinaghsh's picture
add ccd
293829f
'''
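This script trains and evaluates an LSTM that generates fixed-length "cam" feature sequences from CLIP text embeddings.
It was adapted from a script that does conditional image generation on MNIST using a diffusion model; no diffusion is performed in this file.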
This code is modified from,
https://github.com/cloneofsimo/minDiffusion
Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239
The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598
This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487
'''
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
import os
import clip
class LSTM(nn.Module):
    """LSTM that expands one conditioning embedding into a feature sequence.

    The conditioning vector is passed through a linear layer, repeated once
    per time step and fed to an LSTM; a linear head then maps every LSTM
    output to ``output_size`` features.

    Args:
        input_size: Feature size of the sequence given to ``forward``. It does
            not size any layer; ``forward`` only reads the sequence shape.
        hidden_size: Hidden size of the LSTM.
        output_size: Number of features produced per time step.
        embed_size: Size of the conditioning embedding.
        n_layer: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, input_size, hidden_size, output_size, embed_size=512, n_layer=1, bidirectional=False):
        super().__init__()
        self.n_layer = n_layer
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size
        self.num_directions = 2 if bidirectional else 1

        # The LSTM consumes the embed_size-wide output of ``self.embed``, so
        # its input width must be embed_size. Parameter shapes are unchanged
        # whenever embed_size == hidden_size.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=n_layer, batch_first=True, bidirectional=bidirectional)
        # Not used by ``forward``; kept so the state_dict keys stay the same
        # and existing checkpoints still load.
        self.encoder = nn.Sequential(nn.Linear(embed_size, hidden_size))
        # A bidirectional LSTM concatenates both directions on the last axis.
        self.decoder = nn.Sequential(nn.Linear(hidden_size * self.num_directions, output_size))
        self.embed = nn.Sequential(nn.Linear(embed_size, embed_size))

    def initHidden(self, batch_size=1):
        """Return zero ``(h0, c0)`` states on the same device as the model.

        Args:
            batch_size: Number of sequences in the batch.

        Returns:
            Tuple ``(h0, c0)``, each of shape
            ``(n_layer * num_directions, batch_size, hidden_size)``.
        """
        reference = next(self.parameters())
        shape = (self.n_layer * self.num_directions, batch_size, self.hidden_size)
        h0 = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        c0 = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        return (h0, c0)

    def forward(self, input, embed):
        """Generate one output sequence per conditioning embedding.

        Args:
            input: Tensor whose shape is ``(batch, length, features)``. Only
                the batch size and length are read; the values are ignored.
            embed: Conditioning tensor of shape ``(batch, embed_size)``.

        Returns:
            Tensor of shape ``(batch, length, output_size)``.
        """
        bs, length, _ = input.shape
        # Feed the same projected embedding at every time step.
        steps = self.embed(embed).unsqueeze(1).repeat(1, length, 1)
        output, _ = self.lstm(steps, self.initHidden(bs))
        return self.decoder(output)
import torch.utils.data as data
class camdataset(data.Dataset):
    """Dataset pairing fixed-length sequences with randomly sampled label text.

    Args:
        data: Indexable collection of 2-D arrays, one sequence per item, with
            time on the first axis.
        label: Indexable collection of string collections, aligned with
            ``data``.
        max_len: Length every returned sequence is padded or truncated to.
    """

    def __init__(self, data, label, max_len=300):
        self.data = data
        self.label = label
        self.max_len = max_len

    def __getitem__(self, index):
        """Return ``(sequence, text)`` for one item.

        The sequence is a float32 array with ``max_len`` rows: shorter inputs
        are padded by repeating their last row, longer ones are truncated.
        The text is a random non-empty subset of the item's labels joined by
        spaces, or ``""`` when the item has no labels.
        """
        tags = self.label[index]
        if len(tags) > 0:
            # Draw the subset size first, then the subset itself.
            n_pick = np.random.randint(1, len(tags) + 1)
            text = np.random.choice(tags, n_pick, replace=False)
        else:
            # randint(1, 1) would raise; an item without labels gets empty text.
            text = []

        # Truncate first so the pad count below can never be negative.
        seq = np.asarray(self.data[index])[: self.max_len]
        pad = seq[-1:].repeat(self.max_len - len(seq), 0)
        seq = np.concatenate((seq, pad), 0)
        return np.array(seq, dtype="float32"), " ".join(text)

    def __len__(self):
        """Return the number of sequences."""
        return len(self.data)
def train():
    """Train the text-conditioned LSTM on the sequences in ``data.npy``.

    ``data.npy`` must hold a pickled dict with the keys ``"cam"`` (a list of
    sequences) and ``"info"`` (the label lists passed to ``camdataset``).
    Sequences are standardised per feature, the labels are encoded with CLIP
    ViT-B/32, and the LSTM is trained with an MSE loss to reproduce each
    sequence from its text embedding. Weights are written to
    ``./result/latest.pth`` after every epoch and to
    ``./result/model_{ep}.pth`` every 100 epochs.
    """
    # NOTE(review): allow_pickle=True unpickles the file; only load trusted data.
    data = np.load("data.npy", allow_pickle=True)[()]
    stacked = np.concatenate(data["cam"], 0)
    Mean, Std = np.mean(stacked, 0), np.std(stacked, 0)
    for i, seq in enumerate(data["cam"]):
        # The epsilon guards against a zero standard deviation.
        data["cam"][i] = (seq - Mean[None, :]) / (Std[None, :] + 1e-8)

    # hardcoding these here
    n_epoch = 1000
    batch_size = 128
    n_feature = 5
    lrate = 1e-4
    save_model = True
    save_dir = './result/'
    # One device for the LSTM, CLIP and every batch, with a CPU fallback.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # exist_ok avoids the race between an existence check and mkdir.
    os.makedirs(save_dir, exist_ok=True)

    criterion = torch.nn.MSELoss()
    trans = LSTM(input_size=n_feature, hidden_size=512, output_size=n_feature)
    trans.to(device)
    optim = torch.optim.Adam(trans.parameters(), lr=lrate)
    dataloader = DataLoader(camdataset(data['cam'], data['info']), batch_size=batch_size, shuffle=True, num_workers=5)

    model, _ = clip.load("ViT-B/32", device=device)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        trans.train()

        # linear lrate decay
        for group in optim.param_groups:
            group['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)

            # CLIP is a fixed text encoder here; no gradients flow into it.
            with torch.no_grad():
                c = clip.tokenize(c, truncate=True).to(device)
                c = model.encode_text(c).float().detach()

            loss = criterion(trans(x, c), x)
            loss.backward()

            # Exponential moving average of the loss for the progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        torch.save(trans.state_dict(), save_dir + "latest.pth")
        if save_model and ep % 100 == 0:
            torch.save(trans.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")
def eval():
    """Generate sequences for the prompts in ``test_prompt.npy``.

    Loads the per-feature mean and standard deviation from ``Mean_Std.npy``
    (computing and saving them from ``data.npy`` if the file is missing),
    restores the LSTM from ``./result/latest.pth``, encodes the prompts with
    CLIP ViT-B/32 in chunks of 100, de-standardises the generated sequences
    and saves ``{"result", "label"}`` to ``LSTM_test.npy``.

    The function name shadows the built-in ``eval``; it is kept because the
    command-line entry point calls it by this name.
    """
    # NOTE(review): allow_pickle=True unpickles the files; only load trusted data.
    if not os.path.exists("Mean_Std.npy"):
        data = np.load("data.npy", allow_pickle=True)[()]
        stacked = np.concatenate(data["cam"], 0)
        Mean, Std = np.mean(stacked, 0), np.std(stacked, 0)
        np.save("Mean_Std", {"Mean": Mean, "Std": Std})

    stats = np.load("Mean_Std.npy", allow_pickle=True)[()]
    Mean, Std = stats["Mean"], stats["Std"]

    # One device for the LSTM, CLIP and the inputs, with a CPU fallback.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_feature = 5

    trans = LSTM(input_size=n_feature, hidden_size=512, output_size=n_feature)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    trans.load_state_dict(torch.load("./result/latest.pth", map_location=device))
    trans.to(device)
    trans.eval()

    os.makedirs("viz", exist_ok=True)

    model, _ = clip.load("ViT-B/32", device=device)

    prompts = np.load("test_prompt.npy", allow_pickle=True)[()]

    chunk = 100
    result = []
    for start in tqdm(range(0, len(prompts['info']), chunk)):
        text = [" ".join(v) for v in prompts['info'][start:start + chunk]]

        with torch.no_grad():
            c = clip.tokenize(text, truncate=True).to(device)
            c = model.encode_text(c).float().detach()
            # The zero tensor only tells the model the batch size and length.
            sample = trans(torch.zeros(len(c), 300, n_feature, device=device), c)

        sample = sample.detach().cpu().numpy()

        # Undo the standardisation so results are in the original units.
        result.extend(s * Std[None, :] + Mean[None, :] for s in sample)

    np.save("LSTM_test", {"result": result, "label": prompts["label"]})
if __name__ == "__main__":
    import sys

    # Usage: python LSTM.py {train|eval}
    # A missing argument falls through to the error message below instead of
    # raising IndexError.
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == 'train':
        train()
    elif mode == 'eval':
        eval()
    else:
        # Doubled braces print a literal "{train, eval}"; the single
        # placeholder is filled with the instruction that was given.
        print(f'Error, instruction {mode} is not in {{train, eval}}')