import torch.nn
from torch.utils.data import DataLoader
from utils.data import FFTDataset, SplitDataset, AudioINRDataset
from datasets import load_dataset
from utils.train import Trainer, INRTrainer
from utils.models import MultiGraph, ImplicitEncoder
from omegaconf import OmegaConf
# from .utils.models import CNNKan, KanEncoder
from utils.inr import INR
from utils.data_utils import *
from huggingface_hub import login
import yaml
import datetime
import json
import os  # used below; otherwise only available if the star import re-exports it
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter as savgol
from utils.kan import FasterKAN
from utils.relational_transformer import RelationalTransformer
from collections import OrderedDict


def _min_max_normalize(values):
    """Rescale ``values`` linearly to the range [0, 1].

    A constant array (max == min) would otherwise compute 0 / 0 and yield an
    all-NaN curve; it is mapped to an all-zero float array instead.
    """
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def plot_results(dims, i, data, losses, pred_values):
    """Plot a smoothed target signal and a model reconstruction on one figure.

    Both curves are min-max normalised to [0, 1] so they share a scale, then
    the figure is displayed with ``plt.show()``.

    Args:
        dims: Unused; kept so existing call sites keep working.
        i: Unused; kept so existing call sites keep working.
        data: Target tensor. It is smoothed along its last axis with a
            Savitzky-Golay filter (window 250, polynomial order 1); SciPy's
            default ``mode='interp'`` requires that axis to have at least 250
            samples.
        losses: Unused (the loss-curve plot below is commented out).
        pred_values: Model output tensor. Its last two axes are swapped and
            the new last axis is unflattened to ``data.shape[-2:]``, so that
            axis must hold ``prod(data.shape[-2:])`` elements.
    """
    data = savgol(data.cpu().detach().numpy(), window_length=250, polyorder=1)
    pred_values = (
        pred_values.transpose(-1, -2)
        .unflatten(-1, data.shape[-2:])
        .squeeze(0)
        .cpu()
        .detach()
        .numpy()
    )
    pred_values = _min_max_normalize(pred_values)
    data = _min_max_normalize(data)
    plt.plot(data.squeeze())
    plt.plot(pred_values.squeeze())
    # axes[0].set_title('Original')
    # axes[1].set_title('Reconstruction')
    plt.show()
    # plt.plot(np.arange(len(losses)), losses)
    # plt.xlabel('Iteration')
    # plt.ylabel('Reconstruction MSE Error')
    # plt.show()


# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"

# Parse the YAML config once (and close the file) instead of re-opening and
# re-parsing it for every section.
args_dir = 'utils/config.yaml'
with open(args_dir, 'r') as config_file:
    yaml_config = yaml.safe_load(config_file)

data_args = Container(**yaml_config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
rt_args = Container(**yaml_config['RelationalTransformer'])
cnn_args = Container(**yaml_config['CNNEncoder_f'])
conformer_args = Container(**yaml_config['Conformer'])
kan_args = Container(**yaml_config['KAN_INR'])
inr_args = Container(**yaml_config['INR'])

os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

# Strip surrounding whitespace (e.g. the trailing newline of the text file);
# otherwise it is passed to login() as part of the token.
with open("../../logs/token.txt", "r") as f:
    api_key = f.read().strip()

# local_rank, world_size, gpus_per_node = setup()
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)

dataset = load_dataset("rfcx/frugalai", streaming=True)
train_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
val_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)

# for i, batch in enumerate(train_ds):
#     fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array']
#     label = batch['label']
#     fig, axes = plt.subplots(nrows=1, ncols=3)
#     axes = axes.flatten()
#     axes[0].plot(fft_phase)
#     axes[1].plot(fft_mag)
#     axes[2].plot(audio)
#     fig.suptitle(label)
#     plt.tight_layout()
#     plt.show()
#     if i > 20:
#         break

# model = DualEncoder(model_args, model_args_f, conformer_args)
# model = FasterKAN([18000,64,64,16,1])
# model = INR(in_features=1)
# model.kan.speed()
# model = KanEncoder(kan_args.get_dict())
# model = model.to(local_rank)
# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
# new_state_dict = OrderedDict()
# for key, value in state_dict.items():
#     if key.startswith('module.'):
#         key = key[7:]
#     new_state_dict[key] = value
# missing, unexpected = model.load_state_dict(new_state_dict)
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Number of parameters: {num_params}")
import itertools

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# total_steps = int(data_args.num_epochs) * 1000
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
#                                                        T_max=total_steps,
#                                                        eta_min=float((5e-4) / 10))
loss_fn = torch.nn.BCEWithLogitsLoss()
inr_criterion = torch.nn.MSELoss()

