Spaces:
Runtime error
Runtime error
my
committed on
Commit
·
32ca76b
1
Parent(s):
84c0c04
Add application file
Browse files- requirements.txt +3 -0
- .gitignore +146 -0
- app.py +143 -0
- models/__init__.py +0 -0
- models/hinet.py +20 -0
- models/invblock.py +36 -0
- models/module_util.py +79 -0
- models/my_model_v7_recover.py +95 -0
- models/rrdb_denselayer.py +25 -0
- utils/__init__.py +0 -0
- utils/bin_util.py +104 -0
- utils/file_reader.py +77 -0
- utils/metric_util.py +88 -0
- utils/model_util.py +118 -0
- utils/pesq_util.py +25 -0
- utils/pickle_util.py +27 -0
- utils/silent_util.py +18 -0
- utils/wm_add_v2.py +87 -0
- utils/wm_decode_v2.py +113 -0
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchvision==0.14.1
|
3 |
+
torchaudio==0.13.1
|
.gitignore
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
|
3 |
+
#idea
|
4 |
+
.idea
|
5 |
+
wandb/
|
6 |
+
temp/
|
7 |
+
data/
|
8 |
+
|
9 |
+
# Byte-compiled / optimized / DLL files
|
10 |
+
__pycache__/
|
11 |
+
*.py[cod]
|
12 |
+
*$py.class
|
13 |
+
|
14 |
+
# C extensions
|
15 |
+
*.so
|
16 |
+
|
17 |
+
# Distribution / packaging
|
18 |
+
.Python
|
19 |
+
build/
|
20 |
+
develop-eggs/
|
21 |
+
dist/
|
22 |
+
downloads/
|
23 |
+
eggs/
|
24 |
+
.eggs/
|
25 |
+
lib/
|
26 |
+
lib64/
|
27 |
+
parts/
|
28 |
+
sdist/
|
29 |
+
var/
|
30 |
+
wheels/
|
31 |
+
share/python-wheels/
|
32 |
+
*.egg-info/
|
33 |
+
.installed.cfg
|
34 |
+
*.egg
|
35 |
+
MANIFEST
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.nox/
|
51 |
+
.coverage
|
52 |
+
.coverage.*
|
53 |
+
.cache
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
cover/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
db.sqlite3-journal
|
71 |
+
|
72 |
+
# Flask stuff:
|
73 |
+
instance/
|
74 |
+
.webassets-cache
|
75 |
+
|
76 |
+
# Scrapy stuff:
|
77 |
+
.scrapy
|
78 |
+
|
79 |
+
# Sphinx documentation
|
80 |
+
docs/_build/
|
81 |
+
|
82 |
+
# PyBuilder
|
83 |
+
.pybuilder/
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# IPython
|
90 |
+
profile_default/
|
91 |
+
ipython_config.py
|
92 |
+
|
93 |
+
# pyenv
|
94 |
+
# For a library or package, you might want to ignore these files since the code is
|
95 |
+
# intended to run in multiple environments; otherwise, check them in:
|
96 |
+
# .python-version
|
97 |
+
|
98 |
+
# pipenv
|
99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
102 |
+
# install all needed dependencies.
|
103 |
+
#Pipfile.lock
|
104 |
+
|
105 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
106 |
+
__pypackages__/
|
107 |
+
|
108 |
+
# Celery stuff
|
109 |
+
celerybeat-schedule
|
110 |
+
celerybeat.pid
|
111 |
+
|
112 |
+
# SageMath parsed files
|
113 |
+
*.sage.py
|
114 |
+
|
115 |
+
# Environments
|
116 |
+
.env
|
117 |
+
.venv
|
118 |
+
env/
|
119 |
+
venv/
|
120 |
+
ENV/
|
121 |
+
env.bak/
|
122 |
+
venv.bak/
|
123 |
+
|
124 |
+
# Spyder project settings
|
125 |
+
.spyderproject
|
126 |
+
.spyproject
|
127 |
+
|
128 |
+
# Rope project settings
|
129 |
+
.ropeproject
|
130 |
+
|
131 |
+
# mkdocs documentation
|
132 |
+
/site
|
133 |
+
|
134 |
+
# mypy
|
135 |
+
.mypy_cache/
|
136 |
+
.dmypy.json
|
137 |
+
dmypy.json
|
138 |
+
|
139 |
+
# Pyre type checker
|
140 |
+
.pyre/
|
141 |
+
|
142 |
+
# pytype static type analyzer
|
143 |
+
.pytype/
|
144 |
+
|
145 |
+
# Cython debug symbols
|
146 |
+
cython_debug/
|
app.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import time
|
3 |
+
|
4 |
+
import soundfile
|
5 |
+
import streamlit as st
|
6 |
+
import os
|
7 |
+
from utils import wm_add_v2, file_reader, model_util, wm_decode_v2, bin_util
|
8 |
+
from models import my_model_v7_recover
|
9 |
+
import torch
|
10 |
+
import uuid
|
11 |
+
import datetime
|
12 |
+
import numpy as np
|
13 |
+
from huggingface_hub import hf_hub_download, HfApi
|
14 |
+
|
15 |
+
|
16 |
+
# Function to add watermark to audio
|
17 |
+
def add_watermark(audio_path, watermark_text):
    """Embed ``watermark_text`` into an audio file and save the result.

    Relies on the module-level ``len_start_bit``, ``device`` and ``model``
    globals that are set in the ``__main__`` section of this file.

    Args:
        audio_path: Path of the audio file to watermark.
        watermark_text: Payload string; must be exactly 5 characters long.

    Returns:
        Path of the watermarked WAV file written under ``temp/``.

    Raises:
        ValueError: If ``watermark_text`` is not 5 characters long.
    """
    # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
    if len(watermark_text) != 5:
        raise ValueError("watermark_text must be exactly 5 characters long")

    _, _, watermark = wm_add_v2.create_parcel_message(len_start_bit, 32, watermark_text)

    data, sr, _ = file_reader.read_as_single_channel_16k(audio_path, 16000)

    _, signal_wmd, _ = wm_add_v2.add_watermark(watermark, data, 16000, 0.1, device, model)

    # ``temp/`` is git-ignored, so it does not exist on a fresh checkout.
    os.makedirs('temp', exist_ok=True)

    # Timestamp plus UUID keeps output names unique.
    tmp_file_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4()) + ".wav"
    tmp_file_path = 'temp/' + tmp_file_name
    soundfile.write(tmp_file_path, signal_wmd, sr)
    return tmp_file_path
