admin committed on
Commit: 7c99075
Parent: 8be39d8

upl base code

Files changed (7):
  1. .gitattributes +11 -11
  2. .gitignore +6 -0
  3. README.md +1 -1
  4. app.py +219 -0
  5. model.py +145 -0
  6. requirements.txt +6 -0
  7. utils.py +67 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pt
+ __pycache__/*
+ tmp/*
+ flagged/*
+ test.py
+ rename.sh
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: GZ IsoTech
- emoji: 🏆
+ emoji: 🪕🎵
  colorFrom: red
  colorTo: indigo
  sdk: gradio
app.py ADDED
@@ -0,0 +1,219 @@
+ import os
+ import torch
+ import random
+ import shutil
+ import librosa
+ import warnings
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ from utils import get_modelist, find_files, embed_img, TEMP_DIR
+ from collections import Counter
+ from model import EvalNet
+
+
+ TRANSLATE = {
+     "vibrato": "颤音",
+     "upward_portamento": "上滑音",
+     "downward_portamento": "下滑音",
+     "returning_portamento": "回滑音",
+     "glissando": "刮奏, 花指",
+     "tremolo": "摇指",
+     "harmonics": "泛音",
+     "plucks": "勾, 打, 抹, 托, ...",
+ }
+ CLASSES = list(TRANSLATE.keys())
+ SAMPLE_RATE = 44100
+
+
+ def circular_padding(spec: np.ndarray, end: int):
+     size = len(spec)
+     if end <= size:
+         return spec
+
+     num_padding = end - size
+     num_repeat = num_padding // size + int(num_padding % size != 0)
+     padding = np.tile(spec, num_repeat)
+     return np.concatenate((spec, padding))[:end]
+
+
+ def wav2mel(audio_path: str, width=3):
+     os.makedirs(TEMP_DIR, exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
+         total_frames = len(y)
+         if total_frames % (width * sr) != 0:
+             count = total_frames // (width * sr) + 1
+             y = circular_padding(y, count * width * sr)
+
+         mel_spec = librosa.feature.melspectrogram(y=y, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"{TEMP_DIR}/{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav2cqt(audio_path: str, width=3):
+     os.makedirs(TEMP_DIR, exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
+         total_frames = len(y)
+         if total_frames % (width * sr) != 0:
+             count = total_frames // (width * sr) + 1
+             y = circular_padding(y, count * width * sr)
+
+         cqt_spec = librosa.cqt(y=y, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"{TEMP_DIR}/{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav2chroma(audio_path: str, width=3):
+     os.makedirs(TEMP_DIR, exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
+         total_frames = len(y)
+         if total_frames % (width * sr) != 0:
+             count = total_frames // (width * sr) + 1
+             y = circular_padding(y, count * width * sr)
+
+         chroma_spec = librosa.feature.chroma_stft(y=y, sr=sr)
+         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"{TEMP_DIR}/{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def most_frequent_value(lst: list):
+     counter = Counter(lst)
+     max_count = max(counter.values())
+     for element, count in counter.items():
+         if count == max_count:
+             return element
+
+     return None
+
+
+ def infer(wav_path: str, log_name: str, folder_path=TEMP_DIR):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not wav_path:
+         return None, "请输入音频 Please input an audio!"
+
+     try:
+         model = EvalNet(log_name, len(TRANSLATE)).model
+     except Exception as e:
+         return None, f"{e}"
+
+     spec = log_name.split("_")[-3]
+     eval("wav2%s" % spec)(wav_path)
+     jpgs = find_files(folder_path, ".jpg")
+     preds = []
+     for jpg in jpgs:
+         input = embed_img(jpg)
+         output: torch.Tensor = model(input)
+         preds.append(torch.max(output.data, 1)[1].item())  # int ids so the Counter vote groups equal predictions
+
+     pred_id = most_frequent_value(preds)
+     return (
+         os.path.basename(wav_path),
+         f"{TRANSLATE[CLASSES[pred_id]]} ({CLASSES[pred_id].capitalize()})",
+     )
+
+
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore")
+     models = get_modelist()
+     examples = []
+     example_wavs = find_files()
+     model_num = len(models)
+     for wav in example_wavs:
+         examples.append([wav, models[random.randint(0, model_num - 1)]])
+
+     with gr.Blocks() as demo:
+         gr.Interface(
+             fn=infer,
+             inputs=[
+                 gr.Audio(label="上传录音 Upload a recording", type="filepath"),
+                 gr.Dropdown(
+                     choices=models, label="选择模型 Select a model", value=models[0]
+                 ),
+             ],
+             outputs=[
+                 gr.Textbox(label="音频文件名 Audio filename", show_copy_button=True),
+                 gr.Textbox(
+                     label="古筝演奏技法识别 Guzheng playing tech recognition",
+                     show_copy_button=True,
+                 ),
+             ],
+             examples=examples,
+             cache_examples=False,
+             flagging_mode="never",
+             title="建议录音时长保持在 3s 左右<br>It is recommended to keep the recording length around 3s.",
+         )
+
+         gr.Markdown(
+             """
+ # 引用 Cite
+ ```bibtex
+ @dataset{zhaorui_liu_2021_5676893,
+   author = {Monan Zhou, Shenyang Xu, Zhaorui Liu, Zhaowen Wang, Feng Yu, Wei Li and Baoqiang Han},
+   title = {CCMusic: an Open and Diverse Database for Chinese Music Information Retrieval Research},
+   month = {mar},
+   year = {2024},
+   publisher = {HuggingFace},
+   version = {1.2},
+   url = {https://huggingface.co/ccmusic-database}
+ }
+ ```"""
+         )
+
+     demo.launch()
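For reference, a minimal local-usage sketch of the inference path added above (not part of the commit); it assumes the Space has been cloned and its requirements installed, and it only uses whatever model logs and example wavs `get_modelist()` and `find_files()` actually return, so no specific file names are implied:

```python
# Hypothetical smoke test: call infer() the same way the Gradio UI does.
from app import infer
from utils import find_files, get_modelist

models = get_modelist()   # model log folders pulled by snapshot_download()
wavs = find_files()       # example .wav files bundled with the model snapshot
filename, label = infer(wavs[0], models[0])
print(filename, label)    # e.g. "some.wav 颤音 (Vibrato)"
```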
model.py ADDED
@@ -0,0 +1,145 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import MODEL_DIR
+
+
+ class EvalNet:
+     model: nn.Module = None
+     m_type = "squeezenet"
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name: str, cls_num: int):
+         saved_model_path = f"{MODEL_DIR}/{log_name}/save.pt"
+         m_ver = "_".join(log_name.split("_")[:-3])
+         self.m_type, self.input_size = self._model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             raise Exception("Unsupported model.")
+
+         self.model = eval("models.%s()" % m_ver)
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         checkpoint = torch.load(saved_model_path, map_location="cpu")
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, False)
+         self.model.eval()
+
+     def _get_backbone(self, ver: str, backbone_list: list):
+         for bb in backbone_list:
+             if ver == bb["ver"]:
+                 return bb
+
+         print("Backbone name not found, using default option - alexnet.")
+         return backbone_list[0]
+
+     def _model_info(self, m_ver: str):
+         backbone_list = MsDataset.load(
+             "monetjoe/cv_backbones",
+             split="v1",
+         )
+         backbone = self._get_backbone(m_ver, backbone_list)
+         m_type = str(backbone["type"])
+         input_size = int(backbone["input_size"])
+         return m_type, input_size
+
+     def _classifier(self, cls_num: int, output_size: int, linear_output: bool):
+         q = (1.0 * output_size / cls_num) ** 0.25
+         l1 = int(q * cls_num)
+         l2 = int(q * l1)
+         l3 = int(q * l2)
+         if linear_output:
+             return torch.nn.Sequential(
+                 nn.Dropout(),
+                 nn.Linear(output_size, l3),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l3, l2),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l2, l1),
+                 nn.ReLU(inplace=True),
+                 nn.Linear(l1, cls_num),
+             )
+
+         else:
+             return torch.nn.Sequential(
+                 nn.Dropout(),
+                 nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+                 nn.ReLU(inplace=True),
+                 nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+                 nn.Flatten(),
+                 nn.Linear(l3, l2),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l2, l1),
+                 nn.ReLU(inplace=True),
+                 nn.Linear(l1, cls_num),
+             )
+
+     def _set_outsize(self):
+         for name, module in self.model.named_modules():
+             if (
+                 str(name).__contains__("classifier")
+                 or str(name).__eq__("fc")
+                 or str(name).__contains__("head")
+                 or hasattr(module, "classifier")
+             ):
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num: int, linear_output: bool):
+         if self.m_type == "convnext":
+             del self.model.classifier[2]
+             self.model.classifier = nn.Sequential(
+                 *list(self.model.classifier)
+                 + list(self._classifier(cls_num, self.output_size, linear_output))
+             )
+             return
+
+         elif self.m_type == "maxvit":
+             del self.model.classifier[5]
+             self.model.classifier = nn.Sequential(
+                 *list(self.model.classifier)
+                 + list(self._classifier(cls_num, self.output_size, linear_output))
+             )
+             return
+
+         if hasattr(self.model, "classifier"):
+             self.model.classifier = self._classifier(
+                 cls_num, self.output_size, linear_output
+             )
+             return
+
+         elif hasattr(self.model, "fc"):
+             self.model.fc = self._classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "head"):
+             self.model.head = self._classifier(cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = self._classifier(
+             cls_num, self.output_size, linear_output
+         )
+
+     def forward(self, x: torch.Tensor):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == "googlenet":
+             return self.model(x)[0]
+         else:
+             return self.model(x)
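EvalNet can also be exercised in isolation; the sketch below assumes a trained log folder exists under MODEL_DIR and follows the `{model}_{spectrogram}_{date}_{time}` naming that `infer()` relies on, and the log name and image path are illustrative placeholders:

```python
# Hypothetical example: restore a fine-tuned backbone and classify one spectrogram slice.
import torch
from model import EvalNet
from utils import embed_img

net = EvalNet("squeezenet1_1_mel_2024-01-01_00-00-00", 8)  # placeholder log name, 8 classes
x = embed_img("slice.jpg")                                  # placeholder spectrogram image
with torch.no_grad():
    logits = net.model(x)  # app.py drives the inner nn.Module directly
print(int(torch.max(logits.data, 1)[1]))
```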
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ modelscope
+ librosa
+ torch
+ matplotlib
+ torchvision
+ pillow
utils.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import torch
+ import torchvision.transforms as transforms
+ from modelscope import snapshot_download
+ from PIL import Image
+
+ MODEL_DIR = snapshot_download(
+     f"ccmusic-database/GZ_IsoTech",
+     cache_dir=f"{os.getcwd()}/__pycache__",
+ )
+ TEMP_DIR = f"{os.getcwd()}/flagged"
+
+
+ def toCUDA(x):
+     if hasattr(x, "cuda"):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_files(folder_path=f"{MODEL_DIR}/examples", ext=".wav"):
+     wav_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(ext):
+                 file_path = os.path.join(root, file)
+                 wav_files.append(file_path)
+
+     return wav_files
+
+
+ def get_modelist(model_dir=MODEL_DIR):
+     try:
+         entries = os.listdir(model_dir)
+     except OSError as e:
+         print(f"Cannot access {model_dir}: {e}")
+         return
+
+     # Iterate over all entries
+     output = []
+     for entry in entries:
+         # Build the full path
+         full_path = os.path.join(model_dir, entry)
+         # Skip the .git and examples folders
+         if entry == ".git" or entry == "examples":
+             print(f"Skipping .git or examples folder: {full_path}")
+             continue
+
+         # Check whether the entry is a file or a directory
+         if os.path.isdir(full_path):
+             # Collect the directory name as a model log
+             output.append(os.path.basename(full_path))
+
+     return output
+
+
+ def embed_img(img_path: str, input_size=224):
+     transform = transforms.Compose(
+         [
+             transforms.Resize([input_size, input_size]),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
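And a quick check of the preprocessing helpers above, with a placeholder image path:

```python
# Hypothetical check: embed_img() returns a 1 x 3 x 224 x 224 tensor; toCUDA() moves it to GPU when available.
from utils import embed_img, toCUDA

x = toCUDA(embed_img("slice.jpg"))  # placeholder path
print(x.shape)                      # torch.Size([1, 3, 224, 224])
```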