diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..1ba0d9d6906589dd605ea0f949f8b302029cfd71 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +# 使用 Python 3.9 作为基础镜像 +FROM python:3.9 + +# 添加用户 +RUN useradd -m -u 1000 user + +# 设置工作目录 +WORKDIR /app + +# 切换到 root 用户以安装系统依赖 +USER root +RUN apt-get update && apt-get install -y rubberband-cli + +# 切回到普通用户 +USER user +ENV PATH="/home/user/.local/bin:$PATH" + +# 复制并安装 Python 依赖 +COPY --chown=user ./requirements.txt requirements.txt +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# 复制应用文件 +COPY --chown=user . /app + +# 启动 Gradio 应用,假设应用入口文件为 app.py +# CMD ["python", "app.py"] +CMD ["python", "app_chat.py"] diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..82e60e026839ea73cb3699af82e2805d2ef7ff2c --- /dev/null +++ b/app.py @@ -0,0 +1,107 @@ +import torch +import numpy as np +from tqdm import tqdm +from model.DiffSynthSampler import DiffSynthSampler +import soundfile as sf +# import pyrubberband as pyrb +from tqdm import tqdm +from model.VQGAN import get_VQGAN +from model.diffusion import get_diffusion_model +from transformers import AutoTokenizer, ClapModel +from model.diffusion_components import linear_beta_schedule +from model.timbre_encoder_pretrain import get_timbre_encoder +from model.multimodal_model import get_multi_modal_model + + + +import gradio as gr +from webUI.natural_language_guided.gradio_webUI import GradioWebUI +from webUI.natural_language_guided.text2sound import get_text2sound_module +from webUI.natural_language_guided.sound2sound_with_text import get_sound2sound_with_text_module +from webUI.natural_language_guided.inpaint_with_text import get_inpaint_with_text_module +from webUI.natural_language_guided.build_instrument import get_build_instrument_module +from webUI.natural_language_guided.README import get_readme_module + + + +device = "cuda" if torch.cuda.is_available() else "cpu" +use_pretrained_CLAP = False + +# load VQ-GAN +VAE_model_name = "24_1_2024-52_4x_L_D" +modelConfig = {"in_channels": 3, "hidden_channels": [80, 160], "embedding_dim": 4, "out_channels": 3, "block_depth": 2, + "attn_pos": [80, 160], "attn_with_skip": True, + "num_embeddings": 8192, "commitment_cost": 0.25, "decay": 0.99, + "norm_type": "groupnorm", "act_type": "swish", "num_groups": 16} +VAE = get_VQGAN(modelConfig, load_pretrain=True, model_name=VAE_model_name, device=device) + +# load U-Net +UNet_model_name = "history/28_1_2024_CLAP_STFT_180000" if use_pretrained_CLAP else "history/28_1_2024_TE_STFT_300000" +unetConfig = {"in_dim": 4, "down_dims": [96, 96, 192, 384], "up_dims": [384, 384, 192, 96], "attn_type": "linear_add", "condition_type": "natural_language_prompt", "label_emb_dim": 512} +uNet = get_diffusion_model(unetConfig, load_pretrain=True, model_name=UNet_model_name, device=device) + +# load LM +CLAP_temp = ClapModel.from_pretrained("laion/clap-htsat-unfused") # 153,492,890 +CLAP_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + +timbre_encoder_name = "24_1_2024_STFT" +timbre_encoder_Config = {"input_dim": 512, "feature_dim": 512, "hidden_dim": 1024, "num_instrument_classes": 1006, "num_instrument_family_classes": 11, "num_velocity_classes": 128, "num_qualities": 10, "num_layers": 3} +timbre_encoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, device=device) + +if use_pretrained_CLAP: + text_encoder = CLAP_temp +else: + multimodalmodel_name = 
"24_1_2024" + multimodalmodel_config = {"text_feature_dim": 512, "spectrogram_feature_dim": 1024, "multi_modal_emb_dim": 512, "num_projection_layers": 2, + "temperature": 1.0, "dropout": 0.1, "freeze_text_encoder": False, "freeze_spectrogram_encoder": False} + mmm = get_multi_modal_model(timbre_encoder, CLAP_temp, multimodalmodel_config, load_pretrain=True, model_name=multimodalmodel_name, device=device) + + text_encoder = mmm.to("cpu") + + + + + + +gradioWebUI = GradioWebUI(device, VAE, uNet, text_encoder, CLAP_tokenizer, freq_resolution=512, time_resolution=256, channels=4, timesteps=1000, squared=False, + VAE_scale=4, flexible_duration=True, noise_strategy="repeat", GAN_generator=None) + +with gr.Blocks(theme=gr.themes.Soft(), mode="dark") as demo: +# with gr.Blocks(theme='WeixuanYuan/Soft_dark', mode="dark") as demo: + gr.Markdown("DiffuSynth v0.2") + + reconstruction_state = gr.State(value={}) + text2sound_state = gr.State(value={}) + sound2sound_state = gr.State(value={}) + inpaint_state = gr.State(value={}) + super_resolution_state = gr.State(value={}) + virtual_instruments_state = gr.State(value={"virtual_instruments": {}}) + + get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state) + get_sound2sound_with_text_module(gradioWebUI, sound2sound_state, virtual_instruments_state) + get_inpaint_with_text_module(gradioWebUI, inpaint_state, virtual_instruments_state) + get_build_instrument_module(gradioWebUI, virtual_instruments_state) + get_readme_module() + +demo.launch(debug=True, share=True) + + + + + + + + + + + + + + + + + + + + + diff --git a/app_chat.py b/app_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..04cc31aa8d0e06aeaac3b59bb361ed71d831e43f --- /dev/null +++ b/app_chat.py @@ -0,0 +1,7 @@ +import gradio as gr + +def greet(name): + return "Hello " + name + "!!" 
+ +demo = gr.Interface(fn=greet, inputs="text", outputs="text") +demo.launch() diff --git a/metrics/FD.py b/metrics/FD.py new file mode 100644 index 0000000000000000000000000000000000000000..34072570cf64e78527d46c22de099fa52fdd9a09 --- /dev/null +++ b/metrics/FD.py @@ -0,0 +1,293 @@ +import json +import os + +import librosa +import numpy as np +import torch +from tqdm import tqdm +from scipy.linalg import sqrtm + +from metrics.pipelines import sample_pipeline, sample_pipeline_GAN +from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT +from tools import rms_normalize + + +def ASTaudio2feature(device, signal, processor, AST, sampling_rate): + # audio file is decoded on the fly + inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = AST(**inputs) + + last_hidden_states = outputs.last_hidden_state[:, 0, :].to("cpu").detach().numpy() + return last_hidden_states + + +# 计算两个numpy数组的均值和协方差矩阵 +def calculate_statistics(features): + mu = np.mean(features, axis=0) + sigma = np.cov(features, rowvar=False) + return mu, sigma + + +# 计算FID +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + # 在协方差矩阵对角线上添加一个小的正值 + sigma1 += np.eye(sigma1.shape[0]) * eps + sigma2 += np.eye(sigma2.shape[0]) * eps + + ssdiff = np.sum((mu1 - mu2) ** 2.0) + covmean = sqrtm(sigma1.dot(sigma2)) + + # 由于数值问题,有时可能会得到复数,只取实部 + if np.iscomplexobj(covmean): + covmean = covmean.real + + fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) + return fid + + +# 计算FID +def calculate_fid_dict(dict1, dict2, eps=1e-6): + # 在协方差矩阵对角线上添加一个小的正值 + mu1, sigma1 = dict1["mu"], dict1["sigma"] + mu2, sigma2 = dict2["mu"], dict2["sigma"] + sigma1 += np.eye(sigma1.shape[0]) * eps + sigma2 += np.eye(sigma2.shape[0]) * eps + + ssdiff = np.sum((mu1 - mu2) ** 2.0) + covmean = sqrtm(sigma1.dot(sigma2)) + + # 由于数值问题,有时可能会得到复数,只取实部 + if np.iscomplexobj(covmean): + covmean = covmean.real + + fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) + return fid + + +# Todo: AudioLDM +# def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False): + +# diffuSynth_features = [] + +# # Step 1: Load all wav files in AudioLDM_signals_directory_path +# AudioLDM_signals = [] +# signal_lengths = set() + +# for file_name in os.listdir(AudioLDM_signals_directory_path): +# if file_name.endswith('.wav'): +# file_path = os.path.join(AudioLDM_signals_directory_path, file_name) +# signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000 +# # Normalize +# AudioLDM_signals.append(rms_normalize(signal)) +# signal_lengths.add(len(signal)) + +# # Step 2: Check if all signals have the same length +# if len(signal_lengths) != 1: +# raise ValueError("Not all signals have the same length. 
Please ensure all audio files are of the same length.") + +# # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length] +# batch_size = 8 +# signal_length = signal_lengths.pop() # All lengths are the same, get one of them + +# # Create batches +# signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)] + +# for signal_batch in tqdm(signal_batches): + +# features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000) +# diffuSynth_features.extend(features) + +# if return_feature: +# return diffuSynth_features +# else: +# mu, sigma = calculate_statistics(diffuSynth_features) +# return {"mu": mu, "sigma": sigma} + +def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path, return_feature=False): + + diffuSynth_features = [] + + # Step 1: Load all wav files in AudioLDM_signals_directory_path + AudioLDM_signals = [] + signal_lengths = set() + target_length = 4 * 16000 # 4 seconds * 16000 samples per second + + for file_name in os.listdir(AudioLDM_signals_directory_path): + if file_name.endswith('.wav') and not file_name.startswith('._'): + file_path = os.path.join(AudioLDM_signals_directory_path, file_name) + try: + signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000 + if len(signal) >= target_length: + signal = signal[:target_length] # Take only the first 4 seconds + else: + raise ValueError(f"The file {file_name} is shorter than 4 seconds.") + # Normalize + AudioLDM_signals.append(rms_normalize(signal)) + signal_lengths.add(len(signal)) + except Exception as e: + print(f"Error loading {file_name}: {e}") + + # Step 2: Check if all signals have the same length + if len(signal_lengths) != 1: + raise ValueError("Not all signals have the same length. 
Please ensure all audio files are of the same length.") + + # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length] + batch_size = 8 + signal_length = signal_lengths.pop() # All lengths are the same, get one of them + + # Create batches + signal_batches = [AudioLDM_signals[i:i + batch_size] for i in range(0, len(AudioLDM_signals), batch_size)] + + for signal_batch in tqdm(signal_batches): + features = ASTaudio2feature(device, signal_batch, processor, AST, sampling_rate=16000) + diffuSynth_features.extend(features) + + if return_feature: + return diffuSynth_features + else: + mu, sigma = calculate_statistics(diffuSynth_features) + return {"mu": mu, "sigma": sigma} + + + + +def generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, + positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False): + diffuSynth_features = [] + + if task == "spectrograms": + pipe = sample_pipeline + elif task == "STFT": + pipe = sample_pipeline_STFT + else: + raise NotImplementedError + + for _ in tqdm(range(num_batches)): + quantized_latent_representations, reconstruction_batch, signals = pipe(device, uNet, VAE, mmm, + CLAP_tokenizer, + positive_prompts=positive_prompts, + negative_prompts=negative_prompts, + batchsize=8, + sample_steps=sample_steps, + CFG=CFG, seed=None, + return_latent=False) + + features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000) + diffuSynth_features.extend(features) + + if return_feature: + return diffuSynth_features + else: + mu, sigma = calculate_statistics(diffuSynth_features) + return {"mu": mu, "sigma": sigma} + + +def generate_features_with_GAN_and_AST(device, gan_generator, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, + positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms", return_feature=False): + diffuSynth_features = [] + + if task == "spectrograms": + pipe = sample_pipeline_GAN + elif task == "STFT": + pipe = sample_pipeline_GAN_STFT + else: + raise NotImplementedError + + for _ in tqdm(range(num_batches)): + quantized_latent_representations, reconstruction_batch, signals = pipe(device, gan_generator, VAE, mmm, + CLAP_tokenizer, + positive_prompts=positive_prompts, + negative_prompts=negative_prompts, + batchsize=8, + sample_steps=sample_steps, + CFG=CFG, seed=None, + return_latent=False) + + features = ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000) + diffuSynth_features.extend(features) + + if return_feature: + return diffuSynth_features + else: + mu, sigma = calculate_statistics(diffuSynth_features) + return {"mu": mu, "sigma": sigma} + + +def get_FD(train_features, device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, positive_prompts, + negative_prompts="", CFG=1, sample_steps=10): + diffuSynth_features = generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor, + AST, num_batches, positive_prompts, + negative_prompts=negative_prompts, CFG=CFG, + sample_steps=sample_steps) + + mu_real, sigma_real = calculate_statistics(train_features) + mu_gen, sigma_gen = calculate_statistics(diffuSynth_features) + + fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen) + print('FID score:', fid_score) + + +def get_fid_score(feature1, features2): + mu_real, sigma_real = calculate_statistics(feature1) + mu_gen, sigma_gen = calculate_statistics(features2) + + fid_score = calculate_fid(mu_real, sigma_real, mu_gen, 
sigma_gen) + # print('FID score:', fid_score) + return fid_score + + +def calculate_fid_matrix(features_list_1, features_list_2, get_fid_score): + # 初始化一个矩阵来存储FID分数 + # 矩阵的大小为 len(features_list_1) x len(features_list_2) + fid_scores = [[0 for _ in range(len(features_list_2))] for _ in range(len(features_list_1))] + + # 遍历两个列表,并计算每一对特征集合的FID分数 + for i, feature1 in enumerate(features_list_1): + for j, feature2 in enumerate(features_list_2): + fid_scores[i][j] = get_fid_score(feature1, feature2) + + return fid_scores + + +def save_AST_feature(key, mu, sigma, path='results/AST_metric/pre_calculated_features/AST_features.json'): + # 尝试打开并读取现有的JSON文件 + try: + with open(path, 'r') as file: + data = json.load(file) + except FileNotFoundError: + # 如果文件不存在,创建一个新的字典 + data = {} + + if isinstance(mu, np.ndarray): + mu = mu.tolist() + if isinstance(sigma, np.ndarray): + sigma = sigma.tolist() + + # 添加新数据 + data[key] = {"mu": mu, "sigma": sigma} + + # 将更新后的数据写回文件 + with open(path, 'w') as file: + json.dump(data, file, indent=4) + + +def read_AST_features(path='results/AST_metric/pre_calculated_features/AST_features.json'): + try: + # 尝试打开并读取JSON文件 + with open(path, 'r') as file: + AST_features = json.load(file) + + for AST_feature_name in AST_features.keys(): + AST_features[AST_feature_name]["mu"] = np.array(AST_features[AST_feature_name]["mu"]) + AST_features[AST_feature_name]["sigma"] = np.array(AST_features[AST_feature_name]["sigma"]) + + return AST_features + except FileNotFoundError: + # 如果文件不存在,返回一个空字典 + print(f"文件 {path} 未找到.") + return {} + except json.JSONDecodeError: + # 如果文件不是有效的JSON,返回一个空字典 + print(f"文件 {path} 不是有效的JSON格式.") + return {} \ No newline at end of file diff --git a/metrics/IS.py b/metrics/IS.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5dfea4e855b93f0f8856cd7e7f44cfaf1e6a7b --- /dev/null +++ b/metrics/IS.py @@ -0,0 +1,218 @@ +import os + +import librosa +import numpy as np +import torch +from tqdm import tqdm + +from metrics.pipelines import sample_pipeline, inpaint_pipeline, sample_pipeline_GAN +from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT +from tools import rms_normalize, pad_STFT, encode_stft +from webUI.natural_language_guided.utils import InputBatch2Encode_STFT + +def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path): + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + diffuSynth_probabilities = [] + + # Step 1: Load all wav files in AudioLDM_signals_directory_path + AudioLDM_signals = [] + signal_lengths = set() + target_length = 4 * 16000 # 4 seconds * 16000 samples per second + + for file_name in os.listdir(AudioLDM_signals_directory_path): + if file_name.endswith('.wav') and not file_name.startswith('._'): + file_path = os.path.join(AudioLDM_signals_directory_path, file_name) + signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000 + if len(signal) >= target_length: + signal = signal[:target_length] # Take only the first 4 seconds + else: + raise ValueError(f"The file {file_name} is shorter than 4 seconds.") + # Normalize + AudioLDM_signals.append(rms_normalize(signal)) + signal_lengths.add(len(signal)) + + # Step 2: Check if all signals have the same length + if len(signal_lengths) != 1: + raise ValueError("Not all signals have the same length. 
Please ensure all audio files are of the same length.") + + encoded_audios = [] + for origin_audio in AudioLDM_signals: + D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024) + padded_D = pad_STFT(D) + encoded_D = encode_stft(padded_D) + encoded_audios.append(encoded_D) + encoded_audios_np = np.array(encoded_audios) + origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device) + + # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length] + batch_size = 8 + num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size)) + spectrogram_batches = [] + for i in range(num_batches): + batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size] + spectrogram_batches.append(batch) + + for spectrogram_batch in tqdm(spectrogram_batches): + spectrogram_batch = spectrogram_batch.to(device) + _, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer, squared=False) + quantized_latent_representations = quantized_latent_representations + feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) + probabilities = torch.nn.functional.softmax(instrument_logits, dim=1) + + diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy()) + + return inception_score(np.array(diffuSynth_probabilities)) + + +# def get_inception_score_for_AudioLDM(device, timbre_encoder, VAE, AudioLDM_signals_directory_path): +# VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder +# +# diffuSynth_probabilities = [] +# +# # Step 1: Load all wav files in AudioLDM_signals_directory_path +# AudioLDM_signals = [] +# signal_lengths = set() +# +# for file_name in os.listdir(AudioLDM_signals_directory_path): +# if file_name.endswith('.wav'): +# file_path = os.path.join(AudioLDM_signals_directory_path, file_name) +# signal, sr = librosa.load(file_path, sr=16000) # Load audio file with sampling rate 16000 +# # Normalize +# AudioLDM_signals.append(rms_normalize(signal)) +# signal_lengths.add(len(signal)) +# +# # Step 2: Check if all signals have the same length +# if len(signal_lengths) != 1: +# raise ValueError("Not all signals have the same length. 
Please ensure all audio files are of the same length.") +# +# encoded_audios = [] +# for origin_audio in AudioLDM_signals: +# D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024) +# padded_D = pad_STFT(D) +# encoded_D = encode_stft(padded_D) +# encoded_audios.append(encoded_D) +# encoded_audios_np = np.array(encoded_audios) +# origin_spectrogram_batch_tensor = torch.from_numpy(encoded_audios_np).float().to(device) +# +# +# # Step 3: Reshape to signal_batches [number_batches, batch_size=8, signal_length] +# batch_size = 8 +# num_batches = int(np.ceil(origin_spectrogram_batch_tensor.shape[0] / batch_size)) +# spectrogram_batches = [] +# for i in range(num_batches): +# batch = origin_spectrogram_batch_tensor[i * batch_size:(i + 1) * batch_size] +# spectrogram_batches.append(batch) +# +# +# for spectrogram_batch in tqdm(spectrogram_batches): +# spectrogram_batch = spectrogram_batch.to(device) +# _, _, _, _, quantized_latent_representations = InputBatch2Encode_STFT(VAE_encoder, spectrogram_batch, quantizer=VAE_quantizer,squared=False) +# quantized_latent_representations = quantized_latent_representations +# feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) +# probabilities = torch.nn.functional.softmax(instrument_logits, dim=1) +# +# diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy()) +# +# return inception_score(np.array(diffuSynth_probabilities)) + + +def get_inception_score(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"): + diffuSynth_probabilities = [] + + if task == "spectrograms": + pipe = sample_pipeline + elif task == "STFT": + pipe = sample_pipeline_STFT + else: + raise NotImplementedError + + for _ in tqdm(range(num_batches)): + quantized_latent_representations = pipe(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, + batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None) + + quantized_latent_representations = quantized_latent_representations.to(device) + feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) + probabilities = torch.nn.functional.softmax(instrument_logits, dim=1) + + diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy()) + + return inception_score(np.array(diffuSynth_probabilities)) + + +def get_inception_score_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"): + diffuSynth_probabilities = [] + + if task == "spectrograms": + pipe = sample_pipeline_GAN + elif task == "STFT": + pipe = sample_pipeline_GAN_STFT + else: + raise NotImplementedError + + for _ in tqdm(range(num_batches)): + quantized_latent_representations = pipe(device, gan_generator, VAE, MMM, CLAP_tokenizer, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, + batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None) + + quantized_latent_representations = quantized_latent_representations.to(device) + feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) + probabilities = torch.nn.functional.softmax(instrument_logits, dim=1) + + diffuSynth_probabilities.extend(probabilities.to("cpu").detach().numpy()) + + 
return inception_score(np.array(diffuSynth_probabilities)) + + +def predict_qualities_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10): + diffuSynth_qualities = [] + for _ in tqdm(range(num_batches)): + quantized_latent_representations = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, + batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None) + + quantized_latent_representations = quantized_latent_representations.to(device) + feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) + qualities = qualities.to("cpu").detach().numpy() + # qualities = np.where(qualities > 0.5, 1, 0) + + diffuSynth_qualities.extend(qualities) + + return np.mean(diffuSynth_qualities, axis=0) + + +def generate_probabilities_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, timbre_encoder, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10): + + inpaint_probabilities, signals = [], [] + for _ in tqdm(range(num_batches)): + quantized_latent_representations, _, rec_signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, + use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999, + return_latent=False) + + quantized_latent_representations = quantized_latent_representations.to(device) + feature, instrument_logits, instrument_family_logits, velocity_logits, qualities = timbre_encoder(quantized_latent_representations) + probabilities = torch.nn.functional.softmax(instrument_logits, dim=1) + + inpaint_probabilities.extend(probabilities.to("cpu").detach().numpy()) + signals.extend(rec_signals) + + return np.array(inpaint_probabilities), signals + + +def inception_score(pred): + + # 计算每个图像的条件概率分布 P(y|x) + pyx = pred / np.sum(pred, axis=1, keepdims=True) + + # 计算整个数据集的边缘概率分布 P(y) + py = np.mean(pyx, axis=0, keepdims=True) + + # 计算KL散度 + kl_div = pyx * (np.log(pyx + 1e-11) - np.log(py + 1e-11)) + + # 对所有图像求和并平均 + kl_div_sum = np.sum(kl_div, axis=1) + score = np.exp(np.mean(kl_div_sum)) + return score \ No newline at end of file diff --git a/metrics/P_C_T.py b/metrics/P_C_T.py new file mode 100644 index 0000000000000000000000000000000000000000..d11b4691dd9b98da6db1c5989c134fd57d55c899 --- /dev/null +++ b/metrics/P_C_T.py @@ -0,0 +1,12 @@ +import numpy as np +from metrics.precision_recall import knn_precision_recall_features + + +# 生成样本 +real_features = np.random.normal(0, 1, size=(1600, 512)) +generated_features = np.random.normal(0, 1, size=(1600, 512)) + +state = knn_precision_recall_features(real_features, generated_features, nhood_sizes=[1, 2, 3, 4, 5, 10], + row_batch_size=16, col_batch_size=16) + +print(state) \ No newline at end of file diff --git a/metrics/get_reference_AST_features.py b/metrics/get_reference_AST_features.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0b648b70d8af3cf738cd9573a3121fbbc5d446 --- /dev/null +++ b/metrics/get_reference_AST_features.py @@ -0,0 +1,63 @@ +import json +import librosa +import numpy as np +from tqdm import tqdm +from metrics.FD import ASTaudio2feature, calculate_statistics, save_AST_feature 
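+# Pre-compute AST feature statistics (mu, sigma) on NSynth pitch-52 reference notes, grouped by quality and by instrument family, for the Frechet-distance metric.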
+from tools import rms_normalize +from transformers import AutoProcessor, ASTModel + +device = "cpu" +processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593") +AST = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(device) + + +data_split = "train" +with open(f'data/NSynth/{data_split}_examples.json') as f: + data = json.load(f) + +def read_signal(note_str): + y, sr = librosa.load(f"data/NSynth/nsynth-{data_split}-52/audio/{note_str}.wav", sr=16000) + if len(y) >= 64000: + y = y[:64000] + else: + y_extend = [0.0] * 64000 + y_extend[:len(y)] = y + y = y_extend + + return rms_normalize(y) + +for quality in ["bright", "dark", "distortion", "fast_decay", "long_release", "multiphonic", "nonlinear_env", "percussive", "reverb", "tempo-synced"]: + features = [] + for i, (note_str, attributes) in tqdm(enumerate(data.items())): + if not attributes["pitch"] == 52: + continue + if not (quality in attributes['qualities_str']): + continue + + signal = read_signal(note_str) + feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0] + features.append(feature_for_one_signal) + + mu, sigma = calculate_statistics(features) + print(np.shape(mu)) + print(np.shape(sigma)) + + save_AST_feature(f'{data_split}_{quality}', mu.tolist(), sigma.tolist()) + +for instrument_name in ["bass", "brass", "flute", "guitar", "keyboard", "mallet", "organ", "reed", "string", "synth_lead", "vocal"]: + features = [] + for i, (note_str, attributes) in tqdm(enumerate(data.items())): + if not attributes["pitch"] == 52: + continue + if not (attributes["instrument_family_str"] == instrument_name): + continue + + signal = read_signal(note_str) + feature_for_one_signal = ASTaudio2feature(device, [signal], processor, AST, sampling_rate=16000)[0] + features.append(feature_for_one_signal) + + mu, sigma = calculate_statistics(features) + print(np.shape(mu)) + print(np.shape(sigma)) + + save_AST_feature(f'{data_split}_{instrument_name}', mu.tolist(), sigma.tolist()) \ No newline at end of file diff --git a/metrics/pipelines.py b/metrics/pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc6afa03fb5f92a7c9d99825ee3ef4ff3488087 --- /dev/null +++ b/metrics/pipelines.py @@ -0,0 +1,144 @@ +import librosa +import numpy as np +import torch +from tqdm import tqdm + +from tools import VAE_out_put_to_spc, rms_normalize, nnData2Audio +from model.DiffSynthSampler import DiffSynthSampler + +def sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0, + freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True): + + height = int(freq_resolution/VAE_scale) + width = int(time_resolution/VAE_scale) + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + text2sound_embedding = \ + MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device) + negative_condition = \ + MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0].to(device) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True) + mySampler.activate_classifier_free_guidance(CFG, negative_condition) + + mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32))) + + condition = 
text2sound_embedding.repeat(batchsize, 1) + + latent_representations, initial_noise = \ + mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed, + return_tensor=True, condition=condition, sampler=sampler) + + latent_representations = latent_representations[-1] + + quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations) + + if return_latent: + return quantized_latent_representations.detach() + reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy() + time_resolution = int(time_resolution * ((duration+1) / 4)) + + rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution)) + rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals] + + return quantized_latent_representations.detach(), reconstruction_batch, rec_signals + +def sample_pipeline_GAN(device, gan_generator, VAE, MMM, CLAP_tokenizer, + positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0, + freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True): + + height = int(freq_resolution/VAE_scale) + width = int(time_resolution/VAE_scale) + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + text2sound_embedding = \ + MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device) + + condition = text2sound_embedding.repeat(batchsize, 1) + + noise = torch.randn(batchsize, channels, height, width).to(device) + latent_representations = gan_generator(noise, condition) + + quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations) + + if return_latent: + return quantized_latent_representations.detach() + reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy() + time_resolution = int(time_resolution * ((duration+1) / 4)) + + rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution)) + rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals] + + return quantized_latent_representations.detach(), reconstruction_batch, rec_signals + +def inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, use_dynamic_mask, noising_strength, guidance, + positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, duration=3.0, mask_flexivity=0.99, + freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True): + + height = int(freq_resolution/VAE_scale) + width = int(time_resolution * ((duration + 1) / 4) / VAE_scale) + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + + text2sound_embedding = \ + MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0] + negative_condition = \ + MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[0] + + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True) + mySampler.activate_classifier_free_guidance(CFG, negative_condition) + mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32))) + + condition = text2sound_embedding.repeat(batchsize, 1) + guidance = guidance.repeat(batchsize, 1, 1, 1).to(device) + + # mask = 1, freeze + latent_mask = 
torch.zeros((batchsize, 1, height, width), dtype=torch.float32).to(device) + latent_mask[:, :, :, -int(time_resolution * (1 / 4) / VAE_scale):] = 1.0 + + latent_representations, initial_noise = \ + mySampler.inpaint_sample(model=uNet, shape=(batchsize, channels, height, width), + noising_strength=noising_strength, + guide_img=guidance, mask=latent_mask, return_tensor=True, + condition=condition, sampler=sampler, + use_dynamic_mask=use_dynamic_mask, + end_noise_level_ratio=0.0, + mask_flexivity=mask_flexivity) + + latent_representations = latent_representations[-1] + + quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations) + + if return_latent: + return quantized_latent_representations.detach() + reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy() + time_resolution = int(time_resolution * ((duration+1) / 4)) + + rec_signals = nnData2Audio(reconstruction_batch, resolution=(freq_resolution, time_resolution)) + rec_signals = [rms_normalize(rec_signal) for rec_signal in rec_signals] + + return quantized_latent_representations.detach(), reconstruction_batch, rec_signals + + +def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10): + diffuSynth_signals = [] + for _ in tqdm(range(num_batches)): + _, _, signals = sample_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, + batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False) + diffuSynth_signals.extend(signals) + return np.array(diffuSynth_signals) + + +def generate_audios_with_diffuSynth_inpaint(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, guidance, duration, use_dynamic_mask, noising_strength, positive_prompts, negative_prompts="", CFG=6, sample_steps=10): + + diffuSynth_signals = [] + for _ in tqdm(range(num_batches)): + _, _, signals = inpaint_pipeline(device, uNet, VAE, MMM, CLAP_tokenizer, + use_dynamic_mask=use_dynamic_mask, noising_strength=noising_strength, guidance=guidance, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, batchsize=16, sample_steps=sample_steps, CFG=CFG, seed=None, duration=duration, mask_flexivity=0.999, + return_latent=False) + diffuSynth_signals.extend(signals) + return np.array(diffuSynth_signals) \ No newline at end of file diff --git a/metrics/pipelines_STFT.py b/metrics/pipelines_STFT.py new file mode 100644 index 0000000000000000000000000000000000000000..d866f258175ce27211866e9bef4213a030b6de13 --- /dev/null +++ b/metrics/pipelines_STFT.py @@ -0,0 +1,100 @@ +import librosa +import numpy as np +import torch +from tqdm import tqdm + +from tools import rms_normalize, decode_stft, depad_STFT +from model.DiffSynthSampler import DiffSynthSampler + +def sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, + freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True): + "Sample a fix-length audio using a diffusion model, including 'ISTFT+' post-processing." 
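+    # Pipeline: encode the prompts with the multi-modal model (MMM) -> classifier-free-guided DDIM sampling in the
+    # VQ-VAE latent space -> vector quantization -> VAE decoding to an STFT -> ISTFT back to a waveform.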
+ + height = int(freq_resolution/VAE_scale) + width = int(time_resolution/VAE_scale) + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + text2sound_embedding = \ + MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device) + negative_condition = \ + MMM.get_text_features(**CLAP_tokenizer([negative_prompts], padding=True, return_tensors="pt"))[ + 0].to(device) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy, mute=True) + mySampler.activate_classifier_free_guidance(CFG, negative_condition) + + mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32))) + + condition = text2sound_embedding.repeat(batchsize, 1) + + latent_representations, initial_noise = \ + mySampler.sample(model=uNet, shape=(batchsize, channels, height, width), seed=seed, + return_tensor=True, condition=condition, sampler=sampler) + + latent_representations = latent_representations[-1] + + quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations) + + if return_latent: + return quantized_latent_representations.detach() + + reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy() + + rec_signals = [] + + for index, STFT in enumerate(reconstruction_batch): + padded_D_rec = decode_stft(STFT) + D_rec = depad_STFT(padded_D_rec) + # get_audio + rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024) + rec_signals.append(rms_normalize(rec_signal)) + + return quantized_latent_representations.detach(), reconstruction_batch, rec_signals + +def sample_pipeline_GAN_STFT(device, gan_generator, VAE, MMM, CLAP_tokenizer, + positive_prompts, negative_prompts, batchsize, sample_steps, CFG, seed=None, + freq_resolution=512, time_resolution=256, channels=4, VAE_scale=4, timesteps=1000, noise_strategy="repeat", sampler="ddim", return_latent=True): + "Sample fix-length audio using a GAN, including 'ISTFT+' post-processing." + + height = int(freq_resolution/VAE_scale) + width = int(time_resolution/VAE_scale) + VAE_encoder, VAE_quantizer, VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + + text2sound_embedding = \ + MMM.get_text_features(**CLAP_tokenizer([positive_prompts], padding=True, return_tensors="pt"))[0].to(device) + + condition = text2sound_embedding.repeat(batchsize, 1) + + noise = torch.randn(batchsize, channels, height, width).to(device) + latent_representations = gan_generator(noise, condition) + + quantized_latent_representations, _, (_, _, _) = VAE_quantizer(latent_representations) + + if return_latent: + return quantized_latent_representations.detach() + reconstruction_batch = VAE_decoder(quantized_latent_representations).to("cpu").detach().numpy() + + rec_signals = [] + + for index, STFT in enumerate(reconstruction_batch): + padded_D_rec = decode_stft(STFT) + D_rec = depad_STFT(padded_D_rec) + # get_audio + rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024) + rec_signals.append(rms_normalize(rec_signal)) + + return quantized_latent_representations.detach(), reconstruction_batch, rec_signals + + +def generate_audios_with_diffuSynth_sample(device, uNet, VAE, MMM, CLAP_tokenizer, num_batches, positive_prompts, negative_prompts="", CFG=6, sample_steps=10): + "Sample audios using a diffusion model, including 'ISTFT+' post-processing." 
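+    # Returns an array of shape [num_batches * 8, signal_length]; each sample_pipeline_STFT call generates a batch of 8 signals.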
+ + diffuSynth_signals = [] + for _ in tqdm(range(num_batches)): + _, _, signals = sample_pipeline_STFT(device, uNet, VAE, MMM, CLAP_tokenizer, + positive_prompts=positive_prompts, negative_prompts=negative_prompts, + batchsize=8, sample_steps=sample_steps, CFG=CFG, seed=None, return_latent=False) + diffuSynth_signals.extend(signals) + return np.array(diffuSynth_signals) + diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c281b0dbccae83bd194cfd6d84e06f4fa22666 --- /dev/null +++ b/metrics/precision_recall.py @@ -0,0 +1,204 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""k-NN precision and recall.""" + +from time import time + + +# ---------------------------------------------------------------------------- + +import numpy as np +from tqdm import tqdm + + +def batch_pairwise_distances(U, V): + """Compute pair-wise distance in a batch of feature.""" + + norm_u = np.sum(np.square(U), axis=1) + norm_v = np.sum(np.square(V), axis=1) + + norm_u = np.reshape(norm_u, [-1, 1]) + norm_v = np.reshape(norm_v, [1, -1]) + + D = np.maximum(norm_u - 2 * np.dot(U, V.T) + norm_v, 0.0) + return D + + +# ---------------------------------------------------------------------------- + +class DistanceBlock(): + """Compute pair-wise distance in a batch of feature.""" + + def __init__(self, num_features): + self.num_features = num_features + + def pairwise_distances(self, U, V): + return batch_pairwise_distances(U, V) + + + +# ---------------------------------------------------------------------------- + +class ManifoldEstimator(): + """Estimates the manifold of given feature vectors.""" + + def __init__(self, distance_block, features, row_batch_size=16, col_batch_size=16, + nhood_sizes=[3], clamp_to_percentile=None, eps=1e-5, mute=False): + """Estimate the manifold of given feature vectors. + + Args: + distance_block: DistanceBlock object that distributes pairwise distance + calculation to multiple GPUs. + features (np.array/tf.Tensor): Matrix of feature vectors to estimate their manifold. + row_batch_size (int): Row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + col_batch_size (int): Column batch size to compute pairwise distances. + nhood_sizes (list): Number of neighbors used to estimate the manifold. + clamp_to_percentile (float): Prune hyperspheres that have radius larger than + the given percentile. + eps (float): Small number for numerical stability. + """ + num_images = features.shape[0] + self.nhood_sizes = nhood_sizes + self.num_nhoods = len(nhood_sizes) + self.eps = eps + self.row_batch_size = row_batch_size + self.col_batch_size = col_batch_size + self._ref_features = features + self._distance_block = distance_block + self.mute = mute + + # Estimate manifold of features by calculating distances to k-NN of each sample. 
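+        # self.D[i, j] holds the distance from sample i to its nhood_sizes[j]-th nearest neighbor, i.e. the radius of the hypersphere centered at sample i.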
+ self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float32) + distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float32) + seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) + + if mute: + for begin1 in range(0, num_images, row_batch_size): + end1 = min(begin1 + row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, col_batch_size): + end2 = min(begin2 + col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, + col_batch) + + # Find the k-nearest neighbor from the current batch. + self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes] + else: + for begin1 in tqdm(range(0, num_images, row_batch_size)): + end1 = min(begin1 + row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, col_batch_size): + end2 = min(begin2 + col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, + col_batch) + + # Find the k-nearest neighbor from the current batch. + self.D[begin1:end1, :] = np.partition(distance_batch[0:end1 - begin1, :], seq, axis=1)[:, self.nhood_sizes] + + if clamp_to_percentile is not None: + max_distances = np.percentile(self.D, clamp_to_percentile, axis=0) + self.D[self.D > max_distances] = 0 + + def evaluate(self, eval_features, return_realism=False, return_neighbors=False): + """Evaluate if new feature vectors are at the manifold.""" + num_eval_images = eval_features.shape[0] + num_ref_images = self.D.shape[0] + distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) + batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) + max_realism_score = np.zeros([num_eval_images, ], dtype=np.float32) + nearest_indices = np.zeros([num_eval_images, ], dtype=np.int32) + + for begin1 in range(0, num_eval_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_eval_images) + feature_batch = eval_features[begin1:end1] + + for begin2 in range(0, num_ref_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_ref_images) + ref_batch = self._ref_features[begin2:end2] + + distance_batch[0:end1 - begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch, + ref_batch) + + # From the minibatch of new feature vectors, determine if they are in the estimated manifold. + # If a feature vector is inside a hypersphere of some reference sample, then + # the new sample lies at the estimated manifold. + # The radii of the hyperspheres are determined from distances of neighborhood size k. 
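+            # Broadcasting distance_batch[batch, num_ref, None] against self.D[num_ref, num_nhoods] gives a [batch, num_ref, num_nhoods] boolean array; any() over the reference axis marks manifold membership.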
+ samples_in_manifold = distance_batch[0:end1 - begin1, :, None] <= self.D + batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) + + max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1 - begin1, :] + self.eps), + axis=1) + nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1 - begin1, :], axis=1) + + if return_realism and return_neighbors: + return batch_predictions, max_realism_score, nearest_indices + elif return_realism: + return batch_predictions, max_realism_score + elif return_neighbors: + return batch_predictions, nearest_indices + + return batch_predictions + + +# ---------------------------------------------------------------------------- + +def knn_precision_recall_features(ref_features, eval_features, nhood_sizes=[3], + row_batch_size=10000, col_batch_size=50000, mute=False): + """Calculates k-NN precision and recall for two sets of feature vectors. + + Args: + ref_features (np.array/tf.Tensor): Feature vectors of reference images. + eval_features (np.array/tf.Tensor): Feature vectors of generated images. + nhood_sizes (list): Number of neighbors used to estimate the manifold. + row_batch_size (int): Row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + col_batch_size (int): Column batch size to compute pairwise distances. + num_gpus (int): Number of GPUs used to evaluate precision and recall. + + Returns: + State (dict): Dict that contains precision and recall calculated from + ref_features and eval_features. + """ + state = dict() + num_images = ref_features.shape[0] + num_features = ref_features.shape[1] + + # Initialize DistanceBlock and ManifoldEstimators. + distance_block = DistanceBlock(num_features) + ref_manifold = ManifoldEstimator(distance_block, ref_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute) + eval_manifold = ManifoldEstimator(distance_block, eval_features, row_batch_size, col_batch_size, nhood_sizes, mute=mute) + + # Evaluate precision and recall using k-nearest neighbors. + if not mute: + print('Evaluating k-NN precision and recall with %i samples...' % num_images) + start = time() + + # Precision: How many points from eval_features are in ref_features manifold. + precision = ref_manifold.evaluate(eval_features) + state['precision'] = precision.mean(axis=0) + + # Recall: How many points from ref_features are in eval_features manifold. 
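+    # evaluate() returns a 0/1 matrix of shape [num_samples, len(nhood_sizes)]; averaging over samples yields the metric per neighborhood size.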
+ recall = eval_manifold.evaluate(ref_features) + state['recall'] = recall.mean(axis=0) + + if not mute: + print('Evaluated k-NN precision and recall in: %gs' % (time() - start)) + + return state + +# ---------------------------------------------------------------------------- + diff --git a/metrics/visualizations.py b/metrics/visualizations.py new file mode 100644 index 0000000000000000000000000000000000000000..486e8c640c1617e1c1f326b79ee84e126c4b4557 --- /dev/null +++ b/metrics/visualizations.py @@ -0,0 +1,123 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.fft import fft +from scipy.signal import savgol_filter +from tools import rms_normalize + +colors = [ + # (0, 0, 0), # Black + # (86, 180, 233), # Sky blue + # (240, 228, 66), # Yellow + # (204, 121, 167), # Reddish purple + (213, 94, 0), # Vermilion + (0, 114, 178), # Blue + (230, 159, 0), # Orange + (0, 158, 115), # Bluish green +] + + +def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500, + figsize=(10, 6), save_path=None, normalize=False): + """ + 在同一张图上绘制多组音频信号的功率谱密度比较图,使用对数刻度的响度轴(以2为底),并应用平滑处理。 + + 参数: + signals_list: 包含多组音频信号的列表,每组信号形状为 [sample_number, sample_length] 的numpy array + labels_list: 每组音频信号对应的标签字符串列表 + sample_rate: 音频的采样率 + """ + + # 确保传入的signals_list和labels_list长度相同 + assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。" + + signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list] + + # 绘图准备 + plt.figure(figsize=figsize) + + # 遍历所有的音频信号 + i = 0 + for signal, label in zip(signals_list, labels_list): + # 计算FFT + fft_signal = fft(signal, axis=1) + + # 计算平均功率谱密度 + psd_signal = np.mean(np.abs(fft_signal)**2, axis=0) + + # 计算频率轴 + freqs = np.fft.fftfreq(signal.shape[1], 1/sample_rate) + + # 应用Savitzky-Golay滤波器进行平滑 + psd_smoothed = savgol_filter(np.log2(psd_signal[:signal.shape[1] // 2] + 1), window_size, 3) # 窗口大小51, 多项式阶数3 + + # Normalize each curve if normalize is True + if normalize: + psd_smoothed /= np.mean(psd_smoothed) + + # 绘制每组信号的功率谱密度 + plt.plot(freqs[:signal.shape[1] // 2], psd_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1) + i += 1 + + # 设置图表元素 + plt.xlabel('Frequency (Hz)') + plt.ylabel('Mean Log-Amplitude') + plt.legend() + + # 根据save_path参数决定保存图像还是直接显示 + if save_path: + plt.savefig(save_path) + else: + plt.show() + + +def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500, + figsize=(10, 6), save_path=None, normalize=False, start_time=0): + """ + Plot the loudness of multiple sets of audio signals over time on the same graph, + using a logarithmic scale for the loudness axis (base 2), with smoothing applied. 
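+    Each curve is the Savitzky-Golay-smoothed log2 of the mean absolute amplitude across the signal set.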
+ + Parameters: + signals_list: List of sets of audio signals, each set is a numpy array with shape [sample_number, sample_length] + labels_list: List of labels corresponding to each set of audio signals + sample_rate: Sampling rate of the audio + window_size: Window size for the Savitzky-Golay filter + figsize: Figure size + save_path: Path to save the figure, if None, the figure will be displayed + normalize: Whether to normalize each curve so that the sum of each curve is the same + start_time: Time (in seconds) to start plotting, only data after this time will be retained + """ + assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \ + f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}" + + # Compute starting sample index + start_sample = int(start_time * sample_rate) + + # Normalize signals and truncate data + signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list] + time_axis = np.arange(start_sample, start_sample + signals_list[0].shape[1]) / sample_rate + + plt.figure(figsize=figsize) + + i = 0 + for signal, label in zip(signals_list, labels_list): + amplitude_mean = np.mean(np.abs(signal), axis=0) + + amplitude_smoothed = savgol_filter(np.log2(amplitude_mean + 1), window_size, 3) + + # Normalize each curve if normalize is True + if normalize: + amplitude_smoothed /= np.mean(amplitude_smoothed) + + plt.plot(time_axis, amplitude_smoothed, label=label, color=[x/255.0 for x in colors[i % len(colors)]], linewidth=1) + i += 1 + + plt.xlabel('Time (seconds)') + plt.ylabel('Mean Log-Amplitude') + plt.legend() + + # Save or show the figure based on save_path parameter + if save_path: + plt.savefig(save_path) + else: + plt.show() + diff --git a/model/DiffSynthSampler.py b/model/DiffSynthSampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8af761e6d832cf0d8e9a2f4ea202d1e5ffbff97d --- /dev/null +++ b/model/DiffSynthSampler.py @@ -0,0 +1,425 @@ +import numpy as np +import torch +from tqdm import tqdm + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
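+    Example: arr of shape (T,), timesteps of shape (B,), broadcast_shape (B, C, H, W) -> result of shape (B, C, H, W).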
+ """ + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +class DiffSynthSampler: + + def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False, + height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"): + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + self.height = height + self.train_width = train_width + self.max_batchsize = max_batchsize + self.max_width = max_width + self.channels = channels + self.num_timesteps = timesteps + self.timestep_map = list(range(self.num_timesteps)) + self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64) + self.respaced = False + self.define_beta_schedule() + self.CFG = 1.0 + self.mute = mute + self.noise_strategy = noise_strategy + + def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None): + if reference_noise is None: + large_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.max_width), device=self.device) + else: + assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), "reference_noise shape mismatch" + large_noise_tensor = reference_noise + return large_noise_tensor[:batchsize, :, :, :width], None + + def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None): + if self.noise_strategy == "repeat": + noise, concat_points = self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise) + return noise, concat_points + else: + noise, concat_points = self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise) + return noise, concat_points + + + def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None): + # 生成与训练数据长度相等的噪音 + if reference_noise is None: + train_noise_tensor = torch.randn((self.max_batchsize, self.channels, self.height, self.train_width), device=self.device) + else: + assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), "reference_noise shape mismatch" + train_noise_tensor = reference_noise + + release_width = int(self.train_width * 1.0 / 4) + first_part_width = self.train_width - release_width + + first_part = train_noise_tensor[:batchsize, :, :, :first_part_width] + release_part = train_noise_tensor[:batchsize, :, :, -release_width:] + + # 如果所需 length 小于等于 origin length,去掉 first_part 的中间部分 + if width <= self.train_width: + _first_part_head_width = int((width - release_width) / 2) + _first_part_tail_width = width - release_width - _first_part_head_width + all_parts = [first_part[:, :, :, :_first_part_head_width], first_part[:, :, :, -_first_part_tail_width:], release_part] + + # 沿第四维度拼接张量 + noise_tensor = torch.cat(all_parts, dim=3) + + # 记录拼接点的位置 + concat_points = [0] + for part in all_parts[:-1]: + next_point = concat_points[-1] + part.size(3) + concat_points.append(next_point) + + return noise_tensor, concat_points + + # 如果所需 length 大于 origin length,不断地从中间插入 first_part 的中间部分 + else: + # 计算需要重复front_width的次数 + repeats = (width - release_width) // first_part_width + extra = (width - release_width) % first_part_width + + _repeat_first_part_head_width = int(first_part_width / 2) + _repeat_first_part_tail_width = first_part_width - _repeat_first_part_head_width + + 
repeated_first_head_parts = [first_part[:, :, :, :_repeat_first_part_head_width] for _ in range(repeats)] + repeated_first_tail_parts = [first_part[:, :, :, -_repeat_first_part_tail_width:] for _ in range(repeats)] + + # 计算起始索引 + _middle_part_start_index = (first_part_width - extra) // 2 + # 切片张量以获取中间部分 + middle_part = first_part[:, :, :, _middle_part_start_index: _middle_part_start_index + extra] + + all_parts = repeated_first_head_parts + [middle_part] + repeated_first_tail_parts + [release_part] + + # 沿第四维度拼接张量 + noise_tensor = torch.cat(all_parts, dim=3) + + # 记录拼接点的位置 + concat_points = [0] + for part in all_parts[:-1]: + next_point = concat_points[-1] + part.size(3) + concat_points.append(next_point) + + return noise_tensor, concat_points + + def define_beta_schedule(self): + assert self.respaced == False, "This schedule has already been respaced!" + # define alphas + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) + + def activate_classifier_free_guidance(self, CFG, unconditional_condition): + assert ( + not unconditional_condition is None) or CFG == 1.0, "For CFG != 1.0, unconditional_condition must be available" + self.CFG = CFG + self.unconditional_condition = unconditional_condition + + def respace(self, use_timesteps=None): + if not use_timesteps is None: + last_alpha_cumprod = 1.0 + new_betas = [] + self.timestep_map = [] + for i, _alpha_cumprod in enumerate(self.alphas_cumprod): + if i in use_timesteps: + new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = _alpha_cumprod + self.timestep_map.append(i) + self.num_timesteps = len(use_timesteps) + self.betas = np.array(new_betas) + self.define_beta_schedule() + self.respaced = True + + def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None): + assert shape[1] == self.channels, "shape[1] != self.channels" + assert shape[2] == self.height, "shape[2] != self.height" + noise = torch.empty(*shape, device=self.device) + + # 第三种情况:两个端点都不是None,进行线性插值 + if first_endpoint is not None and second_endpoint is not None: + for i in range(shape[0]): + alpha = i / (shape[0] - 1) # 插值系数 + noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint + return noise # 返回插值后的结果,不需要进行后续的均值和方差调整 + else: + # 第一个端点不是None + if first_endpoint is not None: + noise[0] = first_endpoint + if shape[0] > 1: + noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0] + else: + noise[0], _ = self.get_deterministic_noise_tensor(1, shape[3])[0] + if shape[0] > 1: + noise[1], _ = self.get_deterministic_noise_tensor(1, shape[3])[0] + + # 生成其他的噪声点 + for i in range(2, shape[0]): + noise[i] = 2 * noise[i - 1] - noise[i - 2] + + # 当只有一个端点被指定时 + current_var = noise.var() + stddev_ratio = 
torch.sqrt(variance / current_var) + noise = noise * stddev_ratio + + # 如果第一个端点被指定,进行平移调整 + if first_endpoint is not None: + shift = first_endpoint - noise[0] + noise += shift + + return noise + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + assert x_start.shape[1] == self.channels, "shape[1] != self.channels" + assert x_start.shape[2] == self.height, "shape[2] != self.height" + + if noise is None: + # noise = torch.randn_like(x_start) + noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3]) + + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + @torch.no_grad() + def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0): + map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype) + mapped_t = map_tensor[t] + + # Todo: add CFG + + if self.CFG == 1.0: + pred_noise = model(x, mapped_t, condition) + else: + unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat( + *([x.shape[0]] + [1] * len(self.unconditional_condition.shape))) + x_in = torch.cat([x] * 2) + t_in = torch.cat([mapped_t] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = model(x_in, t_in, c_in).chunk(2) + pred_noise = noise_uncond + self.CFG * (noise - noise_uncond) + + # Todo: END + + alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + + pred_x0 = (x - torch.sqrt((1. 
- alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t) + + sigmas_t = ( + ddim_eta + * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t)) + * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev) + ) + + pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise + + + step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3]) + + + x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise + + return x_prev + + def p_sample(self, model, x, t, condition=None, sampler="ddim"): + if sampler == "ddim": + return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0) + elif sampler == "ddpm": + return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0) + else: + raise NotImplementedError() + + def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8): + release_length = int(self.train_width / 4) + assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)" + + fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)] + + # Todo: remove hard-coding + n_guidance_steps = int(n_masks * mask_flexivity) + n_free_steps = n_masks - n_guidance_steps + + masks = [] + # Todo: 在一半的 steps 内收缩 mask。也就是说,在后程对 release 以外的区域不做inpaint,而是 img2img + for i in range(n_guidance_steps): + # mask = 1, freeze + step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device) + step_i_mask[:, :, :, -release_length:] = 1.0 + + for fraction_index in range(len(fraction_lengths)): + + _fraction_mask_length = int((n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_lengths[fraction_index]) + + if fraction_index == 0: + step_i_mask[:, :, :, :_fraction_mask_length] = 1.0 + elif fraction_index == len(fraction_lengths) - 1: + if not _fraction_mask_length == 0: + step_i_mask[:, :, :, -_fraction_mask_length - release_length:] = 1.0 + else: + fraction_mask_start_position = int((fraction_lengths[fraction_index] - _fraction_mask_length) / 2) + + step_i_mask[:, :, :, + concat_points[fraction_index] + fraction_mask_start_position:concat_points[ + fraction_index] + fraction_mask_start_position + _fraction_mask_length] = 1.0 + masks.append(step_i_mask) + + for i in range(n_free_steps): + step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device) + step_i_mask[:, :, :, -release_length:] = 1.0 + masks.append(step_i_mask) + + masks.reverse() + return masks + + @torch.no_grad() + def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0, + return_tensor=False, condition=None, guide_img=None, + mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8): + + assert shape[1] == self.channels, "shape[1] != self.channels" + assert shape[2] == self.height, "shape[2] != self.height" + + initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise) + assert initial_noise.shape == shape, "initial_noise.shape != shape" + + start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio) # not included!!! + end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio) + + timesteps = reversed(range(end_noise_level_index, start_noise_level_index)) + + # configure initial img + assert (start_noise_level_ratio == 1.0) or ( + not guide_img is None), "A guide_img must be given to sample from a non-pure-noise." 
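+        # If guide_img is None, sampling starts from the deterministic initial noise
+        # (pure generation). Otherwise the guide latent is tiled/cropped to the target
+        # width and diffused to the requested start noise level, which gives the
+        # img2img and inpainting behaviour used by img_guided_sample / inpaint_sample.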
+ + if guide_img is None: + img = initial_noise + else: + guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(shape[0], shape[3], reference_noise=guide_img) + assert guide_img.shape == shape, "guide_img.shape != shape" + + if start_noise_level_index > 0: + t = torch.full((shape[0],), start_noise_level_index-1, device=self.device).long() # -1 for start_noise_level_index not included + img = self.q_sample(guide_img, t, noise=initial_noise) + else: + print("Zero noise added to the guidance latent representation.") + img = guide_img + + # get masks + n_masks = start_noise_level_index - end_noise_level_index + if use_dynamic_mask: + masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity) + else: + masks = [mask for _ in range(n_masks)] + + imgs = [img] + current_mask = None + + + for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute): + + # if i == 3: + # return [img], initial_noise # 第1排,第1列 + + img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long), + condition=condition, + sampler=sampler) + # if i == 3: + # return [img], initial_noise # 第1排,第2列 + + if inpaint: + if i > 0: + t = torch.full((shape[0],), int(i-1), device=self.device).long() + img_noise_t = self.q_sample(guide_img, t, noise=initial_noise) + # if i == 3: + # return [img_noise_t], initial_noise # 第2排,第2列 + current_mask = masks.pop() + img = current_mask * img_noise_t + (1 - current_mask) * img + # if i == 3: + # return [img], initial_noise # 第1.5排,最后1列 + else: + img = current_mask * guide_img + (1 - current_mask) * img + + if return_tensor: + imgs.append(img) + else: + imgs.append(img.cpu().numpy()) + + return imgs, initial_noise + + + def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None): + if not seed is None: + torch.manual_seed(seed) + return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0, + return_tensor=return_tensor, condition=condition, sampler=sampler) + + def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False, + condition=None, sampler="ddim", seed=None): + if not seed is None: + torch.manual_seed(seed) + linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint, + second_endpoint=second_endpoint) + return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0, + end_noise_level_ratio=0.0, + return_tensor=return_tensor, condition=condition, sampler=sampler) + + def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None, + sampler="ddim", initial_noise=None, seed=None): + if not seed is None: + torch.manual_seed(seed) + assert guide_img.shape[-1] == shape[-1], "guide_img.shape[:-1] != shape[:-1]" + return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0, + return_tensor=return_tensor, condition=condition, sampler=sampler, + guide_img=guide_img, initial_noise=initial_noise) + + def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None, + sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None, + mask_flexivity=0.8): + if not seed is None: + torch.manual_seed(seed) + return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, 
end_noise_level_ratio=end_noise_level_ratio,
+                                  return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
+                                  sampler=sampler, inpaint=True, initial_noise=initial_noise, use_dynamic_mask=use_dynamic_mask,
+                                  mask_flexivity=mask_flexivity)
\ No newline at end of file
diff --git a/model/GAN.py b/model/GAN.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7cfc5d88176cb79726c2647d5237609d028cb36
--- /dev/null
+++ b/model/GAN.py
@@ -0,0 +1,262 @@
+import json
+import numpy as np
+import torch
+from torch import nn
+from six.moves import xrange
+from torch.utils.tensorboard import SummaryWriter
+import random
+
+from model.diffusion import ConditionedUnet
+from tools import create_key
+
+class Discriminator(nn.Module):
+    def __init__(self, label_emb_dim):
+        super(Discriminator, self).__init__()
+        # Convolutional layers for the spectrogram feature map
+        self.conv_layers = nn.Sequential(
+            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.AdaptiveAvgPool2d(1),  # adaptive pooling to a 1x1 map
+            nn.Flatten()
+        )
+
+        # Text-embedding branch
+        self.text_embedding = nn.Sequential(
+            nn.Linear(label_emb_dim, 512),
+            nn.LeakyReLU(0.2, inplace=True)
+        )
+
+        # Final fully connected layer of the discriminator
+        self.fc = nn.Linear(512 + 512, 1)  # 512 from the feature map plus 512 from the text embedding
+
+    def forward(self, x, text_emb):
+        x = self.conv_layers(x)
+        text_emb = self.text_embedding(text_emb)
+        combined = torch.cat((x, text_emb), dim=1)
+        output = self.fc(combined)
+        return output
+
+
+
+def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
+    generator.to(device)
+    discriminator.to(device)
+    generator.eval()
+    discriminator.eval()
+
+    real_accs = []
+    fake_accs = []
+
+    with torch.no_grad():
+        for i in range(100):
+            data, attributes = next(iter(iterator))
+            data = data.to(device)
+
+            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
+            selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
+            selected_conditions = torch.stack(selected_conditions).float().to(device)
+
+            # Move data and labels to the device
+            real_images = data.to(device)
+            labels = selected_conditions.to(device)
+
+            # Generate noise and fake images
+            noise = torch.randn_like(real_images).to(device)
+            fake_images = generator(noise)
+
+            # Evaluate the discriminator
+            real_preds = discriminator(real_images, labels).reshape(-1)
+            fake_preds = discriminator(fake_images, labels).reshape(-1)
+            real_acc = (real_preds > 0.5).float().mean().item()  # accuracy on real images
+            fake_acc = (fake_preds < 0.5).float().mean().item()  # accuracy on generated images
+
+            real_accs.append(real_acc)
+            fake_accs.append(fake_acc)
+
+
+    # Average the accuracies
+    average_real_acc = sum(real_accs) / len(real_accs)
+    average_fake_acc = sum(fake_accs) / len(fake_accs)
+
+    return average_real_acc, average_fake_acc
+
+
+def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
+    generator = ConditionedUnet(**model_Config)
+    print(f"Model initialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
+    generator.to(device)
+
+    if load_pretrain:
+        print(f"Loading weights from models/{model_name}_generator.pth")
+        checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
+        
generator.load_state_dict(checkpoint['model_state_dict']) + generator.eval() + return generator + + +def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"): + discriminator = Discriminator(**model_Config) + print(f"Model intialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}") + discriminator.to(device) + + if load_pretrain: + print(f"Loading weights from models/{model_name}_discriminator.pth") + checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device) + discriminator.load_state_dict(checkpoint['model_state_dict']) + discriminator.eval() + return discriminator + + +def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain, + encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None): + + if save_model_name is None: + save_model_name = init_model_name + + def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss): + model_hyperparameter = unetConfig + model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE + model_hyperparameter["lr_G"] = lr_G + model_hyperparameter["lr_D"] = lr_D + model_hyperparameter["model_size"] = model_size + model_hyperparameter["current_iter"] = current_iter + model_hyperparameter["current_loss"] = current_loss + with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file: + json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) + + generator = ConditionedUnet(**unetConfig) + discriminator = Discriminator(unetConfig["label_emb_dim"]) + generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad) + discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad) + + print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}") + generator.to(device) + discriminator.to(device) + optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False) + optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False) + + if load_pretrain: + print(f"Loading weights from models/{init_model_name}_generator.pt") + checkpoint = torch.load(f'models/{init_model_name}_generator.pth') + generator.load_state_dict(checkpoint['model_state_dict']) + optimizer_G.load_state_dict(checkpoint['optimizer_state_dict']) + print(f"Loading weights from models/{init_model_name}_discriminator.pt") + checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth') + discriminator.load_state_dict(checkpoint['model_state_dict']) + optimizer_D.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("Model initialized.") + if max_iter == 0: + print("Return model directly.") + return generator, discriminator, optimizer_G, optimizer_D + + + train_loss_G, train_loss_D = [], [] + writer = SummaryWriter(f'runs/{save_model_name}_GAN') + + # average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping) + # print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}") + + criterion = nn.BCEWithLogitsLoss() + generator.train() + for i in xrange(max_iter): + data, attributes = next(iter(iterator)) + data = data.to(device) + + conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes] + unconditional_condition_copy = 
torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach() + selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice( + conditions_of_one_sample) for conditions_of_one_sample in conditions] + batch_size = len(selected_conditions) + selected_conditions = torch.stack(selected_conditions).float().to(device) + + # 将数据和标签移至设备 + real_images = data.to(device) + labels = selected_conditions.to(device) + + # 真实和假的标签 + real_labels = torch.ones(batch_size, 1).to(device) + fake_labels = torch.zeros(batch_size, 1).to(device) + + # ========== 训练鉴别器 ========== + optimizer_D.zero_grad() + + # 计算鉴别器对真实图像的损失 + outputs_real = discriminator(real_images, labels) + loss_D_real = criterion(outputs_real, real_labels) + + # 生成假图像 + noise = torch.randn_like(real_images).to(device) + fake_images = generator(noise, labels) + + # 计算鉴别器对假图像的损失 + outputs_fake = discriminator(fake_images.detach(), labels) + loss_D_fake = criterion(outputs_fake, fake_labels) + + # 反向传播和优化 + loss_D = loss_D_real + loss_D_fake + loss_D.backward() + optimizer_D.step() + + # ========== 训练生成器 ========== + optimizer_G.zero_grad() + + # 计算生成器的损失 + outputs_fake = discriminator(fake_images, labels) + loss_G = criterion(outputs_fake, real_labels) + + # 反向传播和优化 + loss_G.backward() + optimizer_G.step() + + + train_loss_G.append(loss_G.item()) + train_loss_D.append(loss_D.item()) + step = int(optimizer_G.state_dict()['state'][list(optimizer_G.state_dict()['state'].keys())[0]]['step'].numpy()) + + if (i + 1) % 100 == 0: + print('%d step' % (step)) + + if (i + 1) % save_steps == 0: + current_loss_D = np.mean(train_loss_D[-save_steps:]) + current_loss_G = np.mean(train_loss_G[-save_steps:]) + print('current_loss_G: %.5f' % current_loss_G) + print('current_loss_D: %.5f' % current_loss_D) + + writer.add_scalar(f"current_loss_G", current_loss_G, step) + writer.add_scalar(f"current_loss_D", current_loss_D, step) + + + torch.save({ + 'model_state_dict': generator.state_dict(), + 'optimizer_state_dict': optimizer_G.state_dict(), + }, f'models/{save_model_name}_generator.pth') + save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G) + torch.save({ + 'model_state_dict': discriminator.state_dict(), + 'optimizer_state_dict': optimizer_D.state_dict(), + }, f'models/{save_model_name}_discriminator.pth') + save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D) + + if step % 10000 == 0: + torch.save({ + 'model_state_dict': generator.state_dict(), + 'optimizer_state_dict': optimizer_G.state_dict(), + }, f'models/history/{save_model_name}_{step}_generator.pth') + torch.save({ + 'model_state_dict': discriminator.state_dict(), + 'optimizer_state_dict': optimizer_D.state_dict(), + }, f'models/history/{save_model_name}_{step}_discriminator.pth') + + return generator, discriminator, optimizer_G, optimizer_D + + diff --git a/model/VQGAN.py b/model/VQGAN.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4ca0928ced28018e35822c1351a83950ea1990 --- /dev/null +++ b/model/VQGAN.py @@ -0,0 +1,684 @@ +import json +from torch.utils.tensorboard import SummaryWriter +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from six.moves import xrange +from einops import rearrange +from torchvision import models + + +def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): + """Normalization layer""" + + if norm_type == "batchnorm": + return 
torch.nn.BatchNorm2d(in_channels) + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +def nonlinearity(x, act_type="relu"): + """Nonlinear activation function""" + + if act_type == "relu": + return F.relu(x) + else: + # swish + return x * torch.sigmoid(x) + + +class VectorQuantizer(nn.Module): + """Vector quantization layer""" + + def __init__(self, num_embeddings, embedding_dim, commitment_cost): + super(VectorQuantizer, self).__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings) + self._commitment_cost = commitment_cost + + def forward(self, inputs): + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input BCHW -> (BHW)C + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances (input-embedding)^2 + distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight ** 2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding (one-hot-encoding matrix) + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + q_latent_loss = F.mse_loss(quantized, inputs.detach()) + loss = q_latent_loss + self._commitment_cost * e_latent_loss + + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> BCHW + min_encodings, min_encoding_indices = None, None + return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices) + + +class VectorQuantizerEMA(nn.Module): + """Vector quantization layer based on exponential moving average""" + + def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5): + super(VectorQuantizerEMA, self).__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.normal_() + self._commitment_cost = commitment_cost + + self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) + self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) + self._ema_w.data.normal_() + + self._decay = decay + self._epsilon = epsilon + + def forward(self, inputs): + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight ** 2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and 
unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Use EMA to update the embedding vectors + if self.training: + self._ema_cluster_size = self._ema_cluster_size * self._decay + \ + (1 - self._decay) * torch.sum(encodings, 0) + + # Laplace smoothing of the cluster size + n = torch.sum(self._ema_cluster_size.data) + self._ema_cluster_size = ( + (self._ema_cluster_size + self._epsilon) + / (n + self._num_embeddings * self._epsilon) * n) + + dw = torch.matmul(encodings.t(), flat_input) + self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) + + self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + loss = self._commitment_cost * e_latent_loss + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> BCHW + min_encodings, min_encoding_indices = None, None + return quantized.permute(0, 3, 1, 2).contiguous(), loss, (perplexity, min_encodings, min_encoding_indices) + + +class DownSample(nn.Module): + """DownSample layer""" + + def __init__(self, in_channels, out_channels): + super(DownSample, self).__init__() + self._conv2d = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, padding=1) + + def forward(self, x): + return self._conv2d(x) + + +class UpSample(nn.Module): + """UpSample layer""" + + def __init__(self, in_channels, out_channels): + super(UpSample, self).__init__() + self._conv2d = nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, padding=1) + + def forward(self, x): + return self._conv2d(x) + + +class ResnetBlock(nn.Module): + """ResnetBlock is a combination of non-linearity, convolution, and normalization""" + + def __init__(self, *, in_channels, out_channels=None, double_conv=False, conv_shortcut=False, + dropout=0.0, temb_channels=512, norm_type="groupnorm", act_type="relu", num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.act_type = act_type + + self.norm1 = Normalize(in_channels, norm_type=norm_type, num_groups=num_groups) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + + self.double_conv = double_conv + if self.double_conv: + self.norm2 = Normalize(out_channels, norm_type=norm_type, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = nonlinearity(h, act_type=self.act_type) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb, act_type=self.act_type))[:, :, None, None] + + if self.double_conv: + h = 
self.norm2(h) + h = nonlinearity(h, act_type=self.act_type) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinearAttention(nn.Module): + """Efficient attention block based on """ + + def __init__(self, dim, heads=4, dim_head=32, with_skip=True): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + self.with_skip = with_skip + if self.with_skip: + self.nin_shortcut = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + + if self.with_skip: + return self.to_out(out) + self.nin_shortcut(x) + return self.to_out(out) + + +class Encoder(nn.Module): + """The encoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and downsampling layers.""" + + def __init__(self, in_channels, hidden_channels, embedding_dim, block_depth=2, + attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", num_groups=32): + super(Encoder, self).__init__() + + if attn_pos is None: + attn_pos = [] + self._layers = nn.ModuleList([DownSample(in_channels, hidden_channels[0])]) + current_channel = hidden_channels[0] + + for i in range(1, len(hidden_channels)): + for _ in range(block_depth - 1): + self._layers.append(ResnetBlock(in_channels=current_channel, + out_channels=current_channel, + double_conv=False, + conv_shortcut=False, + norm_type=norm_type, + act_type=act_type, + num_groups=num_groups)) + if current_channel in attn_pos: + self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) + + self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) + self._layers.append(nn.ReLU()) + self._layers.append(DownSample(current_channel, hidden_channels[i])) + current_channel = hidden_channels[i] + + for _ in range(block_depth - 1): + self._layers.append(ResnetBlock(in_channels=current_channel, + out_channels=current_channel, + double_conv=False, + conv_shortcut=False, + norm_type=norm_type, + act_type=act_type, + num_groups=num_groups)) + if current_channel in attn_pos: + self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) + + # Conv1x1: hidden_channels[-1] -> embedding_dim + self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) + self._layers.append(nn.ReLU()) + self._layers.append(nn.Conv2d(in_channels=current_channel, + out_channels=embedding_dim, + kernel_size=1, + stride=1)) + + def forward(self, x): + for layer in self._layers: + x = layer(x) + return x + + +class Decoder(nn.Module): + """The decoder, consisting of alternating stacks of ResNet blocks, efficient attention modules, and upsampling layers.""" + + def __init__(self, embedding_dim, hidden_channels, out_channels, block_depth=2, + attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", + num_groups=32): + super(Decoder, self).__init__() + + if attn_pos is None: + attn_pos = [] + reversed_hidden_channels = 
list(reversed(hidden_channels)) + + # Conv1x1: hidden_channels[-1] -> embedding_dim + self._layers = nn.ModuleList([nn.Conv2d(in_channels=embedding_dim, + out_channels=reversed_hidden_channels[0], + kernel_size=1, stride=1, bias=False)]) + + current_channel = reversed_hidden_channels[0] + + for _ in range(block_depth - 1): + if current_channel in attn_pos: + self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) + self._layers.append(ResnetBlock(in_channels=current_channel, + out_channels=current_channel, + double_conv=False, + conv_shortcut=False, + norm_type=norm_type, + act_type=act_type, + num_groups=num_groups)) + + for i in range(1, len(reversed_hidden_channels)): + self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) + self._layers.append(nn.ReLU()) + self._layers.append(UpSample(current_channel, reversed_hidden_channels[i])) + current_channel = reversed_hidden_channels[i] + + for _ in range(block_depth - 1): + if current_channel in attn_pos: + self._layers.append(LinearAttention(current_channel, 1, 32, attn_with_skip)) + self._layers.append(ResnetBlock(in_channels=current_channel, + out_channels=current_channel, + double_conv=False, + conv_shortcut=False, + norm_type=norm_type, + act_type=act_type, + num_groups=num_groups)) + + self._layers.append(Normalize(current_channel, norm_type=norm_type, num_groups=num_groups)) + self._layers.append(nn.ReLU()) + self._layers.append(UpSample(current_channel, current_channel)) + + # final layers + self._layers.append(ResnetBlock(in_channels=current_channel, + out_channels=out_channels, + double_conv=False, + conv_shortcut=False, + norm_type=norm_type, + act_type=act_type, + num_groups=num_groups)) + + + def forward(self, x): + for layer in self._layers: + x = layer(x) + + log_magnitude = torch.nn.functional.softplus(x[:, 0, :, :]) + + cos_phase = torch.tanh(x[:, 1, :, :]) + sin_phase = torch.tanh(x[:, 2, :, :]) + x = torch.stack([log_magnitude, cos_phase, sin_phase], dim=1) + + return x + + +class VQGAN_Discriminator(nn.Module): + """The discriminator employs an 18-layer-ResNet architecture , with the first layer replaced by a 2D convolutional + layer that accommodates spectral representation inputs and the last two layers replaced by a binary classifier + layer.""" + + def __init__(self, in_channels=1): + super(VQGAN_Discriminator, self).__init__() + resnet = models.resnet18(pretrained=True) + + # 修改第一层以接受单通道(黑白)图像 + resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + + # 使用ResNet的特征提取部分 + self.features = nn.Sequential(*list(resnet.children())[:-2]) + + # 添加判别器的额外层 + self.classifier = nn.Sequential( + nn.Linear(512, 1), + nn.Sigmoid() + ) + + def forward(self, x): + x = self.features(x) + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + +class VQGAN(nn.Module): + """The VQ-GAN model. 
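+    Input spectrograms are encoded to a latent grid, quantized against a learned codebook
+    (VectorQuantizerEMA when decay > 0, plain VectorQuantizer otherwise), and decoded back
+    to a 3-channel STFT representation (log-magnitude, cos-phase, sin-phase).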
""" + + def __init__(self, in_channels, hidden_channels, embedding_dim, out_channels, block_depth=2, + attn_pos=None, attn_with_skip=True, norm_type="groupnorm", act_type="relu", + num_embeddings=1024, commitment_cost=0.25, decay=0.99, num_groups=32): + super(VQGAN, self).__init__() + + self._encoder = Encoder(in_channels, hidden_channels, embedding_dim, block_depth=block_depth, + attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, act_type="act_type", num_groups=num_groups) + + if decay > 0.0: + self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, + commitment_cost, decay) + else: + self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, + commitment_cost) + self._decoder = Decoder(embedding_dim, hidden_channels, out_channels, block_depth=block_depth, + attn_pos=attn_pos, attn_with_skip=attn_with_skip, norm_type=norm_type, + act_type=act_type, num_groups=num_groups) + + def forward(self, x): + z = self._encoder(x) + quantized, vq_loss, (perplexity, _, _) = self._vq_vae(z) + x_recon = self._decoder(quantized) + + return vq_loss, x_recon, perplexity + + +class ReconstructionLoss(nn.Module): + def __init__(self, w1, w2, epsilon=1e-3): + super(ReconstructionLoss, self).__init__() + self.w1 = w1 + self.w2 = w2 + self.epsilon = epsilon + + def weighted_mae_loss(self, y_true, y_pred): + # avoid divide by zero + y_true_safe = torch.clamp(y_true, min=self.epsilon) + + # compute weighted MAE + loss = torch.mean(torch.abs(y_pred - y_true) / y_true_safe) + return loss + + def mae_loss(self, y_true, y_pred): + loss = torch.mean(torch.abs(y_pred - y_true)) + return loss + + def forward(self, y_pred, y_true): + # loss for magnitude channel + log_magnitude_loss = self.w1 * self.weighted_mae_loss(y_pred[:, 0, :, :], y_true[:, 0, :, :]) + + # loss for phase channels + phase_loss = self.w2 * self.mae_loss(y_pred[:, 1:, :, :], y_true[:, 1:, :, :]) + + # sum up + rec_loss = log_magnitude_loss + phase_loss + return log_magnitude_loss, phase_loss, rec_loss + + +def evaluate_VQGAN(model, discriminator, iterator, reconstructionLoss, adversarial_loss, trainingConfig): + model.to(trainingConfig["device"]) + model.eval() + train_res_error = [] + for i in xrange(100): + data = next(iter(iterator)) + data = data.to(trainingConfig["device"]) + + # true/fake labels + real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"]) + + vq_loss, data_recon, perplexity = model(data) + + + fake_preds = discriminator(data_recon) + adver_loss = adversarial_loss(fake_preds, real_labels) + + log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data) + loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss + + train_res_error.append(loss.item()) + initial_loss = np.mean(train_res_error) + return initial_loss + + +def get_VQGAN(model_Config, load_pretrain=False, model_name=None, device="cpu"): + VQVAE = VQGAN(**model_Config) + print(f"Model intialized, size: {sum(p.numel() for p in VQVAE.parameters() if p.requires_grad)}") + VQVAE.to(device) + + if load_pretrain: + print(f"Loading weights from models/{model_name}_imageVQVAE.pth") + checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=device) + VQVAE.load_state_dict(checkpoint['model_state_dict']) + VQVAE.eval() + return VQVAE + + +def train_VQGAN(model_Config, trainingConfig, iterator): + + def save_model_hyperparameter(model_Config, trainingConfig, current_iter, + log_magnitude_loss, phase_loss, current_perplexity, current_vq_loss, + 
current_loss): + model_name = trainingConfig["model_name"] + model_hyperparameter = model_Config + model_hyperparameter.update(trainingConfig) + model_hyperparameter["current_iter"] = current_iter + model_hyperparameter["log_magnitude_loss"] = log_magnitude_loss + model_hyperparameter["phase_loss"] = phase_loss + model_hyperparameter["erplexity"] = current_perplexity + model_hyperparameter["vq_loss"] = current_vq_loss + model_hyperparameter["total_loss"] = current_loss + + with open(f"models/hyperparameters/{model_name}_VQGAN_STFT.json", "w") as json_file: + json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) + + # initialize VAE + model = VQGAN(**model_Config) + print(f"VQ_VAE size: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + model.to(trainingConfig["device"]) + + VAE_optimizer = torch.optim.Adam(model.parameters(), lr=trainingConfig["lr"], amsgrad=False) + model_name = trainingConfig["model_name"] + + if trainingConfig["load_pretrain"]: + print(f"Loading weights from models/{model_name}_imageVQVAE.pth") + checkpoint = torch.load(f'models/{model_name}_imageVQVAE.pth', map_location=trainingConfig["device"]) + model.load_state_dict(checkpoint['model_state_dict']) + VAE_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("VAE initialized.") + if trainingConfig["max_iter"] == 0: + print("Return VAE directly.") + return model + + # initialize discriminator + discriminator = VQGAN_Discriminator(model_Config["in_channels"]) + print(f"Discriminator size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}") + discriminator.to(trainingConfig["device"]) + + discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=trainingConfig["d_lr"], amsgrad=False) + + if trainingConfig["load_pretrain"]: + print(f"Loading weights from models/{model_name}_imageVQVAE_discriminator.pth") + checkpoint = torch.load(f'models/{model_name}_imageVQVAE_discriminator.pth', map_location=trainingConfig["device"]) + discriminator.load_state_dict(checkpoint['model_state_dict']) + discriminator_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("Discriminator initialized.") + + # Training + + train_res_phase_loss, train_res_perplexity, train_res_log_magnitude_loss, train_res_vq_loss, train_res_loss = [], [], [], [], [] + train_discriminator_loss, train_adverserial_loss = [], [] + + reconstructionLoss = ReconstructionLoss(w1=trainingConfig["w1"], w2=trainingConfig["w2"], epsilon=trainingConfig["threshold"]) + + adversarial_loss = nn.BCEWithLogitsLoss() + writer = SummaryWriter(f'runs/{model_name}_VQVAE_lr=1e-4') + + previous_lowest_loss = evaluate_VQGAN(model, discriminator, iterator, + reconstructionLoss, adversarial_loss, trainingConfig) + print(f"initial_loss: {previous_lowest_loss}") + + model.train() + for i in xrange(trainingConfig["max_iter"]): + data = next(iter(iterator)) + data = data.to(trainingConfig["device"]) + + # true/fake labels + real_labels = torch.ones(data.size(0), 1).to(trainingConfig["device"]) + fake_labels = torch.zeros(data.size(0), 1).to(trainingConfig["device"]) + + # update discriminator + discriminator_optimizer.zero_grad() + + vq_loss, data_recon, perplexity = model(data) + + real_preds = discriminator(data) + fake_preds = discriminator(data_recon.detach()) + + loss_real = adversarial_loss(real_preds, real_labels) + loss_fake = adversarial_loss(fake_preds, fake_labels) + + loss_D = loss_real + loss_fake + loss_D.backward() + discriminator_optimizer.step() + + + # 
update VQVAE + VAE_optimizer.zero_grad() + + fake_preds = discriminator(data_recon) + adver_loss = adversarial_loss(fake_preds, real_labels) + + log_magnitude_loss, phase_loss, rec_loss = reconstructionLoss(data_recon, data) + + loss = rec_loss + trainingConfig["vq_weight"] * vq_loss + trainingConfig["adver_weight"] * adver_loss + loss.backward() + VAE_optimizer.step() + + train_discriminator_loss.append(loss_D.item()) + train_adverserial_loss.append(trainingConfig["adver_weight"] * adver_loss.item()) + train_res_log_magnitude_loss.append(log_magnitude_loss.item()) + train_res_phase_loss.append(phase_loss.item()) + train_res_perplexity.append(perplexity.item()) + train_res_vq_loss.append(trainingConfig["vq_weight"] * vq_loss.item()) + train_res_loss.append(loss.item()) + step = int(VAE_optimizer.state_dict()['state'][list(VAE_optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy()) + + save_steps = trainingConfig["save_steps"] + if (i + 1) % 100 == 0: + print('%d step' % (step)) + + if (i + 1) % save_steps == 0: + current_discriminator_loss = np.mean(train_discriminator_loss[-save_steps:]) + current_adverserial_loss = np.mean(train_adverserial_loss[-save_steps:]) + current_log_magnitude_loss = np.mean(train_res_log_magnitude_loss[-save_steps:]) + current_phase_loss = np.mean(train_res_phase_loss[-save_steps:]) + current_perplexity = np.mean(train_res_perplexity[-save_steps:]) + current_vq_loss = np.mean(train_res_vq_loss[-save_steps:]) + current_loss = np.mean(train_res_loss[-save_steps:]) + + print('discriminator_loss: %.3f' % current_discriminator_loss) + print('adverserial_loss: %.3f' % current_adverserial_loss) + print('log_magnitude_loss: %.3f' % current_log_magnitude_loss) + print('phase_loss: %.3f' % current_phase_loss) + print('perplexity: %.3f' % current_perplexity) + print('vq_loss: %.3f' % current_vq_loss) + print('total_loss: %.3f' % current_loss) + writer.add_scalar(f"log_magnitude_loss", current_log_magnitude_loss, step) + writer.add_scalar(f"phase_loss", current_phase_loss, step) + writer.add_scalar(f"perplexity", current_perplexity, step) + writer.add_scalar(f"vq_loss", current_vq_loss, step) + writer.add_scalar(f"total_loss", current_loss, step) + if current_loss < previous_lowest_loss: + previous_lowest_loss = current_loss + + torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': VAE_optimizer.state_dict(), + }, f'models/{model_name}_imageVQVAE.pth') + + torch.save({ + 'model_state_dict': discriminator.state_dict(), + 'optimizer_state_dict': discriminator_optimizer.state_dict(), + }, f'models/{model_name}_imageVQVAE_discriminator.pth') + + save_model_hyperparameter(model_Config, trainingConfig, step, + current_log_magnitude_loss, current_phase_loss, current_perplexity, current_vq_loss, + current_loss) + + return model \ No newline at end of file diff --git a/model/__pycache__/DiffSynthSampler.cpython-310.pyc b/model/__pycache__/DiffSynthSampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f26091b0ae12c3b3411ba15f96e90963d3d38f02 Binary files /dev/null and b/model/__pycache__/DiffSynthSampler.cpython-310.pyc differ diff --git a/model/__pycache__/GAN.cpython-310.pyc b/model/__pycache__/GAN.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..476c402c904cdf6b24bef41c168865242c509194 Binary files /dev/null and b/model/__pycache__/GAN.cpython-310.pyc differ diff --git a/model/__pycache__/VQGAN.cpython-310.pyc b/model/__pycache__/VQGAN.cpython-310.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..0c98118b240f1a46cc50d828e0efdcbc7c812c0a Binary files /dev/null and b/model/__pycache__/VQGAN.cpython-310.pyc differ diff --git a/model/__pycache__/diffusion.cpython-310.pyc b/model/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..549244a5c5bc5fd769219a2a8570f401f5dae3f5 Binary files /dev/null and b/model/__pycache__/diffusion.cpython-310.pyc differ diff --git a/model/__pycache__/diffusion_components.cpython-310.pyc b/model/__pycache__/diffusion_components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bb6470f7349664bc53da79e3327e17062460e94 Binary files /dev/null and b/model/__pycache__/diffusion_components.cpython-310.pyc differ diff --git a/model/__pycache__/multimodal_model.cpython-310.pyc b/model/__pycache__/multimodal_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37360da9ad8ffd0c599fc101b16bef17bd58aaad Binary files /dev/null and b/model/__pycache__/multimodal_model.cpython-310.pyc differ diff --git a/model/__pycache__/perceptual_label_predictor.cpython-37.pyc b/model/__pycache__/perceptual_label_predictor.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0fe0f4fce0445e7cd54cae4507a13e3b1b093da Binary files /dev/null and b/model/__pycache__/perceptual_label_predictor.cpython-37.pyc differ diff --git a/model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc b/model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543bd5451b9e86189200de1a162c5622ce575984 Binary files /dev/null and b/model/__pycache__/timbre_encoder_pretrain.cpython-310.pyc differ diff --git a/model/diffusion.py b/model/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd7f68e90a2a365d6fd7d51383f318bfb3e5bf5 --- /dev/null +++ b/model/diffusion.py @@ -0,0 +1,371 @@ +import json +from functools import partial + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from six.moves import xrange +from torch.utils.tensorboard import SummaryWriter +import random + +from metrics.IS import get_inception_score +from tools import create_key + +from model.diffusion_components import default, ConvNextBlock, ResnetBlock, SinusoidalPositionEmbeddings, Residual, \ + PreNorm, \ + Downsample, Upsample, exists, q_sample, get_beta_schedule, pad_and_concat, ConditionalEmbedding, \ + LinearCrossAttention, LinearCrossAttentionAdd + + +class ConditionedUnet(nn.Module): + def __init__( + self, + in_dim, + out_dim=None, + down_dims=None, + up_dims=None, + mid_depth=3, + with_time_emb=True, + time_dim=None, + resnet_block_groups=8, + use_convnext=True, + convnext_mult=2, + attn_type="linear_cat", + n_label_class=11, + condition_type="instrument_family", + label_emb_dim=128, + ): + super().__init__() + + self.label_embedding = ConditionalEmbedding(int(n_label_class + 1), int(label_emb_dim), condition_type) + + if up_dims is None: + up_dims = [128, 128, 64, 32] + if down_dims is None: + down_dims = [32, 32, 64, 128] + + out_dim = default(out_dim, in_dim) + assert len(down_dims) == len(up_dims), "len(down_dims) != len(up_dims)" + assert down_dims[0] == up_dims[-1], "down_dims[0] != up_dims[-1]" + assert up_dims[0] == down_dims[-1], "up_dims[0] != down_dims[-1]" + down_in_out = list(zip(down_dims[:-1], down_dims[1:])) + up_in_out = list(zip(up_dims[:-1], up_dims[1:])) + 
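+        # With the default down_dims [32, 32, 64, 128], down_in_out is
+        # [(32, 32), (32, 64), (64, 128)]; up_in_out pairs up_dims the same way.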
print(f"down_in_out: {down_in_out}") + print(f"up_in_out: {up_in_out}") + time_dim = default(time_dim, int(down_dims[0] * 4)) + + self.init_conv = nn.Conv2d(in_dim, down_dims[0], 7, padding=3) + + if use_convnext: + block_klass = partial(ConvNextBlock, mult=convnext_mult) + else: + block_klass = partial(ResnetBlock, groups=resnet_block_groups) + + if attn_type == "linear_cat": + attn_klass = partial(LinearCrossAttention) + elif attn_type == "linear_add": + attn_klass = partial(LinearCrossAttentionAdd) + else: + raise NotImplementedError() + + # time embeddings + if with_time_emb: + self.time_mlp = nn.Sequential( + SinusoidalPositionEmbeddings(down_dims[0]), + nn.Linear(down_dims[0], time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim), + ) + else: + time_dim = None + self.time_mlp = None + + # left layers + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + skip_dims = [] + + for down_dim_in, down_dim_out in down_in_out: + self.downs.append( + nn.ModuleList( + [ + block_klass(down_dim_in, down_dim_out, time_emb_dim=time_dim), + + Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))), + block_klass(down_dim_out, down_dim_out, time_emb_dim=time_dim), + Residual(PreNorm(down_dim_out, attn_klass(down_dim_out, label_emb_dim=label_emb_dim, ))), + Downsample(down_dim_out), + ] + ) + ) + skip_dims.append(down_dim_out) + + # bottleneck + mid_dim = down_dims[-1] + self.mid_left = nn.ModuleList([]) + self.mid_right = nn.ModuleList([]) + for _ in range(mid_depth - 1): + self.mid_left.append(block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)) + self.mid_right.append(block_klass(mid_dim * 2, mid_dim, time_emb_dim=time_dim)) + self.mid_mid = nn.ModuleList( + [ + block_klass(mid_dim, mid_dim, time_emb_dim=time_dim), + Residual(PreNorm(mid_dim, attn_klass(mid_dim, label_emb_dim=label_emb_dim, ))), + block_klass(mid_dim, mid_dim, time_emb_dim=time_dim), + ] + ) + + # right layers + for ind, (up_dim_in, up_dim_out) in enumerate(up_in_out): + skip_dim = skip_dims.pop() # down_dim_out + self.ups.append( + nn.ModuleList( + [ + # pop&cat (h/2, w/2, down_dim_out) + block_klass(up_dim_in + skip_dim, up_dim_in, time_emb_dim=time_dim), + Residual(PreNorm(up_dim_in, attn_klass(up_dim_in, label_emb_dim=label_emb_dim, ))), + Upsample(up_dim_in), + # pop&cat (h, w, down_dim_out) + block_klass(up_dim_in + skip_dim, up_dim_out, time_emb_dim=time_dim), + Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))), + # pop&cat (h, w, down_dim_out) + block_klass(up_dim_out + skip_dim, up_dim_out, time_emb_dim=time_dim), + Residual(PreNorm(up_dim_out, attn_klass(up_dim_out, label_emb_dim=label_emb_dim, ))), + ] + ) + ) + + self.final_conv = nn.Sequential( + block_klass(down_dims[0] + up_dims[-1], up_dims[-1]), nn.Conv2d(up_dims[-1], out_dim, 3, padding=1) + ) + + def size(self): + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f"Total parameters: {total_params}") + print(f"Trainable parameters: {trainable_params}") + + + def forward(self, x, time, condition=None): + + if condition is not None: + condition_emb = self.label_embedding(condition) + else: + condition_emb = None + + h = [] + + x = self.init_conv(x) + h.append(x) + + time_emb = self.time_mlp(time) if exists(self.time_mlp) else None + + # downsample + for block1, attn1, block2, attn2, downsample in self.downs: + x = block1(x, time_emb) + x = attn1(x, condition_emb) + h.append(x) + x = 
block2(x, time_emb) + x = attn2(x, condition_emb) + h.append(x) + x = downsample(x) + h.append(x) + + # bottleneck + + for block in self.mid_left: + x = block(x, time_emb) + h.append(x) + + (block1, attn, block2) = self.mid_mid + x = block1(x, time_emb) + x = attn(x, condition_emb) + x = block2(x, time_emb) + + for block in self.mid_right: + # This is U-Net!!! + x = pad_and_concat(h.pop(), x) + x = block(x, time_emb) + + # upsample + for block1, attn1, upsample, block2, attn2, block3, attn3 in self.ups: + x = pad_and_concat(h.pop(), x) + x = block1(x, time_emb) + x = attn1(x, condition_emb) + x = upsample(x) + + x = pad_and_concat(h.pop(), x) + x = block2(x, time_emb) + x = attn2(x, condition_emb) + + x = pad_and_concat(h.pop(), x) + x = block3(x, time_emb) + x = attn3(x, condition_emb) + + x = pad_and_concat(h.pop(), x) + x = self.final_conv(x) + return x + + +def conditional_p_losses(denoise_model, x_start, t, condition, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, + noise=None, loss_type="l1"): + if noise is None: + noise = torch.randn_like(x_start) + + x_noisy = q_sample(x_start=x_start, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, + sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise) + predicted_noise = denoise_model(x_noisy, t, condition) + + if loss_type == 'l1': + loss = F.l1_loss(noise, predicted_noise) + elif loss_type == 'l2': + loss = F.mse_loss(noise, predicted_noise) + elif loss_type == "huber": + loss = F.smooth_l1_loss(noise, predicted_noise) + else: + raise NotImplementedError() + + return loss + + +def evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping, + uncondition_rate, unconditional_condition): + model.to(device) + model.eval() + eva_loss = [] + sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps) + for i in xrange(500): + data, attributes = next(iter(iterator)) + data = data.to(device) + + conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes] + selected_conditions = [ + unconditional_condition if random.random() < uncondition_rate else random.choice(conditions_of_one_sample) + for conditions_of_one_sample in conditions] + + selected_conditions = torch.stack(selected_conditions).float().to(device) + + t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long() + loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber", + sqrt_alphas_cumprod=sqrt_alphas_cumprod, + sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod) + + eva_loss.append(loss.item()) + initial_loss = np.mean(eva_loss) + return initial_loss + + +def get_diffusion_model(model_Config, load_pretrain=False, model_name=None, device="cpu"): + UNet = ConditionedUnet(**model_Config) + print(f"Model intialized, size: {sum(p.numel() for p in UNet.parameters() if p.requires_grad)}") + UNet.to(device) + + if load_pretrain: + print(f"Loading weights from models/{model_name}_UNet.pth") + checkpoint = torch.load(f'models/{model_name}_UNet.pth', map_location=device) + UNet.load_state_dict(checkpoint['model_state_dict']) + UNet.eval() + return UNet + + +def train_diffusion_model(VAE, text_encoder, CLAP_tokenizer, timbre_encoder, device, init_model_name, unetConfig, BATCH_SIZE, timesteps, lr, max_iter, iterator, load_pretrain, + encodes2embeddings_mapping, uncondition_rate, unconditional_condition, save_steps=5000, init_loss=None, save_model_name=None, + n_IS_batches=50): + + if save_model_name is None: + 
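+        # Fall back to the initialisation model name for checkpoints and logs
+        # when no separate save name is provided.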
save_model_name = init_model_name + + def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, lr, model_size, current_iter, current_loss): + model_hyperparameter = unetConfig + model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE + model_hyperparameter["lr"] = lr + model_hyperparameter["model_size"] = model_size + model_hyperparameter["current_iter"] = current_iter + model_hyperparameter["current_loss"] = current_loss + with open(f"models/hyperparameters/{model_name}_UNet.json", "w") as json_file: + json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) + + model = ConditionedUnet(**unetConfig) + model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Trainable parameters: {model_size}") + model.to(device) + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, amsgrad=False) + + if load_pretrain: + print(f"Loading weights from models/{init_model_name}_UNet.pt") + checkpoint = torch.load(f'models/{init_model_name}_UNet.pth') + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("Model initialized.") + if max_iter == 0: + print("Return model directly.") + return model, optimizer + + + train_loss = [] + writer = SummaryWriter(f'runs/{save_model_name}_UNet') + if init_loss is None: + previous_loss = evaluate_diffusion_model(device, model, iterator, BATCH_SIZE, timesteps, unetConfig, encodes2embeddings_mapping, + uncondition_rate, unconditional_condition) + else: + previous_loss = init_loss + print(f"initial_IS: {previous_loss}") + sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, _, _ = get_beta_schedule(timesteps) + + model.train() + for i in xrange(max_iter): + data, attributes = next(iter(iterator)) + data = data.to(device) + + conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes] + unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach() + selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice( + conditions_of_one_sample) for conditions_of_one_sample in conditions] + + selected_conditions = torch.stack(selected_conditions).float().to(device) + + optimizer.zero_grad() + + t = torch.randint(0, timesteps, (BATCH_SIZE,), device=device).long() + loss = conditional_p_losses(model, data, t, selected_conditions, loss_type="huber", + sqrt_alphas_cumprod=sqrt_alphas_cumprod, + sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod) + + loss.backward() + optimizer.step() + + train_loss.append(loss.item()) + step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy()) + + if step % 100 == 0: + print('%d step' % (step)) + + if step % save_steps == 0: + current_loss = np.mean(train_loss[-save_steps:]) + print(f"current_loss = {current_loss}") + torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'models/{save_model_name}_UNet.pth') + save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss) + + + if step % 20000 == 0: + current_IS = get_inception_score(device, model, VAE, text_encoder, CLAP_tokenizer, timbre_encoder, n_IS_batches, + positive_prompts="", negative_prompts="", CFG=1, sample_steps=20, task="STFT") + print('current_IS: %.5f' % current_IS) + current_loss = np.mean(train_loss[-save_steps:]) + + 
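+            # Every 20k steps: log the Inception Score of freshly generated samples to
+            # TensorBoard and keep a separate history checkpoint tagged with this step.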
writer.add_scalar(f"current_IS", current_IS, step) + + torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'models/history/{save_model_name}_{step}_UNet.pth') + save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, lr, model_size, step, current_loss) + + return model, optimizer + + diff --git a/model/diffusion_components.py b/model/diffusion_components.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ee12505856b7e153045b92c2414a183aecb59f --- /dev/null +++ b/model/diffusion_components.py @@ -0,0 +1,351 @@ +import torch.nn.functional as F +import torch +from torch import nn +from einops import rearrange +from inspect import isfunction +import math +from tqdm import tqdm + + +def exists(x): + """Return true for x is not None.""" + return x is not None + + +def default(val, d): + """Helper function""" + if exists(val): + return val + return d() if isfunction(d) else d + + +class Residual(nn.Module): + """Skip connection""" + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +def Upsample(dim): + """Upsample layer, a transposed convolution layer with stride=2""" + return nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + +def Downsample(dim): + """Downsample layer, a convolution layer with stride=2""" + return nn.Conv2d(dim, dim, 4, 2, 1) + + +class SinusoidalPositionEmbeddings(nn.Module): + """Return sinusoidal embedding for integer time step.""" + + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, time): + device = time.device + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings + + +class Block(nn.Module): + """Stack of convolution, normalization, and non-linear activation""" + + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + + +class ResnetBlock(nn.Module): + """Stack of [conv + norm + act (+ scale&shift)], with positional embedding inserted """ + + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = ( + nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) + if exists(time_emb_dim) + else None + ) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + h = self.block1(x) + + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + # Adding positional embedding to intermediate layer (by broadcasting along spatial dimension) + h = rearrange(time_emb, "b c -> b c 1 1") + h + + h = self.block2(h) + return h + self.res_conv(x) + + +class ConvNextBlock(nn.Module): + """Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1],with positional embedding inserted""" + + def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, 
norm=True): + super().__init__() + self.mlp = ( + nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim)) + if exists(time_emb_dim) + else None + ) + + self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim) + + self.net = nn.Sequential( + nn.GroupNorm(1, dim) if norm else nn.Identity(), + nn.Conv2d(dim, dim_out * mult, 3, padding=1), + nn.GELU(), + nn.GroupNorm(1, dim_out * mult), + nn.Conv2d(dim_out * mult, dim_out, 3, padding=1), + ) + + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + h = self.ds_conv(x) + + if exists(self.mlp) and exists(time_emb): + assert exists(time_emb), "time embedding must be passed in" + condition = self.mlp(time_emb) + h = h + rearrange(condition, "b c -> b c 1 1") + + h = self.net(h) + return h + self.res_conv(x) + + +class PreNorm(nn.Module): + """Apply normalization before 'fn'""" + + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.GroupNorm(1, dim) + + def forward(self, x, *args, **kwargs): + x = self.norm(x) + return self.fn(x, *args, **kwargs) + + +class ConditionalEmbedding(nn.Module): + """Return embedding for label and projection for text embedding""" + + def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"): + super(ConditionalEmbedding, self).__init__() + if condition_type == "instrument_family": + self.embedding = nn.Embedding(num_labels, embedding_dim) + elif condition_type == "natural_language_prompt": + self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True) + else: + raise NotImplementedError() + + def forward(self, labels): + return self.embedding(labels) + + +class LinearCrossAttention(nn.Module): + """Combination of efficient attention and cross attention.""" + + def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32): + super().__init__() + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) + + # embedding for key and value + self.label_key = nn.Linear(label_emb_dim, hidden_dim) + self.label_value = nn.Linear(label_emb_dim, hidden_dim) + + def forward(self, x, label_embedding=None): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) + + if label_embedding is not None: + label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1) + label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1) + + k = torch.cat([k, label_k], dim=-1) + v = torch.cat([v, label_v], dim=-1) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + q = q * self.scale + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + +def pad_to_match(encoder_tensor, decoder_tensor): + """ + Pads the decoder_tensor to match the spatial dimensions of encoder_tensor. + + :param encoder_tensor: The feature map from the encoder. + :param decoder_tensor: The feature map from the decoder that needs to be upsampled. + :return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor. 
+ """ + + enc_shape = encoder_tensor.shape[2:] # spatial dimensions are at index 2 and 3 + dec_shape = decoder_tensor.shape[2:] + + # assume enc_shape >= dec_shape + delta_w = enc_shape[1] - dec_shape[1] + delta_h = enc_shape[0] - dec_shape[0] + + # padding + padding_left = delta_w // 2 + padding_right = delta_w - padding_left + padding_top = delta_h // 2 + padding_bottom = delta_h - padding_top + decoder_tensor_padded = F.pad(decoder_tensor, (padding_left, padding_right, padding_top, padding_bottom)) + + return decoder_tensor_padded + + +def pad_and_concat(encoder_tensor, decoder_tensor): + """ + Pads the decoder_tensor and concatenates it with the encoder_tensor along the channel dimension. + + :param encoder_tensor: The feature map from the encoder. + :param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor. + :return: Concatenated tensor. + """ + + # pad decoder_tensor + decoder_tensor_padded = pad_to_match(encoder_tensor, decoder_tensor) + # concat encoder_tensor and decoder_tensor_padded + concatenated_tensor = torch.cat((encoder_tensor, decoder_tensor_padded), dim=1) + return concatenated_tensor + + +class LinearCrossAttentionAdd(nn.Module): + def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32): + super().__init__() + self.dim = dim + self.dim_head = dim_head + self.scale = dim_head ** -0.5 + self.heads = heads + self.label_emb_dim = label_emb_dim + self.dim_head = dim_head + + self.hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False) + self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim)) + + # embedding for key and value + self.label_key = nn.Linear(label_emb_dim, self.hidden_dim) + self.label_query = nn.Linear(label_emb_dim, self.hidden_dim) + + + def forward(self, x, condition=None): + b, c, h, w = x.shape + + qkv = self.to_qkv(x).chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) + + # if condition exists,concat its key and value with origin + if condition is not None: + label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1) + label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1) + k = k + label_k + q = q + label_q + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + q = q * self.scale + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + + +def linear_beta_schedule(timesteps): + beta_start = 0.0001 + beta_end = 0.02 + return torch.linspace(beta_start, beta_end, timesteps) + + +def get_beta_schedule(timesteps): + betas = linear_beta_schedule(timesteps=timesteps) + + # define alphas + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + sqrt_recip_alphas = torch.sqrt(1.0 / alphas) + + # calculations for diffusion q(x_t | x_{t-1}) and others + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. 
- alphas_cumprod) + return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas + + +def extract(a, t, x_shape): + batch_size = t.shape[0] + out = a.gather(-1, t.cpu()) + return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) + + +# forward diffusion +def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + + sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) + sqrt_one_minus_alphas_cumprod_t = extract( + sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + + return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise + + + + + + + + + + + + + + diff --git a/model/multimodal_model.py b/model/multimodal_model.py new file mode 100644 index 0000000000000000000000000000000000000000..01a543b4085c375195cb149762393d944433aef6 --- /dev/null +++ b/model/multimodal_model.py @@ -0,0 +1,274 @@ +import itertools +import json +import random + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from tools import create_key +from model.timbre_encoder_pretrain import get_timbre_encoder + + +class ProjectionLayer(nn.Module): + """Single-layer Linear projection with dropout, layer norm, and Gelu activation""" + + def __init__(self, input_dim, output_dim, dropout): + super(ProjectionLayer, self).__init__() + self.projection = nn.Linear(input_dim, output_dim) + self.gelu = nn.GELU() + self.fc = nn.Linear(output_dim, output_dim) + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(output_dim) + + def forward(self, x): + projected = self.projection(x) + x = self.gelu(projected) + x = self.fc(x) + x = self.dropout(x) + x = x + projected + x = self.layer_norm(x) + return x + + +class ProjectionHead(nn.Module): + """Stack of 'ProjectionLayer'""" + + def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2): + super(ProjectionHead, self).__init__() + self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim, + projection_dim, + dropout) for i in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +class multi_modal_model(nn.Module): + """The multi-modal model for contrastive learning""" + + def __init__( + self, + timbre_encoder, + text_encoder, + spectrogram_feature_dim, + text_feature_dim, + multi_modal_emb_dim, + temperature, + dropout, + num_projection_layers=1, + freeze_spectrogram_encoder=True, + freeze_text_encoder=True, + ): + super().__init__() + self.timbre_encoder = timbre_encoder + self.text_encoder = text_encoder + + self.multi_modal_emb_dim = multi_modal_emb_dim + + self.text_projection = ProjectionHead(embedding_dim=text_feature_dim, + projection_dim=self.multi_modal_emb_dim, dropout=dropout, + num_layers=num_projection_layers) + + self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim, + projection_dim=self.multi_modal_emb_dim, dropout=dropout, + num_layers=num_projection_layers) + + self.temperature = temperature + + # Make spectrogram_encoder parameters non-trainable + for param in self.timbre_encoder.parameters(): + param.requires_grad = not freeze_spectrogram_encoder + + # Make text_encoder parameters non-trainable + for param in self.text_encoder.parameters(): + param.requires_grad = not freeze_text_encoder + + def forward(self, spectrogram_batch, tokenized_text_batch): + # Getting Image and Text Embeddings (with same dimension) + 
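+        # Contrastive objective: the timbre encoder's feature vector and the CLAP text
+        # feature are each projected into the shared multi-modal space, and a symmetric
+        # cross-entropy over the scaled similarity matrix pulls matching
+        # (spectrogram, text) pairs together.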
spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch) + text_features = self.text_encoder.get_text_features(**tokenized_text_batch) + + # Concat and apply projection + spectrogram_embeddings = self.spectrogram_projection(spectrogram_features) + text_embeddings = self.text_projection(text_features) + + # Calculating the Loss + logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature + images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T + texts_similarity = text_embeddings @ text_embeddings.T + targets = F.softmax( + (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1 + ) + texts_loss = cross_entropy(logits, targets, reduction='none') + images_loss = cross_entropy(logits.T, targets.T, reduction='none') + contrastive_loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size) + contrastive_loss = contrastive_loss.mean() + + return contrastive_loss + + + def get_text_features(self, input_ids, attention_mask): + text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask) + return self.text_projection(text_features) + + + def get_timbre_features(self, spectrogram_batch): + spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch) + return self.spectrogram_projection(spectrogram_features) + + +def cross_entropy(preds, targets, reduction='none'): + log_softmax = nn.LogSoftmax(dim=-1) + loss = (-targets * log_softmax(preds)).sum(1) + if reduction == "none": + return loss + elif reduction == "mean": + return loss.mean() + + +def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"): + mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config) + print(f"Model intialized, size: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}") + mmm.to(device) + + if load_pretrain: + print(f"Loading weights from models/{model_name}_MMM.pth") + checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device) + mmm.load_state_dict(checkpoint['model_state_dict']) + mmm.eval() + return mmm + + +def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device): + (data, attributes) = next(iter(train_loader)) + keys = [create_key(attribute) for attribute in attributes] + + while(len(set(keys)) != len(keys)): + (data, attributes) = next(iter(train_loader)) + keys = [create_key(attribute) for attribute in attributes] + + data = data.to(device) + + texts = [labels_mapping[create_key(attribute)] for attribute in attributes] + selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts] + + tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device) + + loss = model(data, tokenized_text) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return loss.item() + + +def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device): + (data, attributes) = next(iter(valid_loader)) + keys = [create_key(attribute) for attribute in attributes] + + while(len(set(keys)) != len(keys)): + (data, attributes) = next(iter(valid_loader)) + keys = [create_key(attribute) for attribute in attributes] + + data = data.to(device) + texts = [labels_mapping[create_key(attribute)] for attribute in attributes] + selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts] + + tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device) + + loss = model(data, tokenized_text) + return loss.item() + + 
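+# Editor's usage sketch (illustrative only, never called by the training code): given an
+# already-built multi_modal_model on `device` and the CLAP tokenizer, embed one text
+# prompt and a batch of spectrogram representations in the shared space and score their
+# similarity. The prompt string is an arbitrary example; pass whatever batch your
+# dataloader yields.
+def _example_text_spectrogram_similarity(mmm, text_tokenizer, spectrogram_batch, prompt="organ", device="cpu"):
+    mmm.eval()
+    with torch.no_grad():
+        tokens = text_tokenizer([prompt], padding=True, return_tensors="pt").to(device)
+        text_emb = mmm.get_text_features(tokens["input_ids"], tokens["attention_mask"])
+        spec_emb = mmm.get_timbre_features(spectrogram_batch.to(device))
+        # Cosine similarity between the single prompt and every spectrogram in the batch.
+        return F.cosine_similarity(text_emb, spec_emb, dim=-1)
+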
+def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder, + timbre_encoder_Config, MMM_config, MMM_training_config, + mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True, + timbre_encoder_name=None, init_loss=None, save_steps=2000): + + def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter, + current_loss): + + model_hyperparameter = MMM_config + model_hyperparameter.update(MMM_training_config) + model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE + model_hyperparameter["model_size"] = model_size + model_hyperparameter["current_iter"] = current_iter + model_hyperparameter["current_loss"] = current_loss + with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file: + json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) + + timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name, + device=device) + + mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device) + + print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}") + print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}") + print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}") + print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}") + total_parameters = sum(p.numel() for p in mmm.parameters()) + trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad) + print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}") + + params = [ + {"params": itertools.chain( + mmm.spectrogram_projection.parameters(), + mmm.text_projection.parameters(), + ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]}, + ] + if not MMM_config["freeze_text_encoder"]: + params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"], + "weight_decay": MMM_training_config["text_encoder_weight_decay"]}) + if not MMM_config["freeze_spectrogram_encoder"]: + params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"], + "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]}) + + optimizer = torch.optim.AdamW(params, weight_decay=0.) 
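+    # Only the projection heads are always optimised; the text and timbre encoders are
+    # added as extra parameter groups (with their own learning rates and weight decay)
+    # when they are not frozen. The block below optionally resumes model and optimizer
+    # state from an existing checkpoint and returns immediately when max_iter == 0.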
+ + if load_pretrain: + print(f"Loading weights from models/{mmm_name}_MMM.pt") + checkpoint = torch.load(f'models/{mmm_name}_MMM.pth') + mmm.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("Model initialized.") + + if max_iter == 0: + print("Return model directly.") + return mmm, optimizer + + if init_loss is None: + previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device) + else: + previous_lowest_loss = init_loss + print(f"Initial total loss: {previous_lowest_loss}") + + train_loss_list = [] + for i in range(max_iter): + + mmm.train() + train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device) + train_loss_list.append(train_loss) + + step = int( + optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy()) + if (i + 1) % 100 == 0: + print('%d step' % (step)) + + if (i + 1) % save_steps == 0: + current_loss = np.mean(train_loss_list[-save_steps:]) + print(f"train_total_loss: {current_loss}") + if current_loss < previous_lowest_loss: + previous_lowest_loss = current_loss + torch.save({ + 'model_state_dict': mmm.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'models/{mmm_name}_MMM.pth') + save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step, + current_loss) + + return mmm, optimizer \ No newline at end of file diff --git a/model/timbre_encoder_pretrain.py b/model/timbre_encoder_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..0e22057bca595b08937a18c9b0f1e74f093da14c --- /dev/null +++ b/model/timbre_encoder_pretrain.py @@ -0,0 +1,220 @@ +import json +import numpy as np +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from tools import create_key + + +class TimbreEncoder(nn.Module): + def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1): + super(TimbreEncoder, self).__init__() + + # Input layer + self.input_layer = nn.Linear(input_dim, feature_dim) + + # LSTM Layer + self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True) + + # Fully Connected Layers for classification + self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes) + self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes) + self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes) + self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities) + + # Softmax for converting output to probabilities + self.softmax = nn.LogSoftmax(dim=1) + + def forward(self, x): + # # Merge first two dimensions + batch_size, _, _, seq_len = x.shape + x = x.view(batch_size, -1, seq_len) # [batch_size, input_dim, seq_len] + + # Forward propagate LSTM + x = x.permute(0, 2, 1) + x = self.input_layer(x) + feature, _ = self.lstm(x) + feature = feature[:, -1, :] + + # Apply classification layers + instrument_logits = self.instrument_classifier_layer(feature) + instrument_family_logits = self.instrument_family_classifier_layer(feature) + velocity_logits = self.velocity_classifier_layer(feature) + qualities = self.qualities_classifier_layer(feature) + + # Apply Softmax + instrument_logits = self.softmax(instrument_logits) + instrument_family_logits= 
self.softmax(instrument_family_logits) + velocity_logits = self.softmax(velocity_logits) + qualities = torch.sigmoid(qualities) + + return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities + + +def get_multiclass_acc(outputs, ground_truth): + _, predicted = torch.max(outputs.data, 1) + total = ground_truth.size(0) + correct = (predicted == ground_truth).sum().item() + accuracy = 100 * correct / total + return accuracy + +def get_binary_accuracy(y_pred, y_true): + predictions = (y_pred > 0.5).int() + + correct_predictions = (predictions == y_true).float() + + accuracy = correct_predictions.mean() + + return accuracy.item() * 100.0 + + +def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"): + timbreEncoder = TimbreEncoder(**model_Config) + print(f"Model intialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}") + timbreEncoder.to(device) + + if load_pretrain: + print(f"Loading weights from models/{model_name}_timbre_encoder.pth") + checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device) + timbreEncoder.load_state_dict(checkpoint['model_state_dict']) + timbreEncoder.eval() + return timbreEncoder + + +def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100): + model.to(device) + model.eval() + + eva_loss = [] + for i in range(n_sample): + representation, attributes = next(iter(iterator)) + + instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device) + instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device) + velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device) + qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device) + + _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device)) + + # compute loss + instrument_loss = nll_Loss(instrument_logits, instrument) + instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family) + velocity_loss = nll_Loss(velocity_logits, velocity) + qualities_loss = bce_Loss(qualities_pred, qualities) + + loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss + + eva_loss.append(loss.item()) + + eva_loss = np.mean(eva_loss) + return eva_loss + + +def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain): + def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter, + current_loss): + model_hyperparameter = timbre_encoder_Config + model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE + model_hyperparameter["lr"] = lr + model_hyperparameter["model_size"] = model_size + model_hyperparameter["current_iter"] = current_iter + model_hyperparameter["current_loss"] = current_loss + with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file: + json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4) + + model = TimbreEncoder(**timbre_encoder_Config) + model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Model size: {model_size}") + model.to(device) + nll_Loss = torch.nn.NLLLoss() + bce_Loss = torch.nn.BCELoss() + + optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False) + + if load_pretrain: + print(f"Loading weights 
from models/{model_name}_timbre_encoder.pt") + checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth') + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("Model initialized.") + if max_iter == 0: + print("Return model directly.") + return model, model + + train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], [] + writer = SummaryWriter(f'runs/{model_name}_timbre_encoder') + current_best_model = model + previous_lowest_loss = 100.0 + print(f"initial__loss: {previous_lowest_loss}") + + for i in range(max_iter): + model.train() + + representation, attributes = next(iter(training_iterator)) + + instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device) + instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device) + velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device) + qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device) + + optimizer.zero_grad() + + _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device)) + + # compute loss + instrument_loss = nll_Loss(instrument_logits, instrument) + instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family) + velocity_loss = nll_Loss(velocity_logits, velocity) + qualities_loss = bce_Loss(qualities_pred, qualities) + + loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss + + loss.backward() + optimizer.step() + instrument_acc = get_multiclass_acc(instrument_logits, instrument) + instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family) + velocity_acc = get_multiclass_acc(velocity_logits, velocity) + qualities_acc = get_binary_accuracy(qualities_pred, qualities) + + train_loss.append(loss.item()) + training_instrument_acc.append(instrument_acc) + training_instrument_family_acc.append(instrument_family_acc) + training_velocity_acc.append(velocity_acc) + training_qualities_acc.append(qualities_acc) + step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy()) + + if (i + 1) % 100 == 0: + print('%d step' % (step)) + + save_steps = 500 + if (i + 1) % save_steps == 0: + current_loss = np.mean(train_loss[-save_steps:]) + current_instrument_acc = np.mean(training_instrument_acc[-save_steps:]) + current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:]) + current_velocity_acc = np.mean(training_velocity_acc[-save_steps:]) + current_qualities_acc = np.mean(training_qualities_acc[-save_steps:]) + print('train_loss: %.5f' % current_loss) + print('current_instrument_acc: %.5f' % current_instrument_acc) + print('current_instrument_family_acc: %.5f' % current_instrument_family_acc) + print('current_velocity_acc: %.5f' % current_velocity_acc) + print('current_qualities_acc: %.5f' % current_qualities_acc) + writer.add_scalar(f"train_loss", current_loss, step) + writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step) + writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step) + writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step) + writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step) + + if current_loss < 
previous_lowest_loss: + previous_lowest_loss = current_loss + current_best_model = model + torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'models/{model_name}_timbre_encoder.pth') + save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step, + current_loss) + + return model, current_best_model + + diff --git a/models/24_1_2024-52_4x_L_D_imageVQVAE.pth b/models/24_1_2024-52_4x_L_D_imageVQVAE.pth new file mode 100644 index 0000000000000000000000000000000000000000..7747a92f9389133e157ea8f6059e235f8f9b1eac --- /dev/null +++ b/models/24_1_2024-52_4x_L_D_imageVQVAE.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5feec46219e25e6f95bfa453a4bddb3ec7bc26d29f2e01748defa4901762c9f +size 16069859 diff --git a/models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth b/models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth new file mode 100644 index 0000000000000000000000000000000000000000..892380ebe38b50f4abaab84aab6c88b1ce78347f --- /dev/null +++ b/models/24_1_2024-52_4x_L_D_imageVQVAE_discriminator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69080ff5094f82bba6e69e4310cda27864420dfea9b08d3a001dafe46bbf6808 +size 134268962 diff --git a/models/24_1_2024_MMM.pth b/models/24_1_2024_MMM.pth new file mode 100644 index 0000000000000000000000000000000000000000..70d4c25be12f2301bfc25f56dcab92d8d58aca2c --- /dev/null +++ b/models/24_1_2024_MMM.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:494f7fa1f9874ebd2cd870b6da993cb796e907e857a28550e62be8c335fb9f5a +size 1930637291 diff --git a/models/24_1_2024_STFT_timbre_encoder.pth b/models/24_1_2024_STFT_timbre_encoder.pth new file mode 100644 index 0000000000000000000000000000000000000000..603f99c8c5516d7efbf75ea804be4ba00271b78d --- /dev/null +++ b/models/24_1_2024_STFT_timbre_encoder.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b9288a7fc4a8cdc4b0db5803fd0a5e88e6bcf07bdde22f49ebb8dec12fb33e6 +size 294502949 diff --git a/models/history/28_1_2024_TE_STFT_300000_UNet.pth b/models/history/28_1_2024_TE_STFT_300000_UNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..f0ebf01b6f85cb86ae145196dd4046f075827921 --- /dev/null +++ b/models/history/28_1_2024_TE_STFT_300000_UNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edcdd3c65275badad3233eabdcb7a3ffa7adbf95c673172ca2e35338097d1a1c +size 1284015362 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..eda8db81bd7be59d3a9118a8c908e78ba6b1b384 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +torchmetrics==0.7.0 +torchsynth==1.0.2 +torchaudio +soundfile +einops +pytorch-ssim +piqa +torchinfo +mido +tensorboard +librosa +transformers +matplotlib +gradio==3.50.2 + diff --git a/tools.py b/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..559c09a03b8b48a3a5d53778ef38ead15be6df33 --- /dev/null +++ b/tools.py @@ -0,0 +1,344 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import librosa +from scipy.io.wavfile import write +import torch + +k = 1e-16 + +def np_log10(x): + """Safe log function with base 10.""" + numerator = np.log(x + 1e-16) + denominator = np.log(10) + return numerator / denominator + + +def sigmoid(x): + """Safe log function with base 10.""" + s = 1 / (1 + np.exp(-x)) + return s + + +def inv_sigmoid(s): + """Safe inverse 
sigmoid function.""" + x = np.log((s / (1 - s)) + 1e-16) + return x + + +def spc_to_VAE_input(spc): + """Restrict value range from [0, infinite] to [0, 1]. (deprecated )""" + return spc / (1 + spc) + + +def VAE_out_put_to_spc(o): + """Inverse transform of function 'spc_to_VAE_input'. (deprecated )""" + return o / (1 - o + k) + + + +def np_power_to_db(S, amin=1e-16, top_db=80.0): + """Helper method for numpy data scaling. (deprecated )""" + ref = S.max() + + log_spec = 10.0 * np_log10(np.maximum(amin, S)) + log_spec -= 10.0 * np_log10(np.maximum(amin, ref)) + + log_spec = np.maximum(log_spec, log_spec.max() - top_db) + + return log_spec + + +def show_spc(spc): + """Show a spectrogram. (deprecated )""" + s = np.shape(spc) + spc = np.reshape(spc, (s[0], s[1])) + magnitude_spectrum = np.abs(spc) + log_spectrum = np_power_to_db(magnitude_spectrum) + plt.imshow(np.flipud(log_spectrum)) + plt.show() + + +def save_results(spectrogram, spectrogram_image_path, waveform_path): + """Save the input 'spectrogram' and its waveform (reconstructed by Griffin Lim) + to path provided by 'spectrogram_image_path' and 'waveform_path'.""" + magnitude_spectrum = np.abs(spectrogram) + log_spc = np_power_to_db(magnitude_spectrum) + log_spc = np.reshape(log_spc, (512, 256)) + matplotlib.pyplot.imsave(spectrogram_image_path, log_spc, vmin=-100, vmax=0, + origin='lower') + + # save waveform + abs_spec = np.zeros((513, 256)) + abs_spec[:512, :] = abs_spec[:512, :] + np.sqrt(np.reshape(spectrogram, (512, 256))) + rec_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024) + write(waveform_path, 16000, rec_signal) + + +def plot_log_spectrogram(signal: np.ndarray, + path: str, + n_fft=2048, + frame_length=1024, + frame_step=256): + """Save spectrogram.""" + stft = librosa.stft(signal, n_fft=n_fft, hop_length=frame_step, win_length=frame_length) + amp = np.square(np.real(stft)) + np.square(np.imag(stft)) + magnitude_spectrum = np.abs(amp) + log_mel = np_power_to_db(magnitude_spectrum) + matplotlib.pyplot.imsave(path, log_mel, vmin=-100, vmax=0, origin='lower') + + +def visualize_feature_maps(device, model, inputs, channel_indices=[0, 3,]): + """ + Visualize feature maps before and after quantization for given input. + + Parameters: + - model: Your VQ-VAE model. + - inputs: A batch of input data. + - channel_indices: Indices of feature map channels to visualize. + """ + model.eval() + inputs = inputs.to(device) + + with torch.no_grad(): + z_e = model._encoder(inputs) + z_q, loss, (perplexity, min_encodings, min_encoding_indices) = model._vq_vae(z_e) + + # Assuming inputs have shape [batch_size, channels, height, width] + batch_size = z_e.size(0) + + for idx in range(batch_size): + fig, axs = plt.subplots(1, len(channel_indices)*2, figsize=(15, 5)) + + for i, channel_idx in enumerate(channel_indices): + # Plot encoder output + axs[2*i].imshow(z_e[idx][channel_idx].cpu().numpy(), cmap='viridis') + axs[2*i].set_title(f"Encoder Output - Channel {channel_idx}") + + # Plot quantized output + axs[2*i+1].imshow(z_q[idx][channel_idx].cpu().numpy(), cmap='viridis') + axs[2*i+1].set_title(f"Quantized Output - Channel {channel_idx}") + + plt.show() + + +def adjust_audio_length(audio, desired_length, original_sample_rate, target_sample_rate): + """ + Adjust the audio length to the desired length and resample to target sample rate. 
+ + Parameters: + - audio (np.array): The input audio signal + - desired_length (int): The desired length of the output audio + - original_sample_rate (int): The original sample rate of the audio + - target_sample_rate (int): The target sample rate for the output audio + + Returns: + - np.array: The adjusted and resampled audio + """ + + if not (original_sample_rate == target_sample_rate): + audio = librosa.core.resample(audio, orig_sr=original_sample_rate, target_sr=target_sample_rate) + + if len(audio) > desired_length: + return audio[:desired_length] + + elif len(audio) < desired_length: + padded_audio = np.zeros(desired_length) + padded_audio[:len(audio)] = audio + return padded_audio + else: + return audio + + +def safe_int(s, default=0): + try: + return int(s) + except ValueError: + return default + + +def pad_spectrogram(D): + """Resize spectrogram to (512, 256). (deprecated )""" + D = D[1:, :] + + padding_length = 256 - D.shape[1] + D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant') + return D_padded + + +def pad_STFT(D, time_resolution=256): + """Resize spectral matrix by padding and cropping""" + D = D[1:, :] + + if time_resolution is None: + return D + + padding_length = time_resolution - D.shape[1] + if padding_length > 0: + D_padded = np.pad(D, ((0, 0), (0, padding_length)), 'constant') + return D_padded + else: + return D + + +def depad_STFT(D_padded): + """Inverse function of 'pad_STFT'""" + zero_row = np.zeros((1, D_padded.shape[1])) + + D_restored = np.concatenate([zero_row, D_padded], axis=0) + + return D_restored + + +def nnData2Audio(spectrogram_batch, resolution=(512, 256), squared=False): + """Transform batch of numpy spectrogram into signals and encodings.""" + # Todo: remove resolution hard-coding + frequency_resolution, time_resolution = resolution + + if isinstance(spectrogram_batch, torch.Tensor): + spectrogram_batch = spectrogram_batch.to("cpu").detach().numpy() + + origin_signals = [] + for spectrogram in spectrogram_batch: + spc = VAE_out_put_to_spc(spectrogram) + + # get_audio + abs_spec = np.zeros((frequency_resolution+1, time_resolution)) + + if squared: + abs_spec[1:, :] = abs_spec[1:, :] + np.sqrt(np.reshape(spc, (frequency_resolution, time_resolution))) + else: + abs_spec[1:, :] = abs_spec[1:, :] + np.reshape(spc, (frequency_resolution, time_resolution)) + + origin_signal = librosa.griffinlim(abs_spec, n_iter=32, hop_length=256, win_length=1024) + origin_signals.append(origin_signal) + + return origin_signals + + +def amp_to_audio(amp, n_iter=50): + """The Griffin-Lim algorithm.""" + y_reconstructed = librosa.griffinlim(amp, n_iter=n_iter, hop_length=256, win_length=1024) + return y_reconstructed + + +def rescale(amp, method="log1p"): + """Rescale function.""" + if method == "log1p": + return np.log1p(amp) + elif method == "NormalizedLogisticCompression": + return amp / (1.0 + amp) + else: + raise NotImplementedError() + + +def unrescale(scaled_amp, method="NormalizedLogisticCompression"): + """Inverse function of 'rescale'""" + if method == "log1p": + return np.expm1(scaled_amp) + elif method == "NormalizedLogisticCompression": + return scaled_amp / (1.0 - scaled_amp + 1e-10) + else: + raise NotImplementedError() + + +def create_key(attributes): + """Create unique key for each multi-label.""" + qualities_str = ''.join(map(str, attributes["qualities"])) + instrument_source_str = attributes["instrument_source_str"] + instrument_family = attributes["instrument_family_str"] + key = f"{instrument_source_str}_{instrument_family}_{qualities_str}" 
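+    # Key layout: <instrument_source>_<instrument_family>_<10 binary quality flags>,
+    # so samples sharing source, family and qualities map to the same label entry.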
+ return key + + +def merge_dictionaries(dicts): + """Merge dictionaries.""" + merged_dict = {} + for dictionary in dicts: + for key, value in dictionary.items(): + if key in merged_dict: + merged_dict[key] += value + else: + merged_dict[key] = value + return merged_dict + + +def adsr_envelope(signal, sample_rate, duration, attack_time, decay_time, sustain_level, release_time): + """ + Apply an ADSR envelope to an audio signal. + + :param signal: The original audio signal (numpy array). + :param sample_rate: The sample rate of the audio signal. + :param attack_time: Attack time in seconds. + :param decay_time: Decay time in seconds. + :param sustain_level: Sustain level as a fraction of the peak (0 to 1). + :param release_time: Release time in seconds. + :return: The audio signal with the ADSR envelope applied. + """ + # Calculate the number of samples for each ADSR phase + duration_samples = int(duration * sample_rate) + + # assert (duration_samples + int(1.0 * sample_rate)) <= len(signal), "(duration_samples + sample_rate) > len(signal)" + assert release_time <= 1.0, "release_time > 1.0" + + attack_samples = int(attack_time * sample_rate) + decay_samples = int(decay_time * sample_rate) + release_samples = int(release_time * sample_rate) + sustain_samples = max(0, duration_samples - attack_samples - decay_samples) + + # Create ADSR envelope + attack_env = np.linspace(0, 1, attack_samples) + decay_env = np.linspace(1, sustain_level, decay_samples) + sustain_env = np.full(sustain_samples, sustain_level) + release_env = np.linspace(sustain_level, 0, release_samples) + release_env_expand = np.zeros(int(1.0 * sample_rate)) + release_env_expand[:len(release_env)] = release_env + + # Concatenate all phases to create the complete envelope + envelope = np.concatenate([attack_env, decay_env, sustain_env, release_env_expand]) + + # Apply the envelope to the signal + if len(envelope) <= len(signal): + applied_signal = signal[:len(envelope)] * envelope + else: + signal_expanded = np.zeros(len(envelope)) + signal_expanded[:len(signal)] = signal + applied_signal = signal_expanded * envelope + + return applied_signal + + +def rms_normalize(audio, target_rms=0.1): + """Normalize the RMS value.""" + current_rms = np.sqrt(np.mean(audio**2)) + scaling_factor = target_rms / current_rms + normalized_audio = audio * scaling_factor + return normalized_audio + + +def encode_stft(D): + """'STFT+' function that transform spectral matrix into spectral representation.""" + magnitude = np.abs(D) + phase = np.angle(D) + + log_magnitude = np.log1p(magnitude) + + cos_phase = np.cos(phase) + sin_phase = np.sin(phase) + + encoded_D = np.stack([log_magnitude, cos_phase, sin_phase], axis=0) + return encoded_D + + +def decode_stft(encoded_D): + """'ISTFT+' function that reconstructs spectral matrix from spectral representation.""" + log_magnitude = encoded_D[0, ...] + cos_phase = encoded_D[1, ...] + sin_phase = encoded_D[2, ...] 
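+    # Invert the log1p magnitude compression, recover the phase angle from its
+    # cosine/sine pair, and recombine both into the complex STFT matrix.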
+ + magnitude = np.expm1(log_magnitude) + + phase = np.arctan2(sin_phase, cos_phase) + + D = magnitude * (np.cos(phase) + 1j * np.sin(phase)) + return D diff --git a/webUI/__pycache__/app.cpython-310.pyc b/webUI/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffdcc3fb90facf255ef313f9f5662551788e948b Binary files /dev/null and b/webUI/__pycache__/app.cpython-310.pyc differ diff --git a/webUI/deprecated/interpolationWithCondition.py b/webUI/deprecated/interpolationWithCondition.py new file mode 100644 index 0000000000000000000000000000000000000000..6947c4895bcbcd86a4831bcbaf70a38afea1e2b1 --- /dev/null +++ b/webUI/deprecated/interpolationWithCondition.py @@ -0,0 +1,178 @@ +import gradio as gr +import numpy as np +import torch + +from model.DiffSynthSampler import DiffSynthSampler +from tools import safe_int +from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image + + +def get_interpolation_with_condition_module(gradioWebUI, interpolation_with_text_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels + timesteps = gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def diffusion_random_sample(text2sound_prompts_1, text2sound_prompts_2, text2sound_negative_prompts, text2sound_batchsize, + text2sound_duration, + text2sound_guidance_scale, text2sound_sampler, + text2sound_sample_steps, text2sound_seed, + interpolation_with_text_dict): + text2sound_sample_steps = int(text2sound_sample_steps) + text2sound_seed = safe_int(text2sound_seed, 12345678) + # Todo: take care of text2sound_time_resolution/width + width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale) + text2sound_batchsize = int(text2sound_batchsize) + + text2sound_embedding_1 = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_1], padding=True, return_tensors="pt"))[0].to(device) + text2sound_embedding_2 = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts_2], padding=True, return_tensors="pt"))[0].to(device) + + CFG = int(text2sound_guidance_scale) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + unconditional_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0] + mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device)) + + mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32))) + + condition = torch.linspace(1, 0, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_1 + \ + torch.linspace(0, 1, steps=text2sound_batchsize).unsqueeze(1).to(device) * text2sound_embedding_2 + + # Todo: move this code + torch.manual_seed(text2sound_seed) + initial_noise = torch.randn(text2sound_batchsize, channels, height, width).to(device) + + latent_representations, initial_noise = \ + mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), 
seed=text2sound_seed, + return_tensor=True, condition=condition, sampler=text2sound_sampler, initial_noise=initial_noise) + + latent_representations = latent_representations[-1] + + interpolation_with_text_dict["latent_representations"] = latent_representations + + latent_representation_gradio_images = [] + quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_rec_signals_gradio = [] + + quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations) + # Todo: remove hard-coding + flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations, + resolution=(512, width * VAE_scale), centralized=False, + squared=squared) + + for i in range(text2sound_batchsize): + latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i])) + quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i]) + new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i])) + + def concatenate_arrays(arrays_list): + return np.concatenate(arrays_list, axis=1) + + concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images) + + interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images + interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images + interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + return {text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0], + text2sound_quantized_latent_representation_image: + interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0], + text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image, + text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0], + text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0], + text2sound_seed_textbox: text2sound_seed, + interpolation_with_text_state: interpolation_with_text_dict, + text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1, + visible=True, + label="Sample index.", + info="Swipe to view other samples")} + + def show_random_sample(sample_index, text2sound_dict): + sample_index = int(sample_index) + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][ + sample_index], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][sample_index], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]} + + with gr.Tab("InterpolationCond."): + gr.Markdown("Use interpolation to generate a gradient sound sequence.") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + text2sound_prompts_1_textbox = gr.Textbox(label="Positive prompt 1", lines=2, value="organ") + text2sound_prompts_2_textbox = gr.Textbox(label="Positive prompt 2", lines=2, 
value="string") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + text2sound_sampling_button = gr.Button(variant="primary", + value="Generate a batch of samples and show " + "the first one", + scale=1) + text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + with gr.Row(variant="panel"): + with gr.Column(scale=1, variant="panel"): + text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + text2sound_sampler_radio = gradioWebUI.get_sampler_radio() + text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3) + text2sound_duration_slider = gradioWebUI.get_duration_slider() + text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + text2sound_seed_textbox = gradioWebUI.get_seed_textbox() + + with gr.Column(scale=1): + with gr.Row(variant="panel"): + text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy", + height=420, scale=8) + text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy", + height=420, scale=1) + text2sound_sampled_audio = gr.Audio(type="numpy", label="Play") + + with gr.Row(variant="panel"): + text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy", + height=200, width=100) + text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation", + type="numpy", height=200, width=100) + + text2sound_sampling_button.click(diffusion_random_sample, + inputs=[text2sound_prompts_1_textbox, + text2sound_prompts_2_textbox, + text2sound_negative_prompts_textbox, + text2sound_batchsize_slider, + text2sound_duration_slider, + text2sound_guidance_scale_slider, text2sound_sampler_radio, + text2sound_sample_steps_slider, + text2sound_seed_textbox, + interpolation_with_text_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_concatenated_spectrogram_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio, + text2sound_seed_textbox, + interpolation_with_text_state, + text2sound_sample_index_slider]) + text2sound_sample_index_slider.change(show_random_sample, + inputs=[text2sound_sample_index_slider, interpolation_with_text_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio]) diff --git a/webUI/deprecated/interpolationWithXT.py b/webUI/deprecated/interpolationWithXT.py new file mode 100644 index 0000000000000000000000000000000000000000..fa826732f85fa53a37fd3787f045b97f50392e26 --- /dev/null +++ b/webUI/deprecated/interpolationWithXT.py @@ -0,0 +1,173 @@ +import gradio as gr +import numpy as np +import torch + +from model.DiffSynthSampler import DiffSynthSampler +from tools import safe_int +from webUI.natural_language_guided.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image + + +def get_interpolation_with_xT_module(gradioWebUI, interpolation_with_text_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels + timesteps = 
gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize, + text2sound_duration, + text2sound_noise_variance, text2sound_guidance_scale, text2sound_sampler, + text2sound_sample_steps, text2sound_seed, + interpolation_with_text_dict): + text2sound_sample_steps = int(text2sound_sample_steps) + text2sound_seed = safe_int(text2sound_seed, 12345678) + # Todo: take care of text2sound_time_resolution/width + width = int(time_resolution*((text2sound_duration+1)/4) / VAE_scale) + text2sound_batchsize = int(text2sound_batchsize) + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device) + + CFG = int(text2sound_guidance_scale) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + unconditional_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0] + mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device)) + + mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32))) + + condition = text2sound_embedding.repeat(text2sound_batchsize, 1) + latent_representations, initial_noise = \ + mySampler.interpolate(model=uNet, shape=(text2sound_batchsize, channels, height, width), + seed=text2sound_seed, + variance=text2sound_noise_variance, + return_tensor=True, condition=condition, sampler=text2sound_sampler) + + latent_representations = latent_representations[-1] + + interpolation_with_text_dict["latent_representations"] = latent_representations + + latent_representation_gradio_images = [] + quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_rec_signals_gradio = [] + + quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations) + # Todo: remove hard-coding + flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations, + resolution=(512, width * VAE_scale), centralized=False, + squared=squared) + + for i in range(text2sound_batchsize): + latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i])) + quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i]) + new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i])) + + def concatenate_arrays(arrays_list): + return np.concatenate(arrays_list, axis=1) + + concatenated_spectrogram_gradio_image = concatenate_arrays(new_sound_spectrogram_gradio_images) + + interpolation_with_text_dict["latent_representation_gradio_images"] = latent_representation_gradio_images + interpolation_with_text_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images + interpolation_with_text_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + interpolation_with_text_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + return 
{text2sound_latent_representation_image: interpolation_with_text_dict["latent_representation_gradio_images"][0], + text2sound_quantized_latent_representation_image: + interpolation_with_text_dict["quantized_latent_representation_gradio_images"][0], + text2sound_sampled_concatenated_spectrogram_image: concatenated_spectrogram_gradio_image, + text2sound_sampled_spectrogram_image: interpolation_with_text_dict["new_sound_spectrogram_gradio_images"][0], + text2sound_sampled_audio: interpolation_with_text_dict["new_sound_rec_signals_gradio"][0], + text2sound_seed_textbox: text2sound_seed, + interpolation_with_text_state: interpolation_with_text_dict, + text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1, + visible=True, + label="Sample index.", + info="Swipe to view other samples")} + + def show_random_sample(sample_index, text2sound_dict): + sample_index = int(sample_index) + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][ + sample_index], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][sample_index], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]} + + with gr.Tab("InterpolationXT"): + gr.Markdown("Use interpolation to generate a gradient sound sequence.") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + text2sound_sampling_button = gr.Button(variant="primary", + value="Generate a batch of samples and show " + "the first one", + scale=1) + text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + with gr.Row(variant="panel"): + with gr.Column(scale=1, variant="panel"): + text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + text2sound_sampler_radio = gradioWebUI.get_sampler_radio() + text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider(cpu_batchsize=3) + text2sound_duration_slider = gradioWebUI.get_duration_slider() + text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + text2sound_seed_textbox = gradioWebUI.get_seed_textbox() + text2sound_noise_variance_slider = gr.Slider(minimum=0., maximum=5., value=1., step=0.01, + label="Noise variance", + info="The larger this value, the more diversity the interpolation has.") + + with gr.Column(scale=1): + with gr.Row(variant="panel"): + text2sound_sampled_concatenated_spectrogram_image = gr.Image(label="Interpolations", type="numpy", + height=420, scale=8) + text2sound_sampled_spectrogram_image = gr.Image(label="Selected spectrogram", type="numpy", + height=420, scale=1) + text2sound_sampled_audio = gr.Audio(type="numpy", label="Play") + + with gr.Row(variant="panel"): + text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy", + height=200, width=100) + text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation", + type="numpy", height=200, width=100) + + text2sound_sampling_button.click(diffusion_random_sample, + inputs=[text2sound_prompts_textbox, 
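The deprecated InterpolationXT module above instead interpolates in the initial-noise space x_T through DiffSynthSampler.interpolate with a "Noise variance" control; the sampler itself is not part of this diff, so the following is only a generic illustration of noise-space interpolation, not the repository's implementation. Spherical interpolation is a common choice because it keeps every interpolant at a norm typical of Gaussian noise:

import torch

def slerp(noise_a: torch.Tensor, noise_b: torch.Tensor, t: float) -> torch.Tensor:
    # Spherical linear interpolation between two noise tensors of the same shape, t in [0, 1].
    a, b = noise_a.flatten(), noise_b.flatten()
    omega = torch.arccos(torch.clamp(torch.dot(a / a.norm(), b / b.norm()), -1.0, 1.0))
    if torch.sin(omega).abs() < 1e-6:
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return (1.0 - t) * noise_a + t * noise_b
    return (torch.sin((1.0 - t) * omega) * noise_a + torch.sin(t * omega) * noise_b) / torch.sin(omega)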
text2sound_negative_prompts_textbox, + text2sound_batchsize_slider, + text2sound_duration_slider, + text2sound_noise_variance_slider, + text2sound_guidance_scale_slider, text2sound_sampler_radio, + text2sound_sample_steps_slider, + text2sound_seed_textbox, + interpolation_with_text_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_concatenated_spectrogram_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio, + text2sound_seed_textbox, + interpolation_with_text_state, + text2sound_sample_index_slider]) + text2sound_sample_index_slider.change(show_random_sample, + inputs=[text2sound_sample_index_slider, interpolation_with_text_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio]) diff --git a/webUI/natural_language_guided/GAN.py b/webUI/natural_language_guided/GAN.py new file mode 100644 index 0000000000000000000000000000000000000000..9aacef29ae5a07ae31f3a9b55d4c47938b471021 --- /dev/null +++ b/webUI/natural_language_guided/GAN.py @@ -0,0 +1,164 @@ +import gradio as gr +import numpy as np +import torch + +from tools import safe_int +from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput, latent_representation_to_Gradio_image, \ + add_instrument + + +def get_testGAN(gradioWebUI, text2sound_state, virtual_instruments_state): + # Load configurations + gan_generator = gradioWebUI.GAN_generator + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels + + timesteps = gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def gan_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize, + text2sound_duration, + text2sound_guidance_scale, text2sound_sampler, + text2sound_sample_steps, text2sound_seed, + text2sound_dict): + text2sound_seed = safe_int(text2sound_seed, 12345678) + + width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale) + + text2sound_batchsize = int(text2sound_batchsize) + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to( + device) + + CFG = int(text2sound_guidance_scale) + + condition = text2sound_embedding.repeat(text2sound_batchsize, 1) + + noise = torch.randn(text2sound_batchsize, channels, height, width).to(device) + latent_representations = gan_generator(noise, condition) + + print(latent_representations[0, 0, :3, :3]) + + latent_representation_gradio_images = [] + quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_rec_signals_gradio = [] + + quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations) + # Todo: remove hard-coding + flipped_log_spectrums, rec_signals = encodeBatch2GradioOutput(VAE_decoder, quantized_latent_representations, + resolution=(512, width * VAE_scale), + centralized=False, + squared=squared) + + for i in range(text2sound_batchsize): + 
latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i])) + quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i]) + new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i])) + + text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy() + text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy() + text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images + text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images + text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + text2sound_dict["condition"] = condition.to("cpu").detach().numpy() + # text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy() + text2sound_dict["guidance_scale"] = CFG + text2sound_dict["sampler"] = text2sound_sampler + + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][0], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0], + text2sound_seed_textbox: text2sound_seed, + text2sound_state: text2sound_dict, + text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1, + visible=True, + label="Sample index.", + info="Swipe to view other samples")} + + def show_random_sample(sample_index, text2sound_dict): + sample_index = int(sample_index) + text2sound_dict["sample_index"] = sample_index + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][ + sample_index], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][sample_index], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][ + sample_index], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]} + + + with gr.Tab("Text2sound_GAN"): + gr.Markdown("Use neural networks to select random sounds using your favorite instrument!") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + text2sound_sampling_button = gr.Button(variant="primary", + value="Generate a batch of samples and show " + "the first one", + scale=1) + text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + with gr.Row(variant="panel"): + with gr.Column(scale=1, variant="panel"): + text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + text2sound_sampler_radio = gradioWebUI.get_sampler_radio() + text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider() + text2sound_duration_slider = 
gradioWebUI.get_duration_slider() + text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + text2sound_seed_textbox = gradioWebUI.get_seed_textbox() + + with gr.Column(scale=1): + text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=420) + text2sound_sampled_audio = gr.Audio(type="numpy", label="Play") + + + with gr.Row(variant="panel"): + text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy", + height=200, width=100) + text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation", + type="numpy", height=200, width=100) + + text2sound_sampling_button.click(gan_random_sample, + inputs=[text2sound_prompts_textbox, + text2sound_negative_prompts_textbox, + text2sound_batchsize_slider, + text2sound_duration_slider, + text2sound_guidance_scale_slider, text2sound_sampler_radio, + text2sound_sample_steps_slider, + text2sound_seed_textbox, + text2sound_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio, + text2sound_seed_textbox, + text2sound_state, + text2sound_sample_index_slider]) + + + text2sound_sample_index_slider.change(show_random_sample, + inputs=[text2sound_sample_index_slider, text2sound_state], + outputs=[text2sound_latent_representation_image, + text2sound_quantized_latent_representation_image, + text2sound_sampled_spectrogram_image, + text2sound_sampled_audio]) diff --git a/webUI/natural_language_guided/README.py b/webUI/natural_language_guided/README.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbcba02b92fd83da9988a0fe6c223fcc841fc52 --- /dev/null +++ b/webUI/natural_language_guided/README.py @@ -0,0 +1,53 @@ +import gradio as gr + +readme_content = """## Stable Diffusion for Sound Generation + +This project applies stable diffusion[1] to sound generation. Inspired by the work of AUTOMATIC1111, 2022[2], we have implemented a preliminary version of text2sound, sound2sound, inpaint, as well as an additional interpolation feature, all accessible through a web UI. + +### Neural Network Training Data: +The neural network is trained using the filtered NSynth dataset[3], which is a large-scale and high-quality collection of annotated musical notes, comprising 305,979 musical notes. However, for this project, only samples with a pitch set to E3 were used, resulting in an actual training sample size of 4,096, making it a low-resource project. + +The training took place on an NVIDIA Tesla T4 GPU and spanned approximately 10 hours. + +### Natural Language Guidance: +Natural language guidance is derived from the multi-label annotations of the NSynth dataset. The labels included in the training are: + +- **Instrument Families**: bass, brass, flute, guitar, keyboard, mallet, organ, reed, string, synth lead, vocal. + +- **Instrument Sources**: acoustic, electronic, synthetic. + +- **Note Qualities**: bright, dark, distortion, fast decay, long release, multiphonic, nonlinear env, percussive, reverb, tempo-synced. + +### Usage Hints: + +1. **Prompt Format**: It's recommended to use the format “label1, label2, label3“, e.g., ”organ, dark, long release“. + +2. **Unique Sounds**: If you keep generating the same sound, try setting a different seed! + +3. **Sample Indexing**: Drag the "Sample index slider" to view other samples within the generated batch. + +4. 
**Running on CPU**: Be cautious with the settings for 'batchsize' and 'sample_steps' when running on CPU to avoid timeouts. Recommended settings are batchsize ≤ 4 and sample_steps = 15. + +5. **Editing Sounds**: Generated audio can be downloaded and then re-uploaded for further editing at the sound2sound/inpaint sections. + +6. **Guidance Scale**: A higher 'guidance_scale' intensifies the influence of natural language conditioning on the generation[4]. It's recommended to set it between 3 and 10. + +7. **Noising Strength**: A smaller 'noising_strength' value makes the generated sound closer to the input sound. + +References: + +[1] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 10684-10695). + +[2] AUTOMATIC1111. (2022). Stable Diffusion Web UI [Computer software]. Retrieved from https://github.com/AUTOMATIC1111/stable-diffusion-webui + +[3] Engel, J., Resnick, C., Roberts, A., Dieleman, S., Eck, D., Simonyan, K., & Norouzi, M. (2017). Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders. + +[4] Ho, J., & Salimans, T. (2022). Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598. +""" + +def get_readme_module(): + + with gr.Tab("README"): + # gr.Markdown("Use interpolation to generate a gradient sound sequence.") + with gr.Column(scale=3): + readme_textbox = gr.Textbox(label="readme", lines=40, value=readme_content, interactive=False) \ No newline at end of file diff --git a/webUI/natural_language_guided/__pycache__/README.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/README.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c88a11359138a9cafec9f8b18b492f20c393502 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/README.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3edcb31c80b382fd688370218f164b2c580f79 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/README_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb265c428a5661067127578187f79551dff5f781 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/buildInstrument_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaedf3fd3e5a26fe7e800163c7ecf98c76449e67 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/build_instrument.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..568a07d46f35b6e6cc68964f74a1d7ca8ed79fa0 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/gradioWebUI.cpython-310.pyc differ diff --git 
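Usage hint 6 above, and the activate_classifier_free_guidance calls elsewhere in this diff, refer to classifier-free guidance [4]. The standard way the guidance scale combines the conditional and the unconditional (negative-prompt) predictions is sketched below; this is the textbook formula from [4], not code taken from the repository's sampler:

import torch

def cfg_combine(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guidance_scale = 0 returns the unconditional (negative-prompt) prediction;
    # larger values push the output further towards the positive prompt.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

This is also why hint 6 recommends values between 3 and 10: at 0 the positive prompt has no influence at all.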
a/webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f945df82173dcd8be66d6e80b7aece72bc0f768 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/gradioWebUI_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3aff3c3afb99e9f80540e92d55753006484a9a3 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/gradio_webUI.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a2d3849357672fdf471691a978c1f340e4ad25 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/inpaintWithText.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0190af21f9037966b85e076a66effb15b7bcf104 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/inpaintWithText_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a842bb30c0eb41ce08734731886dcc19ee6262d3 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/inpaint_with_text.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e0f414a916f91b475bf1b858bbdae2c86927ac Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/rec.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5897c3d8f88aad4c72d941ecf90cab61d97f883b Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/recSTFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1027af298bb437ad327c81b56fcf3ee704b33509 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/sound2soundWithText.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3955de3b1176e62901f4dd78e34f870322ccda1a Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/sound2soundWithText_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc 
b/webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a72557c0fe2c9a43471d310e5104ba4af602d6 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/sound2sound_with_text.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a442190b58d7f95800b299a8fa7780dde99e0e3 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/text2sound.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c6c3d8c96e70bc3f57e0c0603f0819fc283f50a Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/text2sound_STFT.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5509f0428dddd9df2c62cf2a80b7f2a26246da15 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/track_maker.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc b/webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963640111ac11e0beb6fd144589446fa8898c313 Binary files /dev/null and b/webUI/natural_language_guided/__pycache__/utils.cpython-310.pyc differ diff --git a/webUI/natural_language_guided/build_instrument.py b/webUI/natural_language_guided/build_instrument.py new file mode 100644 index 0000000000000000000000000000000000000000..25dd102f783fc7bfb400b4d67a0e1abe7aa226f8 --- /dev/null +++ b/webUI/natural_language_guided/build_instrument.py @@ -0,0 +1,274 @@ +import librosa +import numpy as np +import torch +import gradio as gr +import mido +from io import BytesIO +import pyrubberband as pyrb + +from model.DiffSynthSampler import DiffSynthSampler +from tools import adsr_envelope, adjust_audio_length +from webUI.natural_language_guided.track_maker import DiffSynth +from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT, phase_to_Gradio_image, \ + spectrogram_to_Gradio_image + + +def get_build_instrument_module(gradioWebUI, virtual_instruments_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels + + timesteps = gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def select_sound(virtual_instrument_name, virtual_instruments_dict): + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = virtual_instruments[virtual_instrument_name] + + return {source_sound_spectrogram_image: 
virtual_instrument["spectrogram_gradio_image"], + source_sound_phase_image: virtual_instrument["phase_gradio_image"], + source_sound_audio: virtual_instrument["signal"]} + + def make_track(inpaint_steps, midi, noising_strength, attack, before_release, instrument_names, virtual_instruments_dict): + + if noising_strength < 1: + print(f"Warning: making track with noising_strength = {noising_strength} < 1") + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + sample_steps = int(inpaint_steps) + + instrument_names = instrument_names.split("@") + instruments_configs = {} + for virtual_instrument_name in instrument_names: + virtual_instrument = virtual_instruments[virtual_instrument_name] + + latent_representation = torch.tensor(virtual_instrument["latent_representation"], dtype=torch.float32).to(device) + sampler = virtual_instrument["sampler"] + + batchsize = 1 + + latent_representation = latent_representation.repeat(batchsize, 1, 1, 1) + + mid = mido.MidiFile(file=BytesIO(midi)) + instruments_configs[virtual_instrument_name] = { + 'sample_steps': sample_steps, + 'sampler': sampler, + 'noising_strength': noising_strength, + 'latent_representation': latent_representation, + 'attack': attack, + 'before_release': before_release} + + diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device) + + full_audio = diffSynth.get_music(mid, instrument_names) + + return {track_audio: (sample_rate, full_audio)} + + def test_duration_inpaint(virtual_instrument_name, inpaint_steps, duration, noising_strength, end_noise_level_ratio, attack, before_release, mask_flexivity, virtual_instruments_dict, use_dynamic_mask): + width = int(time_resolution * ((duration + 1) / 4) / VAE_scale) + + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = virtual_instruments[virtual_instrument_name] + + latent_representation = torch.tensor(virtual_instrument["latent_representation"], dtype=torch.float32).to(device) + sample_steps = int(inpaint_steps) + sampler = virtual_instrument["sampler"] + batchsize = 1 + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + mySampler.respace(list(np.linspace(0, timesteps - 1, sample_steps, dtype=np.int32))) + + latent_representation = latent_representation.repeat(batchsize, 1, 1, 1) + + # mask = 1, freeze + latent_mask = torch.zeros((batchsize, 1, height, width), dtype=torch.float32).to(device) + + latent_mask[:, :, :, :int(time_resolution * (attack / 4) / VAE_scale)] = 1.0 + latent_mask[:, :, :, -int(time_resolution * ((before_release+1) / 4) / VAE_scale):] = 1.0 + + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to( + device) + condition = text2sound_embedding.repeat(1, 1) + + + latent_representations, initial_noise = \ + mySampler.inpaint_sample(model=uNet, shape=(batchsize, channels, height, width), + noising_strength=noising_strength, + guide_img=latent_representation, mask=latent_mask, return_tensor=True, + condition=condition, sampler=sampler, + use_dynamic_mask=use_dynamic_mask, + end_noise_level_ratio=end_noise_level_ratio, + mask_flexivity=mask_flexivity) + + latent_representations = latent_representations[-1] + + quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations) + # Todo: remove hard-coding + flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder, + 
quantized_latent_representations, + resolution=( + 512, + width * VAE_scale), + original_STFT_batch=None + ) + + + return {test_duration_spectrogram_image: flipped_log_spectrums[0], + test_duration_phase_image: flipped_phases[0], + test_duration_audio: (sample_rate, rec_signals[0])} + + def test_duration_envelope(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict): + + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = virtual_instruments[virtual_instrument_name] + sample_rate, signal = virtual_instrument["signal"] + + applied_signal = adsr_envelope(signal=signal, sample_rate=sample_rate, duration=duration, + attack_time=0.0, decay_time=0.0, sustain_level=1.0, release_time=release) + + D = librosa.stft(applied_signal, n_fft=1024, hop_length=256, win_length=1024)[1:, :] + spc = np.abs(D) + phase = np.angle(D) + + flipped_log_spectrum = spectrogram_to_Gradio_image(spc) + flipped_phase = phase_to_Gradio_image(phase) + + return {test_duration_spectrogram_image: flipped_log_spectrum, + test_duration_phase_image: flipped_phase, + test_duration_audio: (sample_rate, applied_signal)} + + def test_duration_stretch(virtual_instrument_name, duration, noising_strength, attack, before_release, release, virtual_instruments_dict): + + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = virtual_instruments[virtual_instrument_name] + sample_rate, signal = virtual_instrument["signal"] + + s = 3 / duration + applied_signal = pyrb.time_stretch(signal, sample_rate, s) + applied_signal = adjust_audio_length(applied_signal, int((duration+1) * sample_rate), sample_rate, sample_rate) + + D = librosa.stft(applied_signal, n_fft=1024, hop_length=256, win_length=1024)[1:, :] + spc = np.abs(D) + phase = np.angle(D) + + flipped_log_spectrum = spectrogram_to_Gradio_image(spc) + flipped_phase = phase_to_Gradio_image(phase) + + return {test_duration_spectrogram_image: flipped_log_spectrum, + test_duration_phase_image: flipped_phase, + test_duration_audio: (sample_rate, applied_signal)} + + + with gr.Tab("TestInTrack"): + gr.Markdown("Make music with generated sounds!") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1, + placeholder="Name of your instrument", scale=1) + select_instrument_button = gr.Button(variant="primary", value="Select", scale=1) + with gr.Column(scale=3): + inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0, label="inpaint_steps") + noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.) 
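Several handlers in this diff convert the duration slider into a latent width via int(time_resolution * ((duration + 1) / 4) / VAE_scale) and back into a sample count via 256 * (VAE_scale * width - 1), where 256 is the STFT hop length used in the librosa.stft calls. A worked example, assuming the defaults time_resolution=256, VAE_scale=4 and sample_rate=16000 from gradio_webUI.py:

# Worked example of the duration -> latent width -> sample count mapping (default settings assumed).
time_resolution, VAE_scale, hop_length, sample_rate = 256, 4, 256, 16000

duration = 3.0                                                   # slider value in seconds
width = int(time_resolution * ((duration + 1) / 4) / VAE_scale)  # 256 * 1.0 / 4 = 64 latent frames
audio_length = hop_length * (VAE_scale * width - 1)              # 256 * 255 = 65280 samples
print(width, audio_length, audio_length / sample_rate)           # 64 65280 4.08

The same constants make one second on the mask sliders correspond to time_resolution / (VAE_scale * 4) = 16 latent columns, which is the conversion used when the inpaint module builds its latent mask.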
+ end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01, label="end_noise_level_ratio") + attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec") + before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="before_release in sec") + release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec") + mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01, label="mask_flexivity") + with gr.Column(scale=3): + use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True) + test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1) + test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1) + test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1) + duration_slider = gradioWebUI.get_duration_slider() + + with gr.Row(variant="panel"): + with gr.Column(scale=2): + with gr.Row(variant="panel"): + source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + height=600, scale=1) + source_sound_phase_image = gr.Image(label="New sound phase", type="numpy", + height=600, scale=1) + source_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + + with gr.Column(scale=3): + with gr.Row(variant="panel"): + test_duration_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + height=600, scale=1) + test_duration_phase_image = gr.Image(label="New sound phase", type="numpy", + height=600, scale=1) + test_duration_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + + with gr.Row(variant="panel"): + with gr.Column(scale=1): + # track_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + # height=420, scale=1) + midi_file = gr.File(label="Upload midi file", type="binary") + instrument_names_textbox = gr.Textbox(label="Instrument names", lines=2, + placeholder="Names of your instrument used to play the midi", scale=1) + track_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + make_track_button = gr.Button(variant="primary", value="Make track", scale=1) + + select_instrument_button.click(select_sound, + inputs=[instrument_name_textbox, virtual_instruments_state], + outputs=[source_sound_spectrogram_image, + source_sound_phase_image, + source_sound_audio]) + + test_duration_envelope_button.click(test_duration_envelope, + inputs=[instrument_name_textbox, duration_slider, + noising_strength_slider, + attack_slider, + before_release_slider, + release_slider, + virtual_instruments_state, + ], + outputs=[test_duration_spectrogram_image, + test_duration_phase_image, + test_duration_audio]) + + test_duration_stretch_button.click(test_duration_stretch, + inputs=[instrument_name_textbox, duration_slider, + noising_strength_slider, + attack_slider, + before_release_slider, + release_slider, + virtual_instruments_state, + ], + outputs=[test_duration_spectrogram_image, + test_duration_phase_image, + test_duration_audio]) + + test_duration_inpaint_button.click(test_duration_inpaint, + inputs=[instrument_name_textbox, + inpaint_steps_slider, + duration_slider, + noising_strength_slider, + end_noise_level_ratio_slider, + attack_slider, + before_release_slider, + mask_flexivity_slider, + virtual_instruments_state, + use_dynamic_mask_checkbox], + outputs=[test_duration_spectrogram_image, + 
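test_duration_envelope above delegates to tools.adsr_envelope, which is not included in this diff. As a rough idea of what an attack/decay/sustain/release gain curve looks like, here is a minimal sketch under the same argument names; the repository's implementation may differ:

import numpy as np

def adsr_gain(num_samples: int, sample_rate: int, attack_time: float, decay_time: float,
              sustain_level: float, release_time: float) -> np.ndarray:
    # Piecewise-linear gain: 0 -> 1 (attack), 1 -> sustain (decay), hold, sustain -> 0 (release).
    a = int(attack_time * sample_rate)
    d = int(decay_time * sample_rate)
    r = int(release_time * sample_rate)
    s = max(num_samples - a - d - r, 0)
    gain = np.concatenate([
        np.linspace(0.0, 1.0, a, endpoint=False),
        np.linspace(1.0, sustain_level, d, endpoint=False),
        np.full(s, sustain_level),
        np.linspace(sustain_level, 0.0, r),
    ])
    return gain[:num_samples]

# Usage sketch: enveloped = signal * adsr_gain(len(signal), 16000, 0.0, 0.0, 1.0, 0.3)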
test_duration_phase_image, + test_duration_audio]) + + make_track_button.click(make_track, + inputs=[inpaint_steps_slider, midi_file, + noising_strength_slider, + attack_slider, + before_release_slider, + instrument_names_textbox, + virtual_instruments_state], + outputs=[track_audio]) + diff --git a/webUI/natural_language_guided/gradio_webUI.py b/webUI/natural_language_guided/gradio_webUI.py new file mode 100644 index 0000000000000000000000000000000000000000..3f91b98e7559dd92baef496109025c80a559a124 --- /dev/null +++ b/webUI/natural_language_guided/gradio_webUI.py @@ -0,0 +1,68 @@ +import gradio as gr + + +class GradioWebUI(): + + def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer, + freq_resolution=512, time_resolution=256, channels=4, timesteps=1000, + sample_rate=16000, squared=False, VAE_scale=4, + flexible_duration=False, noise_strategy="repeat", + GAN_generator = None): + self.device = device + self.VAE_encoder, self.VAE_quantizer, self.VAE_decoder = VAE._encoder, VAE._vq_vae, VAE._decoder + self.uNet = uNet + self.CLAP, self.CLAP_tokenizer = CLAP, CLAP_tokenizer + self.freq_resolution, self.time_resolution = freq_resolution, time_resolution + self.channels = channels + self.GAN_generator = GAN_generator + + self.timesteps = timesteps + self.sample_rate = sample_rate + self.squared = squared + self.VAE_scale = VAE_scale + self.flexible_duration = flexible_duration + self.noise_strategy = noise_strategy + + self.text2sound_state = gr.State(value={}) + self.interpolation_state = gr.State(value={}) + self.sound2sound_state = gr.State(value={}) + self.inpaint_state = gr.State(value={}) + + def get_sample_steps_slider(self): + default_steps = 10 if (self.device == "cpu") else 20 + return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1, + label="Sample steps", + info="Sampling steps. More sampling steps give better " + "results in theory, but take more time.") + + def get_sampler_radio(self): + # return gr.Radio(choices=["ddpm", "ddim", "dpmsolver++", "dpmsolver"], value="ddim", label="Sampler") + return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler") + + def get_batchsize_slider(self, cpu_batchsize=1): + return gr.Slider(minimum=1., maximum=16, value=cpu_batchsize if (self.device == "cpu") else 8, step=1, label="Batchsize") + + def get_time_resolution_slider(self): + return gr.Slider(minimum=16., maximum=int(1024/self.VAE_scale), value=int(256/self.VAE_scale), step=1, label="Time resolution", interactive=True) + + def get_duration_slider(self): + if self.flexible_duration: + return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec") + else: + return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec") + + def get_guidance_scale_slider(self): + return gr.Slider(minimum=0., maximum=20., value=6., step=1., + label="Guidance scale", + info="The larger this value, the more the generated sound is " + "influenced by the condition. 
Setting it to 0 is equivalent to " + "using only the negative prompt.") + + def get_noising_strength_slider(self, default_noising_strength=0.7): + return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01, + label="Noising strength", + info="The smaller this value, the closer the generated " + "sound stays to the original.") + + def get_seed_textbox(self): + return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0) diff --git a/webUI/natural_language_guided/inpaint_with_text.py b/webUI/natural_language_guided/inpaint_with_text.py new file mode 100644 index 0000000000000000000000000000000000000000..b0de29922ae0cff6f9c78efaf7c4839c78c5a2f4 --- /dev/null +++ b/webUI/natural_language_guided/inpaint_with_text.py @@ -0,0 +1,441 @@ +import librosa +import numpy as np +import torch +import gradio as gr +from scipy.ndimage import zoom + +from model.DiffSynthSampler import DiffSynthSampler +from tools import adjust_audio_length, safe_int, pad_STFT, encode_stft +from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, add_instrument + + +def get_triangle_mask(height, width): + mask = np.zeros((height, width)) + slope = 8 / 3 + for i in range(height): + for j in range(width): + if i < slope * j: + mask[i, j] = 1 + return mask + + +def get_inpaint_with_text_module(gradioWebUI, inpaintWithText_state, virtual_instruments_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels + timesteps = gradioWebUI.timesteps + VAE_encoder = gradioWebUI.VAE_encoder + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def receive_uopoad_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload, sound2sound_origin_microphone, + inpaintWithText_dict): + + if sound2sound_origin_source == "upload": + origin_sr, origin_audio = sound2sound_origin_upload + else: + origin_sr, origin_audio = sound2sound_origin_microphone + + origin_audio = origin_audio / np.max(np.abs(origin_audio)) + + width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale) + audio_length = 256 * (VAE_scale * width - 1) + origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate) + + D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024) + padded_D = pad_STFT(D) + encoded_D = encode_stft(padded_D) + + # Todo: justify batchsize to 1 + origin_spectrogram_batch_tensor = torch.from_numpy( + np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device) + + # Todo: remove hard-coding + origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT( + VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared) + + if sound2sound_origin_source == "upload": + inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist() + inpaintWithText_dict[
"sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + inpaintWithText_dict[ + "sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_upload_image: origin_flipped_phases[0], + sound2sound_origin_spectrogram_microphone_image: gr.update(), + sound2sound_origin_phase_microphone_image: gr.update(), + sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_origin_microphone_latent_representation_image: gr.update(), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(), + inpaintWithText_state: inpaintWithText_dict} + else: + inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist() + inpaintWithText_dict[ + "sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + inpaintWithText_dict[ + "sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_upload_image: origin_flipped_phases[0], + sound2sound_origin_spectrogram_microphone_image: gr.update(), + sound2sound_origin_phase_microphone_image: gr.update(), + sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_origin_microphone_latent_representation_image: gr.update(), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(), + inpaintWithText_state: inpaintWithText_dict} + + def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone, + text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize, + sound2sound_guidance_scale, sound2sound_sampler, + sound2sound_sample_steps, sound2sound_origin_source, + sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area, + mask_time_begin, mask_time_end, mask_frequency_begin, mask_frequency_end, inpaintWithText_dict + ): + + # input preprocessing + sound2sound_seed = safe_int(sound2sound_seed, 12345678) + sound2sound_batchsize = int(sound2sound_batchsize) + noising_strength = sound2sound_noising_strength + sound2sound_sample_steps = int(sound2sound_sample_steps) + CFG = int(sound2sound_guidance_scale) + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device) + + if sound2sound_origin_source == "upload": + origin_latent_representations = torch.tensor( + inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + mask = np.array(sound2sound_origin_spectrogram_upload["mask"]) + elif sound2sound_origin_source 
== "microphone": + origin_latent_representations = torch.tensor( + inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + mask = np.array(sound2sound_origin_spectrogram_microphone["mask"]) + else: + print("Input source not in ['upload', 'microphone']!") + raise NotImplementedError() + + merged_mask = np.all(mask == 255, axis=2).astype(np.uint8) + latent_mask = zoom(merged_mask, (1 / VAE_scale, 1 / VAE_scale)) + latent_mask = np.clip(latent_mask, 0, 1) + print(f"latent_mask.avg = {np.mean(latent_mask)}") + latent_mask[int(mask_frequency_begin):int(mask_frequency_end), int(mask_time_begin*time_resolution/(VAE_scale*4)):int(mask_time_end*time_resolution/(VAE_scale*4))] = 1 + + + # latent_mask = get_triangle_mask(128, 64) + + print(f"latent_mask.avg = {np.mean(latent_mask)}") + if sound2sound_inpaint_area == "inpaint masked": + latent_mask = 1 - latent_mask + latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1, + 1).float().to(device) + latent_mask = torch.flip(latent_mask, [2]) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + unconditional_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0] + mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device)) + + normalized_sample_steps = int(sound2sound_sample_steps / noising_strength) + + mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32))) + + # Todo: remove hard-coding + width = origin_latent_representations.shape[-1] + condition = text2sound_embedding.repeat(sound2sound_batchsize, 1) + + new_sound_latent_representations, initial_noise = \ + mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width), + seed=sound2sound_seed, + noising_strength=noising_strength, + guide_img=origin_latent_representations, mask=latent_mask, return_tensor=True, + condition=condition, sampler=sound2sound_sampler) + + new_sound_latent_representations = new_sound_latent_representations[-1] + + # Quantize new sound latent representations + quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations) + new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder, + quantized_new_sound_latent_representations, + resolution=( + 512, + width * VAE_scale), + original_STFT_batch=None + ) + + new_sound_latent_representation_gradio_images = [] + new_sound_quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_phase_gradio_images = [] + new_sound_rec_signals_gradio = [] + for i in range(sound2sound_batchsize): + new_sound_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(new_sound_latent_representations[i])) + new_sound_quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i]) + new_sound_phase_gradio_images.append(new_sound_flipped_phases[i]) + new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i])) + + inpaintWithText_dict[ + "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images + inpaintWithText_dict[ 
+ "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images + inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images + inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + inpaintWithText_dict["latent_representations"] = new_sound_latent_representations.to("cpu").detach().numpy() + inpaintWithText_dict["quantized_latent_representations"] = quantized_new_sound_latent_representations.to("cpu").detach().numpy() + inpaintWithText_dict["sampler"] = sound2sound_sampler + + return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image( + new_sound_latent_representations[0]), + sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_new_sound_latent_representations[0]), + sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0], + sound2sound_new_sound_phase_image: new_sound_flipped_phases[0], + sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]), + sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0, + step=1.0, + visible=True, + label="Sample index", + info="Swipe to view other samples"), + sound2sound_seed_textbox: sound2sound_seed, + inpaintWithText_state: inpaintWithText_dict} + + def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict): + sample_index = int(sound2sound_sample_index) + return {sound2sound_new_sound_latent_representation_image: + inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_quantized_latent_representation_image: + inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][ + sample_index], + sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][ + sample_index], + sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]} + + def sound2sound_switch_origin_source(sound2sound_origin_source): + + if sound2sound_origin_source == "upload": + return {sound2sound_origin_upload_audio: gr.update(visible=True), + sound2sound_origin_microphone_audio: gr.update(visible=False), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=True), + sound2sound_origin_phase_upload_image: gr.update(visible=True), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False), + sound2sound_origin_phase_microphone_image: gr.update(visible=False), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=True), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)} + elif sound2sound_origin_source == "microphone": + return {sound2sound_origin_upload_audio: gr.update(visible=False), + sound2sound_origin_microphone_audio: gr.update(visible=True), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=False), + sound2sound_origin_phase_upload_image: gr.update(visible=False), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True), + 
sound2sound_origin_phase_microphone_image: gr.update(visible=True), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=False), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)} + else: + print("Input source not in ['upload', 'microphone']!") + + def save_virtual_instrument(sample_index, virtual_instrument_name, sound2sound_dict, virtual_instruments_dict): + + virtual_instruments_dict = add_instrument(sound2sound_dict, virtual_instruments_dict, virtual_instrument_name, sample_index) + return {virtual_instruments_state: virtual_instruments_dict, + sound2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1, + placeholder=f"Saved as {virtual_instrument_name}!")} + + with gr.Tab("Inpaint"): + gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1) + + sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + + with gr.Row(variant="panel"): + with gr.Column(scale=1): + with gr.Tab("Origin sound"): + sound2sound_duration_slider = gradioWebUI.get_duration_slider() + sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload", + label="Input source") + + sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload", + interactive=True, visible=True) + sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone", + interactive=True, visible=False) + with gr.Row(variant="panel"): + sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram", + type="numpy", height=600, + visible=True, tool="sketch") + sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase", + type="numpy", height=600, + visible=True) + sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram", + type="numpy", height=600, + visible=False, tool="sketch") + sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase", + type="numpy", height=600, + visible=False) + sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"], + value="inpaint masked") + + with gr.Tab("Sound2sound settings"): + sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + sound2sound_sampler_radio = gradioWebUI.get_sampler_radio() + sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider() + sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider() + sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + sound2sound_seed_textbox = gradioWebUI.get_seed_textbox() + + with gr.Tab("Mask prototypes"): + with gr.Tab("Mask along time axis"): + mask_time_begin_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="Begin time") + mask_time_end_slider = 
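save_virtual_instrument above hands the selected sample to add_instrument (imported from webUI.natural_language_guided.utils, whose body is not shown in this diff). Judging only from what select_sound, make_track and test_duration_inpaint read back out of virtual_instruments_state, an instrument entry needs at least the keys below; this is a hedged reconstruction, and the real helper may store more (for example the quantized latents):

def add_instrument_entry(sound2sound_dict, virtual_instruments_dict, name, sample_index):
    # Sketch only: mirrors the keys read elsewhere in this diff, not the actual add_instrument code.
    virtual_instruments_dict["virtual_instruments"][name] = {
        "latent_representation": sound2sound_dict["latent_representations"][sample_index],
        "sampler": sound2sound_dict["sampler"],
        "signal": sound2sound_dict["new_sound_rec_signals_gradio"][sample_index],   # (sample_rate, waveform)
        "spectrogram_gradio_image": sound2sound_dict["new_sound_spectrogram_gradio_images"][sample_index],
        "phase_gradio_image": sound2sound_dict["new_sound_phase_gradio_images"][sample_index],
    }
    return virtual_instruments_dict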
gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01, label="End time") + with gr.Tab("Mask along frequency axis"): + mask_frequency_begin_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="Begin freq pixel") + mask_frequency_end_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1, label="End freq pixel") + + with gr.Column(scale=1): + sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + with gr.Row(variant="panel"): + sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + height=600, scale=1) + sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy", + height=600, scale=1) + + + with gr.Row(variant="panel"): + sound2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1, + placeholder="Name of your instrument") + sound2sound_save_instrument_button = gr.Button(variant="primary", + value="Save instrument", + scale=1) + + with gr.Row(variant="panel"): + sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=800, + visible=True) + sound2sound_origin_upload_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=800, visible=True) + + sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=800, + visible=False) + sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=800, visible=False) + + sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation", + type="numpy", height=800) + sound2sound_new_sound_quantized_latent_representation_image = gr.Image( + label="New sound quantized latent representation", type="numpy", height=800) + + sound2sound_origin_upload_audio.change(receive_uopoad_origin_audio, + inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, inpaintWithText_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + inpaintWithText_state]) + sound2sound_origin_microphone_audio.change(receive_uopoad_origin_audio, + inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, inpaintWithText_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + inpaintWithText_state]) + + sound2sound_sample_button.click(sound2sound_sample, + inputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_spectrogram_microphone_image, 
+ text2sound_prompts_textbox, + text2sound_negative_prompts_textbox, + sound2sound_batchsize_slider, + sound2sound_guidance_scale_slider, + sound2sound_sampler_radio, + sound2sound_sample_steps_slider, + sound2sound_origin_source_radio, + sound2sound_noising_strength_slider, + sound2sound_seed_textbox, + sound2sound_inpaint_area_radio, + mask_time_begin_slider, + mask_time_end_slider, + mask_frequency_begin_slider, + mask_frequency_end_slider, + inpaintWithText_state], + outputs=[sound2sound_new_sound_latent_representation_image, + sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio, + sound2sound_sample_index_slider, + sound2sound_seed_textbox, + inpaintWithText_state]) + + sound2sound_sample_index_slider.change(show_sound2sound_sample, + inputs=[sound2sound_sample_index_slider, inpaintWithText_state], + outputs=[sound2sound_new_sound_latent_representation_image, + sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio]) + + sound2sound_origin_source_radio.change(sound2sound_switch_origin_source, + inputs=[sound2sound_origin_source_radio], + outputs=[sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, + sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image]) + + sound2sound_save_instrument_button.click(save_virtual_instrument, + inputs=[sound2sound_sample_index_slider, + sound2sound_instrument_name_textbox, + inpaintWithText_state, + virtual_instruments_state], + outputs=[virtual_instruments_state, + sound2sound_instrument_name_textbox]) \ No newline at end of file diff --git a/webUI/natural_language_guided/rec.py b/webUI/natural_language_guided/rec.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7ac17caa52bb0a5de18291dc696a47c90df4c5 --- /dev/null +++ b/webUI/natural_language_guided/rec.py @@ -0,0 +1,190 @@ +import gradio as gr + +from data_generation.nsynth import get_nsynth_dataloader +from webUI.natural_language_guided_STFT.utils import encodeBatch2GradioOutput_STFT, InputBatch2Encode_STFT, \ + latent_representation_to_Gradio_image + + +def get_recSTFT_module(gradioWebUI, reconstruction_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels + + timesteps = gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_encoder = gradioWebUI.VAE_encoder + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def generate_reconstruction_samples(sample_source, batchsize_slider, encodeCache, + reconstruction_samples): + + vae_batchsize = int(batchsize_slider) + + if sample_source 
== "text2sound_trainSTFT": + training_dataset_path = f'data/NSynth/nsynth-STFT-train-52.hdf5' # Make sure to use your actual path + iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True, + get_latent_representation=False, with_meta_data=False, + task="STFT") + elif sample_source == "text2sound_validSTFT": + training_dataset_path = f'data/NSynth/nsynth-STFT-valid-52.hdf5' # Make sure to use your actual path + iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True, + get_latent_representation=False, with_meta_data=False, + task="STFT") + elif sample_source == "text2sound_testSTFT": + training_dataset_path = f'data/NSynth/nsynth-STFT-test-52.hdf5' # Make sure to use your actual path + iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True, + get_latent_representation=False, with_meta_data=False, + task="STFT") + else: + raise NotImplementedError() + + spectrogram_batch = next(iter(iterator)) + + origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, latent_representations, quantized_latent_representations = InputBatch2Encode_STFT( + VAE_encoder, spectrogram_batch, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared) + + latent_representation_gradio_images, quantized_latent_representation_gradio_images = [], [] + for i in range(vae_batchsize): + latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i])) + quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_latent_representations[i])) + + if quantized_latent_representations is None: + quantized_latent_representations = latent_representations + reconstruction_flipped_log_spectrums, reconstruction_flipped_phases, reconstruction_signals, reconstruction_flipped_log_spectrums_WOA, reconstruction_flipped_phases_WOA, reconstruction_signals_WOA = encodeBatch2GradioOutput_STFT(VAE_decoder, + quantized_latent_representations, + resolution=( + 512, + width * VAE_scale), + original_STFT_batch=spectrogram_batch + ) + + reconstruction_samples["origin_flipped_log_spectrums"] = origin_flipped_log_spectrums + reconstruction_samples["origin_flipped_phases"] = origin_flipped_phases + reconstruction_samples["origin_signals"] = origin_signals + reconstruction_samples["latent_representation_gradio_images"] = latent_representation_gradio_images + reconstruction_samples[ + "quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images + reconstruction_samples[ + "reconstruction_flipped_log_spectrums"] = reconstruction_flipped_log_spectrums + reconstruction_samples[ + "reconstruction_flipped_phases"] = reconstruction_flipped_phases + reconstruction_samples["reconstruction_signals"] = reconstruction_signals + reconstruction_samples[ + "reconstruction_flipped_log_spectrums_WOA"] = reconstruction_flipped_log_spectrums_WOA + reconstruction_samples[ + "reconstruction_flipped_phases_WOA"] = reconstruction_flipped_phases_WOA + reconstruction_samples["reconstruction_signals_WOA"] = reconstruction_signals_WOA + reconstruction_samples["sampleRate"] = sample_rate + + latent_representation_gradio_image = reconstruction_samples["latent_representation_gradio_images"][0] + quantized_latent_representation_gradio_image = \ + reconstruction_samples["quantized_latent_representation_gradio_images"][0] + origin_flipped_log_spectrum = reconstruction_samples["origin_flipped_log_spectrums"][0] + 
origin_flipped_phase = reconstruction_samples["origin_flipped_phases"][0] + origin_signal = reconstruction_samples["origin_signals"][0] + reconstruction_flipped_log_spectrum = reconstruction_samples["reconstruction_flipped_log_spectrums"][0] + reconstruction_flipped_phase = reconstruction_samples["reconstruction_flipped_phases"][0] + reconstruction_signal = reconstruction_samples["reconstruction_signals"][0] + reconstruction_flipped_log_spectrum_WOA = reconstruction_samples["reconstruction_flipped_log_spectrums_WOA"][0] + reconstruction_flipped_phase_WOA = reconstruction_samples["reconstruction_flipped_phases_WOA"][0] + reconstruction_signal_WOA = reconstruction_samples["reconstruction_signals_WOA"][0] + + return {origin_amplitude_image_output: origin_flipped_log_spectrum, + origin_phase_image_output: origin_flipped_phase, + origin_audio_output: (sample_rate, origin_signal), + latent_representation_image_output: latent_representation_gradio_image, + quantized_latent_representation_image_output: quantized_latent_representation_gradio_image, + reconstruction_amplitude_image_output: reconstruction_flipped_log_spectrum, + reconstruction_phase_image_output: reconstruction_flipped_phase, + reconstruction_audio_output: (sample_rate, reconstruction_signal), + reconstruction_amplitude_image_output_WOA: reconstruction_flipped_log_spectrum_WOA, + reconstruction_phase_image_output_WOA: reconstruction_flipped_phase_WOA, + reconstruction_audio_output_WOA: (sample_rate, reconstruction_signal_WOA), + sample_index_slider: gr.update(minimum=0, maximum=vae_batchsize - 1, value=0, step=1.0, + label="Sample index.", + info="Slide to view other samples", scale=1, visible=True), + reconstruction_state: encodeCache, + reconstruction_samples_state: reconstruction_samples} + + def show_reconstruction_sample(sample_index, encodeCache_state, reconstruction_samples_state): + sample_index = int(sample_index) + sampleRate = reconstruction_samples_state["sampleRate"] + latent_representation_gradio_image = reconstruction_samples_state["latent_representation_gradio_images"][ + sample_index] + quantized_latent_representation_gradio_image = \ + reconstruction_samples_state["quantized_latent_representation_gradio_images"][sample_index] + origin_flipped_log_spectrum = reconstruction_samples_state["origin_flipped_log_spectrums"][sample_index] + origin_flipped_phase = reconstruction_samples_state["origin_flipped_phases"][sample_index] + origin_signal = reconstruction_samples_state["origin_signals"][sample_index] + reconstruction_flipped_log_spectrum = reconstruction_samples_state["reconstruction_flipped_log_spectrums"][ + sample_index] + reconstruction_flipped_phase = reconstruction_samples_state["reconstruction_flipped_phases"][ + sample_index] + reconstruction_signal = reconstruction_samples_state["reconstruction_signals"][sample_index] + reconstruction_flipped_log_spectrum_WOA = reconstruction_samples_state["reconstruction_flipped_log_spectrums_WOA"][ + sample_index] + reconstruction_flipped_phase_WOA = reconstruction_samples_state["reconstruction_flipped_phases_WOA"][ + sample_index] + reconstruction_signal_WOA = reconstruction_samples_state["reconstruction_signals_WOA"][sample_index] + return origin_flipped_log_spectrum, origin_flipped_phase, (sampleRate, origin_signal), \ + latent_representation_gradio_image, quantized_latent_representation_gradio_image, \ + reconstruction_flipped_log_spectrum, reconstruction_flipped_phase, (sampleRate, reconstruction_signal), \ + reconstruction_flipped_log_spectrum_WOA, 
reconstruction_flipped_phase_WOA, (sampleRate, reconstruction_signal_WOA), \
+               encodeCache_state, reconstruction_samples_state
+
+    with gr.Tab("Reconstruction"):
+        reconstruction_samples_state = gr.State(value={})
+        gr.Markdown("Test reconstruction.")
+        with gr.Row(variant="panel"):
+            with gr.Column():
+                sample_source_radio = gr.Radio(
+                    choices=["synthetic", "external", "text2sound_trainSTFT", "text2sound_testSTFT", "text2sound_validSTFT"],
+                    value="text2sound_trainSTFT", label="Sample source",
+                    info="Only the text2sound_*STFT sources are currently implemented", scale=2)
+                batchsize_slider = gr.Slider(minimum=1., maximum=16., value=4., step=1.,
+                                             label="batchsize")
+            with gr.Column():
+                generate_button = gr.Button(variant="primary", value="Generate reconstruction samples", scale=1)
+                sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, label="Sample index.",
+                                                info="Slide to view other samples", scale=1, visible=False)
+        with gr.Row(variant="panel"):
+            with gr.Column():
+                origin_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
+                origin_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
+                origin_audio_output = gr.Audio(type="numpy", label="Play the example!")
+            with gr.Column():
+                reconstruction_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
+                reconstruction_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
+                reconstruction_audio_output = gr.Audio(type="numpy", label="Play the example!")
+            with gr.Column():
+                reconstruction_amplitude_image_output_WOA = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
+                reconstruction_phase_image_output_WOA = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
+                reconstruction_audio_output_WOA = gr.Audio(type="numpy", label="Play the example!")
+        with gr.Row(variant="panel", equal_height=True):
+            latent_representation_image_output = gr.Image(label="latent_representation", type="numpy", height=300, width=100)
+            quantized_latent_representation_image_output = gr.Image(label="quantized", type="numpy", height=300, width=100)
+
+    generate_button.click(generate_reconstruction_samples,
+                          inputs=[sample_source_radio, batchsize_slider, reconstruction_state,
+                                  reconstruction_samples_state],
+                          outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
+                                   latent_representation_image_output, quantized_latent_representation_image_output,
+                                   reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
+                                   reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
+                                   sample_index_slider, reconstruction_state, reconstruction_samples_state])
+
+    sample_index_slider.change(show_reconstruction_sample,
+                               inputs=[sample_index_slider, reconstruction_state, reconstruction_samples_state],
+                               outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
+                                        latent_representation_image_output, quantized_latent_representation_image_output,
+                                        reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
+                                        reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
+                                        reconstruction_state, reconstruction_samples_state])
\ No newline at end of file
diff --git a/webUI/natural_language_guided/sound2sound_with_text.py b/webUI/natural_language_guided/sound2sound_with_text.py new
file mode 100644 index 0000000000000000000000000000000000000000..525cfb6cd4c397daeda3f3694073c5b94eefbc5e --- /dev/null +++ b/webUI/natural_language_guided/sound2sound_with_text.py @@ -0,0 +1,416 @@ +import gradio as gr +import librosa +import numpy as np +import torch + +from model.DiffSynthSampler import DiffSynthSampler +from tools import pad_STFT, encode_stft +from tools import safe_int, adjust_audio_length +from webUI.natural_language_guided.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT, \ + latent_representation_to_Gradio_image + + +def get_sound2sound_with_text_module(gradioWebUI, sound2sound_with_text_state, virtual_instruments_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels + timesteps = gradioWebUI.timesteps + VAE_encoder = gradioWebUI.VAE_encoder + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin_source, + sound2sound_origin_upload, sound2sound_origin_microphone, + sound2sound_with_text_dict, virtual_instruments_dict): + + if sound2sound_origin_source == "upload": + origin_sr, origin_audio = sound2sound_origin_upload + else: + origin_sr, origin_audio = sound2sound_origin_microphone + + origin_audio = origin_audio / np.max(np.abs(origin_audio)) + + width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale) + audio_length = 256 * (VAE_scale * width - 1) + origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate) + + D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024) + padded_D = pad_STFT(D) + encoded_D = encode_stft(padded_D) + + # Todo: justify batchsize to 1 + origin_spectrogram_batch_tensor = torch.from_numpy( + np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device) + + # Todo: remove hard-coding + origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT( + VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared) + + default_condition = CLAP.get_text_features(**CLAP_tokenizer([""], padding=True, return_tensors="pt"))[0].to("cpu").detach().numpy() + + if sound2sound_origin_source == "upload": + sound2sound_with_text_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist() + sound2sound_with_text_dict[ + "sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + sound2sound_with_text_dict[ + "sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = {"condition": default_condition, + "negative_condition": default_condition, # care!!! 
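+                              # Both conditions default to the CLAP embedding of an empty prompt, and CFG is stored as 1
+                              # (i.e. no classifier-free guidance) for instruments saved from an uploaded sound.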
+ "CFG": 1, + "latent_representation": origin_latent_representations[0].to("cpu").detach().numpy(), + "quantized_latent_representation": quantized_origin_latent_representations[0].to("cpu").detach().numpy(), + "sampler": "ddim", + "signal": (sample_rate, origin_audio), + "spectrogram_gradio_image": origin_flipped_log_spectrums[0], + "phase_gradio_image": origin_flipped_phases[0]} + virtual_instruments["s2sup"] = virtual_instrument + virtual_instruments_dict["virtual_instruments"] = virtual_instruments + + return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_upload_image: origin_flipped_phases[0], + sound2sound_origin_spectrogram_microphone_image: gr.update(), + sound2sound_origin_phase_microphone_image: gr.update(), + sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_origin_microphone_latent_representation_image: gr.update(), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(), + sound2sound_with_text_state: sound2sound_with_text_dict, + virtual_instruments_state: virtual_instruments_dict} + else: + sound2sound_with_text_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist() + sound2sound_with_text_dict[ + "sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + sound2sound_with_text_dict[ + "sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + + virtual_instruments = virtual_instruments_dict["virtual_instruments"] + virtual_instrument = {"condition": default_condition, + "negative_condition": default_condition, # care!!! 
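+                              # Note: unlike the upload branch above, the raw waveform is stored here without its
+                              # sample rate and no phase image is kept for the microphone-derived instrument.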
+ "CFG": 1, + "latent_representation": origin_latent_representations[0], + "quantized_latent_representation": quantized_origin_latent_representations[0], + "sampler": "ddim", + "signal": origin_audio, + "spectrogram_gradio_image": origin_flipped_log_spectrums[0]} + virtual_instruments["s2sre"] = virtual_instrument + virtual_instruments_dict["virtual_instruments"] = virtual_instruments + + return {sound2sound_origin_spectrogram_upload_image: gr.update(), + sound2sound_origin_phase_upload_image: gr.update(), + sound2sound_origin_spectrogram_microphone_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_microphone_image: origin_flipped_phases[0], + sound2sound_origin_upload_latent_representation_image: gr.update(), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(), + sound2sound_origin_microphone_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_microphone_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_with_text_state: sound2sound_with_text_dict, + virtual_instruments_state: virtual_instruments_dict} + + def sound2sound_sample(sound2sound_prompts, sound2sound_negative_prompts, sound2sound_batchsize, + sound2sound_guidance_scale, sound2sound_sampler, + sound2sound_sample_steps, + sound2sound_origin_source, + sound2sound_noising_strength, sound2sound_seed, sound2sound_dict, virtual_instruments_dict): + + # input processing + sound2sound_seed = safe_int(sound2sound_seed, 12345678) + sound2sound_batchsize = int(sound2sound_batchsize) + noising_strength = sound2sound_noising_strength + sound2sound_sample_steps = int(sound2sound_sample_steps) + CFG = int(sound2sound_guidance_scale) + + if sound2sound_origin_source == "upload": + origin_latent_representations = torch.tensor( + sound2sound_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + elif sound2sound_origin_source == "microphone": + origin_latent_representations = torch.tensor( + sound2sound_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + else: + print("Input source not in ['upload', 'microphone']!") + raise NotImplementedError() + + # sound2sound + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([sound2sound_prompts], padding=True, return_tensors="pt"))[0].to( + device) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + unconditional_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([sound2sound_negative_prompts], padding=True, return_tensors="pt"))[ + 0] + mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device)) + + normalized_sample_steps = int(sound2sound_sample_steps / noising_strength) + mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32))) + + condition = text2sound_embedding.repeat(sound2sound_batchsize, 1) + + # Todo: remove-hard coding + width = origin_latent_representations.shape[-1] + new_sound_latent_representations, initial_noise = \ + mySampler.img_guided_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width), + seed=sound2sound_seed, + noising_strength=noising_strength, + guide_img=origin_latent_representations, return_tensor=True, + condition=condition, + sampler=sound2sound_sampler) + + new_sound_latent_representations = 
new_sound_latent_representations[-1] + + # Quantize new sound latent representations + quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations) + + new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder, + quantized_new_sound_latent_representations, + resolution=( + 512, + width * VAE_scale), + original_STFT_batch=None + ) + + + + new_sound_latent_representation_gradio_images = [] + new_sound_quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_phase_gradio_images = [] + new_sound_rec_signals_gradio = [] + for i in range(sound2sound_batchsize): + new_sound_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(new_sound_latent_representations[i])) + new_sound_quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i]) + new_sound_phase_gradio_images.append(new_sound_flipped_phases[i]) + new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i])) + + sound2sound_dict[ + "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images + sound2sound_dict[ + "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images + sound2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + sound2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images + sound2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image( + new_sound_latent_representations[0]), + sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_new_sound_latent_representations[0]), + sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0], + sound2sound_new_sound_phase_image: new_sound_flipped_phases[0], + sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]), + sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0, + step=1.0, + visible=True, + label="Sample index", + info="Swipe to view other samples"), + sound2sound_seed_textbox: sound2sound_seed, + sound2sound_with_text_state: sound2sound_dict, + virtual_instruments_state: virtual_instruments_dict} + + def show_sound2sound_sample(sound2sound_sample_index, sound2sound_with_text_dict): + sample_index = int(sound2sound_sample_index) + return {sound2sound_new_sound_latent_representation_image: + sound2sound_with_text_dict["new_sound_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_quantized_latent_representation_image: + sound2sound_with_text_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_spectrogram_image: sound2sound_with_text_dict["new_sound_spectrogram_gradio_images"][ + sample_index], + sound2sound_new_sound_phase_image: sound2sound_with_text_dict["new_sound_phase_gradio_images"][ + sample_index], + sound2sound_new_sound_audio: sound2sound_with_text_dict["new_sound_rec_signals_gradio"][sample_index]} + + def sound2sound_switch_origin_source(sound2sound_origin_source): + + if sound2sound_origin_source == "upload": + return 
{sound2sound_origin_upload_audio: gr.update(visible=True), + sound2sound_origin_microphone_audio: gr.update(visible=False), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=True), + sound2sound_origin_phase_upload_image: gr.update(visible=True), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False), + sound2sound_origin_phase_microphone_image: gr.update(visible=False), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=True), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)} + elif sound2sound_origin_source == "microphone": + return {sound2sound_origin_upload_audio: gr.update(visible=False), + sound2sound_origin_microphone_audio: gr.update(visible=True), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=False), + sound2sound_origin_phase_upload_image: gr.update(visible=False), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True), + sound2sound_origin_phase_microphone_image: gr.update(visible=True), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=False), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)} + else: + print("Input source not in ['upload', 'microphone']!") + + with gr.Tab("Sound2Sound"): + gr.Markdown("Generate new sound based on a given sound!") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + sound2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1) + + sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + + with gr.Row(variant="panel"): + with gr.Column(scale=1): + with gr.Tab("Origin sound"): + sound2sound_duration_slider = gradioWebUI.get_duration_slider() + sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload", + label="Input source") + + sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload", + interactive=True, visible=True) + sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone", + interactive=True, visible=False) + with gr.Row(variant="panel"): + sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram", + type="numpy", height=600, + visible=True) + sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase", + type="numpy", height=600, + visible=True) + sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram", + type="numpy", height=600, + visible=False) + sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase", + type="numpy", height=600, + visible=False) + + with gr.Tab("Sound2sound settings"): + sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + sound2sound_sampler_radio = 
gradioWebUI.get_sampler_radio() + sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider() + sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider() + sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + sound2sound_seed_textbox = gradioWebUI.get_seed_textbox() + + with gr.Column(scale=1): + sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + with gr.Row(variant="panel"): + sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + height=600, scale=1) + sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy", + height=600, scale=1) + + with gr.Row(variant="panel"): + sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=800, + visible=True) + sound2sound_origin_upload_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=800, visible=True) + + sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=800, + visible=False) + sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=800, visible=False) + + sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation", + type="numpy", height=800) + sound2sound_new_sound_quantized_latent_representation_image = gr.Image( + label="New sound quantized latent representation", type="numpy", height=800) + + sound2sound_origin_upload_audio.change(receive_upload_origin_audio, + inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, + sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, sound2sound_with_text_state, + virtual_instruments_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + sound2sound_with_text_state, + virtual_instruments_state]) + + sound2sound_origin_microphone_audio.change(receive_upload_origin_audio, + inputs=[sound2sound_duration_slider, + sound2sound_origin_source_radio, sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, sound2sound_with_text_state, + virtual_instruments_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + sound2sound_with_text_state, + virtual_instruments_state]) + + sound2sound_sample_button.click(sound2sound_sample, + inputs=[sound2sound_prompts_textbox, + text2sound_negative_prompts_textbox, + sound2sound_batchsize_slider, + sound2sound_guidance_scale_slider, + sound2sound_sampler_radio, + sound2sound_sample_steps_slider, + 
sound2sound_origin_source_radio, + sound2sound_noising_strength_slider, + sound2sound_seed_textbox, + sound2sound_with_text_state, + virtual_instruments_state], + outputs=[sound2sound_new_sound_latent_representation_image, + sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio, + sound2sound_sample_index_slider, + sound2sound_seed_textbox, + sound2sound_with_text_state, + virtual_instruments_state]) + + sound2sound_sample_index_slider.change(show_sound2sound_sample, + inputs=[sound2sound_sample_index_slider, sound2sound_with_text_state], + outputs=[sound2sound_new_sound_latent_representation_image, + sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio]) + + sound2sound_origin_source_radio.change(sound2sound_switch_origin_source, + inputs=[sound2sound_origin_source_radio], + outputs=[sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, + sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image]) diff --git a/webUI/natural_language_guided/super_resolution_with_text.py b/webUI/natural_language_guided/super_resolution_with_text.py new file mode 100644 index 0000000000000000000000000000000000000000..6faf246ae746ff106e66f7d08636210933fa1ab3 --- /dev/null +++ b/webUI/natural_language_guided/super_resolution_with_text.py @@ -0,0 +1,387 @@ +import librosa +import numpy as np +import torch +import gradio as gr +from scipy.ndimage import zoom + +from model.DiffSynthSampler import DiffSynthSampler +from tools import adjust_audio_length, rescale, safe_int, pad_STFT, encode_stft +from webUI.natural_language_guided_STFT.utils import latent_representation_to_Gradio_image +from webUI.natural_language_guided_STFT.utils import InputBatch2Encode_STFT, encodeBatch2GradioOutput_STFT + + +def get_super_resolution_with_text_module(gradioWebUI, inpaintWithText_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution/VAE_scale), int(time_resolution/VAE_scale), gradioWebUI.channels + timesteps = gradioWebUI.timesteps + VAE_encoder = gradioWebUI.VAE_encoder + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def receive_uopoad_origin_audio(sound2sound_duration, sound2sound_origin_source, sound2sound_origin_upload, sound2sound_origin_microphone, + inpaintWithText_dict): + + if sound2sound_origin_source == "upload": + origin_sr, origin_audio = sound2sound_origin_upload + else: + origin_sr, origin_audio = sound2sound_origin_microphone + + origin_audio = origin_audio / np.max(np.abs(origin_audio)) + + width = int(time_resolution*((sound2sound_duration+1)/4) / VAE_scale) 
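+        # Rough shape check with example values (e.g. time_resolution=256, VAE_scale=4, a hypothetical 3 s input):
+        #   width        = int(256 * ((3 + 1) / 4) / 4) = 64 latent frames
+        #   audio_length = 256 * (4 * 64 - 1)           = 65280 samples
+        # With hop_length=256 and librosa's centered STFT this yields exactly VAE_scale * width = 256 frames,
+        # matching the (512, width * VAE_scale) resolution passed to InputBatch2Encode_STFT below.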
+ audio_length = 256 * (VAE_scale * width - 1) + origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate) + + D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024) + padded_D = pad_STFT(D) + encoded_D = encode_stft(padded_D) + + # Todo: justify batchsize to 1 + origin_spectrogram_batch_tensor = torch.from_numpy( + np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device) + + # Todo: remove hard-coding + origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT( + VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer, squared=squared) + + if sound2sound_origin_source == "upload": + inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist() + inpaintWithText_dict[ + "sound2sound_origin_upload_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + inpaintWithText_dict[ + "sound2sound_origin_upload_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_upload_image: origin_flipped_phases[0], + sound2sound_origin_spectrogram_microphone_image: gr.update(), + sound2sound_origin_phase_microphone_image: gr.update(), + sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_origin_microphone_latent_representation_image: gr.update(), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(), + inpaintWithText_state: inpaintWithText_dict} + else: + inpaintWithText_dict["origin_microphone_latent_representations"] = origin_latent_representations.tolist() + inpaintWithText_dict[ + "sound2sound_origin_microphone_latent_representation_image"] = latent_representation_to_Gradio_image( + origin_latent_representations[0]).tolist() + inpaintWithText_dict[ + "sound2sound_origin_microphone_quantized_latent_representation_image"] = latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]).tolist() + return {sound2sound_origin_spectrogram_upload_image: origin_flipped_log_spectrums[0], + sound2sound_origin_phase_upload_image: origin_flipped_phases[0], + sound2sound_origin_spectrogram_microphone_image: gr.update(), + sound2sound_origin_phase_microphone_image: gr.update(), + sound2sound_origin_upload_latent_representation_image: latent_representation_to_Gradio_image( + origin_latent_representations[0]), + sound2sound_origin_upload_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_origin_latent_representations[0]), + sound2sound_origin_microphone_latent_representation_image: gr.update(), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(), + inpaintWithText_state: inpaintWithText_dict} + + def sound2sound_sample(sound2sound_origin_spectrogram_upload, sound2sound_origin_spectrogram_microphone, + text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize, + sound2sound_guidance_scale, sound2sound_sampler, + sound2sound_sample_steps, 
sound2sound_origin_source, + sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area, inpaintWithText_dict + ): + + # input preprocessing + sound2sound_seed = safe_int(sound2sound_seed, 12345678) + sound2sound_batchsize = int(sound2sound_batchsize) + noising_strength = sound2sound_noising_strength + sound2sound_sample_steps = int(sound2sound_sample_steps) + CFG = int(sound2sound_guidance_scale) + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(device) + + if sound2sound_origin_source == "upload": + origin_latent_representations = torch.tensor( + inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + elif sound2sound_origin_source == "microphone": + origin_latent_representations = torch.tensor( + inpaintWithText_dict["origin_microphone_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to( + device) + else: + print("Input source not in ['upload', 'microphone']!") + raise NotImplementedError() + + high_resolution_latent_representations = torch.zeros((sound2sound_batchsize, channels, 256, 64)).to(device) + high_resolution_latent_representations[:, :, :128, :] = origin_latent_representations + latent_mask = np.ones((256, 64)) + latent_mask[192:, :] = 0.0 + print(f"latent_mask mean: {np.mean(latent_mask)}") + + if sound2sound_inpaint_area == "inpaint masked": + latent_mask = 1 - latent_mask + latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1, + 1).float().to(device) + latent_mask = torch.flip(latent_mask, [2]) + + mySampler = DiffSynthSampler(timesteps, height=height*2, channels=channels, noise_strategy=noise_strategy) + unconditional_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[0] + mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device)) + + normalized_sample_steps = int(sound2sound_sample_steps / noising_strength) + + mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32))) + + # Todo: remove hard-coding + width = high_resolution_latent_representations.shape[-1] + condition = text2sound_embedding.repeat(sound2sound_batchsize, 1) + + new_sound_latent_representations, initial_noise = \ + mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height*2, width), + seed=sound2sound_seed, + noising_strength=noising_strength, + guide_img=high_resolution_latent_representations, mask=latent_mask, return_tensor=True, + condition=condition, sampler=sound2sound_sampler) + + new_sound_latent_representations = new_sound_latent_representations[-1] + + # Quantize new sound latent representations + quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations) + new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder, + quantized_new_sound_latent_representations, + resolution=( + 1024, + width * VAE_scale), + original_STFT_batch=None + ) + + new_sound_latent_representation_gradio_images = [] + new_sound_quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_phase_gradio_images = [] + new_sound_rec_signals_gradio = [] + for i in range(sound2sound_batchsize): + new_sound_latent_representation_gradio_images.append( + 
latent_representation_to_Gradio_image(new_sound_latent_representations[i])) + new_sound_quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i]) + new_sound_phase_gradio_images.append(new_sound_flipped_phases[i]) + new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i])) + + inpaintWithText_dict[ + "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images + inpaintWithText_dict[ + "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images + inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images + inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image( + new_sound_latent_representations[0]), + sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image( + quantized_new_sound_latent_representations[0]), + sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0], + sound2sound_new_sound_phase_image: new_sound_flipped_phases[0], + sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]), + sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0, + step=1.0, + visible=True, + label="Sample index", + info="Swipe to view other samples"), + sound2sound_seed_textbox: sound2sound_seed, + inpaintWithText_state: inpaintWithText_dict} + + def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict): + sample_index = int(sound2sound_sample_index) + return {sound2sound_new_sound_latent_representation_image: + inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_quantized_latent_representation_image: + inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index], + sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][ + sample_index], + sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][ + sample_index], + sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]} + + def sound2sound_switch_origin_source(sound2sound_origin_source): + + if sound2sound_origin_source == "upload": + return {sound2sound_origin_upload_audio: gr.update(visible=True), + sound2sound_origin_microphone_audio: gr.update(visible=False), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=True), + sound2sound_origin_phase_upload_image: gr.update(visible=True), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=False), + sound2sound_origin_phase_microphone_image: gr.update(visible=False), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=True), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=False)} + elif sound2sound_origin_source == "microphone": + return {sound2sound_origin_upload_audio: 
gr.update(visible=False), + sound2sound_origin_microphone_audio: gr.update(visible=True), + sound2sound_origin_spectrogram_upload_image: gr.update(visible=False), + sound2sound_origin_phase_upload_image: gr.update(visible=False), + sound2sound_origin_spectrogram_microphone_image: gr.update(visible=True), + sound2sound_origin_phase_microphone_image: gr.update(visible=True), + sound2sound_origin_upload_latent_representation_image: gr.update(visible=False), + sound2sound_origin_upload_quantized_latent_representation_image: gr.update(visible=False), + sound2sound_origin_microphone_latent_representation_image: gr.update(visible=True), + sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(visible=True)} + else: + print("Input source not in ['upload', 'microphone']!") + + with gr.Tab("Super Resolution"): + gr.Markdown("Select the area to inpaint and use the prompt to guide the synthesis of a new sound!") + with gr.Row(variant="panel"): + with gr.Column(scale=3): + text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ") + text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="") + + with gr.Column(scale=1): + sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1) + + sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False, + label="Sample index", + info="Swipe to view other samples") + + with gr.Row(variant="panel"): + with gr.Column(scale=1): + with gr.Tab("Origin sound"): + sound2sound_duration_slider = gradioWebUI.get_duration_slider() + sound2sound_origin_source_radio = gr.Radio(choices=["upload", "microphone"], value="upload", + label="Input source") + + sound2sound_origin_upload_audio = gr.Audio(type="numpy", label="Upload", source="upload", + interactive=True, visible=True) + sound2sound_origin_microphone_audio = gr.Audio(type="numpy", label="Record", source="microphone", + interactive=True, visible=False) + with gr.Row(variant="panel"): + sound2sound_origin_spectrogram_upload_image = gr.Image(label="Original upload spectrogram", + type="numpy", height=600, + visible=True, tool="sketch") + sound2sound_origin_phase_upload_image = gr.Image(label="Original upload phase", + type="numpy", height=600, + visible=True) + sound2sound_origin_spectrogram_microphone_image = gr.Image(label="Original microphone spectrogram", + type="numpy", height=600, + visible=False, tool="sketch") + sound2sound_origin_phase_microphone_image = gr.Image(label="Original microphone phase", + type="numpy", height=600, + visible=False) + sound2sound_inpaint_area_radio = gr.Radio(choices=["inpaint masked", "inpaint not masked"], + value="inpaint masked") + + with gr.Tab("Sound2sound settings"): + sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider() + sound2sound_sampler_radio = gradioWebUI.get_sampler_radio() + sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider() + sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.0) + sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider() + sound2sound_seed_textbox = gradioWebUI.get_seed_textbox() + + + with gr.Column(scale=1): + sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False) + with gr.Row(variant="panel"): + sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy", + height=1200, scale=1) + sound2sound_new_sound_phase_image = 
gr.Image(label="New sound phase", type="numpy", + height=1200, scale=1) + + with gr.Row(variant="panel"): + sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=1200, + visible=True) + sound2sound_origin_upload_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=1200, visible=True) + + sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation", + type="numpy", height=1200, + visible=False) + sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image( + label="Original quantized latent representation", type="numpy", height=1200, visible=False) + + sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation", + type="numpy", height=1200) + sound2sound_new_sound_quantized_latent_representation_image = gr.Image( + label="New sound quantized latent representation", type="numpy", height=1200) + + sound2sound_origin_upload_audio.change(receive_uopoad_origin_audio, + inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, inpaintWithText_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + inpaintWithText_state]) + sound2sound_origin_microphone_audio.change(receive_uopoad_origin_audio, + inputs=[sound2sound_duration_slider, sound2sound_origin_source_radio, sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, inpaintWithText_state], + outputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image, + inpaintWithText_state]) + + sound2sound_sample_button.click(sound2sound_sample, + inputs=[sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_spectrogram_microphone_image, + text2sound_prompts_textbox, + text2sound_negative_prompts_textbox, + sound2sound_batchsize_slider, + sound2sound_guidance_scale_slider, + sound2sound_sampler_radio, + sound2sound_sample_steps_slider, + sound2sound_origin_source_radio, + sound2sound_noising_strength_slider, + sound2sound_seed_textbox, + sound2sound_inpaint_area_radio, + inpaintWithText_state], + outputs=[sound2sound_new_sound_latent_representation_image, + sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio, + sound2sound_sample_index_slider, + sound2sound_seed_textbox, + inpaintWithText_state]) + + sound2sound_sample_index_slider.change(show_sound2sound_sample, + inputs=[sound2sound_sample_index_slider, inpaintWithText_state], + outputs=[sound2sound_new_sound_latent_representation_image, + 
sound2sound_new_sound_quantized_latent_representation_image, + sound2sound_new_sound_spectrogram_image, + sound2sound_new_sound_phase_image, + sound2sound_new_sound_audio]) + + sound2sound_origin_source_radio.change(sound2sound_switch_origin_source, + inputs=[sound2sound_origin_source_radio], + outputs=[sound2sound_origin_upload_audio, + sound2sound_origin_microphone_audio, + sound2sound_origin_spectrogram_upload_image, + sound2sound_origin_phase_upload_image, + sound2sound_origin_spectrogram_microphone_image, + sound2sound_origin_phase_microphone_image, + sound2sound_origin_upload_latent_representation_image, + sound2sound_origin_upload_quantized_latent_representation_image, + sound2sound_origin_microphone_latent_representation_image, + sound2sound_origin_microphone_quantized_latent_representation_image]) diff --git a/webUI/natural_language_guided/text2sound.py b/webUI/natural_language_guided/text2sound.py new file mode 100644 index 0000000000000000000000000000000000000000..0d6ae2471c60cdce611f1f9f373319fa90d1baa8 --- /dev/null +++ b/webUI/natural_language_guided/text2sound.py @@ -0,0 +1,212 @@ +import gradio as gr +import numpy as np +import torch +from matplotlib import pyplot as plt + +from model.DiffSynthSampler import DiffSynthSampler +from tools import safe_int +from webUI.natural_language_guided.utils import latent_representation_to_Gradio_image, \ + encodeBatch2GradioOutput_STFT, add_instrument + + +def get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state): + # Load configurations + uNet = gradioWebUI.uNet + freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution + VAE_scale = gradioWebUI.VAE_scale + height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels + + timesteps = gradioWebUI.timesteps + VAE_quantizer = gradioWebUI.VAE_quantizer + VAE_decoder = gradioWebUI.VAE_decoder + CLAP = gradioWebUI.CLAP + CLAP_tokenizer = gradioWebUI.CLAP_tokenizer + device = gradioWebUI.device + squared = gradioWebUI.squared + sample_rate = gradioWebUI.sample_rate + noise_strategy = gradioWebUI.noise_strategy + + def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize, + text2sound_duration, + text2sound_guidance_scale, text2sound_sampler, + text2sound_sample_steps, text2sound_seed, + text2sound_dict): + text2sound_sample_steps = int(text2sound_sample_steps) + text2sound_seed = safe_int(text2sound_seed, 12345678) + + width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale) + + text2sound_batchsize = int(text2sound_batchsize) + + text2sound_embedding = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to( + device) + + CFG = int(text2sound_guidance_scale) + + mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy) + negative_condition = \ + CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[ + 0] + mySampler.activate_classifier_free_guidance(CFG, negative_condition.to(device)) + + mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32))) + + condition = text2sound_embedding.repeat(text2sound_batchsize, 1) + + latent_representations, initial_noise = \ + mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed, + return_tensor=True, condition=condition, sampler=text2sound_sampler) + 
+ latent_representations = latent_representations[-1] + print(latent_representations[0, 0, :3, :3]) + + latent_representation_gradio_images = [] + quantized_latent_representation_gradio_images = [] + new_sound_spectrogram_gradio_images = [] + new_sound_phase_gradio_images = [] + new_sound_rec_signals_gradio = [] + + quantized_latent_representations, loss, (_, _, _) = VAE_quantizer(latent_representations) + # Todo: remove hard-coding + flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(VAE_decoder, + quantized_latent_representations, + resolution=( + 512, + width * VAE_scale), + original_STFT_batch=None + ) + + for i in range(text2sound_batchsize): + latent_representation_gradio_images.append(latent_representation_to_Gradio_image(latent_representations[i])) + quantized_latent_representation_gradio_images.append( + latent_representation_to_Gradio_image(quantized_latent_representations[i])) + new_sound_spectrogram_gradio_images.append(flipped_log_spectrums[i]) + new_sound_phase_gradio_images.append(flipped_phases[i]) + new_sound_rec_signals_gradio.append((sample_rate, rec_signals[i])) + + text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy() + text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to("cpu").detach().numpy() + text2sound_dict["latent_representation_gradio_images"] = latent_representation_gradio_images + text2sound_dict["quantized_latent_representation_gradio_images"] = quantized_latent_representation_gradio_images + text2sound_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images + text2sound_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images + text2sound_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio + + text2sound_dict["condition"] = condition.to("cpu").detach().numpy() + text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy() + text2sound_dict["guidance_scale"] = CFG + text2sound_dict["sampler"] = text2sound_sampler + + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][0], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][0], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][0], + text2sound_sampled_phase_image: text2sound_dict["new_sound_phase_gradio_images"][0], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][0], + text2sound_seed_textbox: text2sound_seed, + text2sound_state: text2sound_dict, + text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1, + visible=True, + label="Sample index.", + info="Swipe to view other samples")} + + def show_random_sample(sample_index, text2sound_dict): + sample_index = int(sample_index) + text2sound_dict["sample_index"] = sample_index + return {text2sound_latent_representation_image: text2sound_dict["latent_representation_gradio_images"][ + sample_index], + text2sound_quantized_latent_representation_image: + text2sound_dict["quantized_latent_representation_gradio_images"][sample_index], + text2sound_sampled_spectrogram_image: text2sound_dict["new_sound_spectrogram_gradio_images"][ + sample_index], + text2sound_sampled_phase_image: text2sound_dict["new_sound_phase_gradio_images"][ + sample_index], + text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index]} + + + + 
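+    # Note: save_virtual_instrument below persists one generated sample as a reusable "virtual instrument".
+    # add_instrument (webUI/natural_language_guided/utils.py) copies the selected sample's
+    # "latent_representations", "quantized_latent_representations", "sampler",
+    # "new_sound_rec_signals_gradio", "new_sound_spectrogram_gradio_images" and
+    # "new_sound_phase_gradio_images" entries from text2sound_dict into virtual_instruments_state
+    # under the chosen instrument name.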
+    def save_virtual_instrument(sample_index, virtual_instrument_name, text2sound_dict, virtual_instruments_dict):
+
+        virtual_instruments_dict = add_instrument(text2sound_dict, virtual_instruments_dict, virtual_instrument_name, sample_index)
+
+        return {virtual_instruments_state: virtual_instruments_dict,
+                text2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
+                                                               placeholder=f"Saved as {virtual_instrument_name}!")}
+
+    with gr.Tab("Text2sound"):
+        gr.Markdown("Use the diffusion model to generate random sounds of your favorite instrument from a text prompt!")
+        with gr.Row(variant="panel"):
+            with gr.Column(scale=3):
+                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
+                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
+
+            with gr.Column(scale=1):
+                text2sound_sampling_button = gr.Button(variant="primary",
+                                                       value="Generate a batch of samples and show "
+                                                             "the first one",
+                                                       scale=1)
+                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
+                                                           label="Sample index",
+                                                           info="Swipe to view other samples")
+        with gr.Row(variant="panel"):
+            with gr.Column(variant="panel"):
+                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
+                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
+                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
+                text2sound_duration_slider = gradioWebUI.get_duration_slider()
+                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
+                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()
+
+            with gr.Column(variant="panel"):
+                with gr.Row(variant="panel"):
+                    text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", height=600)
+                    text2sound_sampled_phase_image = gr.Image(label="Sampled phase", type="numpy", height=600)
+                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play")
+
+                with gr.Row(variant="panel"):
+                    text2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
+                                                                    placeholder="Name of your instrument")
+                    text2sound_save_instrument_button = gr.Button(variant="primary",
+                                                                  value="Save instrument",
+                                                                  scale=1)
+
+        with gr.Row(variant="panel"):
+            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
+                                                              height=200, width=100)
+            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
+                                                                        type="numpy", height=200, width=100)
+
+    text2sound_sampling_button.click(diffusion_random_sample,
+                                     inputs=[text2sound_prompts_textbox,
+                                             text2sound_negative_prompts_textbox,
+                                             text2sound_batchsize_slider,
+                                             text2sound_duration_slider,
+                                             text2sound_guidance_scale_slider, text2sound_sampler_radio,
+                                             text2sound_sample_steps_slider,
+                                             text2sound_seed_textbox,
+                                             text2sound_state],
+                                     outputs=[text2sound_latent_representation_image,
+                                              text2sound_quantized_latent_representation_image,
+                                              text2sound_sampled_spectrogram_image,
+                                              text2sound_sampled_phase_image,
+                                              text2sound_sampled_audio,
+                                              text2sound_seed_textbox,
+                                              text2sound_state,
+                                              text2sound_sample_index_slider])
+
+    text2sound_save_instrument_button.click(save_virtual_instrument,
+                                            inputs=[text2sound_sample_index_slider,
+                                                    text2sound_instrument_name_textbox,
+                                                    text2sound_state,
+                                                    virtual_instruments_state],
+                                            outputs=[virtual_instruments_state,
+                                                     text2sound_instrument_name_textbox])
+
+    text2sound_sample_index_slider.change(show_random_sample,
+                                          inputs=[text2sound_sample_index_slider, text2sound_state],
+                                          outputs=[text2sound_latent_representation_image,
+                                                   text2sound_quantized_latent_representation_image,
+                                                   text2sound_sampled_spectrogram_image,
+                                                   text2sound_sampled_phase_image,
+                                                   text2sound_sampled_audio])
diff --git a/webUI/natural_language_guided/track_maker.py b/webUI/natural_language_guided/track_maker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e7515fcbf4a2a6918ec7e10a9aa33858efda9c0
--- /dev/null
+++ b/webUI/natural_language_guided/track_maker.py
@@ -0,0 +1,209 @@
+import numpy as np
+import torch
+
+from model.DiffSynthSampler import DiffSynthSampler
+from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
+import mido
+import torchaudio.transforms as transforms
+from tqdm import tqdm
+
+
+def pitch_shift_audio(waveform, sample_rate, n_steps):
+    # If the input is a numpy array, convert it to a torch.Tensor
+    if isinstance(waveform, np.ndarray):
+        waveform = torch.from_numpy(waveform)
+
+    # Make sure the waveform is torch.float32
+    waveform = waveform.to(torch.float32)
+
+    # Build the pitch-shift transform
+    pitch_shift = transforms.PitchShift(sample_rate, n_steps)
+
+    # Apply the pitch shift and return the result as a numpy array
+    return pitch_shift(waveform).detach().numpy()
+
+
+class NoteEvent:
+    def __init__(self, note, velocity, start_time, duration):
+        self.note = note
+        self.velocity = velocity
+        self.start_time = start_time  # In ticks
+        self.duration = duration  # In ticks
+
+    def __str__(self):
+        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"
+
+
+class Track:
+    def __init__(self, track, ticks_per_beat):
+        self.tempo_events = self._parse_tempo_events(track)
+        self.events = self._parse_note_events(track)
+        self.ticks_per_beat = ticks_per_beat
+
+    def _parse_tempo_events(self, track):
+        tempo_events = []
+        current_tempo = 500000  # Default MIDI tempo is 120 BPM which is 500000 microseconds per beat
+        for msg in track:
+            if msg.type == 'set_tempo':
+                tempo_events.append((msg.time, msg.tempo))
+            elif not msg.is_meta:
+                tempo_events.append((msg.time, current_tempo))
+        return tempo_events
+
+    def _parse_note_events(self, track):
+        events = []
+        start_time = 0
+        note_on_time = 0
+        for msg in track:
+            if not msg.is_meta:
+                start_time += msg.time
+                if msg.type == 'note_on' and msg.velocity > 0:
+                    note_on_time = start_time
+                # Treat explicit 'note_off' messages and zero-velocity 'note_on' messages as note releases
+                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
+                    duration = start_time - note_on_time
+                    events.append(NoteEvent(msg.note, msg.velocity, note_on_time, duration))
+        return events
+
+    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
+        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
+        current_tempo = 500000  # Start with default MIDI tempo 120 BPM
+        duration_note_mapping = {}
+
+        # NOTE: only the first 10 note events are rendered here (this looks like a debug/demo cap)
+        for event in tqdm(self.events[:10]):
+            current_tempo = self._get_tempo_at(event.start_time)
+            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
+            start_time_sec = event.start_time * seconds_per_tick
+            # Todo: set a minimum duration
+            duration_sec = event.duration * seconds_per_tick
+            duration_sec = max(duration_sec, 0.75)
+            start_sample = int(start_time_sec * sample_rate)
+            if not (str(duration_sec) in duration_note_mapping):
+                note_sample = diffSynthSampler(event.velocity, duration_sec)
+                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
+
+            # note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
+            note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
+            # Clip to the allocated buffer so the minimum-duration extension cannot run past the end of the track
+            end_sample = min(start_sample + len(note_audio), len(track_audio))
+            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]
+
+        return track_audio
+
+    def _get_tempo_at(self, time_tick):
+        current_tempo = 500000  # Start with default MIDI tempo 120 BPM
+        elapsed_ticks = 0
+
+        for tempo_change in self.tempo_events:
+            if elapsed_ticks + tempo_change[0] > time_tick:
+                return current_tempo
+            elapsed_ticks += tempo_change[0]
+            current_tempo = tempo_change[1]
+
+        return current_tempo
+
+    def _get_total_time(self):
+        total_time = 0
+        current_tempo = 500000  # Start with default MIDI tempo 120 BPM
+
+        for event in self.events:
+            current_tempo = self._get_tempo_at(event.start_time)
+            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
+            # The track length is the latest note-off time, not the sum of the note durations
+            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)
+
+        return total_time
+
+
+class DiffSynth:
+    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
+                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
+
+        self.noise_prediction_model = noise_prediction_model
+        self.VAE_quantizer = VAE_quantizer
+        self.VAE_decoder = VAE_decoder
+        self.device = device
+        self.model_sample_rate = model_sample_rate
+        self.timesteps = timesteps
+        self.channels = channels
+        self.freq_resolution = freq_resolution
+        self.time_resolution = time_resolution
+        self.height = int(freq_resolution/VAE_scale)
+        self.VAE_scale = VAE_scale
+        self.squared = squared
+        self.text_encoder = text_encoder
+        self.CLAP_tokenizer = CLAP_tokenizer
+
+        # instruments_configs is a dict mapping an instrument name (string) to the settings read by
+        # diffSynthSampler below: 'latent_representation', 'sample_steps', 'sampler', 'noising_strength',
+        # 'attack' and 'before_release'
+        self.instruments_configs = instruments_configs
+        self.diffSynthSamplers = {}
+        self._update_instruments()
+
+
+    def _update_instruments(self):
+
+        def diffSynthSamplerWrapper(instruments_config):
+
+            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
+
+                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
+                sample_steps = instruments_config['sample_steps']
+                sampler = instruments_config['sampler']
+                noising_strength = instruments_config['noising_strength']
+                latent_representation = instruments_config['latent_representation']
+                attack = instruments_config['attack']
+                before_release = instruments_config['before_release']
+
+                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
+
+                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
+
+                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
+                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
+
+                # mask = 1 means "freeze": keep the attack and the pre-release tail of the guide latent unchanged
+                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
+                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
+                latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0
+
+                latent_representations, _ = \
+                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
+                                             noising_strength=noising_strength, condition=condition,
+                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
+                                             sampler=sampler,
+                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
+                                             mask_flexivity=1.0)
+
+
+                latent_representations = latent_representations[-1]
+
+                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
+                # Todo: remove hard-coding
+
+                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
+                                                                                                            quantized_latent_representations,
+                                                                                                            resolution=(
+                                                                                                                512,
+                                                                                                                width * self.VAE_scale),
+                                                                                                            original_STFT_batch=None,
+                                                                                                            )
+
+
+                return rec_signals[0]
+
+            return diffSynthSampler
+
+        for key in self.instruments_configs.keys():
+            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])
+
+    def get_music(self, mid, instrument_names, sample_rate=16000):
+        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
+        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"
+
+        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
+
+        # Pad every track to the length of the longest one so they can be mixed together
+        max_length = max(len(audio) for audio in track_audios)
+        full_audio = np.zeros(max_length, dtype=np.float32)  # Initialize the full mix with zeros
+        for audio in track_audios:
+            # A track may be shorter than the longest one, so zero-pad it
+            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
+            full_audio += padded_audio  # Add the track into the mix
+
+        return full_audio
\ No newline at end of file
diff --git a/webUI/natural_language_guided/utils.py b/webUI/natural_language_guided/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..404515f2caeba68ff8696c00b7766b27c0adf94e
--- /dev/null
+++ b/webUI/natural_language_guided/utils.py
@@ -0,0 +1,174 @@
+import librosa
+import numpy as np
+import torch
+
+from tools import np_power_to_db, decode_stft, depad_STFT
+
+
+def spectrogram_to_Gradio_image(spc):
+    ### input: spc [np.ndarray]
+    frequency_resolution, time_resolution = spc.shape[-2], spc.shape[-1]
+    spc = np.reshape(spc, (frequency_resolution, time_resolution))
+
+    # Todo:
+    magnitude_spectrum = np.abs(spc)
+    log_spectrum = np_power_to_db(magnitude_spectrum)
+    flipped_log_spectrum = np.flipud(log_spectrum)
+
+    colorful_spc = np.ones((frequency_resolution, time_resolution, 3)) * -80.0
+    colorful_spc[:, :, 0] = flipped_log_spectrum
+    colorful_spc[:, :, 1] = flipped_log_spectrum
+    colorful_spc[:, :, 2] = np.ones((frequency_resolution, time_resolution)) * -60.0
+    # Rescale to 0-255 and convert to uint8
+    rescaled = (colorful_spc + 80.0) / 80.0
+    rescaled = (255.0 * rescaled).astype(np.uint8)
+    return rescaled
+
+
+def phase_to_Gradio_image(phase):
+    ### input: phase [np.ndarray]
+    frequency_resolution, time_resolution = phase.shape[-2], phase.shape[-1]
+    phase = np.reshape(phase, (frequency_resolution, time_resolution))
+
+    # Todo:
+    flipped_phase = np.flipud(phase)
+    flipped_phase = (flipped_phase + 1.0) / 2.0
+
+    colorful_spc = np.zeros((frequency_resolution, time_resolution, 3))
+    colorful_spc[:, :, 0] = flipped_phase
+    colorful_spc[:, :, 1] = flipped_phase
+    colorful_spc[:, :, 2] = 0.2
+    # Rescale to 0-255 and convert to uint8
+    rescaled = (255.0 * colorful_spc).astype(np.uint8)
+    return rescaled
+
+
+def latent_representation_to_Gradio_image(latent_representation):
+    # input: latent_representation [torch.tensor]
+    if not isinstance(latent_representation, np.ndarray):
+        latent_representation = latent_representation.to("cpu").detach().numpy()
+    image = latent_representation
+
+    def normalize_image(img):
+        min_val = img.min()
+        max_val = img.max()
+        normalized_img = ((img - min_val) / (max_val - min_val) * 255)
+        return normalized_img
+
+    image[0, :, :] = normalize_image(image[0, :, :])
+    image[1, :, :] = normalize_image(image[1, :, :])
+    image[2, :, :] = normalize_image(image[2, :, :])
+    image[3, :, :] = normalize_image(image[3, :, :])
+    image_transposed = np.transpose(image, (1, 2, 0))
+    enlarged_image = np.repeat(image_transposed, 8, axis=0)
+    enlarged_image = np.repeat(enlarged_image, 8, axis=1)
+    return np.flipud(enlarged_image).astype(np.uint8)
+
+
+def InputBatch2Encode_STFT(encoder, STFT_batch, resolution=(512, 256), quantizer=None, squared=True):
+    """Encode a batch of STFT spectrograms and return Gradio images, reconstructed signals and the (quantized) latent representations."""
+    # Todo: remove resolution hard-coding
+    frequency_resolution, time_resolution = resolution
+
+    device = next(encoder.parameters()).device
+    if not (quantizer is None):
+        latent_representation_batch = encoder(STFT_batch.to(device))
+        quantized_latent_representation_batch, loss, (_, _, _) = quantizer(latent_representation_batch)
+    else:
+        mu, logvar, latent_representation_batch = encoder(STFT_batch.to(device))
+        quantized_latent_representation_batch = None
+
+    STFT_batch = STFT_batch.to("cpu").detach().numpy()
+
+    origin_flipped_log_spectrums, origin_flipped_phases, origin_signals = [], [], []
+    for STFT in STFT_batch:
+
+        padded_D_rec = decode_stft(STFT)
+        D_rec = depad_STFT(padded_D_rec)
+        spc = np.abs(D_rec)
+        phase = np.angle(D_rec)
+
+        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
+        flipped_phase = phase_to_Gradio_image(phase)
+
+        # get_audio
+        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
+
+        origin_flipped_log_spectrums.append(flipped_log_spectrum)
+        origin_flipped_phases.append(flipped_phase)
+        origin_signals.append(rec_signal)
+
+    return origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, \
+           latent_representation_batch, quantized_latent_representation_batch
+
+
+def encodeBatch2GradioOutput_STFT(decoder, latent_vector_batch, resolution=(512, 256), original_STFT_batch=None):
+    """Decode a batch of latent vectors into spectrogram/phase Gradio images and reconstructed audio signals."""
+    # Todo: remove resolution hard-coding
+    frequency_resolution, time_resolution = resolution
+
+    if isinstance(latent_vector_batch, np.ndarray):
+        latent_vector_batch = torch.from_numpy(latent_vector_batch).to(next(decoder.parameters()).device)
+
+    reconstruction_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()
+
+    flipped_log_spectrums, flipped_phases, rec_signals = [], [], []
+    flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp = [], [], []
+
+    for index, STFT in enumerate(reconstruction_batch):
+        padded_D_rec = decode_stft(STFT)
+        D_rec = depad_STFT(padded_D_rec)
+        spc = np.abs(D_rec)
+        phase = np.angle(D_rec)
+
+        flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
+        flipped_phase = phase_to_Gradio_image(phase)
+
+        # get_audio
+        rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
+
+        flipped_log_spectrums.append(flipped_log_spectrum)
+        flipped_phases.append(flipped_phase)
+        rec_signals.append(rec_signal)
+
+        ##########################################
+
+        if original_STFT_batch is not None:
+            STFT[0, :, :] = original_STFT_batch[index, 0, :, :]
+
+            padded_D_rec = decode_stft(STFT)
+            D_rec = depad_STFT(padded_D_rec)
+            spc = np.abs(D_rec)
+            phase = np.angle(D_rec)
+
+            flipped_log_spectrum = spectrogram_to_Gradio_image(spc)
+            flipped_phase = phase_to_Gradio_image(phase)
+
+            # get_audio
+            rec_signal = librosa.istft(D_rec, hop_length=256, win_length=1024)
+
+            flipped_log_spectrums_with_original_amp.append(flipped_log_spectrum)
+            flipped_phases_with_original_amp.append(flipped_phase)
+            rec_signals_with_original_amp.append(rec_signal)
+
+
+    return flipped_log_spectrums, flipped_phases, rec_signals, \
+           flipped_log_spectrums_with_original_amp, flipped_phases_with_original_amp, rec_signals_with_original_amp
+
+
+
+def add_instrument(source_dict, virtual_instruments_dict, virtual_instrument_name, sample_index):
+
+    virtual_instruments = virtual_instruments_dict["virtual_instruments"]
+    virtual_instrument = {
+        "latent_representation": source_dict["latent_representations"][sample_index],
+        "quantized_latent_representation": source_dict["quantized_latent_representations"][sample_index],
+        "sampler": source_dict["sampler"],
+        "signal": source_dict["new_sound_rec_signals_gradio"][sample_index],
+        "spectrogram_gradio_image": source_dict["new_sound_spectrogram_gradio_images"][
+            sample_index],
+        "phase_gradio_image": source_dict["new_sound_phase_gradio_images"][
+            sample_index]}
+    virtual_instruments[virtual_instrument_name] = virtual_instrument
+    virtual_instruments_dict["virtual_instruments"] = virtual_instruments
+    return virtual_instruments_dict
\ No newline at end of file
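
A minimal usage sketch of the Track/DiffSynth classes introduced in webUI/natural_language_guided/track_maker.py, assuming the models and tokenizer (uNet, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer), the target device, and a virtual_instruments_dict populated via add_instrument are already available. The instrument name "organ", the numeric settings, the "ddim" sampler string, and the file names are illustrative placeholders, not values taken from this patch.

import mido
import numpy as np
import soundfile as sf
import torch

from webUI.natural_language_guided.track_maker import DiffSynth

# Pull one saved instrument out of the state produced by add_instrument.
instrument = virtual_instruments_dict["virtual_instruments"]["organ"]
latent = torch.from_numpy(np.asarray(instrument["latent_representation"], dtype=np.float32))
latent = latent.unsqueeze(0).to(device)  # add a batch dimension; exact shape handling is an assumption

instruments_configs = {
    "organ": {                        # keys match what diffSynthSampler reads from its config
        "latent_representation": latent,
        "sample_steps": 20,           # placeholder value
        "sampler": "ddim",            # placeholder sampler name
        "noising_strength": 0.7,      # placeholder value
        "attack": 0.5,                # onset span frozen from the guide latent (same units as duration_sec)
        "before_release": 0.5,        # tail span frozen before the release
    }
}

diff_synth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder,
                       text_encoder, CLAP_tokenizer, device)

mid = mido.MidiFile("example.mid")                     # hypothetical input file
audio = diff_synth.get_music(mid, ["organ"] * len(mid.tracks))
sf.write("rendered.wav", audio, 16000)                 # model_sample_rate defaults to 16 kHz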