|
import os
|
|
import torch
|
|
import librosa
|
|
import functools
|
|
import scipy.stats
|
|
|
|
import numpy as np
|
|
|
|
# Width of one pitch bin in cents (multiplier used when converting bins <-> cents).
CENTS_PER_BIN = 20
# Default upper frequency bound used by predict() and postprocess().
MAX_FMAX = 2006
# Number of pitch classes produced by the classifier head.
PITCH_BINS = 360
# Sample rate the audio is resampled to before framing.
SAMPLE_RATE = 16000
# Number of samples in each analysis frame fed to the network.
WINDOW_SIZE = 1024
|
|
|
|
class Crepe(torch.nn.Module):
    """Six-layer convolutional pitch network with a sigmoid classifier head.

    The network consists of six ``Conv2d`` + ``BatchNorm2d`` stages (each
    followed by a 2x max-pool along the time axis) and a final ``Linear``
    layer producing ``PITCH_BINS`` outputs per input frame.

    Args:
        model: Capacity name; one of ``'full'``, ``'large'``, ``'medium'``,
            ``'small'`` or ``'tiny'``. It selects the channel widths.

    Raises:
        ValueError: If ``model`` is not one of the supported capacity names.
    """

    # Output channel widths of the six convolution layers, per capacity.
    _OUT_CHANNELS = {
        'full': [1024, 128, 128, 128, 256, 512],
        'large': [768, 96, 96, 96, 192, 384],
        'medium': [512, 64, 64, 64, 128, 256],
        'small': [256, 32, 32, 32, 64, 128],
        'tiny': [128, 16, 16, 16, 32, 64],
    }

    def __init__(self, model='full'):
        super().__init__()

        # Fail early with a clear message instead of an UnboundLocalError
        # further down when the capacity name is not recognised.
        if model not in self._OUT_CHANNELS:
            raise ValueError(f"Model {model} is not supported")

        out_channels = self._OUT_CHANNELS[model]
        # Every layer consumes the previous layer's output; the first layer
        # sees a single input channel.
        in_channels = [1] + out_channels[:-1]
        # Flattened feature size expected by the classifier: four times the
        # last conv width (2048/1536/1024/512/256 for the five capacities).
        # With 1024-sample frames the conv/pool stack leaves 4 time steps.
        self.in_features = 4 * out_channels[-1]

        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)

        def make_conv(index):
            return torch.nn.Conv2d(
                in_channels=in_channels[index],
                out_channels=out_channels[index],
                kernel_size=kernel_sizes[index],
                stride=strides[index],
            )

        def make_norm(index):
            return batch_norm_fn(num_features=out_channels[index])

        # Attribute names and registration order are kept so existing
        # checkpoints (state_dict keys) continue to load.
        self.conv1 = make_conv(0)
        self.conv1_BN = make_norm(0)
        self.conv2 = make_conv(1)
        self.conv2_BN = make_norm(1)

        self.conv3 = make_conv(2)
        self.conv3_BN = make_norm(2)
        self.conv4 = make_conv(3)
        self.conv4_BN = make_norm(3)

        self.conv5 = make_conv(4)
        self.conv5_BN = make_norm(4)
        self.conv6 = make_conv(5)
        self.conv6_BN = make_norm(5)

        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        """Run the network on a batch of frames.

        Args:
            x: 2-D tensor of frames (it is indexed as ``x[:, None, :, None]``).
            embed: When true, return the activations after the fifth layer
                instead of the classifier output.

        Returns:
            The fifth-layer activations if ``embed`` is true, otherwise a
            tensor with ``PITCH_BINS`` sigmoid outputs per row.
        """
        x = self.embed(x)
        if embed:
            return x

        # Sixth layer, then flatten to (rows, in_features) for the classifier.
        x = self.layer(x, self.conv6, self.conv6_BN)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
        return torch.sigmoid(self.classifier(x))

    def embed(self, x):
        """Apply the first five conv stages to a 2-D batch of frames."""
        # (rows, samples) -> (rows, 1, samples, 1) so Conv2d can be used.
        x = x[:, None, :, None]

        # The first layer uses symmetric 254/254 padding along the sample
        # axis; the remaining layers use the default 31/32 padding.
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        return self.layer(x, self.conv5, self.conv5_BN)

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Apply one stage: pad -> conv -> ReLU -> batch norm -> 2x max-pool.

        ``padding`` is passed straight to ``torch.nn.functional.pad``; the
        max-pool uses a ``(2, 1)`` kernel and stride.
        """
        x = torch.nn.functional.pad(x, padding)
        x = conv(x)
        x = torch.nn.functional.relu(x)
        x = batch_norm(x)
        return torch.nn.functional.max_pool2d(x, (2, 1), (2, 1))
|
|
|
|
def viterbi(logits):
    """Decode one pitch-bin path per batch item with Viterbi decoding.

    Args:
        logits: Tensor whose dim 1 indexes pitch bins (softmax is taken over
            dim 1); each item along dim 0 is decoded independently.
            Presumably shaped (batch, PITCH_BINS, frames) - confirm with caller.

    Returns:
        Tuple ``(bins, frequency)``: the decoded int64 bin indices (on the
        same device as the input) and the result of ``bins_to_frequency``
        applied to them.
    """
    # Build the transition matrix once and cache it on the function object.
    # Weights fall off linearly with bin distance and are zero from a
    # distance of 12 bins onwards; each row is normalised to sum to 1.
    # The size follows PITCH_BINS instead of a duplicated literal so the
    # matrix cannot drift out of sync with the classifier width.
    if not hasattr(viterbi, 'transition'):
        xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
        weights = np.maximum(12 - abs(xx - yy), 0)
        viterbi.transition = weights / weights.sum(axis=1, keepdims=True)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

    # librosa works on NumPy arrays, so decode on the CPU and move the
    # resulting paths back to the device the probabilities live on.
    paths = [
        librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64)
        for sequence in probs.cpu().numpy()
    ]
    bins = torch.tensor(np.array(paths), device=probs.device)
    return bins, bins_to_frequency(bins)
|
|
|
|
def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    """Estimate pitch (and optionally periodicity) for an audio tensor.

    The audio is split into frames by ``preprocess``, each batch of frames is
    run through either an ONNX session or the cached torch model, and the
    per-batch outputs of ``postprocess`` are concatenated along dim 1.

    Args:
        audio: 2-D audio tensor (``size(0)`` and ``size(1)`` are both used).
        sample_rate: Sample rate of ``audio``.
        hop_length: Hop between frames; forwarded to ``preprocess``.
        fmin: Lower frequency bound forwarded to ``postprocess``.
        fmax: Upper frequency bound forwarded to ``postprocess``.
        model: Capacity name used to pick the checkpoint / ONNX file.
        return_periodicity: When true, also return the periodicity.
        batch_size: Number of frames per inference batch.
        device: Device used for the frames and the torch model.
        pad: Whether ``preprocess`` pads the audio.
        providers: Execution providers passed to ``onnxruntime``.
        onnx: When true, run ``assets/models/predictors/crepe_{model}.onnx``
            through onnxruntime instead of the torch model.

    Returns:
        The concatenated pitch tensor, or ``(pitch, periodicity)`` when
        ``return_periodicity`` is true.
    """
    def concatenate(chunks):
        # Join the per-batch outputs along dim 1.
        if return_periodicity:
            pitch_chunks, periodicity_chunks = zip(*chunks)
            return torch.cat(pitch_chunks, 1), torch.cat(periodicity_chunks, 1)
        return torch.cat(chunks, 1)

    results = []

    if onnx:
        # Imported lazily so onnxruntime is only required for the ONNX path.
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.log_severity_level = 3

        model_path = os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx")
        session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)

        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            output_name = session.get_outputs()[0].name
            input_name = session.get_inputs()[0].name
            raw = session.run([output_name], {input_name: frames.cpu().numpy()})[0]
            # Swap the two axes of the session output and add a leading
            # dimension before handing it to postprocess.
            probabilities = torch.tensor(raw.transpose(1, 0)[None])
            results.append(postprocess(probabilities, fmin, fmax, return_periodicity))

        del session
        return concatenate(results)

    with torch.no_grad():
        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            probabilities = infer(frames, model, device, embed=False)
            probabilities = probabilities.reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
            result = postprocess(probabilities, fmin, fmax, return_periodicity)

            # Move outputs back to the device the input audio lives on.
            if isinstance(result, tuple):
                results.append((result[0].to(audio.device), result[1].to(audio.device)))
            else:
                results.append(result.to(audio.device))

    return concatenate(results)
|
|
|
|
def bins_to_frequency(bins):
    """Convert pitch-bin indices to frequency, adding random dither.

    Args:
        bins: Tensor of bin indices.

    Returns:
        Tensor of frequencies computed as ``10 * 2 ** (cents / 1200)``
        (presumably Hz - confirm). The result is not deterministic because
        triangular noise is added to the cents.
    """
    cents = CENTS_PER_BIN * bins + 1997.3794084376191

    # Symmetric triangular noise over [-CENTS_PER_BIN, +CENTS_PER_BIN].
    noise = scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size())
    dithered = cents + cents.new_tensor(noise)

    return 10 * 2 ** (dithered / 1200)
|
|
|
|
def frequency_to_bins(frequency, quantize_fn=torch.floor):
    """Convert a frequency tensor to integer pitch-bin indices.

    Args:
        frequency: Tensor of frequencies (same scale as the output of
            ``bins_to_frequency``).
        quantize_fn: Rounding function applied to the fractional bin value
            before the integer cast (``torch.floor`` by default).

    Returns:
        Int32 tensor of bin indices.
    """
    cents = 1200 * torch.log2(frequency / 10)
    fractional_bins = (cents - 1997.3794084376191) / CENTS_PER_BIN
    return quantize_fn(fractional_bins).int()
|
|
|
|
def infer(frames, model='full', device='cpu', embed=False):
    """Run the cached network on ``frames``, loading it first if necessary.

    The model is cached on the function object as ``infer.model`` together
    with its capacity name in ``infer.capacity``; it is (re)loaded through
    ``load_model`` when nothing is cached or a different capacity is requested.

    Args:
        frames: Input passed to the model.
        model: Capacity name of the model to use.
        device: Device the model is moved to before the call.
        embed: Forwarded to the model's ``embed`` keyword.

    Returns:
        Whatever the model returns for ``frames``.
    """
    is_cached = (
        hasattr(infer, 'model')
        and hasattr(infer, 'capacity')
        and infer.capacity == model
    )
    if not is_cached:
        load_model(device, model)

    infer.model = infer.model.to(device)

    return infer.model(frames, embed=embed)
|
|
|
|
def load_model(device, capacity='full'):
    """Load the checkpoint for ``capacity`` and cache it on ``infer``.

    Reads ``assets/models/predictors/crepe_{capacity}.pth``, builds a
    ``Crepe`` network of that capacity, moves it to ``device`` and switches
    it to eval mode. On success the model is stored in ``infer.model`` and
    its capacity name in ``infer.capacity``.

    Args:
        device: Device used both as ``map_location`` and as the model device.
        capacity: Capacity name of the model to load.

    Raises:
        Any exception raised while constructing the network or reading the
        checkpoint propagates; in that case the cached state on ``infer`` is
        left untouched.
    """
    checkpoint_path = os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth")

    # Build and load into a local first. Publishing to ``infer`` before the
    # weights are loaded would leave a randomly initialised network cached
    # under the requested capacity if loading fails, and later ``infer``
    # calls would silently use it.
    model = Crepe(capacity)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(torch.device(device))
    model.eval()

    infer.model = model
    infer.capacity = capacity
|
|
|
|
def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
    """Decode pitch (and optionally periodicity) from network outputs.

    Bins outside the ``[fmin, fmax]`` range are set to ``-inf`` along dim 1
    before Viterbi decoding.

    Args:
        probabilities: Network output with pitch bins on dim 1.
        fmin: Lowest allowed frequency; values at or below zero disable the
            lower bound.
        fmax: Highest allowed frequency.
        return_periodicity: When true, also return the periodicity.

    Returns:
        The decoded pitch, or ``(pitch, periodicity)`` when
        ``return_periodicity`` is true.

    Note:
        ``detach()`` shares storage with the input, so the masking below
        writes into the tensor that was passed in.
    """
    probabilities = probabilities.detach()

    # A non-positive frequency has no finite log2, so it maps to bin 0
    # explicitly instead of relying on the integer cast of -inf/NaN.
    min_bin = 0 if fmin <= 0 else int(frequency_to_bins(torch.tensor(fmin)))
    max_bin = 0 if fmax <= 0 else int(frequency_to_bins(torch.tensor(fmax), torch.ceil))

    # Clamp to the valid index range. Without this, a frequency below the
    # lowest bin yields a negative index, which Python slicing interprets as
    # "counted from the end" and masks the wrong bins.
    min_bin = min(max(min_bin, 0), PITCH_BINS)
    max_bin = min(max(max_bin, 0), PITCH_BINS)

    probabilities[:, :min_bin] = -float('inf')
    probabilities[:, max_bin:] = -float('inf')

    bins, pitch = viterbi(probabilities)

    if not return_periodicity:
        return pitch

    return pitch, periodicity(probabilities, bins)
|
|
|
|
def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
    """Yield batches of normalised ``WINDOW_SIZE``-sample frames from audio.

    Args:
        audio: 2-D audio tensor; frames are taken along dim 1.
        sample_rate: Sample rate of ``audio``. Audio at a different rate
            than ``SAMPLE_RATE`` is resampled with librosa first.
        hop_length: Hop between frames in input samples; defaults to
            ``sample_rate // 100``.
        batch_size: Frames per yielded batch; defaults to all frames.
        device: Device the frames are moved to.
        pad: When true, pad ``WINDOW_SIZE // 2`` zeros on both sides.

    Yields:
        Tensors of shape ``(rows, WINDOW_SIZE)`` where every row has had its
        mean subtracted and been divided by its standard deviation
        (floored at 1e-10).
    """
    if hop_length is None:
        hop_length = sample_rate // 100

    if sample_rate != SAMPLE_RATE:
        # NOTE(review): squeeze(0)/unsqueeze(0) looks like it expects a
        # leading dimension of size 1 - confirm against callers.
        samples = audio.detach().cpu().numpy().squeeze(0)
        resampled = librosa.resample(samples, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq")
        audio = torch.tensor(resampled, device=audio.device).unsqueeze(0)
        # Rescale the hop so it covers the same duration at the new rate.
        hop_length = int(hop_length * SAMPLE_RATE / sample_rate)

    if pad:
        total_frames = 1 + int(audio.size(1) // hop_length)
        half_window = WINDOW_SIZE // 2
        audio = torch.nn.functional.pad(audio, (half_window, half_window))
    else:
        total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)

    if batch_size is None:
        batch_size = total_frames

    for first_frame in range(0, total_frames, batch_size):
        # Sample range covering the frames of this batch.
        start = max(0, first_frame * hop_length)
        stop = min(audio.size(1), (first_frame + batch_size - 1) * hop_length + WINDOW_SIZE)

        segment = audio[:, None, None, start:stop]
        frames = torch.nn.functional.unfold(segment, kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
        frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)

        # Per-frame standardisation; the floor avoids dividing by zero on
        # constant frames.
        frames -= frames.mean(dim=1, keepdim=True)
        floor = torch.tensor(1e-10, device=frames.device)
        frames /= torch.max(floor, frames.std(dim=1, keepdim=True))

        yield frames
|
|
|
|
def periodicity(probabilities, bins):
    """Pick, for every frame, the probability of the decoded pitch bin.

    Args:
        probabilities: Tensor of shape ``(batch, n_bins, frames)``.
        bins: Tensor with ``batch * frames`` bin indices, ordered like a
            ``(batch, frames)`` tensor.

    Returns:
        Tensor of shape ``(batch, frames)`` with
        ``probabilities[b, bins[b, t], t]`` at position ``(b, t)``.
    """
    batch, n_bins, n_frames = probabilities.size(0), probabilities.size(1), probabilities.size(2)

    # One row per (batch, frame) pair. The row width is taken from the input
    # rather than a global constant so a mismatching bin count cannot be
    # silently reshaped into misaligned rows.
    per_frame = probabilities.transpose(1, 2).reshape(-1, n_bins)
    selected = per_frame.gather(1, bins.reshape(-1, 1).to(torch.int64))

    return selected.reshape(batch, n_frames)
|
|
|
|
def mean(signals, win_length=9):
    """NaN-aware moving average along the last dimension.

    Each output sample is the mean of the non-NaN input samples inside a
    window of ``win_length`` centred on it (zero-padded at the edges, where
    the padding does not count as valid). Windows without any valid sample
    yield NaN.

    Args:
        signals: 2-D floating-point tensor ``(batch, width)``.
        win_length: Window size in samples.

    Returns:
        2-D tensor of windowed means with the same dtype as ``signals``.
    """
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    valid = ~torch.isnan(signals)
    padding = win_length // 2

    # Kernel and counts follow the input dtype so non-float32 inputs do not
    # fail with a dtype mismatch inside conv1d.
    ones_kernel = torch.ones(signals.size(1), 1, win_length, dtype=signals.dtype, device=signals.device)

    # Sum of the valid samples per window (NaNs contribute zero).
    window_sum = torch.nn.functional.conv1d(
        torch.where(valid, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding
    )
    # Number of valid samples per window.
    valid_count = torch.nn.functional.conv1d(valid.to(signals.dtype), ones_kernel, stride=1, padding=padding)

    avg_pooled = window_sum / valid_count.clamp(min=1)
    # Mark only windows with no valid sample as NaN. Testing the average for
    # zero would also wipe out windows whose true mean is 0.
    avg_pooled[valid_count == 0] = float("nan")

    return avg_pooled.squeeze(1)
|
|
|
|
def median(signals, win_length):
    """NaN-aware moving median along the last dimension.

    Each output sample is the median of the non-NaN input samples inside a
    window of ``win_length`` centred on it; samples beyond the signal edges
    are treated as missing. For an even number of valid samples the lower of
    the two middle values is returned. Windows without any valid sample
    yield NaN (as do windows whose median is infinite).

    Args:
        signals: 2-D tensor ``(batch, width)``.
        win_length: Window size in samples.

    Returns:
        2-D tensor of windowed medians.
    """
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    valid = ~torch.isnan(signals)
    padding = win_length // 2

    # The padded samples are always masked out below, so their value never
    # matters. Constant padding is used because reflect padding raises when
    # the signal is not longer than ``padding``.
    filled = torch.where(valid, signals, torch.zeros_like(signals))
    x = torch.nn.functional.pad(filled, (padding, padding), mode="constant", value=0)
    mask = torch.nn.functional.pad(valid.float(), (padding, padding), mode="constant", value=0)

    # Sliding windows: (batch, 1, positions, win_length).
    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)

    # Push invalid samples to the end of each sorted window.
    x_masked = torch.where(mask.bool(), x.float(), float("inf")).to(x)
    x_sorted, _ = torch.sort(x_masked, dim=-1)

    # Index of the (lower) middle element among the valid samples.
    median_index = ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()
    median_pooled = x_sorted.gather(-1, median_index).squeeze(-1)
    # An infinite result means the window had no valid sample.
    median_pooled[torch.isinf(median_pooled)] = float("nan")

    return median_pooled.squeeze(1)