# Author: IlayMalinyak — commit 72a8e1c ("remove pwvt")
import torch.nn
from torch.utils.data import DataLoader
from utils.data import FFTDataset, SplitDataset, AudioINRDataset
from datasets import load_dataset
from utils.train import Trainer, INRTrainer
from utils.models import MultiGraph, ImplicitEncoder
from omegaconf import OmegaConf
# from .utils.models import CNNKan, KanEncoder
from utils.inr import INR
from utils.data_utils import *
from huggingface_hub import login
import yaml
import datetime
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter as savgol
from utils.kan import FasterKAN
from utils.relational_transformer import RelationalTransformer
from collections import OrderedDict
def plot_results(dims, i, data, losses, pred_values):
    """Overlay a smoothed target signal and its INR reconstruction.

    Parameters
    ----------
    dims : unused; kept so existing call sites keep working.
    i : int
        Sample index — currently unused; kept for call-site compatibility.
    data : torch.Tensor
        Target signal; detached, Savitzky-Golay smoothed and min-max
        normalized before plotting.
    losses : list[float]
        Per-iteration losses — currently unused (loss plot is disabled).
    pred_values : torch.Tensor
        Model predictions; reshaped to match ``data`` and normalized.
    """

    def _minmax(arr):
        # Guard against a constant signal: the original divided by
        # (max - min), which is 0 for constant input and produced NaNs.
        lo, span = np.min(arr), np.max(arr) - np.min(arr)
        return (arr - lo) / span if span != 0 else arr - lo

    # Smooth the target so the overlay is readable.
    data = savgol(data.cpu().detach().numpy(), window_length=250, polyorder=1)
    # Reshape predictions from the flattened INR output back to data's shape.
    pred_values = pred_values.transpose(-1, -2).unflatten(-1, data.shape[-2:]).squeeze(0).cpu().detach().numpy()
    pred_values = _minmax(pred_values)
    data = _minmax(data)
    plt.plot(data.squeeze())
    plt.plot(pred_values.squeeze())
    plt.show()
# ---------------------------------------------------------------------------
# Experiment setup: config parsing, log directory, HF login, streaming data.
# ---------------------------------------------------------------------------
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"
args_dir = 'utils/config.yaml'

# Parse the YAML config ONCE. The original re-opened and re-parsed the file
# for every section, leaking an unclosed file handle each time.
with open(args_dir, 'r') as cfg_file:
    _config = yaml.safe_load(cfg_file)
data_args = Container(**_config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
rt_args = Container(**_config['RelationalTransformer'])
cnn_args = Container(**_config['CNNEncoder_f'])
conformer_args = Container(**_config['Conformer'])
kan_args = Container(**_config['KAN_INR'])
inr_args = Container(**_config['INR'])

# exist_ok avoids the check-then-create race of the original
# `if not os.path.exists(...): os.makedirs(...)` pair.
os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

# HF API token is kept outside the repository.
with open("../../logs/token.txt", "r") as f:
    api_key = f.read()

# local_rank, world_size, gpus_per_node = setup()
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)

# Streaming datasets. NOTE(review): train and val both wrap the "train"
# split; SplitDataset is presumably responsible for the actual partition —
# confirm it deduplicates between is_train=True/False.
dataset = load_dataset("rfcx/frugalai", streaming=True)
train_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
val_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)
# for i, batch in enumerate(train_ds):
# fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array']
# label = batch['label']
# fig, axes = plt.subplots(nrows=1, ncols=3)
# axes = axes.flatten()
# axes[0].plot(fft_phase)
# axes[1].plot(fft_mag)
# axes[2].plot(audio)
# fig.suptitle(label)
# plt.tight_layout()
# plt.show()
# if i > 20:
# break
# model = DualEncoder(model_args, model_args_f, conformer_args)
# model = FasterKAN([18000,64,64,16,1])
# model = INR(in_features=1)
# model.kan.speed()
# model = KanEncoder(kan_args.get_dict())
# model = model.to(local_rank)
# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
# new_state_dict = OrderedDict()
# for key, value in state_dict.items():
# if key.startswith('module.'):
# key = key[7:]
# new_state_dict[key] = value
# missing, unexpected = model.load_state_dict(new_state_dict)
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Number of parameters: {num_params}")
#
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# total_steps = int(data_args.num_epochs) * 1000
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
# T_max=total_steps,
# eta_min=float((5e-4) / 10))
loss_fn = torch.nn.BCEWithLogitsLoss()
inr_criterion = torch.nn.MSELoss()

# Debug/visualization loop: overfit a fresh INR on each sample's FFT
# magnitude, checkpoint the weights and plot the reconstruction.
for i, batch in enumerate(train_ds):
    coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
    coords = coords.to(local_rank)
    fft = fft.to(local_rank)
    audio = audio.to(local_rank)
    # A brand-new INR is fitted from scratch for every sample.
    model = INR(hidden_features=128, n_layers=4,
                in_features=1,
                out_features=1).to(local_rank)
    optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
    pbar = tqdm(range(200))
    losses = []
    print(coords.shape)
    for t in pbar:
        optimizer.zero_grad()
        # coords was already moved to the device above; the original called
        # .to(local_rank) again on every iteration for no reason.
        pred_values = model(coords).float()
        loss = inr_criterion(pred_values, fft)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'loss: {loss.item()}')
        losses.append(loss.item())
    state_dict = model.state_dict()
    # NOTE(review): every sample overwrites the same file 'test' — probably
    # meant to be a per-sample path; confirm before relying on these weights.
    torch.save(state_dict, 'test')
    plot_results(1, i, fft, losses, pred_values)
exit()
# NOTE(review): everything below is unreachable because of the exit() call
# above; kept as-is since it wires up the full relational-transformer pipeline.
# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
# print(f"Missing keys: {missing}")
# print(f"Unexpected keys: {unexpected}")
# Node layout of the INR whose weights feed the graph: input width,
# n_layers hidden widths, then output width.
layer_layout = [inr_args.in_features] + [inr_args.hidden_features for _ in range(inr_args.n_layers)] + [inr_args.out_features]
# Hydra-style "_target_" config consumed by RelationalTransformer to build a
# graph representation of INR weight space.
graph_constructor = OmegaConf.create(
    {
        "_target_": "utils.graph_constructor.GraphConstructor",
        "_recursive_": False,
        "_convert_": "all",
        "d_in": 1,
        "d_edge_in": 1,
        "zero_out_bias": False,
        "zero_out_weights": False,
        "sin_emb": True,
        "sin_emb_dim": rt_args.d_node,
        "use_pos_embed": False,
        "input_layers": 1,
        "inp_factor": 1,
        "num_probe_features": 0,
        "inr_model": None,
        "stats": None,
        "sparsify": False,
        'sym_edges': False,
    }
)
rt_model = RelationalTransformer(layer_layout=layer_layout, graph_constructor=graph_constructor,
                                 **rt_args.get_dict()).to(local_rank)
# Drop the projection head: the transformer is used as a feature encoder here.
rt_model.proj_out= torch.nn.Identity()
multi_graph = MultiGraph(rt_model, cnn_args)
implicit_net = INR(**inr_args.get_dict())
# ImplicitEncoder pairs the INR with the graph encoder into the final model.
model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)
num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_parameters}")
optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
# Trainer drives the train/val loop; BCEWithLogitsLoss (loss_fn) implies a
# single-logit binary classification head (output_dim=1).
trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=1, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=100,
                  exp_name=f"frugal_kan_{exp_num}")
fit_res = trainer.fit(num_epochs=100, device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)
# Persist the fit history as JSON next to the run's other logs.
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)
preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")