from torch.utils.data import Dataset | |
class CustomDataset(Dataset):
    """Expose an indexable, sized container as a PyTorch ``Dataset``.

    Both ``len()`` and item lookup are delegated straight to the wrapped
    ``data`` object; no transformation is applied to the items.
    """

    def __init__(self, data) -> None:
        super().__init__()
        # Any object supporting len() and [] works here (presumably a list
        # of samples — confirm against callers).
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item stored at ``index``, unchanged."""
        return self.data[index]
class EarlyStopping:
    """Signal that training should stop once the loss stops improving.

    Each call compares the current loss with a reference loss. A call counts
    as a "bad" step when the current loss exceeds the reference by more than
    ``min_delta``; any other call resets the count, so only *consecutive*
    bad steps accumulate. Once ``tolerance`` consecutive bad steps have been
    seen, ``early_stop`` is set to True and stays True.

    Args:
        tolerance: Number of consecutive bad steps allowed before stopping.
        min_delta: Slack by which the current loss may exceed the reference
            loss without the step being counted as bad.
    """

    def __init__(self, tolerance=10, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0  # consecutive bad steps seen so far
        self.early_stop = False  # latched to True once tolerance is reached

    def __call__(self, train_loss, min_loss):
        """Record one step and update ``counter`` / ``early_stop``.

        Args:
            train_loss: Loss observed at this step.
            min_loss: Reference loss to compare against (presumably the best
                loss seen so far — confirm against the caller).
        """
        if train_loss - min_loss > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            # The loss is back within min_delta of the reference, so the
            # streak of bad steps is broken: start counting from scratch.
            # Without this reset, isolated bad steps spread over a long run
            # would accumulate and trigger a premature stop.
            self.counter = 0
# def gen_text_from_center(args,plugin_vae, vae_model, decoder_tokenizer,label,epoch,pos): | |
# gen_text = [] | |
# latent_z = gen_latent_center(plugin_vae,pos).to(args.device).repeat((1,1)) | |
# print("latent_z",latent_z.shape) | |
# text_analogy = text_from_latent_code_batch(latent_z, vae_model, args, decoder_tokenizer) | |
# print("label",label) | |
# print(text_analogy) | |
# gen_text.extend([(label,y,epoch) for y in text_analogy]) | |
# text2out(gen_text, '/cognitive_comp/liangyuxin/projects/cond_vae/outputs/test.json') |