Serhiy Stetskovych committed
Commit 98a6a49 · 1 Parent(s): 93c6a78

Use device variable

Files changed (2)
  1. app.py +3 -4
  2. inference.py +0 -150
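
The change replaces the hardcoded .cuda() calls in app.py with a single device variable chosen at startup, so the Space also runs on CPU-only hardware. A minimal sketch of the pattern (the Linear model and random tensor below are placeholders for illustration, not the actual Apollo model or audio chunks):

import torch

# Pick the device once; fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Placeholder model and input; the real app moves the Apollo model and each audio chunk.
model = torch.nn.Linear(8, 8).to(device)
chunk = torch.randn(1, 8).to(device)   # inputs must live on the same device as the model

with torch.no_grad():
    out = model(chunk).cpu()           # bring results back to CPU for numpy/soundfile

The diff below applies this pattern: the model is moved with .to(device) once at load time, and each chunk is moved with .to(device) before the forward pass.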
app.py CHANGED
@@ -1,5 +1,3 @@
-import os
-import torchaudio
 import torch
 import numpy as np
 import gradio as gr
@@ -9,6 +7,7 @@ import tqdm
 
 import look2hear.models
 from ml_collections import ConfigDict
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 def load_audio(file_path):
     audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
@@ -44,7 +43,7 @@ texts
 
 
 apollo_config = get_config('configs/apollo.yaml')
-apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).cuda()
+apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
 
 models = [
     ('MP3 restore', apollo_model)
@@ -87,7 +86,7 @@ def enchance(model, audio):
                 part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
 
 
-        chunk = part.unsqueeze(0).cuda()
+        chunk = part.unsqueeze(0).to(device)
         with torch.no_grad():
             out = model(chunk).squeeze(0).squeeze(0).cpu()
 
 
inference.py DELETED
@@ -1,150 +0,0 @@
-import os
-import torch
-import librosa
-import look2hear.models
-import soundfile as sf
-from tqdm.auto import tqdm
-import argparse
-import numpy as np
-import yaml
-from ml_collections import ConfigDict
-#from omegaconf import OmegaConf
-
-import warnings
-warnings.filterwarnings("ignore")
-
-def get_config(config_path):
-    with open(config_path) as f:
-        #config = OmegaConf.load(config_path)
-        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
-    return config
-
-def load_audio(file_path):
-    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
-    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
-    #audio = dBgain(audio, -6)
-    return torch.from_numpy(audio), samplerate
-
-def save_audio(file_path, audio, samplerate=44100):
-    #audio = dBgain(audio, +6)
-    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")
-
-def process_chunk(chunk):
-    chunk = chunk.unsqueeze(0).cpu()
-    with torch.no_grad():
-        return model(chunk).squeeze(0).squeeze(0).cpu()
-
-def _getWindowingArray(window_size, fade_size):
-    # IMPORTANT NOTE :
-    # no fades here in the end, only removing the failed ending of the chunk
-    fadein = torch.linspace(1, 1, fade_size)
-    fadeout = torch.linspace(0, 0, fade_size)
-    window = torch.ones(window_size)
-    window[-fade_size:] *= fadeout
-    window[:fade_size] *= fadein
-    return window
-
-def dBgain(audio, volume_gain_dB):
-    gain = 10 ** (volume_gain_dB / 20)
-    gained_audio = audio * gain
-    return gained_audio
-
-
-def main(input_wav, output_wav, ckpt_path):
-    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
-
-    global model
-    feature_dim = config['model']['feature_dim']
-    sr = config['model']['sr']
-    win = config['model']['win']
-    layer = config['model']['layer']
-    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cpu()
-
-    test_data, samplerate = load_audio(input_wav)
-
-    C = chunk_size * samplerate  # chunk_size seconds to samples
-    N = overlap
-    step = C // N
-    fade_size = 3 * 44100  # 3 seconds
-    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
-
-    border = C - step
-
-    # handle mono inputs correctly
-    if len(test_data.shape) == 1:
-        test_data = test_data.unsqueeze(0)
-
-    # Pad the input if necessary
-    if test_data.shape[1] > 2 * border and (border > 0):
-        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
-
-    windowingArray = _getWindowingArray(C, fade_size)
-
-    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
-    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
-
-    i = 0
-    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)
-
-    while i < test_data.shape[1]:
-        part = test_data[:, i:i + C]
-        length = part.shape[-1]
-        if length < C:
-            if length > C // 2 + 1:
-                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
-            else:
-                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
-
-        out = process_chunk(part)
-
-        window = windowingArray
-        if i == 0:  # First audio chunk, no fadein
-            window[:fade_size] = 1
-        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fadeout
-            window[-fade_size:] = 1
-
-        result[..., i:i+length] += out[..., :length] * window[..., :length]
-        counter[..., i:i+length] += window[..., :length]
-
-        i += step
-        progress_bar.update(step)
-
-    progress_bar.close()
-
-    final_output = result / counter
-    final_output = final_output.squeeze(0).numpy()
-    np.nan_to_num(final_output, copy=False, nan=0.0)
-
-    # Remove padding if added earlier
-    if test_data.shape[1] > 2 * border and (border > 0):
-        final_output = final_output[..., border:-border]
-
-    save_audio(output_wav, final_output, samplerate)
-    print(f'Success! Output file saved as {output_wav}')
-
-    # Memory clearing
-    model.cpu()
-    del model
-    torch.cuda.empty_cache()
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Audio Inference Script")
-    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
-    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
-    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file", default="model/pytorch_model.bin")
-    parser.add_argument("--config", type=str, help="Path to model config file", default="config/apollo.yaml")
-    parser.add_argument("--chunk_size", type=int, help="chunk size value in seconds", default=10)
-    parser.add_argument("--overlap", type=int, help="Overlap", default=2)
-    args = parser.parse_args()
-
-    ckpt_path = args.ckpt
-    chunk_size = args.chunk_size
-    overlap = args.overlap
-    config = get_config(args.config)
-    print(config['model'])
-    print(f'ckpt_path = {ckpt_path}')
-    #print(f'config = {config}')
-    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
-
-
-    main(args.in_wav, args.out_wav, ckpt_path)
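
The core technique in the removed script is overlap-add chunked inference: split the input into overlapping windows, run the model on each chunk, and blend the outputs back together with a weighting window normalized by the accumulated weights. A compact, self-contained sketch of that idea (the function name, Hann window, and shape assumptions below are illustrative; the removed file used a custom fade window, reflect padding, and edge handling for the first and last chunks):

import torch

def overlap_add_process(model, audio, chunk_size, overlap):
    # Hypothetical helper illustrating the technique, not the removed inference.py.
    # audio: (channels, samples); model is assumed to map
    # (1, channels, chunk_size) -> (1, channels, chunk_size).
    step = chunk_size // overlap
    result = torch.zeros_like(audio)
    counter = torch.zeros_like(audio)
    window = torch.hann_window(chunk_size)        # blending window

    i = 0
    total = audio.shape[-1]
    while i < total:
        part = audio[:, i:i + chunk_size]
        length = part.shape[-1]
        if length < chunk_size:                   # zero-pad the final, shorter chunk
            part = torch.nn.functional.pad(part, (0, chunk_size - length))
        with torch.no_grad():
            out = model(part.unsqueeze(0)).squeeze(0)
        w = window[:length]
        result[:, i:i + length] += out[..., :length] * w
        counter[:, i:i + length] += w
        i += step
    return result / counter.clamp(min=1e-8)       # normalize by accumulated window weight

With chunk_size set to 10 seconds of samples and overlap = 2, this corresponds to the defaults the removed CLI exposed via --chunk_size and --overlap.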