|
30 |
+
|
31 |
+
|
32 |
+
# Function to decode watermark from audio
|
33 |
+
def decode_watermark(audio_path):
    """Try to extract a watermark payload from an audio file.

    Only the first five seconds of the audio are examined. Relies on the
    module-level ``len_start_bit``, ``device`` and ``model`` globals.

    Args:
        audio_path: Path of the audio file to inspect.

    Returns:
        The payload as a hex string, or ``"No Watermark"`` when the decoder
        returns no result.
    """
    samples, sr, _ = file_reader.read_as_single_channel_16k(audio_path, 16000)
    head = samples[0:5 * sr]
    sync_bits = wm_add_v2.fix_pattern[0:len_start_bit]

    _, mean_result, _ = wm_decode_v2.extract_watermark_v2(
        head, sync_bits, 0.1, 16000, 0.3, model, device, "best")

    if mean_result is None:
        return "No Watermark"

    # Drop the leading start-bit pattern; the remainder is the payload.
    return bin_util.binArray2HexStr(mean_result[len_start_bit:])
|
51 |
+
|
52 |
+
|
53 |
+
# Main web app
|
54 |
+
def main():
    """Render the Streamlit page for adding or decoding an audio watermark.

    Reads the module-level ``len_start_bit`` global and delegates the actual
    processing to ``add_watermark`` and ``decode_watermark``.
    """
    if "def_value" not in st.session_state:
        # Random default payload, generated once per session.
        st.session_state.def_value = bin_util.binArray2HexStr(np.random.choice([0, 1], size=32 - len_start_bit))

    # Uploads are staged under temp/, which is git-ignored and therefore
    # missing from a fresh checkout; create it before any write.
    os.makedirs("temp", exist_ok=True)

    st.title("Neural Audio Watermark")
    st.write("Choose the action you want to perform:")

    action = st.selectbox("Select Action", ["Add Watermark", "Decode Watermark"])

    if action == "Add Watermark":
        audio_file = st.file_uploader("Upload Audio File (WAV)", type=["wav"], accept_multiple_files=False)
        if audio_file:
            # basename() keeps a client-supplied name from escaping temp/.
            tmp_input_audio_file = os.path.join("temp", os.path.basename(audio_file.name))
            with open(tmp_input_audio_file, "wb") as f:
                f.write(audio_file.getbuffer())
            st.audio(tmp_input_audio_file, format="audio/wav")

        watermark_text = st.text_input("Enter Watermark Text (5 English letters)", value=st.session_state.def_value)

        add_watermark_button = st.button("Add Watermark", key="add_watermark_btn")
        # Runs only on the rerun triggered by the button click.
        if add_watermark_button and audio_file and watermark_text:
            if len(watermark_text) != 5:
                # add_watermark() requires exactly 5 characters; report it
                # in the UI instead of crashing the script.
                st.error("Watermark text must be exactly 5 characters long.")
            else:
                with st.spinner("Adding Watermark..."):
                    t1 = time.time()
                    watermarked_audio = add_watermark(tmp_input_audio_file, watermark_text)
                    encode_time_cost = time.time() - t1

                st.write("Watermarked Audio:")
                st.audio(watermarked_audio, format="audio/wav")
                st.write("Time Cost:%d seconds" % encode_time_cost)

    elif action == "Decode Watermark":
        audio_file = st.file_uploader("Upload Audio File (WAV/MP3)", type=["wav", "mp3"], accept_multiple_files=False)
        if audio_file:
            if st.button("Decode Watermark"):
                # 1. Save the upload to disk.
                tmp_file_for_decode_path = os.path.join("temp", os.path.basename(audio_file.name))
                with open(tmp_file_for_decode_path, "wb") as f:
                    f.write(audio_file.getbuffer())

                # 2. Run the decoder.
                with st.spinner("Decoding..."):
                    t1 = time.time()
                    decoded_watermark = decode_watermark(tmp_file_for_decode_path)
                    decode_cost = time.time() - t1

                print("decoded_watermark", decoded_watermark)
                # Display the decoded watermark
                st.write("Decoded Watermark:", decoded_watermark)
                st.write("Time Cost:%d seconds" % (decode_cost))
|
110 |
+
|
111 |
+
|
112 |
+
def load_model(resume_path):
    """Build the watermark model and load its checkpoint weights.

    Relies on the module-level ``device`` global.

    Args:
        resume_path: Local checkpoint path. When no file exists there, the
            checkpoint is downloaded from the ``M4869/InvertibleWM`` Hub repo
            using the ``api_key`` entry of the Streamlit secrets.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    n_fft = 1000
    hop_length = 400

    if not os.path.isfile(resume_path):
        # https://huggingface.co/M4869/InvertibleWM/blob/main/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl
        # The token is a secret: never print or log it.
        resume_path = hf_hub_download(repo_id="M4869/InvertibleWM",
                                      filename="step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl",
                                      token=st.secrets["api_key"]
                                      )

    model = my_model_v7_recover.Model(16000, 32, n_fft, hop_length,
                                      use_recover_layer=False, num_layers=8).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    state_dict = model_util.map_state_dict(checkpoint['model'])
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == "__main__":
    # Module-level globals read by add_watermark(), decode_watermark(),
    # main() and load_model().
    len_start_bit = 12

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_model("./data/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl")

    main()
|
143 |
+
# decode_watermark("/Users/my/Downloads/7a95b353a46893903e9f946c24170b210ce14e8c52c63bb2ab3d144e.wav")
|
models/__init__.py
ADDED
File without changes
|
models/hinet.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models.invblock import INV_block
|
3 |
+
|
4 |
+
|
5 |
+
class Hinet(torch.nn.Module):
    """Chain of ``INV_block`` modules that can be run forwards or in reverse."""

    def __init__(self, in_channel=2, num_layers=16):
        super().__init__()
        blocks = [INV_block(in_channel) for _ in range(num_layers)]
        self.inv_blocks = torch.nn.ModuleList(blocks)

    def forward(self, x1, x2, rev=False):
        """Pass ``(x1, x2)`` through every block.

        Args:
            x1: cover input.
            x2: secret input.
            rev: When true, the blocks are visited in reverse order and each
                one is called with ``rev=True``.

        Returns:
            The transformed ``(x1, x2)`` pair.
        """
        if rev:
            for block in reversed(self.inv_blocks):
                x1, x2 = block(x1, x2, rev=True)
            return x1, x2

        for block in self.inv_blocks:
            x1, x2 = block(x1, x2)
        return x1, x2
|
models/invblock.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from models.rrdb_denselayer import ResidualDenseBlock_out
|
4 |
+
|
5 |
+
|
6 |
+
class INV_block(nn.Module):
    """Invertible coupling block built from three sub-networks (r, y, f)."""

    def __init__(self, channel=2, subnet_constructor=ResidualDenseBlock_out, clamp=2.0):
        super().__init__()
        self.clamp = clamp

        # rho: produces the input of the scale term e(.)
        self.r = subnet_constructor(channel, channel)
        # eta: additive term applied to the second branch
        self.y = subnet_constructor(channel, channel)
        # phi: additive term applied to the first branch
        self.f = subnet_constructor(channel, channel)

    def e(self, s):
        """Return the scale factor ``exp(clamp * 2 * (sigmoid(s) - 0.5))``."""
        return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5))

    def forward(self, x1, x2, rev=False):
        """Apply the coupling (``rev=False``) or undo it (``rev=True``).

        Returns:
            The pair ``(y1, y2)``.
        """
        if rev:
            scale, shift = self.r(x1), self.y(x1)
            second = (x2 - shift) / self.e(scale)
            first = (x1 - self.f(second))
            return first, second

        first = x1 + self.f(x2)
        scale, shift = self.r(first), self.y(first)
        second = self.e(scale) * x2 + shift
        return first, second
|
models/module_util.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.init as init
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def initialize_weights(net_l, scale=1):
    """Re-initialise the Conv2d, Linear and BatchNorm2d layers of one or more networks.

    Args:
        net_l: A module or a list of modules.
        scale: Factor applied to Conv2d/Linear weights after Kaiming-normal
            initialisation (small values are used for residual blocks).
    """
    networks = net_l if isinstance(net_l, list) else [net_l]
    for network in networks:
        for layer in network.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)
|
25 |
+
|
26 |
+
|
27 |
+
def make_layer(block, n_layers):
    """Return an ``nn.Sequential`` made of ``n_layers`` fresh ``block()`` instances."""
    return nn.Sequential(*[block() for _ in range(n_layers)])
|
32 |
+
|
33 |
+
|
34 |
+
class ResidualBlock_noBN(nn.Module):
    """Residual block without batch normalisation.

    ---Conv-ReLU-Conv-+-
     |________________|
    """

    def __init__(self, nf=64):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Scale the initial weights down by 0.1 for the residual branch.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the Conv-ReLU-Conv branch applied to ``x``."""
        residual = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + residual
