|
import os |
|
import numpy as np |
|
import random |
|
import pandas as pd |
|
import math |
|
from tqdm import tqdm |
|
import time |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
|
|
import Stage3_source.preprocess as prep |
|
import Stage3_source.cond_diff_transformer_layer as mod |
|
import Stage3_source.transformer_training_helper as train_helper |
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def cond_autocomplete_real_samples(
    model: nn.Module,
    args: any,
    realization: torch.Tensor,
    y_c: torch.Tensor,
    idx: torch.Tensor
) -> (
    any,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor
):
    """Mask a real sample along a random path and score the model's prediction.

    The function puts ``model`` in eval mode, draws a random path, builds the
    mask for step ``idx`` on that path, masks the token labels derived from
    ``realization`` and asks the model for the conditional distribution given
    the masked tokens, the condition ``y_c`` and ``idx``. Gradients are
    disabled for the whole call by the ``torch.no_grad()`` decorator.

    Args:
        model: Network handed to ``train_helper.cond_predict_conditional_prob``.
        args: Configuration object; ``args.device`` is read here and the whole
            object is forwarded to the ``train_helper`` functions.
        realization: Tensor whose ``size()`` unpacks into exactly three
            values (assumed ``(batch, channel, seq_length)`` - TODO confirm).
        y_c: Conditioning input forwarded unchanged to the model helper.
        idx: Step index used both to build the path mask and as model input.

    Returns:
        A 7-tuple ``(conditional_prob, probs, real_token_masked, real_tokens,
        log_prob, sampled_random_path, random_path_mask)``. The first entry is
        returned exactly as produced by the helper; the remaining six are
        moved to CPU with ``.cpu()``.
    """
    model.eval()
    # The middle (channel) dimension is not needed; the unpack still enforces
    # that `realization` has exactly three dimensions.
    bs, _, seq_length = realization.size()

    # Keep the helper call order unchanged so any random state they consume
    # is drawn in the same sequence as before.
    sampled_random_path = train_helper.sample_random_path(bs, seq_length, device=args.device)
    random_path_mask = train_helper.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)

    # The helper also returns batch size and sequence length; they are not
    # used past this point, so they are discarded.
    real_tokens, _, _ = train_helper.create_token_labels(args, realization)
    real_token_masked = train_helper.mask_realizations(real_tokens, random_path_mask)

    conditional_prob, probs = train_helper.cond_predict_conditional_prob(model, real_token_masked, y_c, idx, args)
    log_prob = train_helper.log_prob_of_realization(args, conditional_prob, real_tokens)

    return (
        conditional_prob,
        probs.cpu(),
        real_token_masked.cpu(),
        real_tokens.cpu(),
        log_prob.cpu(),
        sampled_random_path.cpu(),
        random_path_mask.cpu()
    )
|
|
|
|
|
|
|
def extract_samples_with_labels(
    dataloader: DataLoader,
    target_labels: int,
    total_num: int,
    pad_included: bool=False
) -> dict:
    """Collect up to ``total_num`` samples whose label equals ``target_labels``.

    Args:
        dataloader: Iterable yielding ``(data, labels)`` batches, where
            ``data[i]`` is the sample belonging to ``labels[i]``.
        target_labels: Label value to select; compared against
            ``label.item()``.
        total_num: Stop as soon as exactly this many samples were collected.
            If fewer matches exist, every match is returned.
        pad_included: When ``False`` (default) each selected sample is
            shifted by ``+1`` (presumably to free index 0 for a pad/mask
            token - TODO confirm). When ``True`` samples are stored as-is.

    Returns:
        ``{'sample': [...], 'label': [...]}`` with parallel lists of the
        selected samples and their label elements, in iteration order.
    """
    extracted_sampled = {
        'sample': [],
        'label': []
    }

    for data, labels in dataloader:
        for i, label in enumerate(labels):
            if label.item() != target_labels:
                continue

            # Out-of-place add: `data[i] += 1` wrote through a view into the
            # batch tensor, silently modifying the caller's data (and shifting
            # it again on every re-iteration of an in-memory batch list).
            sample = data[i] if pad_included else data[i] + 1

            extracted_sampled['sample'].append(sample)
            extracted_sampled['label'].append(label)
            if len(extracted_sampled['label']) == total_num:
                return extracted_sampled

    return extracted_sampled
