Commit d853526 (parent 6b780da), committed by pengdaqian

fix
Files changed:
- app.py +6 -6
- torchspleeter/__init__.py +4 -0
- torchspleeter/checkpoints/2stems/audio_example.mp3 +0 -0
- torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt +3 -0
- torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt +3 -0
- torchspleeter/command_interface.py +98 -0
- torchspleeter/estimator.py +160 -0
- torchspleeter/test/test_estimator.py +45 -0
- torchspleeter/unet.py +97 -0
app.py
CHANGED
@@ -14,10 +14,9 @@ import gradio as gr
 import librosa
 import numpy as np
 import soundfile
-from spleeter.separator import Separator
 from pydub import AudioSegment
 import uuid
-
+from torchspleeter import split_to_parts
 import logging
 
 logging.getLogger('numba').setLevel(logging.WARNING)
@@ -84,11 +83,11 @@ model = SynthesizerInfer(
 load_svc_model("vits_pretrain/sovits5.0-48k-debug.pth", model)
 model.eval()
 model.to(device)
-separator = Separator('spleeter:2stems')
 whisper_model = whisper.inference.load_model(os.path.join("whisper_pretrain", "medium.pt"))
 
+
 # warm up
-separator.separate_to_file('warm.wav', '/tmp/warm')
+# separator.separate_to_file('warm.wav', '/tmp/warm')
 
 
 def svc_change(argswave, argsspk):
@@ -196,7 +195,8 @@ def svc_main(sid, input_audio):
     soundfile.write(input_audio_tmp_file, audio, sampling_rate, format="wav")
     if not os.path.exists(tmpfile_path):
         os.makedirs(tmpfile_path)
-
+
+    split_to_parts(input_audio_tmp_file, tmpfile_path, models='torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt')
 
     curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
     vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
@@ -346,8 +346,8 @@ def main():
 
         app.queue(max_size=3, api_open=False).launch()
     except KeyboardInterrupt:
-        separator._get_session().close()
         app.close()
+        separator._get_session().close()
         sys.exit(0)
 
 
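The net effect in app.py: the TensorFlow-backed spleeter Separator is dropped in favor of the torchspleeter package bundled by this commit. A hedged sketch of the new call path (file paths below are illustrative; note that split_to_parts, as defined later in this commit, accepts a models argument but in its current body always loads the two bundled default checkpoints):

from torchspleeter import split_to_parts

# Replaces calls such as: separator.separate_to_file('warm.wav', '/tmp/warm')
stems = split_to_parts(
    'input.wav',   # illustrative input file
    '/tmp/out',    # output directory, created if missing
    models='torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt',
)
print(stems)  # two paths of the form /tmp/out/<uuid>-out_0.wav, /tmp/out/<uuid>-out_1.wav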
torchspleeter/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from torchspleeter.command_interface import *
+
+
+version="0.1.5"
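Since this __init__.py re-exports every public name of command_interface via the wildcard import, the `from torchspleeter import split_to_parts` line added to app.py resolves. A quick illustration:

from torchspleeter import split_to_parts
from torchspleeter.command_interface import split_to_parts as direct

assert split_to_parts is direct  # same function object via the re-export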
torchspleeter/checkpoints/2stems/audio_example.mp3
ADDED
Binary file (263 kB)
torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:385ea3372c6a3ceee01f6ded5504bb7ee1e9f0101950ae58869dc18382deb75c
+size 59050239
torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e4d6ede1ecad091468550773e77934aac3f1e039c0697fc9039aba9b935e344
+size 59033471
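Both .ckpt entries are Git LFS pointer files, so a plain clone contains only the three-line stubs shown; `git lfs pull` fetches the ~59 MB weights. A minimal sketch for verifying a pulled file against the oid recorded in its pointer (hash copied from testcheckpoint0.ckpt above):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so a large checkpoint never sits fully in memory.
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

expected = '385ea3372c6a3ceee01f6ded5504bb7ee1e9f0101950ae58869dc18382deb75c'
assert sha256_of('torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt') == expected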
torchspleeter/command_interface.py
ADDED
@@ -0,0 +1,98 @@
+"""
+
+This provides an interface to interact with the spleeter system on the command line.
+
+
+"""
+
+import os
+from torchspleeter.estimator import Estimator
+import argparse
+import uuid
+import numpy as np
+import librosa
+import soundfile
+import torch
+import pydub
+import os
+import shutil
+
+
+def split_to_parts(inputaudiofile, outputdir, instruments=2, models=[]):
+    filedata = pydub.AudioSegment.from_file(inputaudiofile)
+    sr = filedata.frame_rate
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # es = Estimator(2, './checkpoints/2stems/model').to(device)
+    # es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
+    es = Estimator()
+    es.eval()
+
+    # load wav audio
+    testaudiofile = inputaudiofile
+    channels = filedata.channels
+    mono_selection = False
+    if not os.path.exists(outputdir):
+        os.makedirs(outputdir, exist_ok=True)
+    if channels == 1:
+        mono_selection = True
+        multichannel = pydub.AudioSegment.from_mono_audiosegments(filedata, filedata)
+        os.makedirs(os.path.join(outputdir, 'tmp'), exist_ok=True)
+        testaudiofile = os.path.join(outputdir, "tmp" + str(uuid.uuid4()) + "." + testaudiofile.split('.')[-1])
+        # testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
+        multichannel.export(out_f=testaudiofile, format=testaudiofile.split('.')[-1])
+    print(mono_selection)
+    print(channels)
+    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast', sr=sr)
+    wav = torch.Tensor(wav).to(device)
+    if mono_selection:
+        shutil.rmtree(os.path.join(outputdir, "tmp"))
+        # os.remove(testaudiofile)
+    wavs = es.separate(wav)
+    outputname = str(uuid.uuid4())
+    returnarray = []
+    for i in range(len(wavs)):
+        finaloutput = os.path.join(outputdir, outputname)
+        fname = '-out_{}.wav'.format(i)
+        fname = finaloutput + fname
+        print('Writing ', fname)
+        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
+        returnarray.append(fname)
+        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
+    return returnarray
+
+
+def get_file_list(dirname):
+    outputfilelist = []
+    for subdir, dirs, files in os.walk(dirname):
+        for file in files:
+            outputfilelist.append(os.path.join(subdir, file))
+
+    return outputfilelist
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='torchspleeter allows you to separate instrumentals from audio (vocals, instruments, background noise, etc) in a simple, cross platform manner')
+    parser.add_argument('-i', '--inputfile', help='Input Audio File to split into instrumentals', required=True)
+    parser.add_argument('-o', '--output', help='Output directory to deposit split audio', required=True)
+    parser.add_argument('-n', '--number', help="Number of instruments in the model (default 2)", required=False,
+                        default=2, type=int)
+    parser.add_argument('-m', '--modeldir',
+                        help="directory containing number of pre-converted torch compatible model components",
+                        required=False)
+    args = vars(parser.parse_args())
+    print(args)
+    if args['modeldir'] is not None:
+        modelfiles = get_file_list(args['modeldir'])
+        if len(modelfiles) != args['number']:
+            raise ValueError("You must have the same number of models as you do number of instruments!")
+    else:
+        args['modeldir'] = []
+    outputfiles = split_to_parts(args['inputfile'], args['output'], args['number'], args['modeldir'])
+    print("Your output files are:")
+    for item in outputfiles:
+        print(item)
+
+
+if __name__ == "__main__":
+    main()
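A hedged usage sketch for the module above, run from the repository root (the audio file name is illustrative). As written, split_to_parts never reads its instruments or models parameters and always constructs Estimator() with the two bundled default checkpoints:

# CLI form (the argparse entry point above):
#   python -m torchspleeter.command_interface -i song.mp3 -o out/
from torchspleeter.command_interface import split_to_parts

stems = split_to_parts('song.mp3', 'out/')  # 'song.mp3' is illustrative
for path in stems:
    print(path)  # out/<uuid>-out_0.wav and out/<uuid>-out_1.wav, one file per stem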
torchspleeter/estimator.py
ADDED
@@ -0,0 +1,160 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import tqdm
+# from torchaudio.functional import istft
+
+from torchspleeter.unet import UNet
+# from .util import tf2pytorch
+
+import os
+dirname = os.path.dirname(__file__)
+defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
+defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
+
+def load_ckpt(model, ckpt):
+    state_dict = model.state_dict()
+    for k, v in ckpt.items():
+        if k in state_dict:
+            target_shape = state_dict[k].shape
+            assert target_shape == v.shape
+            state_dict.update({k: torch.from_numpy(v)})
+        else:
+            print('Ignore ', k)
+
+    model.load_state_dict(state_dict)
+    return model
+
+
+def pad_and_partition(tensor, T):
+    """
+    pads zero and partition tensor into segments of length T
+
+    Args:
+        tensor(Tensor): BxCxFxL
+
+    Returns:
+        tensor of size (B*[L/T] x C x F x T)
+    """
+    old_size = tensor.size(3)
+    new_size = math.ceil(old_size / T) * T
+    tensor = F.pad(tensor, [0, new_size - old_size])
+    [b, c, t, f] = tensor.shape
+    split = new_size // T
+    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
+
+
+class Estimator(nn.Module):
+    def __init__(self, num_instrumments=2, checkpoint_path=None):
+        super(Estimator, self).__init__()
+        if checkpoint_path is None:
+            checkpoint_path = [defaultmodel0, defaultmodel1]
+        else:
+            if len(checkpoint_path) < 1:
+                checkpoint_path = [defaultmodel0, defaultmodel1]
+        # stft config
+        self.F = 1024
+        self.T = 512
+        self.win_length = 4096
+        self.hop_length = 1024
+        self.win = nn.Parameter(
+            torch.hann_window(self.win_length),
+            requires_grad=False
+        )
+
+        ckpts = []
+        if len(checkpoint_path) != num_instrumments:
+            raise ValueError("You must submit as many models as there are instruments!")
+        for ckpt_path in checkpoint_path:
+            ckpts.append(torch.load(ckpt_path))
+
+        # self.ckpts = ckpt  # torch.load(checkpoint_path)  # , num_instrumments)
+
+        # ckpts =  # tf2pytorch(checkpoint_path, num_instrumments)
+
+        # filter
+        self.instruments = nn.ModuleList()
+        for i in range(num_instrumments):
+            print('Loading model for instrumment {}'.format(i))
+            net = UNet(2)
+            ckpt = ckpts[i]
+            net = load_ckpt(net, ckpt)
+            net.eval()  # change mode to eval
+            self.instruments.append(net)
+
+    def compute_stft(self, wav):
+        """
+        Computes stft feature from wav
+
+        Args:
+            wav (Tensor): B x L
+        """
+
+        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
+                          center=True, return_complex=False, pad_mode='constant')
+
+        # only keep freqs smaller than self.F
+        stft = stft[:, :self.F, :, :]
+        real = stft[:, :, :, 0]
+        im = stft[:, :, :, 1]
+        mag = torch.sqrt(real ** 2 + im ** 2)
+
+        return stft, mag
+
+    def inverse_stft(self, stft):
+        """Inverses stft to wave form"""
+
+        pad = self.win_length // 2 + 1 - stft.size(1)
+        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
+        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
+                          window=self.win)
+        return wav.detach()
+
+    def separate(self, wav):
+        """
+        Separates stereo wav into different tracks corresponding to different instruments
+
+        Args:
+            wav (tensor): 2 x L
+        """
+
+        # stft - 2 X F x L x 2
+        # stft_mag - 2 X F x L
+        stft, stft_mag = self.compute_stft(wav)
+
+        L = stft.size(2)
+
+        # 1 x 2 x F x T
+        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
+        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
+        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
+
+        B = stft_mag.shape[0]
+
+        # compute instruments' mask
+        masks = []
+        for net in self.instruments:
+            mask = net(stft_mag)
+            masks.append(mask)
+
+        # compute denominator
+        mask_sum = sum([m ** 2 for m in masks])
+        mask_sum += 1e-10
+
+        wavs = []
+        for mask in tqdm.tqdm(masks):
+            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
+            mask = mask.transpose(2, 3)  # B x 2 X F x T
+
+            mask = torch.cat(
+                torch.split(mask, 1, dim=0), dim=3)
+
+            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
+            wav_masked = self.inverse_stft(stft_masked)
+
+            wavs.append(wav_masked)
+
+        return wavs
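Two details above deserve a note. pad_and_partition folds the time axis into the batch axis so spectrograms of any length pass through the fixed-width UNet, and separate() then combines the per-instrument outputs as a Wiener-like ratio mask, mask_i = (m_i^2 + eps/2) / (sum_j m_j^2 + eps). A small self-contained shape check mirroring the partition logic:

import math
import torch
import torch.nn.functional as F

def pad_and_partition(tensor, T):
    # Same logic as above: zero-pad dim 3 up to a multiple of T,
    # then stack the T-frame segments along the batch dim.
    old_size = tensor.size(3)
    new_size = math.ceil(old_size / T) * T
    tensor = F.pad(tensor, [0, new_size - old_size])
    return torch.cat(torch.split(tensor, T, dim=3), dim=0)

x = torch.zeros(1, 2, 1024, 1300)        # 1 x channels x F x L (1300 STFT frames)
print(pad_and_partition(x, 512).shape)   # torch.Size([3, 2, 1024, 512])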
torchspleeter/test/test_estimator.py
ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+import librosa
+import soundfile
+import torch
+import pydub
+import os
+from torchspleeter.estimator import Estimator
+dirname = os.path.dirname(__file__)
+testfilename = os.path.join(dirname, 'checkpoints/2stems/audio_example.mp3')
+
+if __name__ == '__main__':
+    sr = 44100
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # es = Estimator(2, './checkpoints/2stems/model').to(device)
+    # es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
+    es = Estimator()
+    es.eval()
+
+    # load wav audio
+    testaudiofile = testfilename
+    filestats = pydub.AudioSegment.from_file(testaudiofile)
+    channels = filestats.channels
+    mono_selection = False
+    if channels == 1:
+        mono_selection = True
+        multichannel = pydub.AudioSegment.from_mono_audiosegments(filestats, filestats)
+        testaudiofile = testaudiofile.split('.')[0] + "-stereo." + testaudiofile.split('.')[-1]
+        multichannel.export(out_f=testaudiofile, format=testaudiofile.split('.')[-1])
+    print(mono_selection)
+    print(channels)
+    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast', sr=sr)
+    wav = torch.Tensor(wav).to(device)
+    if mono_selection:
+        os.remove(testaudiofile)
+
+
+    # normalize audio
+    # wav_torch = wav / (wav.max() + 1e-8)
+
+    wavs = es.separate(wav)
+    for i in range(len(wavs)):
+        fname = 'output/out_{}.wav'.format(i)
+        print('Writing ', fname)
+        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
+        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
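This is a __main__-guarded script rather than a pytest case. A hedged note on running it: invoke it from the repository root, and create the relative output/ directory first, since soundfile.write does not create directories:

# python torchspleeter/test/test_estimator.py
import os
os.makedirs('output', exist_ok=True)  # the script writes output/out_{i}.wav here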
torchspleeter/unet.py
ADDED
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class CustomPad(nn.Module):
+    def __init__(self, padding_setting=(1, 2, 1, 2)):
+        super(CustomPad, self).__init__()
+        self.padding_setting = padding_setting
+
+    def forward(self, x):
+        return F.pad(x, self.padding_setting, "constant", 0)
+
+
+class CustomTransposedPad(nn.Module):
+    def __init__(self, padding_setting=(1, 2, 1, 2)):
+        super(CustomTransposedPad, self).__init__()
+        self.padding_setting = padding_setting
+
+    def forward(self, x):
+        l, r, t, b = self.padding_setting
+        return x[:, :, l:-r, t:-b]
+
+
+def down_block(in_filters, out_filters):
+    return nn.Sequential(CustomPad(),
+                         nn.Conv2d(in_filters, out_filters, kernel_size=5, stride=2, padding=0)), \
+        nn.Sequential(
+            nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
+            nn.LeakyReLU(0.2))
+
+
+def up_block(in_filters, out_filters, dropout=False):
+    layers = [
+        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5, stride=2),
+        CustomTransposedPad(),
+        nn.ReLU(),
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
+    ]
+    if dropout:
+        layers.append(nn.Dropout(0.5))
+
+    return nn.Sequential(*layers)
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels=2):
+        super(UNet, self).__init__()
+        self.down1_conv, self.down1_act = down_block(in_channels, 16)
+        self.down2_conv, self.down2_act = down_block(16, 32)
+        self.down3_conv, self.down3_act = down_block(32, 64)
+        self.down4_conv, self.down4_act = down_block(64, 128)
+        self.down5_conv, self.down5_act = down_block(128, 256)
+        self.down6_conv, self.down6_act = down_block(256, 512)
+
+        self.up1 = up_block(512, 256, dropout=True)
+        self.up2 = up_block(512, 128, dropout=True)
+        self.up3 = up_block(256, 64, dropout=True)
+        self.up4 = up_block(128, 32)
+        self.up5 = up_block(64, 16)
+        self.up6 = up_block(32, 1)
+        self.up7 = nn.Sequential(
+            nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        d1_conv = self.down1_conv(x)
+        d1 = self.down1_act(d1_conv)
+
+        d2_conv = self.down2_conv(d1)
+        d2 = self.down2_act(d2_conv)
+
+        d3_conv = self.down3_conv(d2)
+        d3 = self.down3_act(d3_conv)
+
+        d4_conv = self.down4_conv(d3)
+        d4 = self.down4_act(d4_conv)
+
+        d5_conv = self.down5_conv(d4)
+        d5 = self.down5_act(d5_conv)
+
+        d6_conv = self.down6_conv(d5)
+        d6 = self.down6_act(d6_conv)
+
+        u1 = self.up1(d6_conv)
+        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
+        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
+        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
+        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
+        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
+        u7 = self.up7(u6)
+        return u7 * x
+
+
+if __name__ == '__main__':
+    net = UNet(14)
+    print(net(torch.rand(1, 14, 20, 48)).shape)
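For orientation, a sketch of how Estimator.separate drives this network: the input is a batch of 2-channel magnitude-spectrogram segments laid out B x 2 x T x F with T=512 and F=1024, and forward returns the sigmoid mask already multiplied onto the input (the final u7 * x):

import torch
from torchspleeter.unet import UNet

net = UNet(in_channels=2).eval()
x = torch.rand(3, 2, 512, 1024)  # B x 2 x T x F, as produced upstream by pad_and_partition
with torch.no_grad():
    masked = net(x)              # sigmoid mask times input; same shape as x
print(masked.shape)              # torch.Size([3, 2, 512, 1024])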