"""Beat-aligned segmentation of the ISMIR2025 full-length audio dataset.

For every track, downbeats are estimated with beat_this, the dominant bar length
is inferred from the downbeat intervals, and fixed-length segments of 1, 2, and
3 bars are written to per-bar-length output trees (e.g. segments_wav_2).
"""

from beat_this.inference import File2Beats
import torchaudio
import torch
from pathlib import Path
import numpy as np
from collections import Counter
import os
import argparse
from tqdm import tqdm
import concurrent.futures


def get_segments_from_wav(wav_path, device="cuda"):
    """Extract beat and downbeat times (in seconds) from an audio file."""
    try:
        # A File2Beats instance is built per call with the "final0" checkpoint.
        # Other available checkpoints: "final1", "final2", "small0", "small1",
        # "small2", "single_final0", "single_final1", "single_final2".
        file2beats = File2Beats(checkpoint_path="final0", device=device, dbn=False)
        beats, downbeats = file2beats(wav_path)
        if len(downbeats) == 0:
            # No downbeats detected: fall back to evenly spaced downbeats
            # (0, 2, 4, ... seconds) covering the whole track.
            waveform, sample_rate = torchaudio.load(wav_path)
            duration = waveform.size(1) / sample_rate
            downbeats = np.arange(0, duration, 2)
        return beats, downbeats
    except Exception as e:
        print(f"Error extracting beats from {wav_path}: {str(e)}")
        return None, None


def find_optimal_segment_length(downbeats, round_decimal=1, bar_length=4):
    """Analyze the distribution of downbeat intervals and return the optimal
    segment length (bar_length bars) together with the cleaned downbeat positions."""
    if len(downbeats) < 2:
        return 10.0, downbeats  # not enough downbeats: fall back to a 10-second length

    # Intervals between consecutive downbeats, rounded for histogramming
    intervals = np.diff(downbeats)
    rounded_intervals = np.round(intervals, round_decimal)

    # The most common interval is taken as the length of one bar
    interval_counter = Counter(rounded_intervals)
    most_common_interval = interval_counter.most_common(1)[0][0]

    # Keep only downbeats whose preceding interval is close to one bar
    cleaned_downbeats = [downbeats[0]]  # always keep the first downbeat
    for i in range(1, len(downbeats)):
        interval = rounded_intervals[i - 1]
        # Accept intervals within 10% of the most common interval
        if abs(interval - most_common_interval) <= most_common_interval * 0.1:
            cleaned_downbeats.append(downbeats[i])

    return float(most_common_interval * bar_length), np.array(cleaned_downbeats)


def process_audio_file(audio_file, output_dir, temp_dir, device="cuda"):
    """Process a single audio file and write beat-aligned segments for several bar lengths."""
    try:
        output_dir = Path(output_dir)
        beats, downbeats = get_segments_from_wav(str(audio_file), device=device)

        if beats is None or downbeats is None or len(downbeats) == 0:
            print(f"No beat information extracted for {audio_file.name}, skipping...")
            return 0

        total_segments = 0

        for bar_length in [1, 2, 3]:
            # Write each bar-length variant to its own tree, e.g. "segments_wav_2"
            dir_str = str(output_dir)
            if "segments_wav" in dir_str:
                base_dir = Path(dir_str.replace("segments_wav", f"segments_wav_{bar_length}"))
            else:
                # Fallback when the output directory is not named "segments_wav"
                base_dir = output_dir.parent / f"{output_dir.name}_{bar_length}"
            file_seg_dir = base_dir / audio_file.stem
            file_seg_dir.mkdir(exist_ok=True, parents=True)

            # Optimal segment length and cleaned downbeats for this bar length
            optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats, bar_length=bar_length)

            # Load audio and downmix to mono
            waveform, sample_rate = torchaudio.load(str(audio_file))
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            total_duration = waveform.size(1) / sample_rate

            # Create one segment starting at each cleaned downbeat
            for i, start_time in enumerate(cleaned_downbeats):
                end_time = start_time + optimal_length

                # Skip segments that would run past the end of the file
                if end_time > total_duration:
                    continue

                start_sample = int(start_time * sample_rate)
                end_sample = int(end_time * sample_rate)

                # Slice and save the segment
                segment = waveform[:, start_sample:end_sample]
                save_path = file_seg_dir / f"segment_{i}.wav"
                torchaudio.save(str(save_path), segment, sample_rate)
                total_segments += 1

        # Optionally cache the raw beat information for later inspection
        if temp_dir:
            segments_data = {"beat": beats, "downbeat": downbeats}
            temp_path = Path(temp_dir) / f"{audio_file.stem}_segments.npy"
            np.save(str(temp_path), segments_data)

        return total_segments
    except Exception as e:
        print(f"Error processing {audio_file.name}: {str(e)}")
        return 0
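
# The beat cache above is written with np.save on a plain Python dict, which numpy
# stores as a pickled 0-d object array. The helper below is a minimal sketch (not
# part of the original pipeline) showing how such a cache can be read back:
# allow_pickle=True and .item() are needed to recover the dict.
def load_cached_beats(npy_path):
    """Load a {'beat': ..., 'downbeat': ...} dict cached by process_audio_file."""
    data = np.load(str(npy_path), allow_pickle=True).item()
    return data["beat"], data["downbeat"]
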
def segment_dataset(base_dir, output_base_dir, temp_dir=None, num_workers=4, device="cuda"):
    """Extract segments from the full-length audio of the ISMIR2025 dataset."""
    base_path = Path(base_dir)
    output_base_path = Path(output_base_dir)

    # Processing statistics
    stats = {
        "processed_files": 0,
        "extracted_segments": 0,
        "failed_files": 0
    }

    # Temporary directory for cached beat information
    if temp_dir:
        temp_dir = Path(temp_dir)
        temp_dir.mkdir(exist_ok=True)

    # Process both real and fake audio across all splits
    for label in ["real", "fake"]:
        for split in ["train", "valid", "test"]:
            input_dir = base_path / label / split
            output_dir = output_base_path / label / split

            if not input_dir.exists():
                print(f"Directory not found: {input_dir}")
                continue

            print(f"Processing {label}/{split} files...")
            audio_files = list(input_dir.glob("*.wav")) + list(input_dir.glob("*.mp3"))

            if not audio_files:
                print(f"No audio files found in {input_dir}")
                continue

            if num_workers > 1:
                # Parallel processing: each worker handles one file at a time
                with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                    future_to_file = {
                        executor.submit(process_audio_file, file, output_dir, temp_dir, device): file
                        for file in audio_files
                    }
                    for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(audio_files)):
                        file = future_to_file[future]
                        try:
                            segments_count = future.result()
                            if segments_count > 0:
                                stats["processed_files"] += 1
                                stats["extracted_segments"] += segments_count
                            else:
                                stats["failed_files"] += 1
                        except Exception as e:
                            print(f"Error processing {file.name}: {str(e)}")
                            stats["failed_files"] += 1
            else:
                # Serial processing
                for file in tqdm(audio_files):
                    segments_count = process_audio_file(file, output_dir, temp_dir, device)
                    if segments_count > 0:
                        stats["processed_files"] += 1
                        stats["extracted_segments"] += segments_count
                    else:
                        stats["failed_files"] += 1

    # Final report
    print("\n=== Segmentation Summary ===")
    print(f"Successfully processed files: {stats['processed_files']}")
    print(f"Failed files: {stats['failed_files']}")
    print(f"Total extracted segments: {stats['extracted_segments']}")
    print(f"Average segments per file: {stats['extracted_segments'] / max(1, stats['processed_files']):.2f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract segments from audio files in the ISMIR2025 dataset")
    parser.add_argument("--input", type=str, default="/data/datasets/ISMIR2025/full_length_audio",
                        help="Input directory with full-length audio files")
    parser.add_argument("--output", type=str, default="/data/datasets/ISMIR2025/segments_wav",
                        help="Output directory for segments")
    parser.add_argument("--temp", type=str, default=None,
                        help="Temporary directory for beat information (optional)")
    parser.add_argument("--workers", type=int, default=4,
                        help="Number of parallel workers")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device for beat extraction (cuda or cpu)")
    args = parser.parse_args()

    # Validate the input directory, falling back to known alternative locations
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Input directory not found: {args.input}")
        alternatives = [
            "/data/datasets/ISMIR2025/full_length",
            "/data/ISMIR2025/full_length_audio",
            "/data/ISMIR2025/full_length"
        ]
        for alt_path in alternatives:
            if os.path.exists(alt_path):
                print(f"Found alternative input path: {alt_path}")
                args.input = alt_path
                break
        else:
            print("No valid input directory found.")
            exit(1)

    # Run segmentation
    segment_dataset(
        base_dir=args.input,
        output_base_dir=args.output,
        temp_dir=args.temp,
        num_workers=args.workers,
        device=args.device
    )
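
# Example invocation (a sketch: the script filename "segment_dataset.py" and the
# --temp location are placeholders; the other paths are the argparse defaults above):
#
#   python segment_dataset.py \
#       --input /data/datasets/ISMIR2025/full_length_audio \
#       --output /data/datasets/ISMIR2025/segments_wav \
#       --temp /tmp/beat_cache \
#       --workers 4 --device cuda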