# Sanity check: fit a fresh INR to the FFT magnitude of a few training samples
# and plot the reconstruction. Iterating train_ds unbounded would presumably
# walk the whole (streamed) training split, fitting an INR and calling
# plt.show() (which may block) for every sample before the real training
# below ever starts. Set NUM_INR_SANITY_SAMPLES = None to fit every sample.
NUM_INR_SANITY_SAMPLES = 1
INR_SANITY_STEPS = 200

for i, batch in enumerate(itertools.islice(train_ds, NUM_INR_SANITY_SAMPLES)):
    coords = batch['audio']['coords'].to(local_rank)
    fft = batch['audio']['fft_mag'].to(local_rank)
    audio = batch['audio']['array'].to(local_rank)
    # values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
    model = INR(hidden_features=128, n_layers=4, in_features=1, out_features=1).to(local_rank)
    # model = FasterKAN(layers_hidden=[1,16,16,1]).to(local_rank)
    optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
    pbar = tqdm(range(INR_SANITY_STEPS))
    losses = []
    print(coords.shape)
    for _ in pbar:
        optimizer.zero_grad()
        # coords is already on local_rank; no per-step transfer needed.
        pred_values = model(coords).float()
        loss = inr_criterion(pred_values, fft)
        loss.backward()
        optimizer.step()
        # Read the scalar once per step (on CUDA each .item() forces a sync).
        loss_value = loss.item()
        pbar.set_description(f'loss: {loss_value}')
        losses.append(loss_value)
    state_dict = model.state_dict()
    torch.save(state_dict, 'test')  # NOTE: same path for every sample, so it is overwritten
    # print(f'Sample {i+offset} label {label} saved in {inr_path}')
    plot_results(1, i, fft, losses, pred_values)
    # # exit()

# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
# print(f"Missing keys: {missing}")
# print(f"Unexpected keys: {unexpected}")

# Layer widths of the INR whose weights the relational transformer consumes:
# input, n_layers hidden layers, output.
layer_layout = [inr_args.in_features] + [inr_args.hidden_features for _ in range(inr_args.n_layers)] + [inr_args.out_features]
graph_constructor = OmegaConf.create(
    {
        "_target_": "utils.graph_constructor.GraphConstructor",
        "_recursive_": False,
        "_convert_": "all",
        "d_in": 1,
        "d_edge_in": 1,
        "zero_out_bias": False,
        "zero_out_weights": False,
        "sin_emb": True,
        "sin_emb_dim": rt_args.d_node,
        "use_pos_embed": False,
        "input_layers": 1,
        "inp_factor": 1,
        "num_probe_features": 0,
        "inr_model": None,
        "stats": None,
        "sparsify": False,
        'sym_edges': False,
    }
)
# Build the relational transformer over the INR weight graph, then swap its
# proj_out module for an identity (disabling whatever projection it applied).
rt_model = RelationalTransformer(
    layer_layout=layer_layout,
    graph_constructor=graph_constructor,
    **rt_args.get_dict(),
).to(local_rank)
rt_model.proj_out = torch.nn.Identity()

# Construction order is kept as-is: module constructors may draw from the
# RNG for weight init, so reordering would change the initial weights.
multi_graph = MultiGraph(rt_model, cnn_args)
implicit_net = INR(**inr_args.get_dict())
model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)

trainable_params = [param for param in model.parameters() if param.requires_grad]
num_parameters = sum(param.numel() for param in trainable_params)
print(f"Number of parameters: {num_parameters}")

optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)

trainer_kwargs = dict(
    model=model,
    optimizer=optimizer,
    criterion=loss_fn,
    output_dim=1,
    scaler=None,
    scheduler=None,
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    device=local_rank,
    exp_num=datetime_dir,
    log_path=data_args.log_dir,
    range_update=None,
    accumulation_step=1,
    max_iter=100,
    exp_name=f"frugal_kan_{exp_num}",
)
trainer = Trainer(**trainer_kwargs)

fit_kwargs = dict(
    num_epochs=100,
    device=local_rank,
    early_stopping=10,
    only_p=False,
    best='loss',
    conf=True,
)
fit_res = trainer.fit(**fit_kwargs)

# Persist the fit results next to the run's log directory.
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as results_file:
    json.dump(fit_res, results_file, indent=2)

preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")