Yvonnefanf
first
7b5e67a
########################################################################################################################
# IMPORT #
########################################################################################################################
import torch
import sys
import os
import json
import time
import numpy as np
import argparse
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params
from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, TemporalLoss, DVILoss, SingleVisLoss, DummyTemporalLoss
from singleVis.edge_dataset import DVIDataHandler
from singleVis.trainer import DVITrainer
from singleVis.eval.evaluator import Evaluator
from singleVis.data import NormalDataProvider
# from singleVis.spatial_edge_constructor import SingleEpochSpatialEdgeConstructor
from singleVis.spatial_skeleton_edge_constructor import ProxyBasedSpatialEdgeConstructor
from singleVis.projector import DVIProjector
from singleVis.utils import find_neighbor_preserving_rate
from trustVis.skeleton_generator import CenterSkeletonGenerator
########################################################################################################################
# DVI PARAMETERS #
########################################################################################################################
"""This serves as an example of a DeepVisualInsight (DVI) implementation in PyTorch."""
VIS_METHOD = "DVI"  # DeepVisualInsight; also the top-level key read from config.json below
########################################################################################################################
# LOAD PARAMETERS #
########################################################################################################################


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True for any
    non-empty string. This parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='Process hyperparameters...')
# Default workspace dir: <cwd>/training_dynamic
current_path = os.getcwd()
new_path = os.path.join(current_path, 'training_dynamic')
parser.add_argument('--content_path', type=str, default=new_path)
# type=int is required: the epoch is used in range() arithmetic; without it a CLI value stays a str.
parser.add_argument('--epoch', type=int, default=3)
parser.add_argument('--epoch_period', type=int, default=1)
parser.add_argument('--preprocess', type=int, default=0)
parser.add_argument('--base', type=_str2bool, default=False)
args = parser.parse_args()

CONTENT_PATH = args.content_path
# Put CONTENT_PATH on the import path, presumably so `Model.model` (imported later) resolves there.
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
    config = json.load(f)
config = config[VIS_METHOD]

SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
GPU_ID = config["GPU"]
# NOTE(review): the configured GPU is overridden with device 0 — confirm this hard-coding is intended.
GPU_ID = 0
# The config must still provide the epoch keys, but the CLI overrides them:
# this script processes only the single epoch given by --epoch.
EPOCH_START = config["EPOCH_START"]
EPOCH_END = config["EPOCH_END"]
EPOCH_PERIOD = config["EPOCH_PERIOD"]
EPOCH_START = args.epoch
EPOCH_END = args.epoch
EPOCH_PERIOD = args.epoch_period
# Training parameter (subject model)
TRAINING_PARAMETER = config["TRAINING"]
NET = TRAINING_PARAMETER["NET"]  # name of the subject-model class inside Model/model.py
LEN = TRAINING_PARAMETER["train_num"]
# Training parameter (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
LAMBDA2 = VISUALIZATION_PARAMETER["LAMBDA2"]
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]
S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]
VIS_MODEL_NAME = 'proxy'  # file name the visualization model is saved under / loaded by the projector
EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]
# Define hyperparameters
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")
import Model.model as subject_model
# Look the class up by name instead of eval()-ing text from config.json; the behavior is the
# same for a plain class name, but arbitrary expressions in the config are no longer executed.
net = getattr(subject_model, NET)()
########################################################################################################################
# TRAINING SETTING #
########################################################################################################################
# Data provider over the subject model's saved per-epoch training dynamics.
data_provider = NormalDataProvider(
    CONTENT_PATH,
    net,
    EPOCH_START,
    EPOCH_END,
    EPOCH_PERIOD,
    device=DEVICE,
    epoch_name='Epoch',
    classes=CLASSES,
    verbose=1,
)

# The --preprocess CLI flag takes precedence over the config value.
PREPROCESS = args.preprocess
if PREPROCESS:
    data_provider._meta_data()
    # Boundary estimation is only needed when boundary-sample epochs are requested.
    if B_N_EPOCHS > 0:
        data_provider._estimate_boundary(LEN // 10, l_bound=L_BOUND)

# Visualization model (encoder/decoder dimensions from the config).
model = VisModel(ENCODER_DIMS, DECODER_DIMS)

# ---- Losses ----
negative_sample_rate = 5
min_dist = 0.1
# UMAP curve parameters a, b for spread 1.0 and the chosen min_dist.
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, DEVICE, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)
single_loss_fn = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA1)

# ---- Projector ----
projector = DVIProjector(
    vis_model=model,
    content_path=CONTENT_PATH,
    vis_model_name=VIS_MODEL_NAME,
    device=DEVICE,
)

