Spaces:
Runtime error
Runtime error
import os | |
from glob import glob | |
import joblib | |
import numpy as np | |
import torch | |
from sklearn.cluster import MiniBatchKMeans | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from train import instantiate_from_config | |
class FeatClusterStage(object): | |
def __init__(self, num_clusters=None, cached_kmeans_path=None, feats_dataset_config=None, num_workers=None): | |
if cached_kmeans_path is not None and os.path.exists(cached_kmeans_path): | |
print(f'Precalculated Clusterer already exists, loading from {cached_kmeans_path}') | |
self.clusterer = joblib.load(cached_kmeans_path) | |
elif feats_dataset_config is not None: | |
self.clusterer = self.load_or_precalculate_kmeans(num_clusters, feats_dataset_config, num_workers) | |
else: | |
raise Exception('Neither `feats_dataset_config` nor `cached_kmeans_path` are defined') | |
def eval(self): | |
return self | |
def encode(self, c): | |
# c_quant: cluster centers, c_ind: cluster index | |
B, D, T = c.shape | |
# (B*T, D) <- (B, T, D) <- (B, D, T) | |
c_flat = c.permute(0, 2, 1).view(B*T, D).cpu().numpy() | |
c_ind = self.clusterer.predict(c_flat) | |
c_quant = self.clusterer.cluster_centers_[c_ind] | |
c_ind = torch.from_numpy(c_ind).to(c.device) | |
c_quant = torch.from_numpy(c_quant).to(c.device) | |
c_ind = c_ind.long().unsqueeze(-1) | |
c_quant = c_quant.view(B, T, D).permute(0, 2, 1) | |
info = None, None, c_ind | |
# (B, D, T), (), ((), (768, 1024), (768, 1)) | |
return c_quant, None, info | |
def decode(self, c): | |
return c | |
def get_input(self, batch, k): | |
x = batch[k] | |
x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format) | |
return x.float() | |
def load_or_precalculate_kmeans(self, num_clusters, dataset_cfg, num_workers): | |
print(f'Calculating clustering K={num_clusters}') | |
batch_size = 64 | |
dataset_name = dataset_cfg.target.split('.')[-1] | |
cached_path = os.path.join('./specvqgan/modules/misc/', f'kmeans_K{num_clusters}_{dataset_name}.sklearn') | |
feat_depth = dataset_cfg.params.condition_dataset_cfg.feat_depth | |
feat_crop_len = dataset_cfg.params.condition_dataset_cfg.feat_crop_len | |
feat_loading_dset = instantiate_from_config(dataset_cfg) | |
feat_loading_dset = DataLoader(feat_loading_dset, batch_size, num_workers=num_workers, shuffle=True) | |
clusterer = MiniBatchKMeans(num_clusters, batch_size=batch_size*feat_crop_len, random_state=0) | |
for item in tqdm(feat_loading_dset): | |
batch = item['feature'].reshape(-1, feat_depth).float().numpy() | |
clusterer.partial_fit(batch) | |
joblib.dump(clusterer, cached_path) | |
print(f'Saved the calculated Clusterer @ {cached_path}') | |
return clusterer | |
if __name__ == '__main__': | |
from omegaconf import OmegaConf | |
config = OmegaConf.load('./configs/vggsound_featcluster_transformer.yaml') | |
config.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_specs_vqgan/checkpoints/epoch_39.ckpt' | |
model = instantiate_from_config(config.model.params.cond_stage_config) | |
print(model) | |