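"""CREPE pitch estimation (Kim et al., 2018), adapted for local inference.

Defines the Crepe network in five capacities ('tiny' through 'full'), audio
preprocessing into 1024-sample frames at 16 kHz, Viterbi decoding over 360
pitch bins of 20 cents each, pitch/periodicity post-processing, and NaN-aware
mean/median smoothing. Supports both PyTorch and ONNX model files.
"""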
import os
import torch
import librosa
import functools
import scipy.stats
import numpy as np
CENTS_PER_BIN = 20      # width of each pitch bin, in cents
MAX_FMAX = 2006         # highest representable frequency in Hz (just above bin 359)
PITCH_BINS = 360        # number of output pitch bins
SAMPLE_RATE = 16000     # sample rate CREPE expects, in Hz
WINDOW_SIZE = 1024      # analysis window length, in samples
class Crepe(torch.nn.Module):
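    """CREPE pitch-estimation network; `model` selects the channel capacity."""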
def __init__(self, model='full'):
super().__init__()
if model == 'full':
in_channels = [1, 1024, 128, 128, 128, 256]
out_channels = [1024, 128, 128, 128, 256, 512]
self.in_features = 2048
elif model == 'large':
in_channels = [1, 768, 96, 96, 96, 192]
out_channels = [768, 96, 96, 96, 192, 384]
self.in_features = 1536
elif model == 'medium':
in_channels = [1, 512, 64, 64, 64, 128]
out_channels = [512, 64, 64, 64, 128, 256]
self.in_features = 1024
elif model == 'small':
in_channels = [1, 256, 32, 32, 32, 64]
out_channels = [256, 32, 32, 32, 64, 128]
self.in_features = 512
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256
        else:
            raise ValueError(f"Unknown model capacity: {model}")
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]
        # eps is the float32 representation of 1e-3, kept for parity with the
        # ported weights; momentum=0 freezes the running statistics.
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
        # Six conv + batch-norm blocks; attribute names must match the
        # pretrained state dict (conv1/conv1_BN ... conv6/conv6_BN).
        for i in range(6):
            setattr(self, f'conv{i + 1}', torch.nn.Conv2d(in_channels=in_channels[i], out_channels=out_channels[i], kernel_size=kernel_sizes[i], stride=strides[i]))
            setattr(self, f'conv{i + 1}_BN', batch_norm_fn(num_features=out_channels[i]))
        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
    def forward(self, x, embed=False):
        # x: (batch, WINDOW_SIZE) audio frames.
        x = self.embed(x)
        if embed:
            return x
        # Final conv block, then flatten (time x channels) and classify into
        # per-bin activations in [0, 1].
        x = self.layer(x, self.conv6, self.conv6_BN)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
        return torch.sigmoid(self.classifier(x))
    def embed(self, x):
        # (batch, samples) -> (batch, 1, samples, 1): Conv2d works along dim 2.
        x = x[:, None, :, None]
        # The first block pads more because of its 512-sample kernel.
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        return self.layer(x, self.conv5, self.conv5_BN)
    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        # pad -> conv -> ReLU -> batch norm -> 2x1 max pool along time.
        x = torch.nn.functional.pad(x, padding)
        x = batch_norm(torch.nn.functional.relu(conv(x)))
        return torch.nn.functional.max_pool2d(x, (2, 1), (2, 1))
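# A minimal sketch of running the network directly with random weights
# (pretrained checkpoints are loaded via load_model() below):
#
#   net = Crepe('tiny')
#   frames = torch.randn(4, WINDOW_SIZE)   # four 1024-sample windows
#   probs = net(frames)                    # -> (4, PITCH_BINS)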
def viterbi(logits):
    # Cache a banded transition matrix that forbids jumps of 12 or more bins
    # (240 cents) between consecutive frames.
    if not hasattr(viterbi, 'transition'):
        xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
        transition = np.maximum(12 - abs(xx - yy), 0)
        viterbi.transition = transition / transition.sum(axis=1, keepdims=True)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)
    # Decode the most likely bin path per sequence, then map bins to Hz.
    bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
    return bins, bins_to_frequency(bins)
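# e.g. viterbi(torch.randn(1, PITCH_BINS, 100)) returns a (1, 100) bin path
# and the corresponding (1, 100) frequencies in Hz.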
def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    # Estimate pitch (and optionally periodicity) frame by frame, using either
    # the ONNX or the PyTorch model.
    results = []
    if onnx:
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.log_severity_level = 3
        session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)
        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            logits = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0]
            # Reshape to (1, bins, frames) as postprocess expects.
            results.append(postprocess(torch.tensor(logits.transpose(1, 0)[None]), fmin, fmax, return_periodicity))
        del session
    else:
        with torch.no_grad():
            for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
                logits = infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
                result = postprocess(logits, fmin, fmax, return_periodicity)
                results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))
    # Concatenate the per-batch results along the frame axis.
    if return_periodicity:
        pitch, periodicity = zip(*results)
        return torch.cat(pitch, 1), torch.cat(periodicity, 1)
    return torch.cat(results, 1)
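# A minimal end-to-end sketch (the file name is hypothetical, and the
# crepe_full.pth checkpoint is assumed to exist under assets/models/predictors;
# any mono signal shaped (1, samples) works):
#
#   signal, sr = librosa.load("voice.wav", sr=None, mono=True)
#   audio = torch.tensor(signal).unsqueeze(0)
#   pitch, periodicity = predict(audio, sr, fmin=50, fmax=1100, model='full', return_periodicity=True)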
def bins_to_frequency(bins):
    # Bin index -> cents (20 cents per bin, offset so bin 0 is ~31.7 Hz), with
    # triangular dither to decorrelate the 20-cent quantization, then -> Hz.
    cents = CENTS_PER_BIN * bins + 1997.3794084376191
    dither = cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))
    return 10 * 2 ** ((cents + dither) / 1200)
def frequency_to_bins(frequency, quantize_fn=torch.floor):
    # Hz -> cents -> bin index, quantized with floor (or ceil for upper bounds).
    return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()
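# Worked example of the mapping: bin 0 sits at 1997.38 cents above 10 Hz,
# i.e. 10 * 2 ** (1997.38 / 1200) ~= 31.7 Hz; bin 359 lands at ~2005.5 Hz,
# which is why MAX_FMAX is 2006.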
def infer(frames, model='full', device='cpu', embed=False):
    # Lazily load (and cache on the function itself) the requested capacity.
    if not hasattr(infer, 'model') or getattr(infer, 'capacity', None) != model:
        load_model(device, model)
    infer.model = infer.model.to(device)
    return infer.model(frames, embed=embed)
def load_model(device, capacity='full'):
    # Load pretrained weights for the requested capacity and cache the model
    # (plus its capacity tag) on the infer function.
    infer.capacity = capacity
    infer.model = Crepe(capacity)
    infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
    infer.model = infer.model.to(torch.device(device))
    infer.model.eval()
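# A minimal low-level sketch (assumes `audio` is a (1, samples) tensor and the
# checkpoint exists under assets/models/predictors):
#
#   frames = next(preprocess(audio, 16000))   # (n_frames, 1024)
#   probs = infer(frames, model='full')       # (n_frames, 360)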
def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
    probabilities = probabilities.detach()
    # Mask bins below fmin and above fmax so Viterbi can never select them.
    probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
    probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')
    bins, pitch = viterbi(probabilities)
    if not return_periodicity:
        return pitch
    return pitch, periodicity(probabilities, bins)
def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
    # Default hop is 10 ms; resample to the 16 kHz rate CREPE was trained on.
    hop_length = sample_rate // 100 if hop_length is None else hop_length
    if sample_rate != SAMPLE_RATE:
        audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
        hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
    if pad:
        total_frames = 1 + int(audio.size(1) // hop_length)
        audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
    else:
        total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
    batch_size = total_frames if batch_size is None else batch_size
    for i in range(0, total_frames, batch_size):
        # Slice out enough samples for this batch, then unfold them into
        # overlapping WINDOW_SIZE frames at hop_length intervals.
        start = max(0, i * hop_length)
        end = min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)
        frames = torch.nn.functional.unfold(audio[:, None, None, start:end], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
        frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
        # Standardize each frame to zero mean, unit variance.
        frames -= frames.mean(dim=1, keepdim=True)
        frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
        yield frames
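# Example framing: at the default 10 ms hop (sample_rate // 100), one second of
# 16 kHz audio with pad=True yields 1 + 16000 // 160 = 101 frames of 1024
# samples each.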
def periodicity(probabilities, bins):
    # Confidence of the decoded path: gather each frame's activation at its
    # decoded bin.
    probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
    gathered = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
    return gathered.reshape(probabilities.size(0), probabilities.size(2))
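# Periodicity is the network's activation at the decoded bin; low values
# typically indicate unvoiced or noisy frames.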
def mean(signals, win_length=9):
    # NaN-aware moving average over the last dimension.
    assert signals.dim() == 2
    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2
    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
    # Sum the non-NaN values in each window, then divide by the non-NaN count.
    num = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding)
    den = torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
    avg_pooled = num / den
    # Windows with no valid samples come out as 0; mark them NaN again.
    avg_pooled[avg_pooled == 0] = float("nan")
    return avg_pooled.squeeze(1)
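# e.g. mean(torch.tensor([[1.0, float('nan'), 3.0]]), win_length=3) averages
# only the valid neighbors in each window, giving [[1.0, 2.0, 3.0]].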
def median(signals, win_length):
    # NaN-aware moving median over the last dimension.
    assert signals.dim() == 2
    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2
    x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
    mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)
    # Unfold into sliding windows, push invalid entries to +inf, and sort.
    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)
    x = x.contiguous().view(x.size()[:3] + (-1,))
    mask = mask.contiguous().view(mask.size()[:3] + (-1,))
    x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)
    # Pick the median among the valid (non-inf) entries of each window.
    median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
    # Windows with no valid samples stay inf; convert them back to NaN.
    median_pooled[torch.isinf(median_pooled)] = float("nan")
    return median_pooled.squeeze(1)
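if __name__ == '__main__':
    # A hedged smoke test with random weights and synthetic audio; it checks
    # shapes only, since no pretrained checkpoint is assumed to be available.
    sr = SAMPLE_RATE
    t = torch.arange(sr, dtype=torch.float32) / sr
    audio = torch.sin(2 * torch.pi * 220.0 * t).unsqueeze(0)  # 1 s of 220 Hz
    frames = next(preprocess(audio, sr))                      # (101, 1024)
    print(Crepe('tiny')(frames).shape)                        # -> (101, 360)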