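"""Smoke-test the CLAP training data pipeline: build the model (and its
config), construct the webdataset dataloader, and iterate once over the
training split without training anything."""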
import os

import torch
from tqdm import tqdm

from laion_clap import create_model
from laion_clap.training import parse_args
from laion_clap.training.data import get_data
from laion_clap.training.distributed import is_master, world_info_from_env
from laion_clap.utils import dataset_split


def run_dataloader():
    # drain the training dataloader once; relies on the module-level
    # `dataloader`, `data`, and `args` set up in the __main__ block below
    for i, batch in enumerate(
        tqdm(dataloader, total=data["train"].dataloader.num_samples // args.batch_size)
    ):
        pass


if __name__ == '__main__':
    args = parse_args()
    # sanitize model name for filesystem / URI use; avoid '/' in names as a rule
    args.amodel = args.amodel.replace("/", "-")
    device = torch.device('cpu')
    # discover initial world args early so we can log properly
    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    if args.remotedata and is_master(args):
        # pull the per-split sizes.json metadata from S3 so the data pipeline
        # knows how many samples each webdataset shard set contains
        for dataset_name in args.datasetnames:
            for split in dataset_split[dataset_name]:
                if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
                    os.makedirs(f"./json_files/{dataset_name}/{split}")
                os.system(
                    f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
                )
    # build the CLAP model on CPU; only model_cfg is consumed below by get_data()
    model, model_cfg = create_model(
        args.amodel,
        args.tmodel,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=True,
        pretrained_audio=args.pretrained_audio,
        pretrained_text=args.pretrained_text,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type,
    )
    # build the actual data pipeline, report its size, then walk it end to end
    data = get_data(args, model_cfg)
    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    print('dataset size:', data["train"].dataloader.num_samples)
    print('batch size:', args.batch_size)
    print('num batches:', data["train"].dataloader.num_samples // args.batch_size)
    run_dataloader()
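    # Example invocation (a minimal sketch; the script name is hypothetical,
    # the flags mirror the args this script reads from the laion_clap training
    # CLI, and the dataset name is illustrative):
    #   python test_dataloader.py --datasetnames audiocaps --remotedata \
    #       --amodel HTSAT-tiny --tmodel roberta --batch-size 32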