File size: 3,240 Bytes
1e4a2ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import sys
import time
import tqdm
import torch
import librosa
import traceback
import concurrent.futures
import numpy as np
import torch.nn as nn
sys.path.append(os.getcwd())
from main.library.utils import load_audio
from main.app.variables import logger, translations
from main.inference.extracting.setup_path import setup_paths
class RMSEnergyExtractor(nn.Module):
    """Frame-wise RMS energy extractor exposed as an ``nn.Module``.

    The module holds no learnable parameters, only the framing configuration.
    The computation is delegated to ``librosa.feature.rms`` on a CPU NumPy copy
    of the input, and the result is moved back to the input's device.

    Args:
        frame_length: Analysis window length in samples (forwarded to librosa).
        hop_length: Number of samples between successive frames.
        center: Forwarded to librosa; controls whether the signal is padded so
            frames are centred.
        pad_mode: Padding mode forwarded to librosa (used when ``center`` is True).
    """

    def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode

    def forward(self, x):
        """Compute the RMS energy of a single signal.

        Args:
            x: Tensor of shape ``(1, num_samples)``.

        Returns:
            Tensor with the librosa frame axis squeezed out (one RMS value per
            frame), placed on ``x``'s device.

        Raises:
            ValueError: If ``x`` is not 2-D with a leading dimension of 1.
        """
        # Explicit check instead of ``assert`` so validation still happens under ``python -O``.
        if x.ndim != 2 or x.shape[0] != 1:
            raise ValueError(f"RMSEnergyExtractor expects a tensor of shape (1, T), got {tuple(x.shape)}")

        # NOTE(review): devices whose string starts with "ocl" get contiguous tensors
        # before and after the librosa round-trip, exactly as the original code did.
        is_ocl = str(x.device).startswith("ocl")
        if is_ocl:
            x = x.contiguous()

        rms = torch.from_numpy(
            librosa.feature.rms(
                # detach() so an input that requires grad does not make .numpy() raise.
                y=x.squeeze(0).detach().cpu().numpy(),
                frame_length=self.frame_length,
                hop_length=self.hop_length,
                center=self.center,
                pad_mode=self.pad_mode,
            )
        )
        if is_ocl:
            rms = rms.contiguous()

        # librosa keeps a singleton frame axis at position -2; drop it.
        return rms.squeeze(-2).to(x.device)
def process_file_rms(files, device, threads):
    """Extract and save RMS energy for a batch of audio files on one device.

    Each item of ``files`` is a ``(wav_path, out_dir)`` pair. The audio is
    loaded with ``load_audio(wav_path, 16000)``, run through an
    ``RMSEnergyExtractor`` (frame 2048, hop 160, centred, reflect padding) and
    saved as float32 to ``out_dir/<wav basename>.npy``. Files whose output
    already exists are skipped. Files are processed concurrently on a thread
    pool, and a tqdm progress bar advances once per finished file.

    Per-file failures are logged at debug level and do not stop the batch.

    Args:
        files: Sized sequence of ``(wav_path, out_dir)`` pairs.
        device: Device string. For strings starting with ``"ocl"``, the input
            tensor is not moved to the device before the module call.
        threads: Number of worker threads; values below 1 are clamped to 1.
    """
    threads = max(1, threads)
    module = RMSEnergyExtractor(
        frame_length=2048, hop_length=160, center=True, pad_mode="reflect"
    ).to(device).eval().float()

    def worker(file_info):
        try:
            file, out_path = file_info
            out_file_path = os.path.join(out_path, os.path.basename(file))
            # np.save appends ".npy" to out_file_path, so this checks the exact file written below.
            if os.path.exists(out_file_path + ".npy"):
                return

            with torch.no_grad():
                feats = torch.from_numpy(load_audio(file, 16000)).unsqueeze(0)
                feats = module(feats if device.startswith("ocl") else feats.to(device))
                np.save(out_file_path, feats.float().cpu().numpy(), allow_pickle=False)
        except Exception:
            # Best-effort per file: log and continue. Unlike a bare ``except:``, this
            # does not swallow SystemExit or other BaseException subclasses.
            logger.debug(traceback.format_exc())

    with tqdm.tqdm(total=len(files), ncols=100, unit="p", leave=True) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
            futures = [executor.submit(worker, f) for f in files]
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)
def run_rms_extraction(exp_dir, num_processes, devices, rms_extract):
    """Run RMS extraction for every ``.wav`` file of an experiment, if enabled.

    Does nothing when ``rms_extract`` is falsy. Otherwise the input and output
    directories come from ``setup_paths``. The sorted list of ``.wav`` files is
    split round-robin across ``devices``, and each device slice is handled by
    ``process_file_rms`` in its own process with
    ``num_processes // len(devices)`` threads.

    Args:
        exp_dir: Experiment directory passed to ``setup_paths``.
        num_processes: Total thread budget, divided evenly among devices.
        devices: Sequence of device strings, one worker process per entry.
        rms_extract: Feature flag; extraction runs only when truthy.

    Raises:
        Exception: Any exception raised by a per-device worker process is
            re-raised here instead of being silently dropped.
    """
    if not rms_extract:
        return

    wav_path, out_path = setup_paths(exp_dir, rms_extract=rms_extract)
    paths = sorted(
        (os.path.join(wav_path, file), out_path)
        for file in os.listdir(wav_path)
        if file.endswith(".wav")
    )
    num_devices = len(devices)

    start_time = time.time()
    logger.info(translations["rms_start_extract"].format(num_processes=num_processes))

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_devices) as executor:
        futures = [
            executor.submit(process_file_rms, paths[i::num_devices], device, num_processes // num_devices)
            for i, device in enumerate(devices)
        ]
        # concurrent.futures.wait() would discard worker-process exceptions and
        # still log success; result() re-raises them.
        for future in futures:
            future.result()

    logger.info(translations["rms_success_extract"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))