|
53 |
+
|
54 |
+
|
55 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, 2, H, W); it is permuted to (N, H, W, 2)
            internally. Channel 0 is added to the x (width) pixel
            coordinate and channel 1 to the y (height) coordinate.
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'

    Returns:
        Tensor: warped image or feature map
    """
    flow = flow.permute(0, 2, 3, 1)
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # Pixel-coordinate mesh grid; 'ij' is the historical default and is
    # passed explicitly to avoid the deprecation warning.
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), 2).float()  # (H, W, 2), last dim ordered (x, y)
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    # NOTE(review): align_corners is left at the library default; the
    # normalisation above maps pixel centres to [-1, 1], which matches
    # align_corners=True -- confirm which behaviour callers expect.
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
    return output
|
models/my_model_v7_recover.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
|
3 |
+
import torch.optim
|
4 |
+
import torch.nn as nn
|
5 |
+
from models.hinet import Hinet
|
6 |
+
# from utils.attacks import attack_layer, mp3_attack_v2, butterworth_attack
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
|
10 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
11 |
+
|
12 |
+
|
13 |
+
class Model(nn.Module):
    """Watermark encoder/decoder that works on STFT representations.

    A message is expanded by a linear layer to ``num_point`` samples, both
    the signal and the expanded message are transformed with an STFT, and an
    invertible ``Hinet`` mixes them (encode) or separates them (decode).
    """

    def __init__(self, num_point, num_bit, n_fft, hop_length, use_recover_layer, num_layers):
        super(Model, self).__init__()
        self.hinet = Hinet(num_layers=num_layers)
        self.watermark_fc = torch.nn.Linear(num_bit, num_point)
        self.watermark_fc_back = torch.nn.Linear(num_point, num_bit)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dropout1 = torch.nn.Dropout()
        self.identity = torch.nn.Identity()
        self.recover_layer = SameSizeConv2d(2, 2)
        self.use_recover_layer = use_recover_layer

    def stft(self, data):
        """Return the STFT of ``data`` as a real tensor with a trailing (real, imag) axis.

        The complex STFT API is used and then viewed as real, because
        ``return_complex=False`` is no longer accepted by newer torch
        releases; the resulting layout is the same.
        """
        window = torch.hann_window(self.n_fft).to(data.device)
        spec = torch.stft(data, n_fft=self.n_fft, hop_length=self.hop_length, window=window, return_complex=True)
        # e.g. [1, 501, 41, 2]
        return torch.view_as_real(spec)

    def istft(self, signal_wmd_fft):
        """Invert ``stft``: take a real tensor with a trailing (real, imag) axis, return a waveform."""
        window = torch.hann_window(self.n_fft).to(signal_wmd_fft.device)

        # torch >= 2.0 rejects real-valued input to istft, so rebuild the
        # complex tensor first. view_as_complex needs a unit-stride last
        # axis, hence contiguous() (callers pass permuted tensors).
        spec = torch.view_as_complex(signal_wmd_fft.contiguous())
        return torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length, window=window,
                           return_complex=False)

    def encode(self, signal, message, need_fft=False):
        """Embed ``message`` into ``signal``.

        Returns:
            The watermarked signal, or the tuple
            ``(signal_wmd, signal_fft, message_fft)`` when ``need_fft`` is true.
        """
        # 1. STFT of the signal -> (batch, freq_bins, time_frames, 2)
        signal_fft = self.stft(signal)

        # 2. Expand the message to signal length, then STFT it
        message_expand = self.watermark_fc(message)
        message_fft = self.stft(message_expand)

        # 3. encode -> (batch, freq_bins, time_frames, 2)
        signal_wmd_fft, msg_remain = self.enc_dec(signal_fft, message_fft, rev=False)
        signal_wmd = self.istft(signal_wmd_fft)
        if need_fft:
            return signal_wmd, signal_fft, message_fft

        return signal_wmd

    def decode(self, signal):
        """Recover the message from ``signal``; values are clamped to [-1, 1]."""
        signal_fft = self.stft(signal)
        if self.use_recover_layer:
            signal_fft = self.recover_layer(signal_fft)
        # The signal spectrum itself is fed in as the second branch of the
        # reverse pass.
        watermark_fft = signal_fft
        _, message_restored_fft = self.enc_dec(signal_fft, watermark_fft, rev=True)
        message_restored_expanded = self.istft(message_restored_fft)
        message_restored_float = self.watermark_fc_back(message_restored_expanded).clamp(-1, 1)
        return message_restored_float

    def enc_dec(self, signal, watermark, rev):
        """Run ``Hinet`` on both spectra, swapping axes 1 and 3 on the way in and out."""
        signal = signal.permute(0, 3, 2, 1)
        # e.g. [4, 2, 41, 501]

        watermark = watermark.permute(0, 3, 2, 1)

        signal2, watermark2 = self.hinet(signal, watermark, rev)
        return signal2.permute(0, 3, 2, 1), watermark2.permute(0, 3, 2, 1)
|
80 |
+
|
81 |
+
|
82 |
+
class SameSizeConv2d(nn.Module):
    """1x1 convolution applied over the last axis of a channels-last tensor."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Move the last axis to position 1, convolve, and move it back.

        e.g. (batch, 501, 41, 2) -> (batch, 2, 501, 41) -> conv -> (batch, 501, 41, 2)
        """
        channels_first = x.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first)
        return convolved.permute(0, 2, 3, 1)
|
models/rrdb_denselayer.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import models.module_util as mutil
|
4 |
+
|
5 |
+
|
6 |
+
# Dense connection
|
7 |
+
class ResidualDenseBlock_out(nn.Module):
    """Five densely connected 3x3 convolutions; each one sees the input plus all earlier outputs."""

    def __init__(self, in_channel, out_channel, bias=True):
        super().__init__()
        growth = 32
        self.conv1 = nn.Conv2d(in_channel, growth, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(in_channel + growth, growth, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(in_channel + 2 * growth, growth, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(in_channel + 3 * growth, growth, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(in_channel + 4 * growth, out_channel, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(inplace=True)
        # Scale 0 zeroes the weights of the last conv.
        mutil.initialize_weights([self.conv5], 0.)

    def forward(self, x):
        """Return the output of the last convolution (no activation applied to it)."""
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            stacked = features[0] if len(features) == 1 else torch.cat(features, 1)
            features.append(self.lrelu(conv(stacked)))
        return self.conv5(torch.cat(features, 1))
|
utils/__init__.py
ADDED
File without changes
|
utils/bin_util.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def hexChar2binStr(v):
    """Convert one hex character to its 4-character binary string, e.g. ``'e'`` -> ``'1110'``."""
    assert len(v) == 1
    nibble = int(v, 16)
    return format(nibble, '04b')
|
8 |
+
|
9 |
+
|
10 |
+
def hexStr2BinStr(hex_str):
    """Convert a hex string to a binary string, four bits per hex character."""
    # e.g. 'ec' -> '1110' + '1100'
    return "".join(hexChar2binStr(ch) for ch in hex_str)
|
14 |
+
|
15 |
+
|
16 |
+
def hexStr2BinArray(hex_str):
    """Convert a hex string to a numpy array of 0/1 integers."""
    bits = hexStr2BinStr(hex_str)
    return np.array(list(map(int, bits)))
|
20 |
+
|
21 |
+
|
22 |
+
def binStr2HexStr(binary_str):
    """Convert a binary string to a lowercase hex string, keeping leading zeros.

    The result has one hex character per started group of 4 bits, so a
    20-bit input always gives 5 characters. (Going through ``hex()`` alone
    would drop leading zero nibbles and break the hex -> bits -> hex round
    trip.)

    Args:
        binary_str: Non-empty string of '0'/'1' characters.

    Returns:
        The zero-padded hex representation, without a ``0x`` prefix.

    Raises:
        ValueError: If ``binary_str`` is empty or is not a base-2 number.
    """
    width = (len(binary_str) + 3) // 4
    return format(int(binary_str, 2), 'x').zfill(width)
|
24 |
+
|
25 |
+
|
26 |
+
def binArray2HexStr(bin_array):
    """Convert a sequence of 0/1 values to a hex string via ``binStr2HexStr``."""
    # "%d" formats each element as an integer.
    bit_string = "".join("%d" % bit for bit in bin_array)
    return binStr2HexStr(bit_string)
|
29 |
+
|
30 |
+
|
31 |
+
# 判断是否为合法的16进制字符串
|
32 |
+
def is_hex_str(s):
    """Return True when every character of ``s`` is a hex digit (an empty string is accepted)."""
    hex_chars = "0123456789abcdefABCDEF"
    for ch in s:
        if ch not in hex_chars:
            return False
    return True
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def flip_bytearray(input_bytearray, num_bits_to_flip):
    """Return a copy of ``input_bytearray`` with ``num_bits_to_flip`` randomly chosen bits inverted."""
    bits = bytearray_to_binary_list(input_bytearray)
    flipped = flip_array(bits, num_bits_to_flip)
    return binary_list_to_bytearray(flipped)
|
43 |
+
|
44 |
+
def flip_array(input_bits, num_bits_to_flip):
    """Invert ``num_bits_to_flip`` distinct, randomly chosen positions of a 0/1 sequence.

    Returns:
        The flipped bits; ``input_bits`` itself is not modified.
    """
    # Pick the positions to flip, without replacement.
    chosen = np.random.choice(len(input_bits), num_bits_to_flip, replace=False)

    # Mask that is 1 at the chosen positions and 0 elsewhere.
    flip_mask = np.zeros_like(input_bits)
    flip_mask[chosen] = 1

    # XOR with the mask inverts exactly the chosen bits.
    return input_bits ^ flip_mask
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
def bytearray_to_binary_list(byte_array):
    """Expand each byte into its 8 bits (most significant first) and return one flat list of ints."""
    return [int(bit) for byte in byte_array for bit in format(byte, '08b')]
|
68 |
+
|
69 |
+
|
70 |
+
def binary_list_to_bytearray(binary_list):
    """Pack a flat sequence of bits into a bytearray, 8 bits per byte, most significant first.

    The input length is expected to be a multiple of 8. A shorter trailing
    group is not rejected: it is converted as a binary number with fewer
    digits.
    """
    packed = bytearray()
    for start in range(0, len(binary_list), 8):
        group = binary_list[start:start + 8]
        packed.append(int(''.join(map(str, group)), 2))
    return packed
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
# hex_str = "ecd057f0d1fbb25d6430b338b5d72eb2"
|
84 |
+
# arr = hexStr2BinArray(hex_str)
|
85 |
+
# out = binArray2HexStr(arr)
|
86 |
+
# print(out==hex_str)
|
87 |
+
# bin_str = "".join()
|
88 |
+
# assert bin2hex_str(bin_str) == hex_str
|
89 |
+
# print(bin_str, len(bin_str))
|
90 |
+
#
|
91 |
+
watermark = np.random.randint(2, size=44)
|
92 |
+
res = binArray2HexStr(watermark)
|
93 |
+
print(res)
|
94 |
+
|
95 |
+
test_str1 = "3ad30c748a2"
|
96 |
+
test_str2 = "3ad30Z748a2"
|
97 |
+
|
98 |
+
print(is_hex_str(test_str1)) # 输出 True
|
99 |
+
print(is_hex_str(test_str2)) # 输出 False
|
100 |
+
|
101 |
+
|
102 |
+
# encode_file("1.wav", watermark)
|
103 |
+
# out = decode_file("tmp_output.wav")
|
104 |
+
# assert np.all(watermark == out)
|
utils/file_reader.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import soundfile
|
3 |
+
import librosa
|
4 |
+
import resampy
|
5 |
+
|
6 |
+
|
7 |
+
def is_wav_file(filename):
    """Return True when ``filename`` has a ``.wav`` extension (case-insensitive)."""
    _, extension = os.path.splitext(filename)
    return extension.lower() == ".wav"
|
13 |
+
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
def read_as_single_channel_16k(audio_file, def_sr, verbose=False, aim_second=None):
    """Read an audio file as a single channel at sample rate ``def_sr``.

    Args:
        audio_file: Path to a ``.mp3``, ``.wav`` or ``.flac`` file.
        def_sr: Target sample rate; the audio is resampled when it differs.
        verbose: Print progress messages.
        aim_second: Optional target duration in seconds. Shorter audio is
            looped to reach it and the result is cut to that duration.

    Returns:
        ``(data, sr, audio_length_second)`` where ``audio_length_second`` is
        the duration after resampling but before any looping or trimming.

    Raises:
        AssertionError: If the file does not exist, or the audio is empty
            while ``aim_second`` is given.
        Exception: If the file extension is not supported.
    """
    assert os.path.exists(audio_file), "音频文件不存在"

    file_extension = os.path.splitext(audio_file)[1].lower()

    if file_extension == ".mp3":
        data, origin_sr = librosa.load(audio_file, sr=None)
    elif file_extension in [".wav", ".flac"]:
        data, origin_sr = soundfile.read(audio_file)
    else:
        raise Exception("不支持的文件类型:" + file_extension)

    # Channels: keep only the first one of a multi-channel file.
    if len(data.shape) == 2:
        if verbose:
            print("双通道文件,变为单通道")
        data = data[:, 0]

    # Sample rate
    if origin_sr != def_sr:
        data = resampy.resample(data, origin_sr, def_sr)
        if verbose:
            print("原始音频采样率不是16kHZ,可能会对水印性能造成影响")

    sr = def_sr
    audio_length_second = 1.0 * len(data) / sr
    if verbose:
        print("输入音频长度:%d秒" % audio_length_second)

    if aim_second is not None:
        assert len(data) > 0
        current_second = len(data) / sr
        if current_second < aim_second:
            repeat_count = int(aim_second / current_second) + 1
            # np.tile loops the whole clip; np.repeat would duplicate every
            # single sample and stretch the audio instead.
            data = np.tile(data, repeat_count)
        # int() so that a fractional aim_second still gives a valid slice bound.
        data = data[0:int(sr * aim_second)]

    return data, sr, audio_length_second
|
63 |
+
|
64 |
+
|
65 |
+
def read_as_single_channel(file, aim_sr):
    """Read ``file`` and return its first channel at sample rate ``aim_sr``."""
    is_mp3 = file.endswith(".mp3")
    if is_mp3:
        # librosa resamples to the requested rate while loading.
        samples, rate = librosa.load(file, sr=aim_sr)
    else:
        samples, rate = soundfile.read(file)

    # Two-dimensional data: keep only the first channel.
    if len(samples.shape) == 2:
        samples = samples[:, 0]

    # Resample after channel selection, since soundfile may return
    # multi-channel data.
    if rate != aim_sr:
        samples = resampy.resample(samples, rate, aim_sr)
    return samples
|
utils/metric_util.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def calc_ber(watermark_decoded_tensor, watermark_tensor, threshold=0.5):
    """Return the bit error rate between two tensors after thresholding both at ``threshold``.

    Returns:
        A scalar float32 tensor: the fraction of positions whose binarised
        values differ.
    """
    decoded_bits = watermark_decoded_tensor >= threshold
    reference_bits = watermark_tensor >= threshold
    matches = (decoded_bits == reference_bits).to(torch.float32)
    return 1 - matches.mean()
|
11 |
+
|
12 |
+
|
13 |
+
def to_equal_length(original, signal_watermarked):
    """Trim both inputs to the shorter length when their shapes differ.

    A warning is printed whenever trimming happens.

    Returns:
        The ``(original, signal_watermarked)`` pair with equal shapes.
    """
    if original.shape == signal_watermarked.shape:
        return original, signal_watermarked

    print("警告!输入内容长度不一致", len(original), len(signal_watermarked))
    common = min(len(original), len(signal_watermarked))
    original = original[0:common]
    signal_watermarked = signal_watermarked[0:common]
    assert original.shape == signal_watermarked.shape
    return original, signal_watermarked
|
21 |
+
|
22 |
+
|
23 |
+
def signal_noise_ratio(original, signal_watermarked):
    """Return the signal-to-noise ratio in dB; higher is better.

    The noise is the difference between the two inputs. Returns ``np.inf``
    when they are identical.
    """
    reference, marked = to_equal_length(original, signal_watermarked)
    noise_power = np.sum((reference - marked) ** 2)
    if noise_power == 0:  # the original signal was not changed
        return np.inf

    signal_power = np.sum(reference ** 2)
    # log10(1) == 0, so the SNR is negative when the noise is stronger than
    # the signal. A ratio of 0 would give -inf, so clamp it to a small
    # positive floor first.
    ratio = max(1e-10, signal_power / noise_power)
    return 10 * np.log10(ratio)
|
38 |
+
|
39 |
+
|
40 |
+
def batch_signal_noise_ratio(original, signal_watermarked):
    """Return the mean ``signal_noise_ratio`` over the paired items of two batched tensors."""
    clean_batch = original.detach().cpu().numpy()
    marked_batch = signal_watermarked.detach().cpu().numpy()
    per_item = [signal_noise_ratio(clean, marked)
                for clean, marked in zip(clean_batch, marked_batch)]
    return np.mean(per_item)
|
48 |
+
|
49 |
+
|
50 |
+
def calc_bce_acc(predictions, ground_truth, threshold=0.5):
    """Return the fraction of thresholded predictions that equal ``ground_truth``, as a Python float."""
    assert predictions.shape == ground_truth.shape

    # Turn scores into 0/1 labels.
    labels = (predictions >= threshold).float()

    # Share of labels that match the ground truth.
    hits = (labels == ground_truth).float()
    return hits.mean().item()
|
59 |
+
|
60 |
+
|
61 |
+
def resample_to16k(data, old_sr):
    """Resample ``data`` from ``old_sr`` Hz to 16 kHz.

    Uses polyphase resampling, so it works for any rate (including rates
    below 16 kHz and non-integer ratios such as 44.1 kHz) and applies an
    anti-aliasing filter. Slicing with a step of ``int(old_sr / 16000)``
    handles neither case: the step is 0 below 16 kHz and is rounded down
    for non-integer ratios.

    Args:
        data: 1-D array of samples.
        old_sr: Sample rate of ``data`` in Hz.

    Returns:
        The samples at 16 kHz; ``data`` itself when it is already at 16 kHz.
    """
    import math

    from scipy.signal import resample_poly

    new_fs = 16000
    old_sr = int(old_sr)
    if old_sr == new_fs:
        return data
    # Reduce the rate ratio to lowest terms for the up/down factors.
    common = math.gcd(new_fs, old_sr)
    return resample_poly(data, new_fs // common, old_sr // common)
|
66 |
+
|
67 |
+
|
68 |
+
import pypesq
|
69 |
+
|
70 |
+
|
71 |
+
def pesq(signal1, signal2, sr):
    """Return the PESQ score between two signals, or 0 if scoring fails.

    PESQ (Perceptual Evaluation of Speech Quality) lies in [-0.5, 4.5];
    above 3.5 the audio quality is good and above 4.0 the difference is
    essentially inaudible. The scorer is called at 16 kHz, so any other
    sample rate is resampled first. Because the metric measures
    perceptibility, resampling here is unrelated to watermark robustness.
    """
    signal1, signal2 = to_equal_length(signal1, signal2)

    needs_resampling = sr != 16000
    if needs_resampling:
        signal1 = resample_to16k(signal1, sr)
        signal2 = resample_to16k(signal2, sr)

    try:
        score = pypesq.pesq(signal1, signal2, 16000)
    except Exception as e:
        # e.g. "ValueError: ref is all zeros, processing error!"
        score = 0
        print("pesq计算错误:", e)
    return score
|
utils/model_util.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import sys
|
5 |
+
from utils import pickle_util
|
6 |
+
|
7 |
+
# Paths of the checkpoints written by save_model / save_model_v4, in the order
# they were saved; delete_last_saved_model pops the most recent entry.
history_array = []
|
8 |
+
|
9 |
+
|
10 |
+
def save_model(epoch, model, optimizer, file_save_path):
    """Write a checkpoint to ``file_save_path`` and record the path.

    The checkpoint is a dict with keys ``'epoch'``, ``'model'`` (the model's
    state dict) and ``'optimizer'`` (the optimizer's state dict, or ``None``
    when no optimizer is given). The parent directory is created if missing,
    and the path is appended to ``history_array``.
    """
    parent_dir = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(parent_dir):
        print("mkdir:", parent_dir)
        os.makedirs(parent_dir)

    optimizer_state = optimizer.state_dict() if optimizer is not None else None

    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer_state,
    }
    torch.save(obj=checkpoint, f=file_save_path)

    history_array.append(file_save_path)
|
27 |
+
|
28 |
+
|
29 |
+
def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
    """Write a checkpoint that also carries ``discriminator``.

    Same layout as ``save_model`` plus a ``"discriminator"`` entry that stores
    the ``discriminator`` argument exactly as passed in. The parent directory
    is created if missing and the path is appended to ``history_array``.
    """
    parent_dir = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(parent_dir):
        print("mkdir:", parent_dir)
        os.makedirs(parent_dir)

    optimizer_state = optimizer.state_dict() if optimizer is not None else None

    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer_state,
        "discriminator": discriminator,
    }
    torch.save(obj=checkpoint, f=file_save_path)

    history_array.append(file_save_path)
|
47 |
+
|
48 |
+
|
49 |
+
def delete_last_saved_model():
    """Delete the most recently recorded checkpoint and its ``.json`` sidecar.

    Pops the newest path from ``history_array``; does nothing when the
    history is empty. Files that no longer exist are skipped.
    """
    if not history_array:
        return

    newest = history_array.pop()
    if os.path.exists(newest):
        os.remove(newest)
        print("delete model:", newest)

    sidecar = newest + ".json"
    if os.path.exists(sidecar):
        os.remove(sidecar)
|
59 |
+
|
60 |
+
|
61 |
+
def load_model(resume_path, model, optimizer=None, strict=True):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint.

    Args:
        resume_path: path of a checkpoint dict with ``'epoch'`` and ``'model'``
            entries and, optionally, an ``'optimizer'`` entry.
        model: module whose ``load_state_dict`` receives the stored weights.
        optimizer: optimizer to restore, or ``None`` to skip.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The epoch to resume from, i.e. the stored epoch plus one.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model'], strict=strict)

    if optimizer is not None:
        # Checkpoints saved without an optimizer store None here; passing that
        # to load_state_dict would crash, so leave the optimizer untouched.
        optimizer_state = checkpoint.get('optimizer')
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
        else:
            print("checkpoint has no optimizer state; optimizer left unchanged")

    print("checkpoint loaded!")
    return start_epoch
|
69 |
+
|
70 |
+
|
71 |
+
def save_model_v2(model, args, model_save_name):
    """Save ``model`` alone (epoch 0, no optimizer) under the run folder.

    The target is
    ``<args.model_save_folder>/<args.project>/<args.name>/<model_save_name>``.
    """
    target = os.path.join(
        args.model_save_folder, args.project, args.name, model_save_name
    )
    save_model(0, model, None, target)
    print("save:", target)
|
75 |
+
|
76 |
+
|
77 |
+
def save_project_info(args):
    """Dump the command line and parsed arguments to ``run_info.json``.

    The file is written to
    ``<args.model_save_folder>/<args.project>/<args.name>/run_info.json``;
    the folder is created if it does not exist.
    """
    payload = {
        "cmd_str": ' '.join(sys.argv[1:]),
        "args": vars(args),
    }

    out_dir = os.path.join(args.model_save_folder, args.project, args.name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    json_file_path = os.path.join(out_dir, "run_info.json")
    with open(json_file_path, "w") as f:
        json.dump(payload, f)

    print("save_project_info:", json_file_path)
|
93 |
+
|
94 |
+
|
95 |
+
def get_pkl_json(folder):
    """Load the single ``*.pkl.json`` file found in ``folder``.

    Asserts that exactly one directory entry contains ``".pkl.json"`` and
    returns its parsed JSON content.
    """
    matches = []
    for entry in os.listdir(folder):
        if ".pkl.json" in entry:
            matches.append(entry)
    assert len(matches) == 1
    return pickle_util.read_json(os.path.join(folder, matches[0]))
|
101 |
+
|
102 |
+
|
103 |
+
# Parallel (DataParallel) checkpoint helpers
|
104 |
+
|
105 |
+
def is_data_parallel_checkpoint(state_dict):
    """Return True if any key of ``state_dict`` starts with ``'module.'``."""
    for key in state_dict.keys():
        if key.startswith('module.'):
            return True
    return False
|
107 |
+
|
108 |
+
|
109 |
+
def map_state_dict(state_dict):
    """Strip the ``'module.'`` prefix that DataParallel adds to keys.

    If no key carries the prefix, the very same ``state_dict`` object is
    returned. Otherwise a new ``OrderedDict`` is built in which prefixed keys
    lose the prefix and all other keys are kept as they are.
    """
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict

    from collections import OrderedDict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
|
utils/pesq_util.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pypesq
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def batch_pesq(batch_signal, batch_signal_wmd):
    """Return the mean PESQ over a batch, or -1 if no item could be scored.

    Both tensors are detached, moved to the CPU and converted to numpy. Each
    pair is scored at 16 kHz; pairs whose scoring raises or yields NaN are
    reported on stdout and left out of the mean.
    """
    references = batch_signal.detach().cpu().numpy()
    degraded = batch_signal_wmd.detach().cpu().numpy()

    scores = []
    for ref, deg in zip(references, degraded):
        try:
            score = pypesq.pesq(ref, deg, 16000)
        except Exception as e:
            # e.g. "ValueError: ref is all zeros, processing error!"
            print(e)
            continue
        if np.isnan(score):
            print("pesq is nan!")
            continue
        scores.append(score)

    return np.mean(scores) if len(scores) > 0 else -1
|
utils/pickle_util.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import _pickle as pickle # python3
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
|
5 |
+
|
6 |
+
def read_pickle(filepath):
    """Load and return the object pickled in ``filepath``.

    The file is opened in a ``with`` block so the handle is closed even when
    unpickling raises.
    """
    # NOTE: unpickling executes arbitrary code; only read trusted files.
    with open(filepath, 'rb') as f:
        return pickle.load(f)
|
11 |
+
|
12 |
+
|
13 |
+
def save_pickle(save_path, save_data):
    """Pickle ``save_data`` into ``save_path``.

    The file is opened in a ``with`` block so the handle is closed (and the
    data flushed) even when pickling raises.
    """
    with open(save_path, 'wb') as f:
        pickle.dump(save_data, f)
|
17 |
+
|
18 |
+
|
19 |
+
def read_json(filepath):
    """Parse the JSON file at ``filepath`` and return the resulting object.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding, so non-ASCII content loads the same way on every system.
    """
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)
|
23 |
+
|
24 |
+
|
25 |
+
def save_json(save_path, obj):
    """Serialize ``obj`` as JSON into ``save_path``, overwriting the file."""
    with open(save_path, 'w') as handle:
        handle.write(json.dumps(obj))
|
utils/silent_util.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def is_silent(data, silence_threshold=0.01):
    """Return whether the RMS of ``data`` is below ``silence_threshold``."""
    mean_power = np.mean(data ** 2)
    root_mean_square = np.sqrt(mean_power)
    return root_mean_square < silence_threshold
|
7 |
+
|
8 |
+
|
9 |
+
def has_silent_part(trunck):
    """Return True if any of three equal consecutive parts of ``trunck`` is silent.

    The parts each span ``int(len(trunck) / 3)`` samples starting from index
    0, so up to two trailing samples are not examined.
    """
    num_part = 3
    part_length = int(len(trunck) / num_part)
    parts = (
        trunck[k * part_length:(k + 1) * part_length]
        for k in range(num_part)
    )
    return any(is_silent(part) for part in parts)
|
utils/wm_add_v2.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import silent_util
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from utils import bin_util
|
5 |
+
|
6 |
+
# Fixed 100-bit pattern; create_parcel_message takes its first len_start_bit
# entries as the start-bit marker of a watermark.
fix_pattern = [
    1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
    0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
    1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
    0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,
]
|
11 |
+
|
12 |
+
|
13 |
+
def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False):
    """Build a watermark made of a fixed start marker followed by a message.

    Args:
        len_start_bit: number of leading ``fix_pattern`` bits used as marker.
        num_bit: required total watermark length.
        wm_text: hex string converted to bits via ``bin_util.hexStr2BinArray``;
            when falsy, a random message fills the remaining bits.
        verbose: print the marker length and the derived error figure.

    Returns:
        ``(start_bit, msg_arr, watermark)`` where ``watermark`` is the
        concatenation of the other two and is asserted to be ``num_bit`` long.
    """
    # Start marker: a prefix of the fixed pattern.
    start_bit = fix_pattern[0:len_start_bit]
    error_prob = 2 ** len_start_bit / 10000
    # TODO: what about the error rate once the threshold is taken into account?
    if verbose:
        print("起始bit长度:%d,错误率:%.1f万" % (len(start_bit), error_prob))

    # Message payload: either the supplied text or random bits.
    payload_len = num_bit - len(start_bit)
    msg_arr = (
        bin_util.hexStr2BinArray(wm_text)
        if wm_text
        else np.random.choice([0, 1], size=payload_len)
    )

    # Pack marker and payload together.
    watermark = np.concatenate([start_bit, msg_arr])
    assert len(watermark) == num_bit
    return start_bit, msg_arr, watermark
|
33 |
+
|
34 |
+
|
35 |
+
import time
|
36 |
+
|
37 |
+
|
38 |
+
def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
    """Embed ``bir_array`` into every full chunk of ``data``.

    ``data`` is cut into chunks of ``num_point + int(num_point * shift_range)``
    samples. In each full chunk the first ``num_point`` samples (cover area)
    are passed to ``encode_trunck_with_silence_check`` and the remainder
    (gap area) is kept as is. A trailing chunk that is too short is copied
    through unchanged.

    Returns:
        ``(data, reconstructed_array, time_cost)``: the untouched input, the
        watermarked signal and the elapsed wall-clock time in seconds.
    """
    started_at = time.time()
    chunk_size = num_point + int(num_point * shift_range)

    output_chunks = []
    for idx_trunck, offset in enumerate(range(0, len(data), chunk_size)):
        current_chunk = data[offset:offset + chunk_size].copy()
        if len(current_chunk) < chunk_size:
            # Last chunk is too short to carry a watermark.
            output_chunks.append(current_chunk)
            break

        # Chunk layout: [cover area | gap area]
        cover_area = current_chunk[0:num_point]
        shift_area = current_chunk[num_point:]
        cover_area_wmd = encode_trunck_with_silence_check(
            silence_check, idx_trunck, cover_area, bir_array, device, model
        )
        rebuilt = np.concatenate([cover_area_wmd, shift_area])
        assert rebuilt.shape == current_chunk.shape
        output_chunks.append(rebuilt)

    assert len(output_chunks) > 0
    reconstructed_array = np.concatenate(output_chunks)
    return data, reconstructed_array, time.time() - started_at
|
68 |
+
|
69 |
+
|
70 |
+
def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
    """Watermark ``trunck`` unless silence checking is on and it is silent.

    When ``silence_check`` is true and ``silent_util.is_silent(trunck)`` holds,
    the chunk is reported as skipped and returned unchanged; otherwise the
    result of ``encode_trunck`` is returned.
    """
    if not silence_check or not silent_util.is_silent(trunck):
        # Regular path: embed the watermark.
        return encode_trunck(trunck, wm, device, model)

    print("跳过静音区块:", trunck_idx)
    return trunck
|
79 |
+
|
80 |
+
|
81 |
+
def encode_trunck(trunck, wm, device, model):
    """Run ``model.encode`` on one chunk and return the result as numpy.

    ``trunck`` and ``wm`` are converted to float tensors on ``device`` with a
    leading batch axis of size one; the model output is moved back to the
    CPU and squeezed. No gradients are tracked.
    """
    with torch.no_grad():
        signal = torch.FloatTensor(trunck).to(device).unsqueeze(0)
        message = torch.FloatTensor(np.array(wm)).to(device).unsqueeze(0)
        encoded = model.encode(signal, message)
        return encoded.detach().cpu().numpy().squeeze()
|
utils/wm_decode_v2.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from utils import bin_util
|
6 |
+
|
7 |
+
|
8 |
+
def decode_trunck(trunck, model, device):
    """Decode one chunk into a 0/1 integer numpy array.

    The chunk is sent through ``model.decode`` as a float tensor with a
    leading batch axis of size one; outputs ``>= 0.5`` become 1, the rest 0.
    No gradients are tracked.
    """
    with torch.no_grad():
        batch = torch.FloatTensor(trunck).to(device)[None]
        hard_bits = (model.decode(batch) >= 0.5).int()
        return hard_bits.detach().cpu().numpy().squeeze()
|
14 |
+
|
15 |
+
|
16 |
+
def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
    """Return whether the decoded marker's bit error rate is below the threshold.

    Both arrays must have the same shape; the comparison is strict
    (``ber < start_bit_ber_threshold``).
    """
    assert decoded_start_bit.shape == start_bit.shape
    agreement = np.mean(start_bit == decoded_start_bit)
    bit_error_rate = 1 - agreement
    return bit_error_rate < start_bit_ber_threshold
|
20 |
+
|
21 |
+
|
22 |
+
def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
                      verbose=False):
    """Scan ``data`` for watermarked chunks and merge what was decoded.

    A window of ``num_point`` samples is decoded at the current position. If
    its leading bits match ``start_bit`` (see ``is_start_bit_match``) the
    decoded bits are kept and the window jumps past the chunk and its gap;
    otherwise it slides forward by ``int(shift_range * num_point)`` samples.

    Args:
        data: 1-D signal to scan.
        start_bit: expected start-marker bits (numpy array).
        shift_range: gap length as a fraction of ``num_point``.
        num_point: number of samples per watermarked chunk.
        start_bit_ber_threshold: maximum (exclusive) marker bit error rate.
        model: object passed to ``decode_trunck``.
        device: device passed to ``decode_trunck``.
        verbose: print the decoded message of every matching chunk.

    Returns:
        ``(support_count, exist_prob, mean_result, first_result)``. With no
        match the last three are ``None``; otherwise ``mean_result`` is the
        bit-wise majority over all matches, ``exist_prob`` is the fraction of
        its marker bits equal to ``start_bit`` and ``first_result`` is the
        first matching decode.
    """
    shift_range_points = int(shift_range * num_point)
    # A zero step (e.g. shift_range == 0 or a tiny num_point) would leave the
    # pointer stuck on the first mismatch and loop forever; always advance.
    scan_step = max(1, shift_range_points)
    match_step = max(1, num_point + shift_range_points)

    i = 0  # current pointer position
    results = []
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:len(start_bit)]
        if not is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
            i += scan_step
            continue

        # Found a chunk start.
        if verbose:
            msg_bit = bit_array[len(start_bit):]
            msg_str = bin_util.binArray2HexStr(msg_bit)
            print(i, "解码信息:", msg_str)
        results.append(bit_array)
        i += match_step

    support_count = len(results)
    if support_count == 0:
        mean_result = None
        first_result = None
        exist_prob = None
    else:
        mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
        exist_prob = (mean_result[0:len(start_bit)] == start_bit).mean()
        first_result = results[0]

    return support_count, exist_prob, mean_result, first_result
|
59 |
+
|
60 |
+
|
61 |
+
def extract_watermark_v2(data, start_bit, shift_range, num_point,
                         start_bit_ber_threshold, model, device,
                         merge_type,
                         shift_range_p=0.5, ):
    """Densely scan ``data`` for watermarks and merge the candidates.

    A window of ``num_point`` samples is decoded at every step of
    ``int(shift_range * num_point * shift_range_p)`` samples. Windows whose
    start-marker bit error rate does not exceed ``start_bit_ber_threshold``
    are kept as candidates together with their marker similarity.

    Args:
        data: 1-D signal to scan.
        start_bit: expected start-marker bits (numpy array).
        shift_range: gap length as a fraction of ``num_point``.
        num_point: number of samples per watermarked chunk.
        start_bit_ber_threshold: maximum (inclusive) marker bit error rate.
        model: object passed to ``decode_trunck``.
        device: device passed to ``decode_trunck``.
        merge_type: only ``"best"`` is supported; it is checked only when at
            least one candidate was found.
        shift_range_p: fraction of the gap used as scanning step.

    Returns:
        ``(support_count, mean_result, results)``. ``results`` is a list of
        ``{"sim", "msg"}`` dicts. ``mean_result`` is ``None`` without
        candidates; with ``"best"`` it is the bit-wise majority over all
        candidates with similarity 1.0, or else the single most similar one.

    Raises:
        NotImplementedError: ``merge_type == "weighted"`` with candidates.
        ValueError: any other unsupported ``merge_type`` with candidates.
    """
    # A zero step would never move the pointer: the loop would spin forever
    # and, on a match, grow ``results`` without bound. Always advance.
    shift_range_points = max(1, int(shift_range * num_point * shift_range_p))

    i = 0  # current pointer position
    results = []
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:len(start_bit)]

        ber_start_bit = 1 - np.mean(start_bit == decoded_start_bit)
        if ber_start_bit <= start_bit_ber_threshold:
            # Found a candidate chunk start.
            results.append({
                "sim": 1 - ber_start_bit,
                "msg": bit_array,
            })
        # Important: keep the small step even after a match. Jumping a whole
        # chunk ahead could skip better candidates when the threshold is loose.
        i += shift_range_points

    support_count = len(results)
    if support_count == 0:
        mean_result = None
    elif merge_type == "best":
        # Most similar candidate; ties resolve to the earliest one.
        best_val = max(results, key=lambda item: item["sim"])
        if np.isclose(1.0, best_val["sim"]):
            # Average every candidate whose marker matched perfectly.
            perfect = [item["msg"] for item in results if np.isclose(item["sim"], 1.0)]
            mean_result = (np.array(perfect).mean(axis=0) >= 0.5).astype(int)
        else:
            mean_result = best_val["msg"]
    elif merge_type == "weighted":
        raise NotImplementedError('merge_type "weighted" is not implemented')
    else:
        raise ValueError("unsupported merge_type: %r" % (merge_type,))

    return support_count, mean_result, results
|