'''
This script trains a Transformer classifier on camera-trajectory data
(the *_mnist function names are inherited from the conditional-MNIST
diffusion example this file was adapted from).
The code is modified from
https://github.com/cloneofsimo/minDiffusion
The diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239
The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598
This technique also features in Imagen, 'Photorealistic Text-to-Image
Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487
'''
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
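# Expected inputs (inferred from the loading code below, not from any spec):
#   data.npy      -- a pickled dict with keys "train_cam"/"test_cam" (lists of
#                    float arrays of shape (T, 5), T <= 300 frames) and
#                    "train_label"/"test_label" (integer class ids in [0, 6)).
#   <result>.npy  -- same layout, with keys "result" and "label", used by
#                    eval_mnist() and process_feature().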
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017)."""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # pe[pos, 2i]   = sin(pos / 10000^(2i/d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model)
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
class TimestepEmbedder(nn.Module):
    """Maps integer timesteps to embeddings by indexing the sinusoidal
    table and passing the result through a two-layer MLP."""
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder
        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # (1, batch, latent_dim) -> (batch, 1, latent_dim)
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
class Transformer(nn.Module):
    """Transformer encoder that classifies a fixed-length (sliding_wind-frame)
    camera-feature sequence into one of n_label classes."""
    def __init__(self, n_feature, n_label, latent_dim=256,
                 num_heads=4, ff_size=1024, dropout=0.1, activation='gelu',
                 num_layers=4, sliding_wind=300):
        super().__init__()
        self.n_feature = n_feature
        self.n_label = n_label
        self.num_heads = num_heads
        self.ff_size = ff_size
        self.dropout = dropout
        self.activation = activation
        self.num_layers = num_layers
        self.latent_dim = latent_dim
        self.input_process = nn.Linear(self.n_feature, self.latent_dim)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=self.num_heads,
                                                          dim_feedforward=self.ff_size,
                                                          dropout=self.dropout,
                                                          activation=self.activation)
        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=self.num_layers)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        # Instantiated for parity with the diffusion model; not used in forward().
        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
        # Collapse the latent dimension to one scalar per frame ...
        self.output_process = nn.Sequential(
            nn.Linear(self.latent_dim, 1),
            nn.ReLU()
        )
        # ... then map the sliding_wind per-frame scalars to class logits.
        self.pred = nn.Sequential(
            nn.Linear(sliding_wind, n_label),
            # nn.Softmax(dim=1),  # not needed: CrossEntropyLoss expects raw logits
        )

    def forward(self, x):
        # x: (batch, seq_len, n_feature)
        bs = len(x)
        x = self.input_process(x.permute(1, 0, 2))  # (seq_len, batch, latent_dim)
        xseq = self.sequence_pos_encoder(x)
        xseq = self.seqTransEncoder(xseq)
        xseq = self.output_process(xseq).permute(1, 0, 2)  # (batch, seq_len, 1)
        xseq = xseq.view(bs, -1)  # (batch, seq_len)
        return self.pred(xseq)  # (batch, n_label)

    def forward_feature(self, x):
        # Same as forward(), but stops before the classification head and
        # returns the per-frame scalars as a (batch, seq_len) feature vector.
        bs = len(x)
        x = self.input_process(x.permute(1, 0, 2))
        xseq = self.sequence_pos_encoder(x)
        xseq = self.seqTransEncoder(xseq)
        xseq = self.output_process(xseq).permute(1, 0, 2)
        return xseq.view(bs, -1)
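# Minimal shape check (a sketch, not part of the training pipeline):
#   model = Transformer(n_feature=5, n_label=6)
#   x = torch.randn(2, 300, 5)        # (batch, frames, features)
#   logits = model(x)                 # -> torch.Size([2, 6])
#   feats = model.forward_feature(x)  # -> torch.Size([2, 300])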
class camdataset(Dataset):
    """Wraps variable-length camera clips and pads each one to a fixed
    300 frames by repeating its last frame."""
    def __init__(self, cam, label):
        self.cam = cam
        self.label = label

    def __getitem__(self, index):
        d = self.cam[index]
        data = np.concatenate((d, d[-1:].repeat(300 - len(d), 0)), 0)
        return np.array(data, dtype="float32"), self.label[index]

    def __len__(self):
        return len(self.cam)
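# Padding illustration (a sketch): a clip of shape (120, 5) is extended to
# (300, 5) by repeating its last row 180 times. A clip longer than 300 frames
# would make the repeat count negative and raise an error, so inputs are
# assumed to be at most 300 frames.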
def train_mnist():
    # NOTE: the *_mnist names are inherited from the MNIST diffusion example;
    # this actually trains the camera-trajectory classifier.
    data = np.load("data.npy", allow_pickle=True)[()]

    # Normalize every clip with statistics computed over train and test
    # together, and cache the statistics for later evaluation runs.
    d = np.concatenate(data["train_cam"] + data["test_cam"], 0)
    Mean, Std = np.mean(d, 0), np.std(d, 0)
    np.save("Mean_Std", {"Mean": Mean, "Std": Std})
    for i in range(len(data["train_cam"])):
        data["train_cam"][i] = (data["train_cam"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)
    for i in range(len(data["test_cam"])):
        data["test_cam"][i] = (data["test_cam"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)

    # hardcoding these here
    n_epoch = 1000
    batch_size = 128
    device = "cuda:0"
    n_feature = 5
    n_label = 6
    lrate = 1e-4
    save_model = True
    save_dir = './result/'
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    criterion = torch.nn.CrossEntropyLoss()

    trans = Transformer(n_feature=n_feature, n_label=n_label)
    trans.to(device)
    optim = torch.optim.Adam(trans.parameters(), lr=lrate)

    dataloader = DataLoader(camdataset(data['train_cam'], data['train_label']), batch_size=batch_size, shuffle=True, num_workers=5)
    testloader = DataLoader(camdataset(data['test_cam'], data['test_label']), batch_size=batch_size, shuffle=False, num_workers=5)

    for ep in range(n_epoch):
        print(f'epoch {ep}')

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        trans.train()
        correct = 0
        total = 0
        for cam, label in pbar:
            cam = cam.to(device)
            label = label.to(device)

            pred_v = trans(cam)
            predictions = torch.argmax(pred_v, dim=1)
            correct += torch.sum(predictions == label).item()
            total += len(predictions)

            optim.zero_grad()
            loss = criterion(pred_v, label)
            loss.backward()
            pbar.set_description(f"training acc: {100.0 * correct / total:.4f}")
            optim.step()

        trans.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for cam, label in testloader:
                cam = cam.to(device)
                label = label.to(device)
                pred_v = trans(cam)
                predictions = torch.argmax(pred_v, dim=1)
                correct += torch.sum(predictions == label).item()
                total += len(predictions)
        print("evaluation accuracy : {}".format(1.0 * correct / total))

        torch.save(trans.state_dict(), save_dir + "latest.pth")
        if save_model and ep % 100 == 0:
            torch.save(trans.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")
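# Checkpointing, as written above: "latest.pth" is overwritten every epoch,
# and a numbered "model_{ep}.pth" snapshot is kept every 100 epochs, all
# under ./result/.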
def eval_mnist(file_name):
    # Recompute and cache the normalization statistics if they are missing.
    if not os.path.exists("Mean_Std.npy"):
        data = np.load("data.npy", allow_pickle=True)[()]
        d = np.concatenate(data["train_cam"] + data["test_cam"], 0)
        Mean, Std = np.mean(d, 0), np.std(d, 0)
        np.save("Mean_Std", {"Mean": Mean, "Std": Std})

    d = np.load("Mean_Std.npy", allow_pickle=True)[()]
    Mean, Std = d["Mean"], d["Std"]

    data = np.load(file_name + ".npy", allow_pickle=True)[()]
    for i in range(len(data["result"])):
        data["result"][i] = (data["result"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)

    device = "cuda:0"
    n_feature = 5
    n_label = 6

    trans = Transformer(n_feature=n_feature, n_label=n_label)
    trans.to(device)
    # load the trained classifier
    trans.load_state_dict(torch.load("./result/latest.pth"))

    testloader = DataLoader(camdataset(data['result'], data['label']), batch_size=8, num_workers=5)

    correct = 0
    total = 0
    t = [0] * 10  # per-class correct counts (oversized; only n_label entries used)
    f = [0] * 10  # per-class error counts
    trans.eval()
    with torch.no_grad():
        for cam, label in tqdm(testloader):
            cam = cam.to(device)
            label = label.to(device)
            pred_v = trans(cam)
            predictions = torch.argmax(pred_v, dim=1)
            correct += torch.sum(predictions == label).item()
            total += len(predictions)
            for i in range(len(predictions)):
                if predictions[i] == label[i]:
                    t[label[i]] += 1
                else:
                    f[label[i]] += 1
    print("gen accuracy : {}/{}={} ".format(correct, total, 1.0 * correct / total))
    # Per-class breakdown: class id, correct, total.
    for i in range(n_label):
        print("{} {} {}".format(i, t[i], t[i] + f[i]))
def process_feature(file_list):
    """Extract classifier features for the train/test splits and for each
    result file in file_list, then save everything to feature.npy."""
    data = np.load("data.npy", allow_pickle=True)[()]

    # Same normalization as in training.
    d = np.concatenate(data["train_cam"] + data["test_cam"], 0)
    Mean, Std = np.mean(d, 0), np.std(d, 0)
    for i in range(len(data["train_cam"])):
        data["train_cam"][i] = (data["train_cam"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)
    for i in range(len(data["test_cam"])):
        data["test_cam"][i] = (data["test_cam"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)

    device = "cuda:0"
    n_feature = 5
    n_label = 6

    trans = Transformer(n_feature=n_feature, n_label=n_label)
    trans.to(device)
    # load the trained classifier
    trans.load_state_dict(torch.load("./result/latest.pth"))
    trans.eval()

    d = dict()

    testloader = DataLoader(camdataset(data['train_cam'], data['train_label']), batch_size=8, num_workers=5)
    feature = []
    with torch.no_grad():
        for cam, label in tqdm(testloader):
            cam = cam.to(device)
            pred_v = trans.forward_feature(cam).detach().cpu().numpy()
            for v in pred_v:
                feature.append(v)
    d["train_data"] = feature

    testloader = DataLoader(camdataset(data['test_cam'], data['test_label']), batch_size=8, num_workers=5)
    feature = []
    with torch.no_grad():
        for cam, label in tqdm(testloader):
            cam = cam.to(device)
            pred_v = trans.forward_feature(cam).detach().cpu().numpy()
            for v in pred_v:
                feature.append(v)
    d["test_data"] = feature

    for file in file_list:
        data = np.load(file + ".npy", allow_pickle=True)[()]
        for i in range(len(data["result"])):
            data["result"][i] = (data["result"][i] - Mean[None, :]) / (Std[None, :] + 1e-8)
        testloader = DataLoader(camdataset(data['result'], data['label']), batch_size=8, num_workers=5)
        feature = []
        with torch.no_grad():
            for cam, label in tqdm(testloader):
                cam = cam.to(device)
                pred_v = trans.forward_feature(cam).detach().cpu().numpy()
                for v in pred_v:
                    feature.append(v)
        d[file] = feature

    np.save("feature", d)
if __name__ == "__main__":
    train_mnist()
    # eval_mnist("<result_file>")         # basename of a .npy result dict
    # process_feature(["<result_file>"])  # list of such basenames