from beat_this.inference import File2Beats
import torchaudio
import torch
from pathlib import Path
import numpy as np
from collections import Counter
import os
import argparse
from tqdm import tqdm
import concurrent.futures
def get_segments_from_wav(wav_path, device="cuda"):
    """Extract beats and downbeats from an audio file."""
    try:
        # Available checkpoints: "final0", "final1", "final2", "small0", "small1",
        # "small2", "single_final0", "single_final1", "single_final2"
        file2beats = File2Beats(checkpoint_path="final0", device=device, dbn=False)
        beats, downbeats = file2beats(wav_path)
        if len(downbeats) == 0:
            # No downbeats detected: fall back to a fixed 2-second grid
            # (0, 2, 4, ...) spanning the length of the audio.
            waveform, sample_rate = torchaudio.load(wav_path)
            duration = waveform.size(1) / sample_rate
            downbeats = np.arange(0, duration, 2)
        return beats, downbeats
    except Exception as e:
        print(f"Error extracting beats from {wav_path}: {str(e)}")
        return None, None
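# Note: File2Beats is expected to return beat and downbeat times in seconds as
# 1-D arrays; the 2-second fallback grid above assumes the same format.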
def find_optimal_segment_length(downbeats, round_decimal=1, bar_length=4):
    """Analyze the distribution of downbeat intervals and return the optimal
    segment length (bar_length bars) along with the cleaned downbeat positions."""
    if len(downbeats) < 2:
        return 10.0, downbeats  # Fall back to a default 10-second length
    # Intervals between consecutive downbeats
    intervals = np.diff(downbeats)
    rounded_intervals = np.round(intervals, round_decimal)
    # Most common interval = length of one bar
    interval_counter = Counter(rounded_intervals)
    most_common_interval = interval_counter.most_common(1)[0][0]
    # Keep only downbeats whose preceding interval matches the most common one
    cleaned_downbeats = [downbeats[0]]  # Always keep the first downbeat
    for i in range(1, len(downbeats)):
        interval = rounded_intervals[i - 1]
        # Accept intervals within 10% of the most common interval
        if abs(interval - most_common_interval) <= most_common_interval * 0.1:
            cleaned_downbeats.append(downbeats[i])
    return float(most_common_interval * bar_length), np.array(cleaned_downbeats)
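# A minimal sanity-check sketch of find_optimal_segment_length (not part of the
# pipeline; the helper below is illustrative and never called). With downbeats
# roughly 2.0 s apart and one irregular ~2.9 s gap, the mode of the rounded
# intervals is 2.0, so bar_length=4 gives an 8.0 s segment length and the
# downbeat that follows the irregular gap is dropped from the cleaned list.
def _example_find_optimal_segment_length():
    example_downbeats = np.array([0.0, 2.0, 4.0, 6.1, 9.0, 11.0])
    seg_len, cleaned = find_optimal_segment_length(example_downbeats, bar_length=4)
    # seg_len == 8.0; cleaned == [0.0, 2.0, 4.0, 6.1, 11.0] (9.0 follows a ~2.9 s gap)
    return seg_len, cleaned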
def process_audio_file(audio_file, output_dir, temp_dir, device="cuda"):
    """Process a single audio file and extract segments for each bar length."""
    try:
        output_dir = Path(output_dir)  # Ensure output_dir is a Path object
        # Extract beat information
        beats, downbeats = get_segments_from_wav(str(audio_file), device=device)
        if beats is None or downbeats is None or len(downbeats) == 0:
            print(f"No beat information extracted for {audio_file.name}, skipping...")
            return 0
        # Load the audio once and downmix to mono
        waveform, sample_rate = torchaudio.load(str(audio_file))
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        total_duration = waveform.size(1) / sample_rate
        segments_count = 0
        for bar_length in [1, 2, 3]:
            # Write each bar_length variant to "segments_wav_<bar_length>"
            dir_str = str(output_dir)
            if "segments_wav" in dir_str:
                base_dir = Path(dir_str.replace("segments_wav", f"segments_wav_{bar_length}"))
            else:
                # Fall back to suffixing the directory name when "segments_wav" is absent
                base_dir = output_dir.parent / f"{output_dir.name}_{bar_length}"
            file_seg_dir = base_dir / audio_file.stem
            file_seg_dir.mkdir(exist_ok=True, parents=True)
            # Optimal segment length and cleaned downbeats for this bar_length
            optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats, bar_length=bar_length)
            # Create one segment starting at each cleaned downbeat
            for i, start_time in enumerate(cleaned_downbeats):
                end_time = start_time + optimal_length
                # Skip segments that would run past the end of the file
                if end_time > total_duration:
                    continue
                start_sample = int(start_time * sample_rate)
                end_sample = int(end_time * sample_rate)
                # Extract and save the segment
                segment = waveform[:, start_sample:end_sample]
                save_path = file_seg_dir / f"segment_{i}.wav"
                torchaudio.save(str(save_path), segment, sample_rate)
                segments_count += 1
        # Optionally cache the raw beat information
        if temp_dir:
            segments_data = {'beat': beats, 'downbeat': downbeats}
            temp_path = Path(temp_dir) / f"{audio_file.stem}_segments.npy"
            np.save(str(temp_path), segments_data)
        return segments_count
    except Exception as e:
        print(f"Error processing {audio_file.name}: {str(e)}")
        return 0
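# Expected on-disk layout (assuming the default --output of .../segments_wav and
# the real/fake x train/valid/test structure used below):
#   segments_wav_<bar_length>/<label>/<split>/<track_stem>/segment_<i>.wav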
def segment_dataset(base_dir, output_base_dir, temp_dir=None, num_workers=4, device="cuda"):
    """Extract segments from the full_length folders of the ISMIR2025 dataset."""
    base_path = Path(base_dir)
    output_base_path = Path(output_base_dir)
    # Processing statistics
    stats = {
        "processed_files": 0,
        "extracted_segments": 0,
        "failed_files": 0
    }
    # Temporary directory for cached beat information
    if temp_dir:
        temp_dir = Path(temp_dir)
        temp_dir.mkdir(exist_ok=True, parents=True)
    # Process both real and fake audio
    for label in ["real", "fake"]:
        for split in ["train", "valid", "test"]:
            input_dir = base_path / label / split
            output_dir = output_base_path / label / split
            if not input_dir.exists():
                print(f"Directory not found: {input_dir}")
                continue
            print(f"Processing {label}/{split} files...")
            audio_files = list(input_dir.glob("*.wav")) + list(input_dir.glob("*.mp3"))
            if not audio_files:
                print(f"No audio files found in {input_dir}")
                continue
            # Parallel processing
            if num_workers > 1:
                with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                    future_to_file = {
                        executor.submit(process_audio_file, file, output_dir, temp_dir, device): file
                        for file in audio_files
                    }
                    for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(audio_files)):
                        file = future_to_file[future]
                        try:
                            segments_count = future.result()
                            if segments_count > 0:
                                stats["processed_files"] += 1
                                stats["extracted_segments"] += segments_count
                            else:
                                stats["failed_files"] += 1
                        except Exception as e:
                            print(f"Error processing {file.name}: {str(e)}")
                            stats["failed_files"] += 1
            else:
                # Serial processing
                for file in tqdm(audio_files):
                    segments_count = process_audio_file(file, output_dir, temp_dir, device)
                    if segments_count > 0:
                        stats["processed_files"] += 1
                        stats["extracted_segments"] += segments_count
                    else:
                        stats["failed_files"] += 1
    # Final statistics report
    print("\n=== Segmentation Summary ===")
    print(f"Successfully processed files: {stats['processed_files']}")
    print(f"Failed files: {stats['failed_files']}")
    print(f"Total extracted segments: {stats['extracted_segments']}")
    print(f"Average segments per file: {stats['extracted_segments'] / max(1, stats['processed_files']):.2f}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract segments from audio files in ISMIR2025 dataset")
    parser.add_argument("--input", type=str, default="/data/datasets/ISMIR2025/full_length_audio",
                        help="Input directory with full_length audio files")
    parser.add_argument("--output", type=str, default="/data/datasets/ISMIR2025/segments_wav",
                        help="Output directory for segments")
    parser.add_argument("--temp", type=str, default=None,
                        help="Temporary directory for beat information (optional)")
    parser.add_argument("--workers", type=int, default=4,
                        help="Number of parallel workers")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device for beat extraction (cuda or cpu)")
    args = parser.parse_args()
    # Validate the input directory
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Input directory not found: {args.input}")
        # Check other likely locations
        alternatives = [
            "/data/datasets/ISMIR2025/full_length",
            "/data/ISMIR2025/full_length_audio",
            "/data/ISMIR2025/full_length"
        ]
        for alt_path in alternatives:
            if os.path.exists(alt_path):
                print(f"Found alternative input path: {alt_path}")
                args.input = alt_path
                break
        else:
            print("No valid input directory found.")
            exit(1)
    # Run segment extraction
    segment_dataset(
        base_dir=args.input,
        output_base_dir=args.output,
        temp_dir=args.temp,
        num_workers=args.workers,
        device=args.device
    )
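# Example invocation (the script name is illustrative; adjust paths to your setup):
#   python segment_audio.py --input /data/datasets/ISMIR2025/full_length_audio \
#       --output /data/datasets/ISMIR2025/segments_wav --temp /tmp/beat_cache --workers 8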