import os from pathlib import Path import json from beat_this.inference import File2Beats import numpy as np import torch from typing import List, Tuple, Optional import pytorch_lightning as pl from model import MusicClassifier, MusicAudioClassifier import argparse import torch import torchaudio import deepspeed import scipy.signal as signal from typing import Dict, List from google.cloud import storage from dataset_f import FakeMusicCapsDataset from preprocess import get_segments_from_wav, find_optimal_segment_length #not for ismir def download_from_gcs(bucket_name, source_blob_name, destination_file_name): destination_dir = os.path.dirname(destination_file_name) if not os.path.exists(destination_dir): os.makedirs(destination_dir) storage_client = storage.Client() bucket = storage_client.bucket(bucket_name) blob = bucket.blob(source_blob_name) blob.download_to_filename(destination_file_name) def highpass_filter(y, sr, cutoff=1000, order=5): if isinstance(sr, np.ndarray): sr = np.mean(sr) if not isinstance(sr, (int, float)): raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}") nyquist = 0.5 * sr if cutoff <= 0 or cutoff >= nyquist: cutoff = max(10, min(cutoff, nyquist - 1)) normal_cutoff = cutoff / nyquist b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) y_filtered = signal.lfilter(b, a, y) return y_filtered def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]: """ 오디오 파일을 불러와 세그먼트로 분할합니다. 고정된 길이의 세그먼트를 최대 48개 추출하고, 부족한 경우 패딩을 추가합니다. Args: audio_path: 오디오 파일 경로 sr: 목표 샘플링 레이트 (기본값 24000) Returns: Tuple containing: - 오디오 파형이 담긴 텐서 (48, 1, 240000) - 패딩 마스크 텐서 (48), True = 패딩, False = 실제 오디오 """ beats, downbeats = get_segments_from_wav(audio_path) optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats) waveform, sample_rate = torchaudio.load(audio_path) # 데이터 타입을 float32로 변환 waveform = waveform.to(torch.float32) if sample_rate != sr: resampler = torchaudio.transforms.Resample(sample_rate, sr) waveform = resampler(waveform) # 모노로 변환 (필요한 경우) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # 240000 샘플 = 10초 @ 24kHz fixed_samples = 240000 # 각 downbeat에서 시작하는 segment 생성 segments = [] # 다운비트가 없거나 매우 적을 경우 전체 오디오를 일정 간격으로 분할 if len(cleaned_downbeats) < 2: # 오디오 총 길이 (초) total_duration = waveform.size(1) / sr # 5초 간격으로 세그먼트 시작점 생성 (또는 더 짧은 간격으로 설정 가능) # 240000 샘플은 10초이므로 5초 간격은 세그먼트 간 50% 오버랩 segment_interval = 5.0 # 초 단위 # 시작 시간 목록 생성 (0초부터 시작) start_times = [t for t in np.arange(0, total_duration - (fixed_samples/sr) + 0.01, segment_interval)] # 최소한 하나의 세그먼트는 보장 if not start_times and total_duration > 0: start_times = [0.0] else: # 기존 방식대로 다운비트 사용 start_times = cleaned_downbeats # 세그먼트 추출 for i, start_time in enumerate(start_times): # 시작 샘플 인덱스 계산 start_sample = int(start_time * sr) # 끝 샘플 인덱스 계산 (시작 지점 + 고정 길이) end_sample = start_sample + fixed_samples # 파일 끝을 넘어가는지 확인 if end_sample > waveform.size(1): # 짧은 곡의 경우: 끝에서부터 거꾸로 세그먼트 추출 시도 if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples: start_sample = waveform.size(1) - fixed_samples end_sample = waveform.size(1) else: continue # 정확히 fixed_samples 길이의 세그먼트 추출 segment = waveform[:, start_sample:end_sample] # 하이패스 필터 적용 - 채널 차원 유지 filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) segments.append(filtered) # 최대 48개 세그먼트만 사용 if len(segments) >= 48: break # 세그먼트가 없는 경우, 곡이 너무 짧아서 고정 길이 세그먼트를 만들 수 없는 경우 if not segments: if waveform.size(1) > 0: # 오디오가 있지만 매우 짧은 경우 # 전체 오디오를 하나의 세그먼트로 사용하고 나머지는 제로 패딩 segment = waveform # 필요한 길이에 맞게 패딩 추가 padding_length = fixed_samples - segment.size(1) if padding_length > 0: segment = torch.nn.functional.pad(segment, (0, padding_length)) # 하이패스 필터 적용 filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) segments.append(filtered) else: # 완전히 빈 오디오일 경우 return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool) # 스택하여 텐서로 변환 - (n_segments, 1, time_samples) 형태 유지 stacked_segments = torch.stack(segments) # 실제 세그먼트 수 (패딩 아님) num_segments = stacked_segments.shape[0] # 패딩 마스크 생성 (False = 실제 오디오, True = 패딩) padding_mask = torch.zeros(48, dtype=torch.bool) # 48개 미만인 경우 패딩 추가 if num_segments < 48: # 빈 세그먼트로 패딩 (zeros) padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32) stacked_segments = torch.cat([stacked_segments, padding], dim=0) # 패딩 마스크 설정 (True = 패딩) padding_mask[num_segments:] = True return stacked_segments, padding_mask def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict: """ Run inference on audio segments. Args: model: The loaded model audio_segments: Preprocessed audio segments tensor (48, 1, 240000) device: Device to run inference on Returns: Dictionary with prediction results """ model.eval() model.to(device) model = model.half() print(padding_mask.shape) with torch.no_grad(): # 데이터 형태 확인 및 조정 # wav_collate_with_mask 함수와 일치하도록 처리 if audio_segments.shape[1] == 1: # (48, 1, 240000) 형태 # 채널 차원 제거하고 배치 차원 추가 audio_segments = audio_segments[:, 0, :].unsqueeze(0) # (1, 48, 240000) else: audio_segments = audio_segments.unsqueeze(0) # (1, 48, 768) # 사실 audio가 아니라 embedding segments일수도 # 데이터를 half 타입으로 변환 if padding_mask.dim() == 1: padding_mask = padding_mask.unsqueeze(0) # [48] -> [1, 48] audio_segments = audio_segments.to(device).half() mask = padding_mask.to(device) print(f"Input shape: {audio_segments.shape}") print(f"Mask shape: {mask.shape}") print(f"Mask: {mask}") # 추론 실행 (마스크 포함) outputs = model(audio_segments, mask) print(f"Output type: {type(outputs)}") print(f"Output: {outputs}") # 모델 출력 구조에 따라 처리 if isinstance(outputs, dict): result = outputs else: # 단일 텐서인 경우 (로짓) logits = outputs.squeeze() prob = torch.sigmoid(logits).item() result = { "prediction": "Fake" if prob > 0.5 else "Real", "confidence": f"{max(prob, 1-prob)*100:.2f}%", "fake_probability": f"{prob:.4f}", "real_probability": f"{1-prob:.4f}", "raw_output": logits.cpu().numpy().tolist() } return result def get_model(model_type, device): """Load the specified model.""" if model_type == "MERT": from ISMIR_2025.MERT.networks import CCV model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = 'mert_finetune_10.pth' model.load_state_dict(torch.load(ckpt_file, map_location=device)) embed_dim = 768 else: raise ValueError(f"Unknown model type: {model_type}") model.eval() return model, embed_dim """ elif model_type == "music2vec": from ISMIR_2025.music2vec.networks import Music2VecClassifier model = Music2VecClassifier(freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth' embed_dim = 768 elif model_type == "wav2vec": from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth' embed_dim = 768 elif model_type == "ccv": from ISMIR_2025.Model.networks import CCV model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth' embed_dim = 512 """ def inference_with_audio(audio_path): #audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3" model_type = "MERT" checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt" device = 'cuda' # Note: Model loading would be handled by your code print(f"Loading model of type {model_type} from {checkpoint_path}") backbone_model, input_dim = get_model(model_type, device) segments, padding_mask = load_audio(audio_path, sr=24000) segments = segments.to(device).to(torch.float32) padding_mask = padding_mask.to(device).unsqueeze(0) logits,embedding = backbone_model(segments.squeeze(1)) test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0) test_data, test_target = test_dataset[0] test_data = test_data.to(device).to(torch.float32) test_target = test_target.to(device) output, _ = backbone_model(test_data.unsqueeze(0)) # 모델 로드 부분 추가 model = MusicAudioClassifier.load_from_checkpoint( checkpoint_path, input_dim=input_dim, #emb_model=backbone_model is_emb = True, mode = 'both' ) # Run inference print(f"Segments shape: {segments.shape}") print("Running inference...") results = run_inference(model, embedding, padding_mask, device=device) # 결과 출력 print(f"Results: {results}") return str(results) if __name__ == "__main__": inference_with_audio()