|
|
|
|
|
|
|
def corrupt_samples(
    args: any,
    realization: torch.Tensor,
    perc: float
) -> (
    torch.Tensor,
    torch.Tensor,
    torch.Tensor
):
    """Mask a sample at the step given by a fraction of the diffusion steps.

    The step index is ``args.diffusion_steps * perc`` converted to an integer
    tensor (``.to(int)``), so any fractional part is dropped by the cast.

    Args:
        args: Configuration object; ``args.diffusion_steps`` and
            ``args.device`` are read here and the object is forwarded to
            ``train_helper.create_token_labels``.
        realization: Tensor whose ``size()`` unpacks into exactly three
            values (assumed ``(batch, channel, seq_length)`` - TODO confirm).
        perc: Fraction of ``args.diffusion_steps`` selecting the step index.

    Returns:
        A 3-tuple ``(real_token_masked, sampled_random_path, idx)``: the
        masked token labels, the random path used to build the mask, and the
        one-element integer step-index tensor on ``args.device``.
    """
    # The channel dimension is unused; the unpack still enforces a 3-D input.
    bs, _, seq_length = realization.size()

    idx = (args.diffusion_steps * torch.Tensor([perc])).to(int).to(args.device)

    # Helper call order is preserved so random draws happen in the same
    # sequence as before.
    sampled_random_path = train_helper.sample_random_path(bs, seq_length, device=args.device)
    random_path_mask = train_helper.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)

    # Batch size / sequence length returned by the helper are not needed.
    real_tokens, _, _ = train_helper.create_token_labels(args, realization)
    real_token_masked = train_helper.mask_realizations(real_tokens, random_path_mask)

    return (
        real_token_masked,
        sampled_random_path,
        idx
    )
|
|
|
|
|
@torch.no_grad()
def predict_next_index(
    model: nn.Module,
    args: any,
    mask_realization: torch.Tensor,
    y_c: torch.Tensor,
    idx: torch.Tensor
) -> (
    any,
    torch.Tensor
):
    """Run the model on a masked sample and return its conditional prediction.

    Puts ``model`` in eval mode and calls
    ``train_helper.cond_predict_conditional_prob`` with the masked sample
    (dimension 1 squeezed out), the condition ``y_c`` and the step ``idx``.
    Gradients are disabled by the ``torch.no_grad()`` decorator.

    Args:
        model: Network forwarded to the helper.
        args: Configuration object forwarded to the helper.
        mask_realization: Tensor whose ``size()`` unpacks into exactly three
            values (assumed ``(batch, 1, seq_length)`` - TODO confirm).
        y_c: Conditioning input forwarded unchanged.
        idx: Step index forwarded unchanged.

    Returns:
        ``(conditional_prob, probs)``: the first item exactly as returned by
        the helper, the second moved to CPU.
    """
    model.eval()
    # None of the sizes are used, but the unpack rejects inputs that are not
    # 3-D before they reach the model (kept from the original behaviour).
    _batch, _channel, _seq_length = mask_realization.size()

    conditional_prob, probs = train_helper.cond_predict_conditional_prob(
        model, mask_realization.squeeze(1), y_c, idx, args
    )

    return (
        conditional_prob,
        probs.cpu(),
    )
|
|
|
|
|
|
|
|
|
def generate_denoised_sampled(
    args: any,
    model: nn.Module,
    extract_digit_samples: torch.Tensor,
    extract_time: torch.Tensor,
    extract_digit_label: torch.Tensor,
    sampling_path: torch.Tensor
) -> (
    list,
    list
):
    """Iteratively fill in a masked sample, one path step per iteration.

    Starting at step ``extract_time`` and running up to
    ``args.diffusion_steps``, each iteration queries ``predict_next_index``,
    samples from the returned distribution, takes the argmax over the last
    dimension, and writes the result into the positions where
    ``sampling_path`` equals the current step. A snapshot of the working
    sample and of the step index is recorded after every iteration.

    Args:
        args: Configuration object; ``args.device`` and
            ``args.diffusion_steps`` are read here.
        model: Network forwarded to ``predict_next_index``.
        extract_digit_samples: Partially masked tokens; a dimension is
            inserted at position 1 and values are cast to ``long``. The write
            through ``[0, current_location]`` touches only the first batch
            entry (assumes a batch of one - TODO confirm).
        extract_time: Starting step; wrapped as ``torch.Tensor([extract_time])``.
        extract_digit_label: Conditioning input passed to the model as ``y_c``.
        sampling_path: Per-position step indices compared against the current
            step to locate the position(s) to fill.

    Returns:
        ``(mask_realization_list, time_idx_list)``: NumPy snapshots of the
        working sample and of the step index, one pair per iteration.
    """
    mask_realization_list, time_idx_list = [], []

    temp_y_c = extract_digit_label.to(args.device)
    # clone(): when the input is already int64 on args.device, `.long()` and
    # `.to()` return the same storage, so the in-place writes below would
    # otherwise modify the caller's tensor.
    temp_mask_realization = extract_digit_samples.unsqueeze(1).long().to(args.device).clone()
    temp_idx = torch.Tensor([extract_time]).to(args.device).squeeze(0)
    temp_sampling_path = sampling_path.to(args.device)

    for _ in tqdm(range(int(temp_idx.item()), args.diffusion_steps)):
        # Boolean mask of the position(s) scheduled for the current step.
        current_location = temp_sampling_path == temp_idx
        print(current_location.shape)

        conditional_prob, _ = predict_next_index(
            model=model,
            args=args,
            mask_realization=temp_mask_realization,
            y_c=temp_y_c,
            idx=temp_idx
        )

        next_temp_realization = torch.argmax(
            conditional_prob.sample(), dim=-1
        )

        temp_mask_realization[0, current_location] = next_temp_realization[current_location]

        # .copy(): on CPU, `.cpu().numpy()` shares memory with the tensor, so
        # without a copy every stored snapshot (and step index) would be
        # overwritten by later in-place updates and end up identical.
        mask_realization_list.append(temp_mask_realization.cpu().numpy().copy())
        time_idx_list.append(temp_idx.cpu().numpy().copy())
        temp_idx += 1

    return (
        mask_realization_list,
        time_idx_list
    )
