seungheondoh commited on
Commit
e48ca55
·
1 Parent(s): 7ccf3fd
Files changed (4) hide show
  1. app.py +81 -4
  2. model/bart.py +151 -0
  3. model/modules.py +95 -0
  4. utils/audio_utils.py +247 -0
app.py CHANGED
@@ -1,7 +1,84 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!"
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
  import gradio as gr
4
+ from timeit import default_timer as timer
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from huggingface_hub import hf_hub_download
9
+ from model.bart import BartCaptionModel
10
+ from utils.audio_utils import load_audio, STR_CH_FIRST
11
 
12
+ if os.path.isfile("transfer.pth") == False:
13
+ torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth', 'transfer.pth')
14
+ torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/electronic.mp3', 'electronic.mp3')
15
+ torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/orchestra.wav', 'orchestra.wav')
16
 
17
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
18
+
19
+ example_list = ['electronic.mp3', 'orchestra.wav']
20
+ model = BartCaptionModel(max_length = 128)
21
+ pretrained_object = torch.load('./transfer.pth', map_location='cpu')
22
+ state_dict = pretrained_object['state_dict']
23
+ model.load_state_dict(state_dict)
24
+ torch.cuda.set_device(device)
25
+ model = model.cuda(device)
26
+ model.eval()
27
+
28
+ def get_audio(audio_path, duration=10, target_sr=16000):
29
+ n_samples = int(duration * target_sr)
30
+ audio, sr = load_audio(
31
+ path= audio_path,
32
+ ch_format= STR_CH_FIRST,
33
+ sample_rate= target_sr,
34
+ downmix_to_mono= True,
35
+ )
36
+ if len(audio.shape) == 2:
37
+ audio = audio.mean(0, False) # to mono
38
+ input_size = int(n_samples)
39
+ if audio.shape[-1] < input_size: # pad sequence
40
+ pad = np.zeros(input_size)
41
+ pad[: audio.shape[-1]] = audio
42
+ audio = pad
43
+ ceil = int(audio.shape[-1] // n_samples)
44
+ audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
45
+ return audio
46
+
47
+ def captioning(audio_path):
48
+ audio_tensor = get_audio(audio_path = audio_path)
49
+ if device is not None:
50
+ audio_tensor = audio_tensor.to(device)
51
+ with torch.no_grad():
52
+ output = model.generate(
53
+ samples=audio_tensor,
54
+ num_beams=5,
55
+ )
56
+ inference = ""
57
+ number_of_chunks = range(audio_tensor.shape[0])
58
+ for chunk, text in zip(number_of_chunks, output):
59
+ time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
60
+ inference += f"{time}\n{text} \n \n"
61
+ return inference
62
+
63
+ title = "Interactive demo: Music Captioning 🤖🎵"
64
+ description = """
65
+ <p style='text-align: center'> LP-MusicCaps: LLM-Based Pseudo Music Captioning</p>
66
+ <p style='text-align: center'> SeungHeon Doh, Keunwoo Choi, Jongpil Lee, Juhan Nam, ISMIR 2023</p>
67
+ <p style='text-align: center'> <a href='#' target='_blank'>ArXiv</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>Github</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps-Dataset</a> </p>
68
+ <p style='text-align: center'> To use it, simply upload your audio and click 'submit', or click one of the examples to load them. Read more at the links below. </p>
69
+ """
70
+ article = "<p style='text-align: center'><a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps Github</a> | <a href='#' target='_blank'>LP-MusicCaps Paper</a></p>"
71
+
72
+
73
+ demo = gr.Interface(fn=captioning,
74
+ inputs=gr.Audio(type="filepath"),
75
+ outputs=[
76
+ gr.Textbox(label="Caption generated by LP-MusicCaps Transfer Model"),
77
+ ],
78
+ examples=example_list,
79
+ title=title,
80
+ description=description,
81
+ article=article,
82
+ cache_examples=False
83
+ )
84
+ demo.launch()
model/bart.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from .modules import AudioEncoder
6
+ from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
7
+
8
+ class BartCaptionModel(nn.Module):
9
+ def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
10
+ super(BartCaptionModel, self).__init__()
11
+ # non-finetunning case
12
+ bart_config = BartConfig.from_pretrained(bart_type)
13
+ self.tokenizer = BartTokenizer.from_pretrained(bart_type)
14
+ self.bart = BartForConditionalGeneration(bart_config)
15
+
16
+ self.n_sample = sr * duration
17
+ self.hop_length = int(0.01 * sr) # hard coding hop_size
18
+ self.n_frames = int(self.n_sample // self.hop_length)
19
+ self.num_of_stride_conv = num_of_conv - 1
20
+ self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
21
+ self.audio_encoder = AudioEncoder(
22
+ n_mels = n_mels, # hard coding n_mel
23
+ n_ctx = self.n_ctx,
24
+ audio_dim = audio_dim,
25
+ text_dim = self.bart.config.hidden_size,
26
+ num_of_stride_conv = self.num_of_stride_conv
27
+ )
28
+
29
+ self.max_length = max_length
30
+ self.loss_fct = nn.CrossEntropyLoss(label_smoothing= label_smoothing, ignore_index=-100)
31
+
32
+ @property
33
+ def device(self):
34
+ return list(self.parameters())[0].device
35
+
36
+ def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
37
+ """
38
+ Shift input ids one token to the right.ls
39
+ """
40
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
41
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
42
+ shifted_input_ids[:, 0] = decoder_start_token_id
43
+
44
+ if pad_token_id is None:
45
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
46
+ # replace possible -100 values in labels by `pad_token_id`
47
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
48
+ return shifted_input_ids
49
+
50
+ def forward_encoder(self, audio):
51
+ audio_embs = self.audio_encoder(audio)
52
+ encoder_outputs = self.bart.model.encoder(
53
+ input_ids=None,
54
+ inputs_embeds=audio_embs,
55
+ return_dict=True
56
+ )["last_hidden_state"]
57
+ return encoder_outputs, audio_embs
58
+
59
+ def forward_decoder(self, text, encoder_outputs):
60
+ text = self.tokenizer(text,
61
+ padding='longest',
62
+ truncation=True,
63
+ max_length=self.max_length,
64
+ return_tensors="pt")
65
+ input_ids = text["input_ids"].to(self.device)
66
+ attention_mask = text["attention_mask"].to(self.device)
67
+
68
+ decoder_targets = input_ids.masked_fill(
69
+ input_ids == self.tokenizer.pad_token_id, -100
70
+ )
71
+
72
+ decoder_input_ids = self.shift_tokens_right(
73
+ decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
74
+ )
75
+
76
+ decoder_outputs = self.bart(
77
+ input_ids=None,
78
+ attention_mask=None,
79
+ decoder_input_ids=decoder_input_ids,
80
+ decoder_attention_mask=attention_mask,
81
+ inputs_embeds=None,
82
+ labels=None,
83
+ encoder_outputs=(encoder_outputs,),
84
+ return_dict=True
85
+ )
86
+ lm_logits = decoder_outputs["logits"]
87
+ loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
88
+ return loss
89
+
90
+ def forward(self, audio, text):
91
+ encoder_outputs, _ = self.forward_encoder(audio)
92
+ loss = self.forward_decoder(text, encoder_outputs)
93
+ return loss
94
+
95
+ def generate(self,
96
+ samples,
97
+ use_nucleus_sampling=False,
98
+ num_beams=5,
99
+ max_length=128,
100
+ min_length=2,
101
+ top_p=0.9,
102
+ repetition_penalty=1.0,
103
+ ):
104
+
105
+ # self.bart.force_bos_token_to_be_generated = True
106
+ audio_embs = self.audio_encoder(samples)
107
+ encoder_outputs = self.bart.model.encoder(
108
+ input_ids=None,
109
+ attention_mask=None,
110
+ head_mask=None,
111
+ inputs_embeds=audio_embs,
112
+ output_attentions=None,
113
+ output_hidden_states=None,
114
+ return_dict=True)
115
+
116
+ input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
117
+ input_ids[:, 0] = self.bart.config.decoder_start_token_id
118
+ decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
119
+ if use_nucleus_sampling:
120
+ outputs = self.bart.generate(
121
+ input_ids=None,
122
+ attention_mask=None,
123
+ decoder_input_ids=input_ids,
124
+ decoder_attention_mask=decoder_attention_mask,
125
+ encoder_outputs=encoder_outputs,
126
+ max_length=max_length,
127
+ min_length=min_length,
128
+ do_sample=True,
129
+ top_p=top_p,
130
+ num_return_sequences=1,
131
+ repetition_penalty=1.1)
132
+ else:
133
+ outputs = self.bart.generate(input_ids=None,
134
+ attention_mask=None,
135
+ decoder_input_ids=input_ids,
136
+ decoder_attention_mask=decoder_attention_mask,
137
+ encoder_outputs=encoder_outputs,
138
+ head_mask=None,
139
+ decoder_head_mask=None,
140
+ inputs_embeds=None,
141
+ decoder_inputs_embeds=None,
142
+ use_cache=None,
143
+ output_attentions=None,
144
+ output_hidden_states=None,
145
+ max_length=max_length,
146
+ min_length=min_length,
147
+ num_beams=num_beams,
148
+ repetition_penalty=repetition_penalty)
149
+
150
+ captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
151
+ return captions
model/modules.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
2
+
3
+ import os
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from torch import Tensor, nn
9
+ from typing import Dict, Iterable, Optional
10
+
11
+ # hard-coded audio hyperparameters
12
+ SAMPLE_RATE = 16000
13
+ N_FFT = 1024
14
+ N_MELS = 128
15
+ HOP_LENGTH = int(0.01 * SAMPLE_RATE)
16
+ DURATION = 10
17
+ N_SAMPLES = int(DURATION * SAMPLE_RATE)
18
+ N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
19
+
20
+ def sinusoids(length, channels, max_timescale=10000):
21
+ """Returns sinusoids for positional embedding"""
22
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
23
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
24
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
25
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
26
+
27
+ class MelEncoder(nn.Module):
28
+ """
29
+ time-frequency represntation
30
+ """
31
+ def __init__(self,
32
+ sample_rate= 16000,
33
+ f_min=0,
34
+ f_max=8000,
35
+ n_fft=1024,
36
+ win_length=1024,
37
+ hop_length = int(0.01 * 16000),
38
+ n_mels = 128,
39
+ power = None,
40
+ pad= 0,
41
+ normalized= False,
42
+ center= True,
43
+ pad_mode= "reflect"
44
+ ):
45
+ super(MelEncoder, self).__init__()
46
+ self.window = torch.hann_window(win_length)
47
+ self.spec_fn = torchaudio.transforms.Spectrogram(
48
+ n_fft = n_fft,
49
+ win_length = win_length,
50
+ hop_length = hop_length,
51
+ power = power
52
+ )
53
+ self.mel_scale = torchaudio.transforms.MelScale(
54
+ n_mels,
55
+ sample_rate,
56
+ f_min,
57
+ f_max,
58
+ n_fft // 2 + 1)
59
+
60
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
61
+
62
+ def forward(self, wav):
63
+ spec = self.spec_fn(wav)
64
+ power_spec = spec.real.abs().pow(2)
65
+ mel_spec = self.mel_scale(power_spec)
66
+ mel_spec = self.amplitude_to_db(mel_spec) # Log10(max(reference value and amin))
67
+ return mel_spec
68
+
69
+ class AudioEncoder(nn.Module):
70
+ def __init__(
71
+ self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
72
+ ):
73
+ super().__init__()
74
+ self.mel_encoder = MelEncoder(n_mels=n_mels)
75
+ self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
76
+ self.conv_stack = nn.ModuleList([])
77
+ for _ in range(num_of_stride_conv):
78
+ self.conv_stack.append(
79
+ nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
80
+ )
81
+ # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
82
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))
83
+
84
+ def forward(self, x: Tensor):
85
+ """
86
+ x : torch.Tensor, shape = (batch_size, waveform)
87
+ single channel wavform
88
+ """
89
+ x = self.mel_encoder(x) # (batch_size, n_mels, n_ctx)
90
+ x = F.gelu(self.conv1(x))
91
+ for conv in self.conv_stack:
92
+ x = F.gelu(conv(x))
93
+ x = x.permute(0, 2, 1)
94
+ x = (x + self.positional_embedding).to(x.dtype)
95
+ return x
utils/audio_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STR_CLIP_ID = 'clip_id'
2
+ STR_AUDIO_SIGNAL = 'audio_signal'
3
+ STR_TARGET_VECTOR = 'target_vector'
4
+
5
+
6
+ STR_CH_FIRST = 'channels_first'
7
+ STR_CH_LAST = 'channels_last'
8
+
9
+ import io
10
+ import os
11
+ import tqdm
12
+ import logging
13
+ import subprocess
14
+ from typing import Tuple
15
+ from pathlib import Path
16
+
17
+ # import librosa
18
+ import numpy as np
19
+ import soundfile as sf
20
+
21
+ import itertools
22
+ from numpy.fft import irfft
23
+
24
+ def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
25
+ """
26
+ Decoding, downmixing, and downsampling by librosa.
27
+ Returns a channel-first audio signal.
28
+
29
+ Args:
30
+ path:
31
+ sample_rate:
32
+ downmix_to_mono:
33
+
34
+ Returns:
35
+ (audio signal, sample rate)
36
+ """
37
+
38
+ def _decode_resample_by_ffmpeg(filename, sr):
39
+ """decode, downmix, and resample audio file"""
40
+ channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option
41
+ resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option
42
+ cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
43
+ p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
44
+ out, err = p.communicate()
45
+ return out
46
+
47
+ src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
48
+ return src.T, sr
49
+
50
+
51
+ def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
52
+ """
53
+ Decoding, downmixing, and downsampling by librosa.
54
+ Returns a channel-first audio signal.
55
+ """
56
+ src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)
57
+ return src, sr
58
+
59
+
60
+ def load_audio(
61
+ path: str or Path,
62
+ ch_format: str,
63
+ sample_rate: int = None,
64
+ downmix_to_mono: bool = False,
65
+ resample_by: str = 'ffmpeg',
66
+ **kwargs,
67
+ ) -> Tuple[np.ndarray, int]:
68
+ """A wrapper of librosa.load that:
69
+ - forces the returned audio to be 2-dim,
70
+ - defaults to sr=None, and
71
+ - defaults to downmix_to_mono=False.
72
+
73
+ The audio decoding is done by `audioread` or `soundfile` package and ultimately, often by ffmpeg.
74
+ The resampling is done by `librosa`'s child package `resampy`.
75
+
76
+ Args:
77
+ path: audio file path
78
+ ch_format: one of 'channels_first' or 'channels_last'
79
+ sample_rate: target sampling rate. if None, use the rate of the audio file
80
+ downmix_to_mono:
81
+ resample_by (str): 'librosa' or 'ffmpeg'. it decides backend for audio decoding and resampling.
82
+ **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type.
83
+
84
+ Returns:
85
+ (audio, sr) tuple
86
+ """
87
+ if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
88
+ raise ValueError(f'ch_format is wrong here -> {ch_format}')
89
+
90
+ if os.stat(path).st_size > 8000:
91
+ if resample_by == 'librosa':
92
+ src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
93
+ elif resample_by == 'ffmpeg':
94
+ src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
95
+ else:
96
+ raise NotImplementedError(f'resample_by: "{resample_by}" is not supposred yet')
97
+ else:
98
+ raise ValueError('Given audio is too short!')
99
+ return src, sr
100
+
101
+ # if src.ndim == 1:
102
+ # src = np.expand_dims(src, axis=0)
103
+ # # now always 2d and channels_first
104
+
105
+ # if ch_format == STR_CH_FIRST:
106
+ # return src, sr
107
+ # else:
108
+ # return src.T, sr
109
+
110
+ def ms(x):
111
+ """Mean value of signal `x` squared.
112
+ :param x: Dynamic quantity.
113
+ :returns: Mean squared of `x`.
114
+ """
115
+ return (np.abs(x)**2.0).mean()
116
+
117
+ def normalize(y, x=None):
118
+ """normalize power in y to a (standard normal) white noise signal.
119
+ Optionally normalize to power in signal `x`.
120
+ #The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1.
121
+ """
122
+ if x is not None:
123
+ x = ms(x)
124
+ else:
125
+ x = 1.0
126
+ return y * np.sqrt(x / ms(y))
127
+
128
+ def noise(N, color='white', state=None):
129
+ """Noise generator.
130
+ :param N: Amount of samples.
131
+ :param color: Color of noise.
132
+ :param state: State of PRNG.
133
+ :type state: :class:`np.random.RandomState`
134
+ """
135
+ try:
136
+ return _noise_generators[color](N, state)
137
+ except KeyError:
138
+ raise ValueError("Incorrect color.")
139
+
140
+ def white(N, state=None):
141
+ """
142
+ White noise.
143
+ :param N: Amount of samples.
144
+ :param state: State of PRNG.
145
+ :type state: :class:`np.random.RandomState`
146
+ White noise has a constant power density. It's narrowband spectrum is therefore flat.
147
+ The power in white noise will increase by a factor of two for each octave band,
148
+ and therefore increases with 3 dB per octave.
149
+ """
150
+ state = np.random.RandomState() if state is None else state
151
+ return state.randn(N)
152
+
153
+ def pink(N, state=None):
154
+ """
155
+ Pink noise.
156
+ :param N: Amount of samples.
157
+ :param state: State of PRNG.
158
+ :type state: :class:`np.random.RandomState`
159
+ Pink noise has equal power in bands that are proportionally wide.
160
+ Power density decreases with 3 dB per octave.
161
+ """
162
+ state = np.random.RandomState() if state is None else state
163
+ uneven = N % 2
164
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
165
+ S = np.sqrt(np.arange(len(X)) + 1.) # +1 to avoid divide by zero
166
+ y = (irfft(X / S)).real
167
+ if uneven:
168
+ y = y[:-1]
169
+ return normalize(y)
170
+
171
+ def blue(N, state=None):
172
+ """
173
+ Blue noise.
174
+ :param N: Amount of samples.
175
+ :param state: State of PRNG.
176
+ :type state: :class:`np.random.RandomState`
177
+ Power increases with 6 dB per octave.
178
+ Power density increases with 3 dB per octave.
179
+ """
180
+ state = np.random.RandomState() if state is None else state
181
+ uneven = N % 2
182
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
183
+ S = np.sqrt(np.arange(len(X))) # Filter
184
+ y = (irfft(X * S)).real
185
+ if uneven:
186
+ y = y[:-1]
187
+ return normalize(y)
188
+
189
+ def brown(N, state=None):
190
+ """
191
+ Violet noise.
192
+ :param N: Amount of samples.
193
+ :param state: State of PRNG.
194
+ :type state: :class:`np.random.RandomState`
195
+ Power decreases with -3 dB per octave.
196
+ Power density decreases with 6 dB per octave.
197
+ """
198
+ state = np.random.RandomState() if state is None else state
199
+ uneven = N % 2
200
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
201
+ S = (np.arange(len(X)) + 1) # Filter
202
+ y = (irfft(X / S)).real
203
+ if uneven:
204
+ y = y[:-1]
205
+ return normalize(y)
206
+
207
+ def violet(N, state=None):
208
+ """
209
+ Violet noise. Power increases with 6 dB per octave.
210
+ :param N: Amount of samples.
211
+ :param state: State of PRNG.
212
+ :type state: :class:`np.random.RandomState`
213
+ Power increases with +9 dB per octave.
214
+ Power density increases with +6 dB per octave.
215
+ """
216
+ state = np.random.RandomState() if state is None else state
217
+ uneven = N % 2
218
+ X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
219
+ S = (np.arange(len(X))) # Filter
220
+ y = (irfft(X * S)).real
221
+ if uneven:
222
+ y = y[:-1]
223
+ return normalize(y)
224
+
225
+ _noise_generators = {
226
+ 'white': white,
227
+ 'pink': pink,
228
+ 'blue': blue,
229
+ 'brown': brown,
230
+ 'violet': violet,
231
+ }
232
+
233
+ def noise_generator(N=44100, color='white', state=None):
234
+ """Noise generator.
235
+ :param N: Amount of unique samples to generate.
236
+ :param color: Color of noise.
237
+ Generate `N` amount of unique samples and cycle over these samples.
238
+ """
239
+ #yield from itertools.cycle(noise(N, color)) # Python 3.3
240
+ for sample in itertools.cycle(noise(N, color, state)):
241
+ yield sample
242
+
243
+ def heaviside(N):
244
+ """Heaviside.
245
+ Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`.
246
+ """
247
+ return 0.5 * (np.sign(N) + 1)