nininigold committed · verified
Commit 5b13050 · 1 Parent(s): 8204aee

Upload folder using huggingface_hub

Files changed (3)
  1. inference.py +0 -15
  2. preprocess.py +4 -8
  3. requirements.txt +2 -1
inference.py CHANGED
@@ -10,27 +10,12 @@ from model import MusicClassifier, MusicAudioClassifier
 import argparse
 import torch
 import torchaudio
-import deepspeed
 import scipy.signal as signal
 from typing import Dict, List
-from google.cloud import storage
 from dataset_f import FakeMusicCapsDataset
 
 from preprocess import get_segments_from_wav, find_optimal_segment_length
 
-#not for ismir
-
-def download_from_gcs(bucket_name, source_blob_name, destination_file_name):
-    destination_dir = os.path.dirname(destination_file_name)
-    if not os.path.exists(destination_dir):
-        os.makedirs(destination_dir)
-
-    storage_client = storage.Client()
-    bucket = storage_client.bucket(bucket_name)
-    blob = bucket.blob(source_blob_name)
-    blob.download_to_filename(destination_file_name)
-
-
 
 def highpass_filter(y, sr, cutoff=1000, order=5):
     if isinstance(sr, np.ndarray):
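Note: the hunk above cuts off inside `highpass_filter`. For readers following along, here is a minimal sketch of what a Butterworth high-pass built on the `scipy.signal` import kept by this commit could look like; the coercion of an array-valued `sr` and the zero-phase `filtfilt` call are assumptions, not the file's actual body.

```python
import numpy as np
import scipy.signal as signal

def highpass_filter(y, sr, cutoff=1000, order=5):
    # The diff shows a guard for sr arriving as an np.ndarray; coercing it
    # to a plain int is an assumed handling, since the real body is truncated.
    if isinstance(sr, np.ndarray):
        sr = int(np.asarray(sr).flatten()[0])
    # Butterworth high-pass with the cutoff normalized to the Nyquist rate.
    b, a = signal.butter(order, cutoff / (0.5 * sr), btype="high")
    # filtfilt runs the filter forward and backward for zero phase shift,
    # which keeps transients (e.g. beat onsets) from drifting in time.
    return signal.filtfilt(b, a, y)
```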
preprocess.py CHANGED
@@ -9,8 +9,6 @@ import argparse
 from tqdm import tqdm
 import shutil
 import concurrent.futures
-from madmom.features.downbeats import DBNDownBeatTrackingProcessor
-from madmom.features.downbeats import RNNDownBeatProcessor
 
 
 def get_segments_from_wav(wav_path, device="cuda"):
@@ -19,12 +17,10 @@ def get_segments_from_wav(wav_path, device="cuda"):
     file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False)
     all_models = ["final0", "final1", "final2", "small0", "small1", "small2","single_final0", "single_final1", "single_final2"]
     beats, downbeats = file2beats(wav_path)
-    if len(downbeats) <1:
-        proc = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100)
-        act = RNNDownBeatProcessor()(wav_path)
-        array = proc(act)
-        downbeats = array[:, 0]
-        return beats, downbeats  # beats is expected to be empty here
+    if len(downbeats)==0:  # just fill downbeats at 0, 2, 4, ... to match the track length
+        waveform, sample_rate = torchaudio.load(wav_path)
+        duration = waveform.size(1) / sample_rate
+        downbeats = np.arange(0, duration, 2)
 
     return beats, downbeats
     #except Exception as e:
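The replacement branch is simpler than the madmom fallback it removes: when `File2Beats` returns no downbeats, the commit synthesizes them on a fixed 2-second grid over the clip, i.e. one assumed bar every two seconds (120 BPM in 4/4). A self-contained sketch of that behavior; the silent test clip and the `fallback_downbeats` name are illustrative, not part of the repo.

```python
import numpy as np
import torch
import torchaudio

def fallback_downbeats(wav_path, spacing=2.0):
    # Mirrors the added branch: derive the duration from the waveform and
    # place a synthetic downbeat every `spacing` seconds starting at t=0.
    waveform, sample_rate = torchaudio.load(wav_path)
    duration = waveform.size(1) / sample_rate
    return np.arange(0, duration, spacing)

# A 10-second silent clip (hypothetical input) yields [0. 2. 4. 6. 8.].
torchaudio.save("silent.wav", torch.zeros(1, 160000), 16000)
print(fallback_downbeats("silent.wav"))
```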
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch
 tqdm
 imblearn
 transformers
-https://github.com/CPJKU/beat_this/archive/main.zip
+https://github.com/CPJKU/beat_this/archive/main.zip
+pytorch_lightning
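The GitHub archive line installs CPJKU's beat_this package directly, and pytorch_lightning is added alongside it, presumably because the beat_this checkpoints load through Lightning. A quick post-install import check follows; the `beat_this.inference` module path is inferred from how preprocess.py constructs `File2Beats` and should be treated as an assumption.

```python
# Post-install sanity check (a sketch; the module path for File2Beats is
# an assumption based on its use in preprocess.py).
import pytorch_lightning as pl
from beat_this.inference import File2Beats

print("pytorch_lightning", pl.__version__)
# Instantiate on CPU with the same checkpoint name preprocess.py uses,
# just to confirm the dependency chain resolves.
f2b = File2Beats(checkpoint_path="final0", device="cpu", dbn=False)
print(type(f2b).__name__, "ready")
```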