|
|
|
|
|
def batch_generate_denoised_sampled(
    args: any,
    model: nn.Module,
    extract_digit_samples: torch.Tensor,
    extract_time: torch.Tensor,
    extract_digit_label: torch.Tensor,
    sampling_path: torch.Tensor
) -> (list, list):
    """Batched variant of the iterative fill-in sampler.

    Every iteration queries ``predict_next_index`` for the whole batch,
    samples from the returned distribution (argmax over the last dimension),
    and for each batch row writes the sampled token at the position where
    that row's ``sampling_path`` equals that row's current step. Snapshots of
    the working batch and of the step indices are recorded per iteration.

    The loop range starts at the first row's step, and the loop stops early
    as soon as any row's step reaches ``args.diffusion_steps``.

    Args:
        args: Configuration object; ``args.device`` and
            ``args.diffusion_steps`` are read here.
        model: Network forwarded to ``predict_next_index``.
        extract_digit_samples: Partially masked tokens, batch first; a
            dimension is inserted at position 1 and values are cast to
            ``long``.
        extract_time: Per-row starting step, batch first (assumed integer
            dtype and identical across rows - TODO confirm).
        extract_digit_label: Conditioning input passed to the model as ``y_c``.
        sampling_path: Per-row, per-position step indices, batch first.

    Returns:
        ``(mask_realization_list, time_idx_list)``: NumPy snapshots of the
        working batch and of the step indices, one pair per iteration.

    Raises:
        AssertionError: If the four tensor inputs disagree on dimension 0.
    """
    assert extract_digit_samples.size(0) == extract_digit_label.size(0) == sampling_path.size(0) == extract_time.size(0), "Mismatched batch dimensions"

    batch_size = extract_digit_samples.size(0)
    mask_realization_list, time_idx_list = [], []
    print('batch_size:', batch_size)

    temp_y_c = extract_digit_label.to(args.device)
    # clone(): `.unsqueeze()`, `.long()` and `.to()` may all return views of /
    # the same storage as the inputs; without a copy the in-place updates
    # below would modify the caller's `extract_digit_samples` / `extract_time`.
    temp_mask_realization = extract_digit_samples.unsqueeze(1).long().to(args.device).clone()
    temp_idx = extract_time.unsqueeze(-1).to(args.device).clone()
    temp_sampling_path = sampling_path.to(args.device)
    print(f"Starting temp_idx: {temp_idx[0].item()}")

    # int(): range() needs a Python int even if the step tensor is floating.
    start_time_index = int(temp_idx[0].item())
    max_diffusion_step = args.diffusion_steps

    # Row indices paired with the per-row column indices below. Kept on CPU,
    # like the column indices, so they can index tensors on any device.
    batch_rows = torch.arange(batch_size)

    for _ in tqdm(range(start_time_index, max_diffusion_step), initial=start_time_index, total=max_diffusion_step):

        if torch.any(temp_idx >= args.diffusion_steps):
            break

        conditional_prob, _ = predict_next_index(
            model=model,
            args=args,
            mask_realization=temp_mask_realization,
            y_c=temp_y_c,
            idx=temp_idx
        )

        next_temp_realization = torch.argmax(conditional_prob.sample(), dim=-1)

        # Column index of the position scheduled at the current step, one per
        # row. NOTE(review): argmax over an all-False row yields 0, so a row
        # with no matching position would have column 0 overwritten - confirm
        # the paths always contain the current step.
        step_hits = temp_sampling_path == temp_idx
        current_location = torch.argmax(step_hits.detach().cpu() * 1, dim=-1)

        # Paired (row, column) indexing: row b is updated only at its own
        # position. The previous `[:, 0, current_location]` form indexed the
        # cross product, overwriting every row at every row's position.
        temp_mask_realization[batch_rows, 0, current_location] = next_temp_realization[batch_rows, current_location]

        # .copy(): on CPU, `.cpu().numpy()` shares memory with the tensor, so
        # without a copy all stored snapshots would alias the final state.
        mask_realization_list.append(temp_mask_realization.cpu().numpy().copy())
        time_idx_list.append(temp_idx.cpu().numpy().copy())

        temp_idx += 1

    return mask_realization_list, time_idx_list
|
|
|
|
|
|
|
|
|
def convert_num_to_chars(
    tokenizer: any,
    num_seq: list
) -> str:
    """Decode a sequence of token indices into a single string.

    Args:
        tokenizer: Any object supporting ``tokenizer[num]`` lookup that
            returns a string for each index (e.g. a list or a dict).
        num_seq: Iterable of indices to look up, in output order.

    Returns:
        The looked-up strings concatenated without a separator; an empty
        string when ``num_seq`` is empty.
    """
    return "".join(tokenizer[num] for num in num_seq)
|
|