{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3c1e9f0",
   "metadata": {},
   "source": [
    "# Export MelodyFlow checkpoints for the `audiocraft` loaders\n",
    "\n",
    "Loads a DiT checkpoint (`BOX_CHECKPOINT_PATH`) and a codec checkpoint (`CODEC_CHECKPOINT_PATH`), removes the config entries listed under *Configuration*, and writes `state_dict.bin` and `compression_state_dict.bin` into `DEMO_PATH`. Each export is read back with `load_dit_model_melodyflow` / `load_compression_model` right after it is written.\n",
    "\n",
    "Run top to bottom on a fresh kernel. The loaded checkpoints are never modified in place, so the cleaning and export cells can be re-run individually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d179a842",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b561bd5-2495-47ab-9cdf-786662610840",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from audiocraft.models.loaders import load_compression_model, load_dit_model_melodyflow\n",
    "import audiocraft.models.builders as builders\n",
    "from omegaconf import OmegaConf, DictConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d42a15",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Pick the model to export with `PRESET`; every path below is derived from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef93dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that holds the training checkpoints and receives the exported demo folders.\n",
    "# Set MELODYFLOW_CHECKPOINT_DIR to run this notebook on another machine.\n",
    "CHECKPOINT_DIR = os.environ.get(\n",
    "    'MELODYFLOW_CHECKPOINT_DIR', '/Users/glelan/projects/melodyflow/checkpoints'\n",
    ")\n",
    "\n",
    "# Checkpoint files of each exportable model, relative to CHECKPOINT_DIR.\n",
    "# The preset name doubles as the name of the output folder.\n",
    "PRESETS = {\n",
    "    # T18 32k mono\n",
    "    't18_32k_25hz_mono_10s': {\n",
    "        'box': '923877216v250.pt',\n",
    "        'codec': 'epoch_200_train_step_401200_eval_step_24875.pt',\n",
    "    },\n",
    "    # T24 48k stereo 25Hz 30s\n",
    "    't24_48k_25hz_stereo_30s': {\n",
    "        'box': '924369475v250.pt',\n",
    "        'codec': 'vae/epoch_200_train_step_401200_eval_step_24875.pt',\n",
    "    },\n",
    "}\n",
    "PRESET = 't24_48k_25hz_stereo_30s'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95af0ec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "BOX_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, PRESETS[PRESET]['box'])\n",
    "CODEC_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, PRESETS[PRESET]['codec'])\n",
    "DEMO_PATH = os.path.join(CHECKPOINT_DIR, PRESET)\n",
    "\n",
    "# Written to cfg.dataset.segment_duration of the exported DiT config.\n",
    "# NOTE(review): the T18 preset name ends in '10s' -- confirm that 30.0 is also\n",
    "# the right value before exporting that preset.\n",
    "SEGMENT_DURATION = 30.0\n",
    "\n",
    "# DiT state-dict entries carrying this prefix are exported with the prefix removed;\n",
    "# all other entries are left out of the export.\n",
    "EMA_PREFIX = 'ema.module.'\n",
    "\n",
    "# Codec entries copied into the DiT export as latent_mean / latent_std and\n",
    "# left out of the codec export.\n",
    "LATENT_MEAN_KEY = 'bottleneck_norm.running_mean'\n",
    "LATENT_STD_KEY = 'bottleneck_norm.running_std'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "570506c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cfg entries removed from both checkpoints before export.\n",
    "SHARED_CFG_KEYS = [\n",
    "    'disable_cuda_benchmarks', 'enable_profiler', 'manifold_bucket', 'manifold_path',\n",
    "    'ckpt_path', 'auto_retry_from_last_checkpoint', 'checkpoint_ttl', 'async_checkpoint',\n",
    "    'pg_timeout_in_minutes', 'check_preemption_every_n_train_steps', 'model_entity_id',\n",
    "    'max_eval_steps_per_epoch', 'prefetch_factor', 'max_audio_size',\n",
    "    'data', 'trainer', 'seed', 'monitor', 'generate_every', 'epochs', 'steps_per_epoch',\n",
    "    'batch_size_per_gpu', 'num_workers_per_gpu', 'save_every_n_epochs', 'unit', 'lr', 'optim',\n",
    "]\n",
    "\n",
    "# cfg entries removed from the codec checkpoint only.\n",
    "CODEC_CFG_KEYS = [\n",
    "    'conditioners', 'lfq', 'residual_lfq', 'optimizer',\n",
    "    'msstftd', 'msd', 'mpd', 'l1', 'l2', 'mrstft', 'sdstft', 'mel', 'sisnr', 'sisdr', 'msspec',\n",
    "    'losses', 'balancer', 'adversarial',\n",
    "]\n",
    "\n",
    "# cfg entries removed from the DiT checkpoint only.\n",
    "DIT_CFG_KEYS = [\n",
    "    'sample_rate_video', 'video_crop', 'compression_model_checkpoint', 'public_compression',\n",
    "    'finetune_cp_path', 'clip_grad_norm', 'generate', 'logging', 'schedule',\n",
    "]\n",
    "\n",
    "# cfg.transformer_lm entries removed from the DiT checkpoint.\n",
    "DIT_TRANSFORMER_LM_KEYS = [\n",
    "    'spectral_norm_attn_iters', 'spectral_norm_ff_iters',\n",
    "    'residual_balancer_attn', 'residual_balancer_ff',\n",
    "    'mask_cross_attention', 'qk_norm', 'n_q', 'q_modeling', 'emb_lr', 'card',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c58e6b23",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0027692",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_checkpoint(path):\n",
    "    \"\"\"Read the checkpoint file at `path` with `torch.load`, mapped to the CPU.\"\"\"\n",
    "    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.\n",
    "    # Only open checkpoints that come from a trusted source.\n",
    "    with open(path, 'rb') as fileobj:\n",
    "        return torch.load(fileobj, map_location=torch.device('cpu'))\n",
    "\n",
    "\n",
    "def drop_keys(mapping, keys):\n",
    "    \"\"\"Delete every entry named in `keys` from `mapping`, in place.\n",
    "\n",
    "    Uses a plain `del`, so an absent key raises instead of being skipped:\n",
    "    a checkpoint with an unexpected layout stops the export early.\n",
    "    \"\"\"\n",
    "    for key in keys:\n",
    "        del mapping[key]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d69f7c31",
   "metadata": {},
   "source": [
    "## Load the checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d0ad5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "codec_checkpoint = load_checkpoint(CODEC_CHECKPOINT_PATH)\n",
    "dit_checkpoint = load_checkpoint(BOX_CHECKPOINT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a08d42",
   "metadata": {},
   "source": [
    "## Clean the configs\n",
    "\n",
    "Both cells start from a fresh copy of the loaded config, so they give the same result every time they are run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "057a5f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "codec_cfg = copy.deepcopy(codec_checkpoint['hyper_parameters']['cfg'])\n",
    "drop_keys(codec_cfg, SHARED_CFG_KEYS + CODEC_CFG_KEYS)\n",
    "codec_cfg['encodec']['quantizer'] = 'no_quant'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6d2305",
   "metadata": {},
   "outputs": [],
   "source": [
    "dit_cfg = copy.deepcopy(dit_checkpoint['hyper_parameters']['cfg'])\n",
    "drop_keys(dit_cfg, SHARED_CFG_KEYS + DIT_CFG_KEYS)\n",
    "drop_keys(dit_cfg['transformer_lm'], DIT_TRANSFORMER_LM_KEYS)\n",
    "\n",
    "cfg = OmegaConf.create(dit_cfg)\n",
    "cfg.fuser.input_interpolate = []\n",
    "cfg.dataset = {'segment_duration': SEGMENT_DURATION}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81b9e53",
   "metadata": {},
   "source": [
    "## Export the DiT model\n",
    "\n",
    "Keeps the `ema.module.*` weights (prefix stripped) and adds the codec's `bottleneck_norm` running statistics as `latent_mean` / `latent_std`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df52ded5",
   "metadata": {},
   "outputs": [],
   "source": [
    "codec_ema_state = codec_checkpoint['ema_model']\n",
    "\n",
    "dit_best_state = {\n",
    "    name[len(EMA_PREFIX):]: tensor\n",
    "    for name, tensor in dit_checkpoint['state_dict'].items()\n",
    "    if name.startswith(EMA_PREFIX)\n",
    "}\n",
    "assert dit_best_state, f'no {EMA_PREFIX}* entries found in {BOX_CHECKPOINT_PATH}'\n",
    "dit_best_state['latent_mean'] = codec_ema_state[LATENT_MEAN_KEY]\n",
    "dit_best_state['latent_std'] = codec_ema_state[LATENT_STD_KEY]\n",
    "\n",
    "cleaned_dit_checkpoint = {'xp.cfg': cfg, 'best_state': dit_best_state}\n",
    "\n",
    "# The output folder may not exist yet for a preset that was never exported.\n",
    "os.makedirs(DEMO_PATH, exist_ok=True)\n",
    "with open(os.path.join(DEMO_PATH, 'state_dict.bin'), 'wb') as fileobj:\n",
    "    torch.save(cleaned_dit_checkpoint, fileobj)\n",
    "\n",
    "# Read the export back through the regular loader.\n",
    "flow_model = load_dit_model_melodyflow(DEMO_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a92c4d7",
   "metadata": {},
   "source": [
    "## Export the codec\n",
    "\n",
    "The `bottleneck_norm` running statistics were stored with the DiT export above, so they are left out of the codec weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793188cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shallow copy: the tensors are shared, only the key set differs from the loaded state.\n",
    "codec_best_state = copy.copy(codec_checkpoint['ema_model'])\n",
    "drop_keys(codec_best_state, [LATENT_MEAN_KEY, LATENT_STD_KEY])\n",
    "\n",
    "cleaned_codec_checkpoint = {'xp.cfg': codec_cfg, 'best_state': codec_best_state}\n",
    "\n",
    "os.makedirs(DEMO_PATH, exist_ok=True)\n",
    "with open(os.path.join(DEMO_PATH, 'compression_state_dict.bin'), 'wb') as fileobj:\n",
    "    torch.save(cleaned_codec_checkpoint, fileobj)\n",
    "\n",
    "# Read the export back through the regular loader.\n",
    "compression_model = load_compression_model(DEMO_PATH)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "melodyflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}