# Truthy only on the first training iteration, when no previous model exists yet.
start_flag = 1
# Holds a frozen copy of the most recently trained model (refreshed at the end of each iteration).
prev_model = VisModel(ENCODER_DIMS, DECODER_DIMS)
for iteration in range(EPOCH_START, EPOCH_END + EPOCH_PERIOD, EPOCH_PERIOD):
    # ---- Define DVI loss ----
    if start_flag:
        # First iteration: no previous visualization model, so the temporal term is a dummy with weight 0.
        temporal_loss_fn = DummyTemporalLoss(DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=0.0, device=DEVICE)
        start_flag = 0
    else:
        # TODO AL mode, redefine train_representation
        prev_data = data_provider.train_representation(iteration - EPOCH_PERIOD)
        prev_data = prev_data.reshape(prev_data.shape[0], prev_data.shape[1])
        curr_data = data_provider.train_representation(iteration)
        curr_data = curr_data.reshape(curr_data.shape[0], curr_data.shape[1])
        # The temporal weight is scaled by the neighbor-preserving rate between the previous and current epoch.
        npr = torch.tensor(find_neighbor_preserving_rate(prev_data, curr_data, N_NEIGHBORS)).to(DEVICE)
        # w_prev: frozen parameters of the model trained in the previous iteration (set at the end of the loop).
        temporal_loss_fn = TemporalLoss(w_prev, DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=LAMBDA2 * npr, device=DEVICE)

    # ---- Define training parameters ----
    optimizer = torch.optim.Adam(model.parameters(), lr=.01, weight_decay=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=.1)

    # ---- Generate the skeleton (proxies) ----
    # Build proxies from the same epoch the spatial complex is constructed for below
    # (previously EPOCH_START was used on every iteration).
    skeleton_generator = CenterSkeletonGenerator(data_provider, iteration, 1)
    start_time = time.time()
    high_bom, _ = skeleton_generator.center_skeleton_genertaion()
    elapsed_time = time.time() - start_time
    print("proxy generation finished ")
    print('proxy-generation:', elapsed_time)

    # ---- Construct the spatial complex ----
    t0 = time.time()
    spatial_cons = ProxyBasedSpatialEdgeConstructor(data_provider, iteration, S_N_EPOCHS, B_N_EPOCHS, N_NEIGHBORS, net, high_bom)
    edge_to, edge_from, probs, feature_vectors, attention = spatial_cons.construct()
    t1 = time.time()
    print('complex-construct:', t1 - t0)

    # Rescale edge probabilities so the largest is just below 1, then drop near-zero edges.
    probs = probs / (probs.max() + 1e-3)
    eliminate_zeros = probs > 1e-3
    edge_to = edge_to[eliminate_zeros]
    edge_from = edge_from[eliminate_zeros]
    probs = probs[eliminate_zeros]

    dataset = DVIDataHandler(edge_to, edge_from, feature_vectors, attention)
    # Sampling proportional to probs with this total draws each edge about S_N_EPOCHS * prob times.
    n_samples = int(np.sum(S_N_EPOCHS * probs) // 1)
    # Choose the sampler based on the dataset size; presumably the custom sampler works around
    # torch.multinomial's 2**24-category limit used by WeightedRandomSampler.
    if len(edge_to) > pow(2, 24):
        sampler = CustomWeightedRandomSampler(probs, n_samples, replacement=True)
    else:
        sampler = WeightedRandomSampler(probs, n_samples, replacement=True)
    edge_loader = DataLoader(dataset, batch_size=2000, sampler=sampler, num_workers=8, prefetch_factor=10)

    ####################################################################################################################
    # TRAIN #
    ####################################################################################################################
    trainer = DVITrainer(model, criterion, optimizer, lr_scheduler, edge_loader=edge_loader, DEVICE=DEVICE)
    t2 = time.time()
    trainer.train(PATIENT, MAX_EPOCH)
    t3 = time.time()
    print('training:', t3 - t2)

    # ---- Save results ----
    save_dir = data_provider.model_path
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "complex_construction", str(iteration), t1 - t0)
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "training", str(iteration), t3 - t2)
    save_dir = os.path.join(data_provider.model_path, "Epoch_{}".format(iteration))
    trainer.save(save_dir=save_dir, file_name="{}".format(VIS_MODEL_NAME))
    print("Finish epoch {}...".format(iteration))

    # Freeze a copy of this iteration's model; the next iteration's TemporalLoss regularizes against it.
    prev_model.load_state_dict(model.state_dict())
    for param in prev_model.parameters():
        param.requires_grad = False
    w_prev = dict(prev_model.named_parameters())
########################################################################################################################
# VISUALIZATION #
########################################################################################################################
from singleVis.visualizer import visualizer

now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
vis = visualizer(data_provider, projector, 200, "tab10")
save_dir = os.path.join(data_provider.content_path, "Proxy")
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates missing parent dirs.
os.makedirs(save_dir, exist_ok=True)
for i in range(EPOCH_START, EPOCH_END + 1, EPOCH_PERIOD):
    vis.savefig(i, path=os.path.join(save_dir, "{}_{}_{}_{}.png".format(DATASET, i, VIS_METHOD, now)))
    data = data_provider.train_representation(i)
    data = data.reshape(data.shape[0], data.shape[1])
    # Save the embeddings and background for the front-end visualization.
    emb = projector.batch_project(i, data)
    np.save(os.path.join(CONTENT_PATH, 'Model', 'Epoch_{}'.format(i), 'embedding.npy'), emb)
    vis.get_background(i, 200)
########################################################################################################################
# EVALUATION #
########################################################################################################################
# Every trained epoch is evaluated. Earlier versions used a fixed per-dataset subset instead
# (mnist: 1/10/15, fmnist: 1/25/50, cifar10: 1/100/199).
evaluator = Evaluator(data_provider, projector)
# NOTE(review): this hard-coded name is used instead of EVALUATION_NAME read from the config — confirm intended.
Evaluation_NAME = 'proxy_eval'
eval_epochs = range(EPOCH_START, EPOCH_END + 1, EPOCH_PERIOD)
for i in eval_epochs:
    # 15 is presumably the neighborhood size for the spatial metrics — confirm against Evaluator.save_epoch_eval.
    evaluator.save_epoch_eval(i, 15, temporal_k=5, file_name=Evaluation_NAME)