Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- inference.py +0 -15
- preprocess.py +4 -8
- requirements.txt +2 -1
inference.py
CHANGED
@@ -10,27 +10,12 @@ from model import MusicClassifier, MusicAudioClassifier
|
|
10 |
import argparse
|
11 |
import torch
|
12 |
import torchaudio
|
13 |
-
import deepspeed
|
14 |
import scipy.signal as signal
|
15 |
from typing import Dict, List
|
16 |
-
from google.cloud import storage
|
17 |
from dataset_f import FakeMusicCapsDataset
|
18 |
|
19 |
from preprocess import get_segments_from_wav, find_optimal_segment_length
|
20 |
|
21 |
-
#not for ismir
|
22 |
-
|
23 |
-
def download_from_gcs(bucket_name, source_blob_name, destination_file_name):
|
24 |
-
destination_dir = os.path.dirname(destination_file_name)
|
25 |
-
if not os.path.exists(destination_dir):
|
26 |
-
os.makedirs(destination_dir)
|
27 |
-
|
28 |
-
storage_client = storage.Client()
|
29 |
-
bucket = storage_client.bucket(bucket_name)
|
30 |
-
blob = bucket.blob(source_blob_name)
|
31 |
-
blob.download_to_filename(destination_file_name)
|
32 |
-
|
33 |
-
|
34 |
|
35 |
def highpass_filter(y, sr, cutoff=1000, order=5):
|
36 |
if isinstance(sr, np.ndarray):
|
|
|
10 |
import argparse
|
11 |
import torch
|
12 |
import torchaudio
|
|
|
13 |
import scipy.signal as signal
|
14 |
from typing import Dict, List
|
|
|
15 |
from dataset_f import FakeMusicCapsDataset
|
16 |
|
17 |
from preprocess import get_segments_from_wav, find_optimal_segment_length
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def highpass_filter(y, sr, cutoff=1000, order=5):
|
21 |
if isinstance(sr, np.ndarray):
|
preprocess.py
CHANGED
@@ -9,8 +9,6 @@ import argparse
|
|
9 |
from tqdm import tqdm
|
10 |
import shutil
|
11 |
import concurrent.futures
|
12 |
-
from madmom.features.downbeats import DBNDownBeatTrackingProcessor
|
13 |
-
from madmom.features.downbeats import RNNDownBeatProcessor
|
14 |
|
15 |
|
16 |
def get_segments_from_wav(wav_path, device="cuda"):
|
@@ -19,12 +17,10 @@ def get_segments_from_wav(wav_path, device="cuda"):
|
|
19 |
file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False)
|
20 |
all_models = ["final0", "final1", "final2", "small0", "small1", "small2","single_final0", "single_final1", "single_final2"]
|
21 |
beats, downbeats = file2beats(wav_path)
|
22 |
-
if len(downbeats)
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
downbeats = array[:, 0]
|
27 |
-
return beats, downbeats#beats는 빈거 맞음
|
28 |
|
29 |
return beats, downbeats
|
30 |
#except Exception as e:
|
|
|
9 |
from tqdm import tqdm
|
10 |
import shutil
|
11 |
import concurrent.futures
|
|
|
|
|
12 |
|
13 |
|
14 |
def get_segments_from_wav(wav_path, device="cuda"):
|
|
|
17 |
file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False)
|
18 |
all_models = ["final0", "final1", "final2", "small0", "small1", "small2","single_final0", "single_final1", "single_final2"]
|
19 |
beats, downbeats = file2beats(wav_path)
|
20 |
+
if len(downbeats)==0: # downbeats를 그냥 0 2 4..로 넣어주자. 음악 길이에 맞게
|
21 |
+
waveform, sample_rate = torchaudio.load(wav_path)
|
22 |
+
duration = waveform.size(1) / sample_rate
|
23 |
+
downbeats = np.arange(0, duration, 2)
|
|
|
|
|
24 |
|
25 |
return beats, downbeats
|
26 |
#except Exception as e:
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ torch
|
|
3 |
tqdm
|
4 |
imblearn
|
5 |
transformers
|
6 |
-
https://github.com/CPJKU/beat_this/archive/main.zip
|
|
|
|
3 |
tqdm
|
4 |
imblearn
|
5 |
transformers
|
6 |
+
https://github.com/CPJKU/beat_this/archive/main.zip
|
7 |
+
pytorch_lightning
|