Santipab committed
Commit a19d827 · verified · 1 Parent(s): d588160

Upload 30 files

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ CX0075_png.rf.f86acbac9d6c41151e8caed4914a3e89.jpg filter=lfs diff=lfs merge=lfs -text
+ dev_3.jpg filter=lfs diff=lfs merge=lfs -text
+ DP1983_png.rf.3f2a58f7f0feb4f9ad7b34149149553b.jpg filter=lfs diff=lfs merge=lfs -text
+ HS1500_png.rf.8659b481c780f6b582532eb56d6f5349.jpg filter=lfs diff=lfs merge=lfs -text
+ image_1.jpg filter=lfs diff=lfs merge=lfs -text
+ image_3.jpg filter=lfs diff=lfs merge=lfs -text
CX0075_png.rf.f86acbac9d6c41151e8caed4914a3e89.jpg ADDED

Git LFS Details

  • SHA256: 62b1d1ec521592728cc449a1b81a59e44b889f6666ddb4133b6201c35c3e29d0
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
DP1983_png.rf.3f2a58f7f0feb4f9ad7b34149149553b.jpg ADDED

Git LFS Details

  • SHA256: 5ae18118bab61893d1e9d12cded22474cd446e1cd24cf502f1ecc0bb08b04953
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
HS1500_png.rf.8659b481c780f6b582532eb56d6f5349.jpg ADDED

Git LFS Details

  • SHA256: 82dc4ca15fa710cbc2d452e3d4705e5b9c350182bbfeca8486ccfa25ff5b07d1
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 yijingru
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,57 @@
- ---
- title: Cra Innovation Home App Demo
- emoji: 💻
- colorFrom: pink
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.31.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Vertebra-Focused-Landmark-Detection-Pytorch
+ Vertebra-Focused Landmark Detection for Scoliosis Assessment [[arXiv](https://arxiv.org/pdf/2001.03187.pdf)]
+
+ Accepted to ISBI 2020.
+
+ Please cite the article in your publications if it helps your research:
+
+     @inproceedings{yi2020vertebra,
+       title={Vertebra-Focused Landmark Detection for Scoliosis Assessment},
+       author={Yi, Jingru and Wu, Pengxiang and Huang, Qiaoying and Qu, Hui and Metaxas, Dimitris N},
+       booktitle={ISBI},
+       year={2020}
+     }
+
+ <p align="center">
+   <img src="imgs/pic1.png" width="400">
+ </p>
+
+ <p align="center">
+   <img src="imgs/pic2.png" width="800">
+ </p>
+
+ # Dependencies
+ Ubuntu 14.04, Python 3.6.4, PyTorch 1.1.0, OpenCV-Python 4.1.0.25
+
+ # How to start
+ ## Prepare Dataset
+ To use dataset.py directly, arrange the dataset as follows:
+ ```
+ /dataPath/data
+     /train/*.jpg
+     /val/*.jpg
+     /test/*.jpg
+ /dataPath/labels/
+     /train/*.mat
+     /val/*.mat
+     /test/*.mat
+ ```
+ The source dataset is from [[dataset16](http://spineweb.digitalimaginggroup.ca/spineweb/index.php?n=Main.Datasets#Dataset_16.3A_609_spinal_anterior-posterior_x-ray_images)].
+ To adapt the code to your own dataset, modify dataset.py; for example, change the `load_gt_pts` function to match your own annotation format (a sketch follows this section). The pretrained weights can be downloaded [here](https://drive.google.com/drive/folders/1LhKnGVE8dUw0nK9_x4vPNY_L7sPY2_aQ?usp=sharing).
+
+ ## Train the model
+ ```bash
+ python main.py --data_dir dataPath --num_epoch 50 --batch_size 2 --dataset spinal --phase train
+ ```
+
+ ## Test the model
+ ```bash
+ python main.py --resume weightPath --data_dir dataPath --dataset spinal --phase test
+ ```
+
+ ## Evaluate the model
+ ```bash
+ python main.py --resume weightPath --data_dir dataPath --dataset spinal --phase eval
+ ```
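
For the dataset-adaptation step mentioned above, here is a minimal sketch of overriding `load_gt_pts`. It assumes a hypothetical CSV annotation format with one `x,y` corner per row (the CSV layout and the `CSVDataset` name are illustrative, not part of this repo), and it should be run from the repo root so `dataset.py` and its local imports resolve:

```python
import os
import numpy as np
from dataset import BaseDataset, rearrange_pts  # rearrange_pts reorders corners to tl, tr, bl, br

class CSVDataset(BaseDataset):
    """Hypothetical BaseDataset variant that reads CSV landmarks instead of .mat files."""

    def load_annoFolder(self, img_id):
        # one .csv per image, e.g. dataPath/labels/train/xxx.jpg.csv
        return os.path.join(self.data_dir, 'labels', self.phase, img_id + '.csv')

    def load_gt_pts(self, annopath):
        pts = np.loadtxt(annopath, delimiter=',', dtype=np.float32)  # num x 2 (x, y)
        return rearrange_pts(pts)
```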
app.py ADDED
@@ -0,0 +1,399 @@
+ import os
+ import sys
+ import shutil
+ import importlib.util
+ from io import BytesIO
+ from ultralytics import YOLO
+ from PIL import Image
+
+ import torch
+ # ─── FORCE CPU ONLY ─────────────────────────────────────────────────────────
+ torch.Tensor.cuda = lambda self, *args, **kwargs: self
+ torch.nn.Module.cuda = lambda self, *args, **kwargs: self
+ torch.cuda.synchronize = lambda *args, **kwargs: None
+ torch.cuda.is_available = lambda: False
+ torch.cuda.device_count = lambda: 0
+ _orig_to = torch.Tensor.to
+ def _to_cpu(self, *args, **kwargs):
+     new_args = []
+     for a in args:
+         if isinstance(a, str) and a.lower().startswith("cuda"):
+             new_args.append("cpu")
+         elif isinstance(a, torch.device) and a.type == "cuda":
+             new_args.append(torch.device("cpu"))
+         else:
+             new_args.append(a)
+     if "device" in kwargs:
+         dev = kwargs["device"]
+         if (isinstance(dev, str) and dev.lower().startswith("cuda")) or \
+            (isinstance(dev, torch.device) and dev.type == "cuda"):
+             kwargs["device"] = torch.device("cpu")
+     return _orig_to(self, *new_args, **kwargs)
+ torch.Tensor.to = _to_cpu
+
+ from torch.utils.data import DataLoader as _DL
+ def _dl0(ds, *a, **kw):
+     kw['num_workers'] = 0
+     return _DL(ds, *a, **kw)
+ import torch.utils.data as _du
+ _du.DataLoader = _dl0
+
+ import cv2
+ import numpy as np
+ import streamlit as st
+ from argparse import Namespace
+
+ # ─── DYNAMIC IMPORT ─────────────────────────────────────────────────────────
+ REPO = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(REPO)
+ models_dir = os.path.join(REPO, "models")
+ os.makedirs(models_dir, exist_ok=True)
+ open(os.path.join(models_dir, "__init__.py"), "a").close()
+
+ def load_mod(name, path):
+     spec = importlib.util.spec_from_file_location(name, path)
+     m = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(m)
+     sys.modules[name] = m
+     return m
+
+ dataset_mod = load_mod("dataset", os.path.join(REPO, "dataset.py"))
+ decoder_mod = load_mod("decoder", os.path.join(REPO, "decoder.py"))
+ draw_mod = load_mod("draw_points", os.path.join(REPO, "draw_points.py"))
+ test_mod = load_mod("test", os.path.join(REPO, "test.py"))
+ load_mod("models.dec_net", os.path.join(models_dir, "dec_net.py"))
+ load_mod("models.model_parts", os.path.join(models_dir, "model_parts.py"))
+ load_mod("models.resnet", os.path.join(models_dir, "resnet.py"))
+ load_mod("models.spinal_net", os.path.join(models_dir, "spinal_net.py"))
+
+ BaseDataset = dataset_mod.BaseDataset
+ Network = test_mod.Network
+
+ # ─── STREAMLIT UI ───────────────────────────────────────────────────────────
+ st.set_page_config(layout="wide", page_title="Vertebral Compression Fracture")
+
+ st.markdown(
+     """
+     <div style='border: 2px solid #0080FF; border-radius: 5px; padding: 10px'>
+         <h1 style='text-align: center; color: #0080FF'>
+             🦴 Vertebral Compression Fracture Detection 🖼️
+         </h1>
+     </div>
+     """, unsafe_allow_html=True)
+ st.markdown("")
+ st.markdown("")
+ st.markdown("")
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col4:
+     feature = st.selectbox(
+         "🔀 Select Feature",
+         ["How to use", "AP - Detection", "AP - Cobb angle", "LA - Image Segmentation", "Contact"],
+         index=0,  # default to "How to use"
+         help="Choose which view to display"
+     )
+
+ if feature == "How to use":
+     st.markdown("## 📖 How to use this app")
+
+     col1, col2, col3 = st.columns(3)
+
+     with col1:
+         st.markdown(
+             """
+             <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h2>Step 1️⃣</h2>
+                 <p>Go to <b>AP - Detection</b> or <b>LA - Image Segmentation</b></p>
+                 <p>Select a sample image or upload your own image file.</p>
+                 <p style='color:#008000;'><b>✅ Tip:</b> Works best with X-ray images with clear vertebra visibility.</p>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
+
+     with col2:
+         st.markdown(
+             """
+             <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h2>Step 2️⃣</h2>
+                 <p>Press the <b>Enter</b> button.</p>
+                 <p>The system will process your image automatically.</p>
+                 <p style='color:#FFA500;'><b>⏳ Note:</b> Processing time depends on image size.</p>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
+
+     with col3:
+         st.markdown(
+             """
+             <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h2>Step 3️⃣</h2>
+                 <p>See the prediction results:</p>
+                 <p style='text-align:left'>1. Bounding boxes & landmarks (AP)</p>
+                 <p style='text-align:left'>2. Segmentation masks (LA)</p>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
+
+     st.markdown(" ")
+     st.info("Features can be switched via the Select Feature box; each feature comes with sample images showing what it does.")
+
+ elif feature == "AP - Detection":
+     uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
+     # store original dimensions
+     orig_w = orig_h = None
+     img0 = None
+     run = st.button("Enter", use_container_width=True)
+     # ─── Maintain selected sample in session state ─────────
+     if "sample_img" not in st.session_state:
+         st.session_state.sample_img = None
+
+     # ─── SAMPLE BUTTONS ─────────────────────────────────────
+     with col1:
+         if st.button(" 1️⃣ Example", use_container_width=True):
+             st.session_state.sample_img = "image_1.jpg"
+     with col2:
+         if st.button(" 2️⃣ Example", use_container_width=True):
+             st.session_state.sample_img = "image_2.jpg"
+     with col3:
+         if st.button(" 3️⃣ Example", use_container_width=True):
+             st.session_state.sample_img = "image_3.jpg"
+
+     # ─── UI FOR UPLOAD + DISPLAY ───────────────────────────
+     col4, col5, col6 = st.columns(3)
+     with col4:
+         st.subheader("1️⃣ Upload & Run")
+
+         sample_img = st.session_state.sample_img  # read persisted choice
+
+         # case 1: uploaded file
+         if uploaded:
+             buf = uploaded.getvalue()
+             arr = np.frombuffer(buf, np.uint8)
+             img0 = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+             orig_h, orig_w = img0.shape[:2]
+             st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB), caption="Uploaded Image", use_container_width=True)
+
+         # case 2: selected sample image
+         elif sample_img is not None:
+             img_path = os.path.join(REPO, sample_img)
+             img0 = cv2.imread(img_path)
+             if img0 is not None:
+                 orig_h, orig_w = img0.shape[:2]
+                 st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB),
+                          caption=f"Sample Image: {sample_img}",
+                          use_container_width=True)
+             else:
+                 st.error(f"Cannot find {sample_img} in directory!")
+
+     with col5:
+         st.subheader("2️⃣ Predictions")
+     with col6:
+         st.subheader("3️⃣ Heatmap")
+
+     # ─── ARGS & CHECKPOINT ─────────────────────────────────
+     args = Namespace(
+         resume="model_30.pth",
+         data_dir=os.path.join(REPO, "dataPath"),
+         dataset="spinal",
+         phase="test",
+         input_h=1024,
+         input_w=512,
+         down_ratio=4,
+         num_classes=1,
+         K=17,
+         conf_thresh=0.2,
+     )
+     weights_dir = os.path.join(REPO, "weights_spinal")
+     os.makedirs(weights_dir, exist_ok=True)
+     src_ckpt = os.path.join(REPO, "model_backup", args.resume)
+     dst_ckpt = os.path.join(weights_dir, args.resume)
+     if os.path.isfile(src_ckpt) and not os.path.isfile(dst_ckpt):
+         shutil.copy(src_ckpt, dst_ckpt)
+
+     # ─── MAIN LOGIC ────────────────────────────────────────
+     if img0 is not None and run and orig_w and orig_h:
+         # determine name for saving
+         if uploaded:
+             name = os.path.splitext(uploaded.name)[0] + ".jpg"
+         else:
+             name = os.path.splitext(sample_img)[0] + ".jpg"
+
+         testd = os.path.join(args.data_dir, "data", "test")
+         os.makedirs(testd, exist_ok=True)
+         cv2.imwrite(os.path.join(testd, name), img0)
+
+         orig_init = BaseDataset.__init__
+         def patched_init(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
+             orig_init(self, data_dir, phase, input_h, input_w, down_ratio)
+             if phase == "test":
+                 self.img_ids = [name]
+         BaseDataset.__init__ = patched_init
+
+         with st.spinner("Running model…"):
+             net = Network(args)
+             net.test(args, save=True)
+
+         out_dir = os.path.join(REPO, f"results_{args.dataset}")
+         pred_file = [f for f in os.listdir(out_dir)
+                      if f.startswith(name) and f.endswith("_pred.jpg")][0]
+         txtf = os.path.join(out_dir, f"{name}.txt")
+         imgf = os.path.join(out_dir, pred_file)
+
+         # ─── Annotated Predictions ─────────────────────────
+         base = cv2.imread(imgf)
+         txt = np.loadtxt(txtf)
+         tlx, tly = txt[:, 2].astype(int), txt[:, 3].astype(int)
+         trx, try_ = txt[:, 4].astype(int), txt[:, 5].astype(int)
+         blx, bly = txt[:, 6].astype(int), txt[:, 7].astype(int)
+         brx, bry = txt[:, 8].astype(int), txt[:, 9].astype(int)
+
+         top_pts, bot_pts, mids, dists = [], [], [], []
+         for (x1, y1), (x2, y2), (x3, y3), (x4, y4) in zip(
+                 zip(tlx, tly), zip(trx, try_),
+                 zip(blx, bly), zip(brx, bry)):
+             tm = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
+             bm = np.array([(x3 + x4) / 2, (y3 + y4) / 2])
+             top_pts.append(tm)
+             bot_pts.append(bm)
+             mids.append((tm + bm) / 2)
+             dists.append(np.linalg.norm(bm - tm))
+
+         ref = dists[-1]
+         ann = base.copy()
+         for tm, bm in zip(top_pts, bot_pts):
+             cv2.line(ann, tuple(tm.astype(int)), tuple(bm.astype(int)), (0, 255, 255), 2)
+         for m, d in zip(mids, dists):
+             pct = (d - ref) / ref * 100
+             clr = (0, 255, 255) if pct <= 20 else (0, 165, 255) if pct <= 40 else (0, 0, 255)
+             pos = (int(m[0]) + 40, int(m[1]) + 5)
+             cv2.putText(ann, f"{pct:.0f}%", pos,
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2, cv2.LINE_AA)
+
+         ann_resized = cv2.resize(ann, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
+         with col5:
+             st.image(cv2.cvtColor(ann_resized, cv2.COLOR_BGR2RGB), use_container_width=True)
+
+         H, W = base.shape[:2]
+         heat = np.zeros((H, W), np.float32)
+         for cx, cy in [(int(m[0]), int(m[1])) for m in mids]:
+             blob = np.zeros_like(heat)
+             blob[cy, cx] = 1.0
+             heat += cv2.GaussianBlur(blob, (0, 0), sigmaX=8, sigmaY=8)
+         heat /= (heat.max() + 1e-8)
+         hm8 = (heat * 255).astype(np.uint8)
+         hm_c = cv2.applyColorMap(hm8, cv2.COLORMAP_JET)
+
+         raw = cv2.imread(imgf, cv2.IMREAD_GRAYSCALE)
+         raw_b = cv2.cvtColor(raw, cv2.COLOR_GRAY2BGR)
+         overlay = cv2.addWeighted(raw_b, 0.6, hm_c, 0.4, 0)
+         overlay_resized = cv2.resize(overlay, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
+
+         with col6:
+             st.image(cv2.cvtColor(overlay_resized, cv2.COLOR_BGR2RGB), use_container_width=True)
+
+ elif feature == "AP - Cobb angle":
+     st.write("Under development")
+
+ elif feature == "LA - Image Segmentation":
+     uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
+     img0 = None
+
+     # ─── Maintain selected sample in session state ─────────
+     if "sample_img_la" not in st.session_state:
+         st.session_state.sample_img_la = None
+
+     # ─── SAMPLE BUTTONS ─────────────────────────────────────
+     with col1:
+         if st.button(" 1️⃣ Example ", use_container_width=True):
+             st.session_state.sample_img_la = "image_1_la.jpg"
+     with col2:
+         if st.button(" 2️⃣ Example ", use_container_width=True):
+             st.session_state.sample_img_la = "image_2_la.jpg"
+     with col3:
+         if st.button(" 3️⃣ Example ", use_container_width=True):
+             st.session_state.sample_img_la = "image_3_la.jpg"
+
+     # ─── UI FOR UPLOAD + DISPLAY ───────────────────────────
+     run_la = st.button("Enter", use_container_width=True)
+     col7, col8 = st.columns(2)
+
+     with col7:
+         st.subheader("🖼️ Original Image")
+
+         sample_img_la = st.session_state.sample_img_la  # read persisted choice
+
+         # case 1: uploaded file
+         if uploaded:
+             buf = uploaded.getvalue()
+             img0 = Image.open(BytesIO(buf)).convert("RGB")
+             st.image(img0, caption="Uploaded Image", use_container_width=True)
+
+         # case 2: selected sample image
+         elif sample_img_la is not None:
+             img_path = os.path.join(REPO, sample_img_la)
+             if os.path.isfile(img_path):
+                 img0 = Image.open(img_path).convert("RGB")
+                 st.image(img0, caption=f"Sample Image: {sample_img_la}", use_container_width=True)
+             else:
+                 st.error(f"Cannot find {sample_img_la} in directory!")
+
+     with col8:
+         st.subheader("🔎 Predicted Image")
+
+         # ─── PREDICTION ────────────────────────────────────
+         if img0 is not None and run_la:
+             img_np = np.array(img0)
+             model = YOLO('./best.pt')  # or your correct path to best.pt
+             with st.spinner("Running YOLO model…"):
+                 results = model(img_np, imgsz=640)
+                 pred_img = results[0].plot(boxes=False, probs=False)  # returns a numpy image with annotations
+             st.image(pred_img, caption="Prediction Result", use_container_width=True)
+
+ elif feature == "Contact":
+     with col1:
+         st.image("dev_1.jpg", caption=None, use_container_width=True)
+         st.markdown(
+             """
+             <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h3>Thitsanapat Uma</h3>
+                 <a href='https://www.facebook.com/thitsanapat.uma' target='_blank'>
+                     🔗 Facebook Profile
+                 </a>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
+     with col2:
+         st.image("dev_2.jpg", caption=None, use_container_width=True)
+         st.markdown(
+             """
+             <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h3>Santipab Tongchan</h3>
+                 <a href='https://www.facebook.com/santipab.tongchan.2025' target='_blank'>
+                     🔗 Facebook Profile
+                 </a>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
+     with col3:
+         st.image("dev_3.jpg", caption=None, use_container_width=True)
+         st.markdown(
+             """
+             <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
+                 <h3>Suphanat Kamphapan</h3>
+                 <a href='https://www.facebook.com/suphanat.kamphapan' target='_blank'>
+                     🔗 Facebook Profile
+                 </a>
+             </div>
+             """,
+             unsafe_allow_html=True
+         )
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b51882a04c2f47922248e7c57c712e659984f55a95ec55a7050109e7ae61a401
+ size 55847450
cobb_evaluate.py ADDED
@@ -0,0 +1,124 @@
+ ###########################################################################################
+ ## This code is transferred from the MATLAB version of the MICCAI challenge
+ ## Oct 1 2019
+ ###########################################################################################
+ import numpy as np
+ import cv2
+
+
+ def is_S(mid_p_v):
+     # mid_p_v: 34 x 2
+     ll = []
+     num = mid_p_v.shape[0]
+     for i in range(num-2):
+         term1 = (mid_p_v[i, 1]-mid_p_v[num-1, 1])/(mid_p_v[0, 1]-mid_p_v[num-1, 1])
+         term2 = (mid_p_v[i, 0]-mid_p_v[num-1, 0])/(mid_p_v[0, 0]-mid_p_v[num-1, 0])
+         ll.append(term1-term2)
+     ll = np.asarray(ll, np.float32)[:, np.newaxis]  # 32 x 1
+     ll_pair = np.matmul(ll, np.transpose(ll))  # 32 x 32
+     a = sum(sum(ll_pair))
+     b = sum(sum(abs(ll_pair)))
+     if abs(a-b) < 1e-4:
+         return False
+     else:
+         return True
+
+ def cobb_angle_calc(pts, image):
+     pts = np.asarray(pts, np.float32)  # 68 x 2
+     h, w, c = image.shape
+     num_pts = pts.shape[0]  # number of points, 68
+     vnum = num_pts//4-1
+
+     mid_p_v = (pts[0::2, :]+pts[1::2, :])/2  # 34 x 2
+     mid_p = []
+     for i in range(0, num_pts, 4):
+         pt1 = (pts[i, :]+pts[i+2, :])/2
+         pt2 = (pts[i+1, :]+pts[i+3, :])/2
+         mid_p.append(pt1)
+         mid_p.append(pt2)
+     mid_p = np.asarray(mid_p, np.float32)  # 34 x 2
+
+     for pt in mid_p:
+         cv2.circle(image,
+                    (int(pt[0]), int(pt[1])),
+                    12, (0, 255, 255), -1, 1)
+
+     for pt1, pt2 in zip(mid_p[0::2, :], mid_p[1::2, :]):
+         cv2.line(image,
+                  (int(pt1[0]), int(pt1[1])),
+                  (int(pt2[0]), int(pt2[1])),
+                  color=(0, 0, 255),
+                  thickness=5, lineType=1)
+
+     vec_m = mid_p[1::2, :]-mid_p[0::2, :]  # 17 x 2
+     dot_v = np.matmul(vec_m, np.transpose(vec_m))  # 17 x 17
+     mod_v = np.sqrt(np.sum(vec_m**2, axis=1))[:, np.newaxis]  # 17 x 1
+     mod_v = np.matmul(mod_v, np.transpose(mod_v))  # 17 x 17
+     cosine_angles = np.clip(dot_v/mod_v, a_min=0., a_max=1.)
+     angles = np.arccos(cosine_angles)  # 17 x 17
+     pos1 = np.argmax(angles, axis=1)
+     maxt = np.amax(angles, axis=1)
+     pos2 = np.argmax(maxt)
+     cobb_angle1 = np.amax(maxt)
+     cobb_angle1 = cobb_angle1/np.pi*180
+     flag_s = is_S(mid_p_v)
+     if not flag_s:  # not S
+         # print('Not S')
+         cobb_angle2 = angles[0, pos2]/np.pi*180
+         cobb_angle3 = angles[vnum, pos1[pos2]]/np.pi*180
+         cv2.line(image,
+                  (int(mid_p[pos2 * 2, 0]), int(mid_p[pos2 * 2, 1])),
+                  (int(mid_p[pos2 * 2 + 1, 0]), int(mid_p[pos2 * 2 + 1, 1])),
+                  color=(0, 255, 0), thickness=5, lineType=2)
+         cv2.line(image,
+                  (int(mid_p[pos1[pos2] * 2, 0]), int(mid_p[pos1[pos2] * 2, 1])),
+                  (int(mid_p[pos1[pos2] * 2 + 1, 0]), int(mid_p[pos1[pos2] * 2 + 1, 1])),
+                  color=(0, 255, 0), thickness=5, lineType=2)
+
+     else:
+         if (mid_p_v[pos2*2, 1]+mid_p_v[pos1[pos2]*2, 1]) < h:
+             # print('Is S: condition1')
+             angle2 = angles[pos2, :(pos2+1)]
+             cobb_angle2 = np.max(angle2)
+             pos1_1 = np.argmax(angle2)
+             cobb_angle2 = cobb_angle2/np.pi*180
+
+             angle3 = angles[pos1[pos2], pos1[pos2]:(vnum+1)]
+             cobb_angle3 = np.max(angle3)
+             pos1_2 = np.argmax(angle3)
+             cobb_angle3 = cobb_angle3/np.pi*180
+             pos1_2 = pos1_2 + pos1[pos2]-1
+
+             cv2.line(image,
+                      (int(mid_p[pos1_1 * 2, 0]), int(mid_p[pos1_1 * 2, 1])),
+                      (int(mid_p[pos1_1 * 2 + 1, 0]), int(mid_p[pos1_1 * 2 + 1, 1])),
+                      color=(0, 255, 0), thickness=5, lineType=2)
+
+             cv2.line(image,
+                      (int(mid_p[pos1_2 * 2, 0]), int(mid_p[pos1_2 * 2, 1])),
+                      (int(mid_p[pos1_2 * 2 + 1, 0]), int(mid_p[pos1_2 * 2 + 1, 1])),
+                      color=(0, 255, 0), thickness=5, lineType=2)
+
+         else:
+             # print('Is S: condition2')
+             angle2 = angles[pos2, :(pos2+1)]
+             cobb_angle2 = np.max(angle2)
+             pos1_1 = np.argmax(angle2)
+             cobb_angle2 = cobb_angle2/np.pi*180
+
+             angle3 = angles[pos1_1, :(pos1_1+1)]
+             cobb_angle3 = np.max(angle3)
+             pos1_2 = np.argmax(angle3)
+             cobb_angle3 = cobb_angle3/np.pi*180
+
+             cv2.line(image,
+                      (int(mid_p[pos1_1 * 2, 0]), int(mid_p[pos1_1 * 2, 1])),
+                      (int(mid_p[pos1_1 * 2 + 1, 0]), int(mid_p[pos1_1 * 2 + 1, 1])),
+                      color=(0, 255, 0), thickness=5, lineType=2)
+
+             cv2.line(image,
+                      (int(mid_p[pos1_2 * 2, 0]), int(mid_p[pos1_2 * 2, 1])),
+                      (int(mid_p[pos1_2 * 2 + 1, 0]), int(mid_p[pos1_2 * 2 + 1, 1])),
+                      color=(0, 255, 0), thickness=5, lineType=2)
+
+     return [cobb_angle1, cobb_angle2, cobb_angle3]
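
The core of `cobb_angle_calc` above is the pairwise angle between endplate direction vectors, recovered via `arccos` of normalized dot products. A minimal sketch of that step in isolation, with two made-up vectors (values are illustrative only):

```python
import numpy as np

# endplate direction vectors (mid_p[1::2] - mid_p[0::2]) for two vertebrae; made-up values
v = np.array([[10., 1.],
              [9., -3.]])
dot = v @ v.T                               # pairwise dot products, 2 x 2
mod = np.sqrt((v ** 2).sum(axis=1, keepdims=True))
cos = np.clip(dot / (mod @ mod.T), 0., 1.)  # same clipping as cobb_angle_calc
angles = np.degrees(np.arccos(cos))
print(angles[0, 1])                         # tilt between the two vertebrae, about 24.1 degrees
```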
dataset.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import torch.utils.data as data
+ import pre_proc
+ import cv2
+ from scipy.io import loadmat
+ import numpy as np
+
+
+ def rearrange_pts(pts):
+     boxes = []
+     for k in range(0, len(pts), 4):
+         pts_4 = pts[k:k+4, :]
+         x_inds = np.argsort(pts_4[:, 0])
+         pt_l = np.asarray(pts_4[x_inds[:2], :])
+         pt_r = np.asarray(pts_4[x_inds[2:], :])
+         y_inds_l = np.argsort(pt_l[:, 1])
+         y_inds_r = np.argsort(pt_r[:, 1])
+         tl = pt_l[y_inds_l[0], :]
+         bl = pt_l[y_inds_l[1], :]
+         tr = pt_r[y_inds_r[0], :]
+         br = pt_r[y_inds_r[1], :]
+         # boxes.append([tl, tr, bl, br])
+         boxes.append(tl)
+         boxes.append(tr)
+         boxes.append(bl)
+         boxes.append(br)
+     return np.asarray(boxes, np.float32)
+
+
+ class BaseDataset(data.Dataset):
+     def __init__(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
+         super(BaseDataset, self).__init__()
+         self.data_dir = data_dir
+         self.phase = phase
+         self.input_h = input_h
+         self.input_w = input_w
+         self.down_ratio = down_ratio
+         self.class_name = ['__background__', 'cell']
+         self.num_classes = 68
+         self.img_dir = os.path.join(data_dir, 'data', self.phase)
+         self.img_ids = sorted(os.listdir(self.img_dir))
+
+     def load_image(self, index):
+         image = cv2.imread(os.path.join(self.img_dir, self.img_ids[index]))
+         return image
+
+     def load_gt_pts(self, annopath):
+         pts = loadmat(annopath)['p2']  # num x 2 (x, y)
+         pts = rearrange_pts(pts)
+         return pts
+
+     def load_annoFolder(self, img_id):
+         return os.path.join(self.data_dir, 'labels', self.phase, img_id+'.mat')
+
+     def load_annotation(self, index):
+         img_id = self.img_ids[index]
+         annoFolder = self.load_annoFolder(img_id)
+         pts = self.load_gt_pts(annoFolder)
+         return pts
+
+     def __getitem__(self, index):
+         img_id = self.img_ids[index]
+         image = self.load_image(index)
+         if self.phase == 'test':
+             images = pre_proc.processing_test(image=image, input_h=self.input_h, input_w=self.input_w)
+             return {'images': images, 'img_id': img_id}
+         else:
+             aug_label = False
+             if self.phase == 'train':
+                 aug_label = True
+             pts = self.load_annotation(index)  # num_pts x 2
+             out_image, pts_2 = pre_proc.processing_train(image=image,
+                                                          pts=pts,
+                                                          image_h=self.input_h,
+                                                          image_w=self.input_w,
+                                                          down_ratio=self.down_ratio,
+                                                          aug_label=aug_label,
+                                                          img_id=img_id)
+
+             data_dict = pre_proc.generate_ground_truth(image=out_image,
+                                                        pts_2=pts_2,
+                                                        image_h=self.input_h//self.down_ratio,
+                                                        image_w=self.input_w//self.down_ratio,
+                                                        img_id=img_id)
+             return data_dict
+
+     def __len__(self):
+         return len(self.img_ids)
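
`rearrange_pts` normalizes the four annotated corners of each vertebra into a fixed top-left, top-right, bottom-left, bottom-right order. A small sanity check with made-up coordinates (run from the repo root so `dataset.py` and its local imports resolve):

```python
import numpy as np
from dataset import rearrange_pts

# one vertebra's corners, deliberately scrambled
pts = np.array([[5., 9.], [1., 8.], [5., 2.], [1., 1.]], np.float32)
print(rearrange_pts(pts))
# -> [[1. 1.]   top-left
#     [5. 2.]   top-right
#     [1. 8.]   bottom-left
#     [5. 9.]]  bottom-right
```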
decoder.py ADDED
@@ -0,0 +1,77 @@
+ import torch.nn.functional as F
+ import numpy as np
+ import torch
+
+ class DecDecoder(object):
+     def __init__(self, K, conf_thresh):
+         # NOTE: K is hard-coded to 17 (one per vertebra); the K argument is ignored
+         self.K = 17
+         self.conf_thresh = conf_thresh
+
+     def _topk(self, scores):
+         batch, cat, height, width = scores.size()
+
+         topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), self.K)
+
+         topk_inds = topk_inds % (height * width)
+         topk_ys = (topk_inds / width).int().float()
+         topk_xs = (topk_inds % width).int().float()
+
+         topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), self.K)
+         topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, self.K)
+         topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, self.K)
+         topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, self.K)
+
+         return topk_score, topk_inds, topk_ys, topk_xs
+
+     def _nms(self, heat, kernel=3):
+         hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
+         keep = (hmax == heat).float()
+         return heat * keep
+
+     def _gather_feat(self, feat, ind, mask=None):
+         dim = feat.size(2)
+         ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
+         feat = feat.gather(1, ind)
+         if mask is not None:
+             mask = mask.unsqueeze(2).expand_as(feat)
+             feat = feat[mask]
+             feat = feat.view(-1, dim)
+         return feat
+
+     def _tranpose_and_gather_feat(self, feat, ind):
+         feat = feat.permute(0, 2, 3, 1).contiguous()
+         feat = feat.view(feat.size(0), -1, feat.size(3))
+         feat = self._gather_feat(feat, ind)
+         return feat
+
+     def ctdet_decode(self, heat, wh, reg):
+         # output: num_obj x 11
+         # 11: cenx, ceny, four corner points (tl, tr, bl, br), score
+         batch, c, height, width = heat.size()
+         heat = self._nms(heat)  # [1, 1, 256, 128]
+         scores, inds, ys, xs = self._topk(heat)
+         scores = scores.view(batch, self.K, 1)
+         reg = self._tranpose_and_gather_feat(reg, inds)
+         reg = reg.view(batch, self.K, 2)
+         xs = xs.view(batch, self.K, 1) + reg[:, :, 0:1]
+         ys = ys.view(batch, self.K, 1) + reg[:, :, 1:2]
+         wh = self._tranpose_and_gather_feat(wh, inds)
+         wh = wh.view(batch, self.K, 2*4)
+
+         tl_x = xs - wh[:, :, 0:1]
+         tl_y = ys - wh[:, :, 1:2]
+         tr_x = xs - wh[:, :, 2:3]
+         tr_y = ys - wh[:, :, 3:4]
+         bl_x = xs - wh[:, :, 4:5]
+         bl_y = ys - wh[:, :, 5:6]
+         br_x = xs - wh[:, :, 6:7]
+         br_y = ys - wh[:, :, 7:8]
+
+         pts = torch.cat([xs, ys,
+                          tl_x, tl_y,
+                          tr_x, tr_y,
+                          bl_x, bl_y,
+                          br_x, br_y,
+                          scores], dim=2).squeeze(0)
+         return pts.data.cpu().numpy()
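
`ctdet_decode` turns the three network heads into one record per vertebra: a refined center, four corner points regressed from that center, and a score. A shape-level smoke test with random tensors; the sizes follow the app's 1024x512 input at `down_ratio=4`, and this is only a sketch, not the real inference path:

```python
import torch
from decoder import DecDecoder

dec = DecDecoder(K=17, conf_thresh=0.2)
hm  = torch.rand(1, 1, 256, 128)  # center heatmap (input_h/4 x input_w/4)
wh  = torch.rand(1, 8, 256, 128)  # four corner offsets, (x, y) each
reg = torch.rand(1, 2, 256, 128)  # sub-pixel center refinement
pts = dec.ctdet_decode(hm, wh, reg)
print(pts.shape)                  # (17, 11): cx, cy, tl, tr, bl, br, score
```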
dev_1.jpg ADDED
dev_2.jpg ADDED
dev_3.jpg ADDED

Git LFS Details

  • SHA256: 9773ecf6b207c51c56fa7ab9b4d9761ca0154a257a83656318e4aaf731a99633
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
draw_gaussian.py ADDED
@@ -0,0 +1,49 @@
+ import numpy as np
+
+
+ def gaussian_radius(det_size, min_overlap=0.7):
+     height, width = det_size
+
+     a1 = 1
+     b1 = (height + width)
+     c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+     sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
+     r1 = (b1 + sq1) / 2
+
+     a2 = 4
+     b2 = 2 * (height + width)
+     c2 = (1 - min_overlap) * width * height
+     sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
+     r2 = (b2 + sq2) / 2
+
+     a3 = 4 * min_overlap
+     b3 = -2 * min_overlap * (height + width)
+     c3 = (min_overlap - 1) * width * height
+     sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
+     r3 = (b3 + sq3) / 2
+     return min(r1, r2, r3)
+
+ def gaussian2D(shape, sigma=1):
+     m, n = [(ss - 1.) / 2. for ss in shape]
+     y, x = np.ogrid[-m:m+1, -n:n+1]
+
+     h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+     h[h < np.finfo(h.dtype).eps * h.max()] = 0
+     return h
+
+ def draw_umich_gaussian(heatmap, center, radius, k=1):
+     diameter = 2 * radius + 1
+     gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
+
+     x, y = int(center[0]), int(center[1])
+
+     height, width = heatmap.shape[0:2]
+
+     left, right = min(x, radius), min(width - x, radius + 1)
+     top, bottom = min(y, radius), min(height - y, radius + 1)
+
+     masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+     masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
+     if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
+         np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+     return heatmap
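
`gaussian_radius` follows the CornerNet-style derivation (three quadratic cases, keeping the smallest root), and `draw_umich_gaussian` splats a clipped 2-D Gaussian onto the heatmap. A tiny self-contained check; the box size is made up:

```python
import numpy as np
from draw_gaussian import gaussian_radius, draw_umich_gaussian

heatmap = np.zeros((64, 64), np.float32)
r = max(0, int(gaussian_radius((20, 12))))  # (height, width) of a box in heatmap cells
draw_umich_gaussian(heatmap, center=(32, 32), radius=r)
print(heatmap.max())                        # 1.0, the peak sits exactly at the center
```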
draw_loss.py ADDED
@@ -0,0 +1,64 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+
+ def load_data(filename):
+     pts = []
+     f = open(filename, "rb")
+     for line in f:
+         pts.append(float(line.strip()))
+     f.close()
+     return pts
+
+ dataset = 'spinal'
+ weights_path = 'weights_' + dataset
+
+ ###############################################
+ # Load data
+ train_pts = load_data(os.path.join(weights_path, 'train_loss.txt'))
+ val_pts = load_data(os.path.join(weights_path, 'val_loss.txt'))
+
+ def draw_loss():
+     x = np.linspace(0, len(train_pts), len(train_pts))
+     plt.plot(x, train_pts, 'ro-', label='train')
+     plt.plot(x, val_pts, 'bo-', label='val')
+     # plt.axis([0, 50, 9.25, 11])
+     plt.legend(loc='upper right')
+
+     plt.xlabel('Epochs')
+     plt.ylabel('Loss')
+
+     plt.show()
+
+
+ def draw_loss_ap():
+     ap05_pts = load_data(os.path.join(weights_path, 'ap_05_list.txt'))
+     ap07_pts = load_data(os.path.join(weights_path, 'ap_07_list.txt'))
+
+     x = np.linspace(0, len(train_pts), len(train_pts))
+     x1 = np.linspace(0, len(train_pts), len(ap05_pts))
+
+     fig, ax1 = plt.subplots()
+
+     color = 'tab:red'
+     ax1.set_xlabel('Epochs')
+     ax1.set_ylabel('Loss', color=color)
+     ax1.plot(x, train_pts, 'ro-', label='train')
+     ax1.plot(x, val_pts, 'bo-', label='val')
+     ax1.tick_params(axis='y', labelcolor=color)
+     plt.legend(loc='lower right')
+     ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
+     color = 'tab:blue'
+     ax2.set_ylabel('AP', color=color)  # we already handled the x-label with ax1
+     ax2.plot(x1, ap05_pts, 'go-', label='AP@05')
+     ax2.plot(x1, ap07_pts, 'yo-', label='AP@07')
+     ax2.tick_params(axis='y', labelcolor=color)
+
+     fig.tight_layout()  # otherwise the right y-label is slightly clipped
+     plt.legend(loc='upper right')
+     plt.show()
+
+
+ if __name__ == '__main__':
+     draw_loss()
+     # draw_loss_ap()
draw_points.py ADDED
@@ -0,0 +1,92 @@
+ import cv2
+ import numpy as np
+
+ colors = [[0.76590096, 0.0266074, 0.9806378],
+           [0.54197179, 0.81682527, 0.95081629],
+           [0.0799733, 0.79737015, 0.15173816],
+           [0.93240442, 0.8993321, 0.09901344],
+           [0.73130136, 0.05366301, 0.98405681],
+           [0.01664966, 0.16387004, 0.94158259],
+           [0.54197179, 0.81682527, 0.45081629],
+           # [0.92074915, 0.09919099, 0.97590748],
+           [0.83445145, 0.97921679, 0.12250426],
+           [0.7300924, 0.23253621, 0.29764521],
+           [0.3856775, 0.94859286, 0.9910683],  # 10
+           [0.45762137, 0.03766411, 0.98755338],
+           [0.99496697, 0.09113071, 0.83322314],
+           [0.96478873, 0.0233309, 0.13149931],
+           [0.33240442, 0.9993321, 0.59901344],
+           # [0.77690519, 0.81783954, 0.56220024],
+           # [0.93240442, 0.8993321, 0.09901344],
+           [0.95815068, 0.88436046, 0.55782268],
+           [0.03728425, 0.0618827, 0.88641827],
+           [0.05281129, 0.89572238, 0.08913828],
+           ]
+
+
+ def draw_landmarks_regress_test(pts0, ori_image_regress, ori_image_points):
+     for i, pt in enumerate(pts0):
+         # color = np.random.rand(3)
+         color = colors[i]
+         # print(i+1, color)
+         color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
+         cv2.circle(ori_image_regress, (int(pt[0]), int(pt[1])), 6, color_255, -1, 1)
+         # cv2.circle(ori_image, (int(pt[2]), int(pt[3])), 5, color_255, -1, 1)
+         # cv2.circle(ori_image, (int(pt[4]), int(pt[5])), 5, color_255, -1, 1)
+         # cv2.circle(ori_image, (int(pt[6]), int(pt[7])), 5, color_255, -1, 1)
+         # cv2.circle(ori_image, (int(pt[8]), int(pt[9])), 5, color_255, -1, 1)
+         cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[2]), int(pt[3])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[4]), int(pt[5])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[6]), int(pt[7])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[8]), int(pt[9])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.putText(ori_image_regress, '{}'.format(i + 1),
+                     (int(pt[4] + 10), int(pt[5] + 10)),
+                     cv2.FONT_HERSHEY_DUPLEX,
+                     1.2,
+                     color_255,  # (255,255,255),
+                     1,
+                     1)
+         # cv2.circle(ori_image, (int(pt[0]), int(pt[1])), 6, (255,255,255), -1, 1)
+         cv2.circle(ori_image_points, (int(pt[2]), int(pt[3])), 5, color_255, -1, 1)
+         cv2.circle(ori_image_points, (int(pt[4]), int(pt[5])), 5, color_255, -1, 1)
+         cv2.circle(ori_image_points, (int(pt[6]), int(pt[7])), 5, color_255, -1, 1)
+         cv2.circle(ori_image_points, (int(pt[8]), int(pt[9])), 5, color_255, -1, 1)
+     return ori_image_regress, ori_image_points
+
+
+ def draw_landmarks_pre_proc(out_image, pts):
+     for i in range(17):
+         pts_4 = pts[4 * i:4 * i + 4, :]
+         color = colors[i]
+         color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
+         cv2.circle(out_image, (int(pts_4[0, 0]), int(pts_4[0, 1])), 5, color_255, -1, 1)
+         cv2.circle(out_image, (int(pts_4[1, 0]), int(pts_4[1, 1])), 5, color_255, -1, 1)
+         cv2.circle(out_image, (int(pts_4[2, 0]), int(pts_4[2, 1])), 5, color_255, -1, 1)
+         cv2.circle(out_image, (int(pts_4[3, 0]), int(pts_4[3, 1])), 5, color_255, -1, 1)
+     return np.uint8(out_image)
+
+
+ def draw_regress_pre_proc(out_image, pts):
+     for i in range(17):
+         pts_4 = pts[4 * i:4 * i + 4, :]
+         pt = np.mean(pts_4, axis=0)
+         color = colors[i]
+         color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
+         cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[0, 0]), int(pts_4[0, 1])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[1, 0]), int(pts_4[1, 1])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[2, 0]), int(pts_4[2, 1])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[3, 0]), int(pts_4[3, 1])), color_255, 2, 1,
+                         tipLength=0.2)
+         cv2.putText(out_image, '{}'.format(i + 1), (int(pts_4[1, 0] + 10), int(pts_4[1, 1] + 10)),
+                     cv2.FONT_HERSHEY_DUPLEX, 1.2, color_255, 1, 1)
+     return np.uint8(out_image)
eval.py ADDED
@@ -0,0 +1,222 @@
+ import torch
+ import numpy as np
+ from models import spinal_net
+ import decoder
+ import os
+ from dataset import BaseDataset
+ import time
+ import cobb_evaluate
+
+ def apply_mask(image, mask, alpha=0.5):
+     """Apply the given mask to the image."""
+     color = np.random.rand(3)
+     for c in range(3):
+         image[:, :, c] = np.where(mask == 1,
+                                   image[:, :, c] *
+                                   (1 - alpha) + alpha * color[c] * 255,
+                                   image[:, :, c])
+     return image
+
+ class Network(object):
+     def __init__(self, args):
+         torch.manual_seed(317)
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         heads = {'hm': args.num_classes,  # cen, tl, tr, bl, br
+                  'reg': 2*args.num_classes,
+                  'wh': 2*4,}
+
+         self.model = spinal_net.SpineNet(heads=heads,
+                                          pretrained=True,
+                                          down_ratio=args.down_ratio,
+                                          final_kernel=1,
+                                          head_conv=256)
+         self.num_classes = args.num_classes
+         self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
+         self.dataset = {'spinal': BaseDataset}
+
+     def load_model(self, model, resume):
+         checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
+         print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
+         state_dict_ = checkpoint['state_dict']
+         model.load_state_dict(state_dict_, strict=False)
+         return model
+
+     def eval(self, args, save):
+         save_path = 'weights_'+args.dataset
+         self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         dataset_module = self.dataset[args.dataset]
+         dsets = dataset_module(data_dir=args.data_dir,
+                                phase='test',
+                                input_h=args.input_h,
+                                input_w=args.input_w,
+                                down_ratio=args.down_ratio)
+
+         data_loader = torch.utils.data.DataLoader(dsets,
+                                                   batch_size=1,
+                                                   shuffle=False,
+                                                   num_workers=1,
+                                                   pin_memory=True)
+
+         total_time = []
+         landmark_dist = []
+         pr_cobb_angles = []
+         gt_cobb_angles = []
+         for cnt, data_dict in enumerate(data_loader):
+             begin_time = time.time()
+             images = data_dict['images'][0]
+             img_id = data_dict['img_id'][0]
+             images = images.to(self.device)  # was .to('cuda'); use the configured device so CPU-only machines work
+             print('processing {}/{} image ...'.format(cnt, len(data_loader)))
+
+             with torch.no_grad():
+                 output = self.model(images)
+                 hm = output['hm']
+                 wh = output['wh']
+                 reg = output['reg']
+             if self.device.type == 'cuda':
+                 torch.cuda.synchronize(self.device)
+             pts2 = self.decoder.ctdet_decode(hm, wh, reg)  # 17, 11
+             pts0 = pts2.copy()
+             pts0[:, :10] *= args.down_ratio
+             x_index = range(0, 10, 2)
+             y_index = range(1, 10, 2)
+             ori_image = dsets.load_image(dsets.img_ids.index(img_id)).copy()
+             h, w, c = ori_image.shape
+             pts0[:, x_index] = pts0[:, x_index]/args.input_w*w
+             pts0[:, y_index] = pts0[:, y_index]/args.input_h*h
+             # sort by the y axis
+             sort_ind = np.argsort(pts0[:, 1])
+             pts0 = pts0[sort_ind]
+             pr_landmarks = []
+             for i, pt in enumerate(pts0):
+                 pr_landmarks.append(pt[2:4])
+                 pr_landmarks.append(pt[4:6])
+                 pr_landmarks.append(pt[6:8])
+                 pr_landmarks.append(pt[8:10])
+             pr_landmarks = np.asarray(pr_landmarks, np.float32)  # [68, 2]
+
+             end_time = time.time()
+             total_time.append(end_time-begin_time)
+
+             gt_landmarks = dsets.load_gt_pts(dsets.load_annoFolder(img_id))
+             for pr_pt, gt_pt in zip(pr_landmarks, gt_landmarks):
+                 landmark_dist.append(np.sqrt((pr_pt[0]-gt_pt[0])**2+(pr_pt[1]-gt_pt[1])**2))
+
+             pr_cobb_angles.append(cobb_evaluate.cobb_angle_calc(pr_landmarks, ori_image))
+             gt_cobb_angles.append(cobb_evaluate.cobb_angle_calc(gt_landmarks, ori_image))
+
+         pr_cobb_angles = np.asarray(pr_cobb_angles, np.float32)
+         gt_cobb_angles = np.asarray(gt_cobb_angles, np.float32)
+
+         out_abs = abs(gt_cobb_angles - pr_cobb_angles)
+         out_add = gt_cobb_angles + pr_cobb_angles
+
+         term1 = np.sum(out_abs, axis=1)
+         term2 = np.sum(out_add, axis=1)
+
+         SMAPE = np.mean(term1 / term2 * 100)
+
+         print('mean distance of landmarks is {}'.format(np.mean(landmark_dist)))
+         print('SMAPE is {}'.format(SMAPE))
+
+         total_time = total_time[1:]
+         print('avg time is {}'.format(np.mean(total_time)))
+         print('FPS is {}'.format(1./np.mean(total_time)))
+
+
+     def SMAPE_single_angle(self, gt_cobb_angles, pr_cobb_angles):
+         out_abs = abs(gt_cobb_angles - pr_cobb_angles)
+         out_add = gt_cobb_angles + pr_cobb_angles
+
+         term1 = out_abs
+         term2 = out_add
+
+         term2[term2 == 0] += 1e-5
+
+         SMAPE = np.mean(term1 / term2 * 100)
+         return SMAPE
+
+     def eval_three_angles(self, args, save):
+         save_path = 'weights_'+args.dataset
+         self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         dataset_module = self.dataset[args.dataset]
+         dsets = dataset_module(data_dir=args.data_dir,
+                                phase='test',
+                                input_h=args.input_h,
+                                input_w=args.input_w,
+                                down_ratio=args.down_ratio)
+
+         data_loader = torch.utils.data.DataLoader(dsets,
+                                                   batch_size=1,
+                                                   shuffle=False,
+                                                   num_workers=1,
+                                                   pin_memory=True)
+
+         total_time = []
+         landmark_dist = []
+         pr_cobb_angles = []
+         gt_cobb_angles = []
+         for cnt, data_dict in enumerate(data_loader):
+             begin_time = time.time()
+             images = data_dict['images'][0]
+             img_id = data_dict['img_id'][0]
+             images = images.to(self.device)  # same device fix as in eval()
+             print('processing {}/{} image ...'.format(cnt, len(data_loader)))
+
+             with torch.no_grad():
+                 output = self.model(images)
+                 hm = output['hm']
+                 wh = output['wh']
+                 reg = output['reg']
+             if self.device.type == 'cuda':
+                 torch.cuda.synchronize(self.device)
+             pts2 = self.decoder.ctdet_decode(hm, wh, reg)  # 17, 11
+             pts0 = pts2.copy()
+             pts0[:, :10] *= args.down_ratio
+             x_index = range(0, 10, 2)
+             y_index = range(1, 10, 2)
+             ori_image = dsets.load_image(dsets.img_ids.index(img_id)).copy()
+             h, w, c = ori_image.shape
+             pts0[:, x_index] = pts0[:, x_index]/args.input_w*w
+             pts0[:, y_index] = pts0[:, y_index]/args.input_h*h
+             # sort by the y axis
+             sort_ind = np.argsort(pts0[:, 1])
+             pts0 = pts0[sort_ind]
+             pr_landmarks = []
+             for i, pt in enumerate(pts0):
+                 pr_landmarks.append(pt[2:4])
+                 pr_landmarks.append(pt[4:6])
+                 pr_landmarks.append(pt[6:8])
+                 pr_landmarks.append(pt[8:10])
+             pr_landmarks = np.asarray(pr_landmarks, np.float32)  # [68, 2]
+
+             end_time = time.time()
+             total_time.append(end_time-begin_time)
+
+             gt_landmarks = dsets.load_gt_pts(dsets.load_annoFolder(img_id))
+             for pr_pt, gt_pt in zip(pr_landmarks, gt_landmarks):
+                 landmark_dist.append(np.sqrt((pr_pt[0]-gt_pt[0])**2+(pr_pt[1]-gt_pt[1])**2))
+
+             pr_cobb_angles.append(cobb_evaluate.cobb_angle_calc(pr_landmarks, ori_image))
+             gt_cobb_angles.append(cobb_evaluate.cobb_angle_calc(gt_landmarks, ori_image))
+
+         pr_cobb_angles = np.asarray(pr_cobb_angles, np.float32)
+         gt_cobb_angles = np.asarray(gt_cobb_angles, np.float32)
+
+         print('SMAPE1 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 0], pr_cobb_angles[:, 0])))
+         print('SMAPE2 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 1], pr_cobb_angles[:, 1])))
+         print('SMAPE3 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 2], pr_cobb_angles[:, 2])))
+
+         print('mean distance of landmarks is {}'.format(np.mean(landmark_dist)))
+
+         total_time = total_time[1:]
+         print('avg time is {}'.format(np.mean(total_time)))
+         print('FPS is {}'.format(1./np.mean(total_time)))
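
The SMAPE printed by `eval` is the symmetric mean absolute percentage error over the three Cobb angles. A worked numeric example with made-up angles:

```python
import numpy as np

gt = np.array([[30., 20., 10.]], np.float32)  # ground-truth Cobb angles, one image
pr = np.array([[27., 22., 9.]], np.float32)   # predictions
smape = np.mean(np.sum(np.abs(gt - pr), axis=1) / np.sum(gt + pr, axis=1) * 100)
print(smape)  # (3 + 2 + 1) / (57 + 42 + 19) * 100, about 5.08
```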
image_1.jpg ADDED

Git LFS Details

  • SHA256: 4a43c63edf522f9757b2fe688a288432991d5abb2c5e19c4588dbbbd570ddca0
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
image_1_la.jpg ADDED
image_2.jpg ADDED
image_2_la.jpg ADDED
image_3.jpg ADDED

Git LFS Details

  • SHA256: 496588ac26ad1b7f0d0d69e807a630576bbad595588a91a35648ae69baa47814
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
image_3_la.jpg ADDED
loss.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class RegL1Loss(nn.Module):
+     def __init__(self):
+         super(RegL1Loss, self).__init__()
+
+     def _gather_feat(self, feat, ind, mask=None):
+         dim = feat.size(2)
+         ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
+         feat = feat.gather(1, ind)
+         if mask is not None:
+             mask = mask.unsqueeze(2).expand_as(feat)
+             feat = feat[mask]
+             feat = feat.view(-1, dim)
+         return feat
+
+     def _tranpose_and_gather_feat(self, feat, ind):
+         feat = feat.permute(0, 2, 3, 1).contiguous()
+         feat = feat.view(feat.size(0), -1, feat.size(3))
+         feat = self._gather_feat(feat, ind)
+         return feat
+
+     def forward(self, output, mask, ind, target):
+         pred = self._tranpose_and_gather_feat(output, ind)
+         mask = mask.unsqueeze(2).expand_as(pred).float()
+         loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
+         loss = loss / (mask.sum() + 1e-4)
+         return loss
+
+ class FocalLoss(nn.Module):
+     def __init__(self):
+         super(FocalLoss, self).__init__()
+
+     def forward(self, pred, gt):
+         pos_inds = gt.eq(1).float()
+         neg_inds = gt.lt(1).float()
+         neg_weights = torch.pow(1 - gt, 4)
+
+         loss = 0
+
+         pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
+         neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
+
+         num_pos = pos_inds.float().sum()
+         pos_loss = pos_loss.sum()
+         neg_loss = neg_loss.sum()
+
+         if num_pos == 0:
+             loss = loss - neg_loss
+         else:
+             loss = loss - (pos_loss + neg_loss) / num_pos
+         return loss
+
+ class LossAll(torch.nn.Module):
+     def __init__(self):
+         super(LossAll, self).__init__()
+         self.L_hm = FocalLoss()
+         self.L_off = RegL1Loss()
+         self.L_wh = RegL1Loss()
+
+     def forward(self, pr_decs, gt_batch):
+         hm_loss = self.L_hm(pr_decs['hm'], gt_batch['hm'])
+         wh_loss = self.L_wh(pr_decs['wh'], gt_batch['reg_mask'], gt_batch['ind'], gt_batch['wh'])
+         off_loss = self.L_off(pr_decs['reg'], gt_batch['reg_mask'], gt_batch['ind'], gt_batch['reg'])
+         loss_dec = hm_loss + off_loss + wh_loss
+         return loss_dec
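
`FocalLoss` is the CenterNet variant: positives are the cells where the ground-truth heatmap equals exactly 1, and every other cell is a negative down-weighted by (1 - gt)^4, so pixels near a peak are penalized gently. A minimal numeric demo; the values are made up, and `pred` must already be a probability map:

```python
import torch
from loss import FocalLoss

# 1 x 1 x 2 x 2 heatmaps: one exact positive at (0, 0), soft/zero targets elsewhere
pred = torch.tensor([[[[0.9, 0.2],
                       [0.1, 0.1]]]])
gt   = torch.tensor([[[[1.0, 0.6],
                       [0.0, 0.0]]]])
print(FocalLoss()(pred, gt))  # scalar; the gt=0.6 cell is down-weighted by (1-0.6)^4
```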
main.py ADDED
@@ -0,0 +1,40 @@
+ import argparse
+ import train
+ import test
+ import eval
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='CenterNet Modification Implementation')
+     parser.add_argument('--num_epoch', type=int, default=50, help='Number of epochs')
+     parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
+     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
+     parser.add_argument('--init_lr', type=float, default=1.25e-4, help='Initial learning rate')
+     parser.add_argument('--down_ratio', type=int, default=4, help='down ratio')
+     parser.add_argument('--input_h', type=int, default=1024, help='input height')
+     parser.add_argument('--input_w', type=int, default=512, help='input width')
+     parser.add_argument('--K', type=int, default=100, help='maximum number of objects')
+     parser.add_argument('--conf_thresh', type=float, default=0.2, help='confidence threshold')
+     parser.add_argument('--seg_thresh', type=float, default=0.5, help='segmentation threshold')
+     parser.add_argument('--num_classes', type=int, default=1, help='number of classes')
+     parser.add_argument('--ngpus', type=int, default=0, help='number of gpus')
+     parser.add_argument('--resume', type=str, default='model_last.pth', help='weights to be resumed')
+     parser.add_argument('--data_dir', type=str, default='../../Datasets/spinal/', help='data directory')
+     parser.add_argument('--phase', type=str, default='test', help='phase: train, test, or eval')
+     parser.add_argument('--dataset', type=str, default='spinal', help='dataset name')
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     if args.phase == 'train':
+         is_object = train.Network(args)
+         is_object.train_network(args)
+     elif args.phase == 'test':
+         is_object = test.Network(args)
+         is_object.test(args, save=True)
+     elif args.phase == 'eval':
+         is_object = eval.Network(args)
+         is_object.eval(args, save=True)
+         # is_object.eval_three_angles(args, save=False)
make_requirements.py ADDED
@@ -0,0 +1,73 @@
+ #!/usr/bin/env python3
+ import os
+ import ast
+ import sys
+ from importlib import metadata
+
+ # --- CONFIGURE THIS ---
+ PROJECT_PATH = r"C:\Users\santi\Desktop\Oto\Vertebra-Landmark-Detection"
+ OUTPUT_FILE = os.path.join(PROJECT_PATH, "requirements.txt")
+ # ----------------------
+
+ def find_py_files(root):
+     for dirpath, dirnames, filenames in os.walk(root):
+         # skip __pycache__
+         dirnames[:] = [d for d in dirnames if d != "__pycache__"]
+         for fname in filenames:
+             if fname.endswith(".py"):
+                 yield os.path.join(dirpath, fname)
+
+ def collect_imports(py_path):
+     with open(py_path, "r", encoding="utf8") as f:
+         node = ast.parse(f.read(), filename=py_path)
+     imports = set()
+     for stmt in ast.walk(node):
+         if isinstance(stmt, ast.Import):
+             for n in stmt.names:
+                 imports.add(n.name.split(".")[0])
+         elif isinstance(stmt, ast.ImportFrom):
+             if stmt.module and stmt.level == 0:
+                 imports.add(stmt.module.split(".")[0])
+     return imports
+
+ def is_local_module(mod_name, project_root):
+     # if there's a folder or file matching mod_name in the project, treat it as local
+     path1 = os.path.join(project_root, mod_name + ".py")
+     path2 = os.path.join(project_root, mod_name)
+     return os.path.exists(path1) or os.path.exists(path2)
+
+ def main():
+     all_imports = set()
+     for py in find_py_files(PROJECT_PATH):
+         all_imports |= collect_imports(py)
+
+     # filter out builtins, stdlib, and local modules
+     externals = set()
+     for mod in sorted(all_imports):
+         if is_local_module(mod, PROJECT_PATH):
+             continue
+         try:
+             # try to see if it's installed as a distribution
+             dist = metadata.distribution(mod)
+             externals.add(f"{dist.metadata['Name']}=={dist.version}")
+         except metadata.PackageNotFoundError:
+             # not a top-level distribution, maybe stdlib or a nested import;
+             # crude check: if we can import it and it lives in site-packages,
+             # include it without a version, otherwise skip it as stdlib
+             try:
+                 m = __import__(mod)
+                 if hasattr(m, "__file__") and "site-packages" in (m.__file__ or ""):
+                     externals.add(mod)
+             except ImportError:
+                 pass
+
+     # write requirements.txt
+     with open(OUTPUT_FILE, "w", encoding="utf8") as out:
+         for line in sorted(externals):
+             out.write(line + "\n")
+
+     print(f"Written {len(externals)} packages to {OUTPUT_FILE}")
+
+ if __name__ == "__main__":
+     main()
pre_proc.py ADDED
@@ -0,0 +1,146 @@
+ import cv2
+ import math
+ import numpy as np  # np is used throughout; import it explicitly rather than relying on the star import
+ import torch
+ from draw_gaussian import *
+ import transform
+
+
+ def processing_test(image, input_h, input_w):
+     image = cv2.resize(image, (input_w, input_h))
+     out_image = image.astype(np.float32) / 255.
+     out_image = out_image - 0.5
+     out_image = out_image.transpose(2, 0, 1).reshape(1, 3, input_h, input_w)
+     out_image = torch.from_numpy(out_image)
+     return out_image
+
+
+ def draw_spinal(pts, out_image):
+     colors = [(0, 0, 255), (0, 255, 255), (255, 0, 255), (0, 255, 0)]
+     for i in range(4):
+         cv2.circle(out_image, (int(pts[i, 0]), int(pts[i, 1])), 3, colors[i], 1, 1)
+         cv2.putText(out_image, '{}'.format(i+1), (int(pts[i, 0]), int(pts[i, 1])),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), 1, 1)
+     for i, j in zip([0, 1, 2, 3], [1, 2, 3, 0]):
+         cv2.line(out_image,
+                  (int(pts[i, 0]), int(pts[i, 1])),
+                  (int(pts[j, 0]), int(pts[j, 1])),
+                  color=colors[i], thickness=1, lineType=1)
+     return out_image
+
+
+ def rearrange_pts(pts):
+     # rearrange left/right sequence within each group of 4 corners
+     boxes = []
+     centers = []
+     for k in range(0, len(pts), 4):
+         pts_4 = pts[k:k+4, :]
+         x_inds = np.argsort(pts_4[:, 0])
+         pt_l = np.asarray(pts_4[x_inds[:2], :])
+         pt_r = np.asarray(pts_4[x_inds[2:], :])
+         y_inds_l = np.argsort(pt_l[:, 1])
+         y_inds_r = np.argsort(pt_r[:, 1])
+         tl = pt_l[y_inds_l[0], :]
+         bl = pt_l[y_inds_l[1], :]
+         tr = pt_r[y_inds_r[0], :]
+         br = pt_r[y_inds_r[1], :]
+         boxes.append(tl)
+         boxes.append(tr)
+         boxes.append(bl)
+         boxes.append(br)
+         centers.append(np.mean(pts_4, axis=0))
+     bboxes = np.asarray(boxes, np.float32)
+     # rearrange top-to-bottom sequence by vertebra center
+     centers = np.asarray(centers, np.float32)
+     sort_tb = np.argsort(centers[:, 1])
+     new_bboxes = []
+     for sort_i in sort_tb:
+         new_bboxes.append(bboxes[4*sort_i, :])
+         new_bboxes.append(bboxes[4*sort_i+1, :])
+         new_bboxes.append(bboxes[4*sort_i+2, :])
+         new_bboxes.append(bboxes[4*sort_i+3, :])
+     new_bboxes = np.asarray(new_bboxes, np.float32)
+     return new_bboxes
+
+
+ def generate_ground_truth(image, pts_2, image_h, image_w, img_id):
+     hm = np.zeros((1, image_h, image_w), dtype=np.float32)
+     wh = np.zeros((17, 2*4), dtype=np.float32)
+     reg = np.zeros((17, 2), dtype=np.float32)
+     ind = np.zeros((17), dtype=np.int64)
+     reg_mask = np.zeros((17), dtype=np.uint8)
+
+     if pts_2[:, 0].max() > image_w:
+         print('w is big', pts_2[:, 0].max())
+     if pts_2[:, 1].max() > image_h:
+         print('h is big', pts_2[:, 1].max())
+
+     if pts_2.shape[0] != 68:
+         print('ATTENTION!! image {} pts does not equal 68!!!'.format(img_id))
+
+     for k in range(17):
+         pts = pts_2[4*k:4*k+4, :]
+         bbox_h = np.mean([np.sqrt(np.sum((pts[0, :]-pts[2, :])**2)),
+                           np.sqrt(np.sum((pts[1, :]-pts[3, :])**2))])
+         bbox_w = np.mean([np.sqrt(np.sum((pts[0, :]-pts[1, :])**2)),
+                           np.sqrt(np.sum((pts[2, :]-pts[3, :])**2))])
+         cen_x, cen_y = np.mean(pts, axis=0)
+         ct = np.asarray([cen_x, cen_y], dtype=np.float32)
+         ct_int = ct.astype(np.int32)
+         radius = gaussian_radius((math.ceil(bbox_h), math.ceil(bbox_w)))
+         radius = max(0, int(radius))
+         draw_umich_gaussian(hm[0, :, :], ct_int, radius=radius)
+         ind[k] = ct_int[1] * image_w + ct_int[0]
+         reg[k] = ct - ct_int
+         reg_mask[k] = 1
+         for i in range(4):
+             wh[k, 2*i:2*i+2] = ct - pts[i, :]
+
+     ret = {'input': image,
+            'hm': hm,
+            'ind': ind,
+            'reg': reg,
+            'wh': wh,
+            'reg_mask': reg_mask,
+            }
+
+     return ret
+
+
+ # def filter_pts(pts, w, h):
+ #     pts_new = []
+ #     for pt in pts:
+ #         if pt[0] < 0 or pt[1] < 0 or pt[0] > w - 1 or pt[1] > h - 1:
+ #             continue
+ #         else:
+ #             pts_new.append(pt)
+ #     return np.asarray(pts_new, np.float32)
+
+
+ def processing_train(image, pts, image_h, image_w, down_ratio, aug_label, img_id):
+     # filter pts ----------------------------------------------------
+     h, w, c = image.shape
+     # pts = filter_pts(pts, w, h)
+     # ---------------------------------------------------------------
+     data_aug = {'train': transform.Compose([transform.ConvertImgFloat(),
+                                             transform.PhotometricDistort(),
+                                             transform.Expand(max_scale=1.5, mean=(0, 0, 0)),
+                                             transform.RandomMirror_w(),
+                                             transform.Resize(h=image_h, w=image_w)]),
+                 'val': transform.Compose([transform.ConvertImgFloat(),
+                                           transform.Resize(h=image_h, w=image_w)])}
+     if aug_label:
+         out_image, pts = data_aug['train'](image.copy(), pts)
+     else:
+         out_image, pts = data_aug['val'](image.copy(), pts)
+
+     out_image = np.clip(out_image, a_min=0., a_max=255.)
+     out_image = np.transpose(out_image / 255. - 0.5, (2, 0, 1))
+     pts = rearrange_pts(pts)
+     pts2 = transform.rescale_pts(pts, down_ratio=down_ratio)
+
+     return np.asarray(out_image, np.float32), pts2
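To make the corner ordering concrete, here is a small hedged check of rearrange_pts on one scrambled 4-corner box (it assumes the repo modules above and their dependencies are importable); the expected output order is top-left, top-right, bottom-left, bottom-right:

    import numpy as np
    from pre_proc import rearrange_pts

    pts = np.array([[30., 110.],   # bottom-left
                    [10., 100.],   # top-left
                    [90., 105.],   # top-right
                    [95., 120.]],  # bottom-right
                   np.float32)
    print(rearrange_pts(pts))
    # -> [[10. 100.] [90. 105.] [30. 110.] [95. 120.]]  (tl, tr, bl, br)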
test.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ import numpy as np
+ from models import spinal_net
+ import cv2
+ import decoder
+ import os
+ from dataset import BaseDataset
+ import draw_points
+
+
+ def apply_mask(image, mask, alpha=0.5):
+     """Apply the given mask to the image."""
+     color = np.random.rand(3)
+     for c in range(3):
+         image[:, :, c] = np.where(mask == 1,
+                                   image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
+                                   image[:, :, c])
+     return image
+
+
+ class Network(object):
+     def __init__(self, args):
+         torch.manual_seed(317)
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         heads = {'hm': args.num_classes,
+                  'reg': 2*args.num_classes,
+                  'wh': 2*4}
+
+         self.model = spinal_net.SpineNet(heads=heads,
+                                          pretrained=True,
+                                          down_ratio=args.down_ratio,
+                                          final_kernel=1,
+                                          head_conv=256)
+         self.num_classes = args.num_classes
+         self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
+         self.dataset = {'spinal': BaseDataset}
+
+     def load_model(self, model, resume):
+         checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
+         print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
+         state_dict_ = checkpoint['state_dict']
+         model.load_state_dict(state_dict_, strict=False)
+         return model
+
+     def map_mask_to_image(self, mask, img, color=None):
+         if color is None:
+             color = np.random.rand(3)
+         mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
+         mskd = img * mask
+         clmsk = np.ones(mask.shape) * mask
+         clmsk[:, :, 0] = clmsk[:, :, 0] * color[0] * 256
+         clmsk[:, :, 1] = clmsk[:, :, 1] * color[1] * 256
+         clmsk[:, :, 2] = clmsk[:, :, 2] * color[2] * 256
+         img = img + 1. * clmsk - 1. * mskd
+         return np.uint8(img)
+
+     def test(self, args, save):
+         save_path = 'weights_'+args.dataset
+         self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         dataset_module = self.dataset[args.dataset]
+         dsets = dataset_module(data_dir=args.data_dir,
+                                phase='test',
+                                input_h=args.input_h,
+                                input_w=args.input_w,
+                                down_ratio=args.down_ratio)
+
+         data_loader = torch.utils.data.DataLoader(dsets,
+                                                   batch_size=1,
+                                                   shuffle=False,
+                                                   num_workers=1,
+                                                   pin_memory=True)
+
+         for cnt, data_dict in enumerate(data_loader):
+             images = data_dict['images'][0]
+             img_id = data_dict['img_id'][0]
+             images = images.to(self.device)  # was hard-coded to 'cuda'; use the detected device
+             print('processing {}/{} image ... {}'.format(cnt, len(data_loader), img_id))
+             with torch.no_grad():
+                 output = self.model(images)
+                 hm = output['hm']
+                 wh = output['wh']
+                 reg = output['reg']
+
+             if self.device.type == 'cuda':  # synchronize only when a GPU is actually in use
+                 torch.cuda.synchronize(self.device)
+             pts2 = self.decoder.ctdet_decode(hm, wh, reg)   # shape (17, 11)
+             pts0 = pts2.copy()
+             pts0[:, :10] *= args.down_ratio
+
+             print('total pts num is {}'.format(len(pts2)))
+
+             ori_image = dsets.load_image(dsets.img_ids.index(img_id))
+             ori_image_regress = cv2.resize(ori_image, (args.input_w, args.input_h))
+             ori_image_points = ori_image_regress.copy()
+
+             h, w, c = ori_image.shape
+             pts0 = np.asarray(pts0, np.float32)
+             # pts0[:, 0::2] = pts0[:, 0::2]/args.input_w*w
+             # pts0[:, 1::2] = pts0[:, 1::2]/args.input_h*h
+             sort_ind = np.argsort(pts0[:, 1])
+             pts0 = pts0[sort_ind]
+
+             ori_image_regress, ori_image_points = draw_points.draw_landmarks_regress_test(pts0,
+                                                                                           ori_image_regress,
+                                                                                           ori_image_points)
+
+             if save:
+                 # 1) set up the results folder
+                 save_dir = os.path.join('results_'+args.dataset)
+                 os.makedirs(save_dir, exist_ok=True)
+
+                 # 2) save the coordinates to a .txt file
+                 txt_path = os.path.join(save_dir, f'{img_id}.txt')
+                 # pts0 is assumed to be an array of shape (N, 2) or (N, 4), whichever you want to save
+                 np.savetxt(txt_path, pts0, fmt='%.4f')
+
+                 # 3) save the overlay image
+                 img_path = os.path.join(save_dir, f'{img_id}_pred.jpg')
+                 cv2.imwrite(img_path, ori_image_points)
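The per-image .txt written above can be reloaded directly with numpy for downstream measurement; a hedged sketch (the path layout follows the save block, and the file name here is a placeholder):

    import numpy as np

    # one row per detected vertebra; the first 10 columns are the decoded
    # coordinates already scaled back by down_ratio, as in the test loop above
    pts = np.loadtxt('results_spinal/some_image.jpg.txt')
    print(pts.shape)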
train.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ import torch.nn as nn
+ import os
+ import numpy as np
+ from models import spinal_net
+ import decoder
+ import loss
+ from dataset import BaseDataset
+
+
+ def collater(data):
+     out_data_dict = {}
+     for name in data[0]:
+         out_data_dict[name] = []
+     for sample in data:
+         for name in sample:
+             out_data_dict[name].append(torch.from_numpy(sample[name]))
+     for name in out_data_dict:
+         out_data_dict[name] = torch.stack(out_data_dict[name], dim=0)
+     return out_data_dict
+
+
+ class Network(object):
+     def __init__(self, args):
+         torch.manual_seed(317)
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         heads = {'hm': args.num_classes,
+                  'reg': 2*args.num_classes,
+                  'wh': 2*4}
+
+         self.model = spinal_net.SpineNet(heads=heads,
+                                          pretrained=True,
+                                          down_ratio=args.down_ratio,
+                                          final_kernel=1,
+                                          head_conv=256)
+         self.num_classes = args.num_classes
+         self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
+         self.dataset = {'spinal': BaseDataset}
+
+     def save_model(self, path, epoch, model):
+         if isinstance(model, torch.nn.DataParallel):
+             state_dict = model.module.state_dict()
+         else:
+             state_dict = model.state_dict()
+         data = {'epoch': epoch, 'state_dict': state_dict}
+         torch.save(data, path)
+
+     def load_model(self, model, resume, strict=True):
+         checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
+         print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
+         state_dict_ = checkpoint['state_dict']
+         state_dict = {}
+
+         # strip a leading 'module.' (DataParallel) prefix where present
+         for k in state_dict_:
+             if k.startswith('module') and not k.startswith('module_list'):
+                 state_dict[k[7:]] = state_dict_[k]
+             else:
+                 state_dict[k] = state_dict_[k]
+         model_state_dict = model.state_dict()
+
+         if not strict:
+             for k in state_dict:
+                 if k in model_state_dict:
+                     if state_dict[k].shape != model_state_dict[k].shape:
+                         print('Skip loading parameter {}, required shape {}, '
+                               'loaded shape {}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))
+                         state_dict[k] = model_state_dict[k]
+                 else:
+                     print('Drop parameter {}.'.format(k))
+             for k in model_state_dict:
+                 if not (k in state_dict):
+                     print('No param {}.'.format(k))
+                     state_dict[k] = model_state_dict[k]
+         model.load_state_dict(state_dict, strict=False)
+         return model
+
+     def train_network(self, args):
+         save_path = 'weights_'+args.dataset
+         if not os.path.exists(save_path):
+             os.mkdir(save_path)
+         self.optimizer = torch.optim.Adam(self.model.parameters(), args.init_lr)
+         scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.96, last_epoch=-1)
+         if args.ngpus > 0:
+             if torch.cuda.device_count() > 1:
+                 print("Let's use", torch.cuda.device_count(), "GPUs!")
+                 self.model = nn.DataParallel(self.model)
+
+         self.model.to(self.device)
+
+         criterion = loss.LossAll()
+         print('Setting up data...')
+
+         dataset_module = self.dataset[args.dataset]
+
+         dsets = {x: dataset_module(data_dir=args.data_dir,
+                                    phase=x,
+                                    input_h=args.input_h,
+                                    input_w=args.input_w,
+                                    down_ratio=args.down_ratio)
+                  for x in ['train', 'val']}
+
+         dsets_loader = {'train': torch.utils.data.DataLoader(dsets['train'],
+                                                              batch_size=args.batch_size,
+                                                              shuffle=True,
+                                                              num_workers=args.num_workers,
+                                                              pin_memory=True,
+                                                              drop_last=True,
+                                                              collate_fn=collater),
+                         'val': torch.utils.data.DataLoader(dsets['val'],
+                                                            batch_size=1,
+                                                            shuffle=False,
+                                                            num_workers=1,
+                                                            pin_memory=True,
+                                                            collate_fn=collater)}
+
+         print('Starting training...')
+         train_loss = []
+         val_loss = []
+         for epoch in range(1, args.num_epoch+1):
+             print('-'*10)
+             print('Epoch: {}/{} '.format(epoch, args.num_epoch))
+             epoch_loss = self.run_epoch(phase='train',
+                                         data_loader=dsets_loader['train'],
+                                         criterion=criterion)
+             train_loss.append(epoch_loss)
+             scheduler.step()  # passing the epoch index here is deprecated in recent PyTorch
+
+             epoch_loss = self.run_epoch(phase='val',
+                                         data_loader=dsets_loader['val'],
+                                         criterion=criterion)
+             val_loss.append(epoch_loss)
+
+             np.savetxt(os.path.join(save_path, 'train_loss.txt'), train_loss, fmt='%.6f')
+             np.savetxt(os.path.join(save_path, 'val_loss.txt'), val_loss, fmt='%.6f')
+
+             if epoch % 10 == 0 or epoch == 1:
+                 self.save_model(os.path.join(save_path, 'model_{}.pth'.format(epoch)), epoch, self.model)
+
+             if len(val_loss) > 1:
+                 if val_loss[-1] < np.min(val_loss[:-1]):
+                     self.save_model(os.path.join(save_path, 'model_last.pth'), epoch, self.model)
+
+     def run_epoch(self, phase, data_loader, criterion):
+         if phase == 'train':
+             self.model.train()
+         else:
+             self.model.eval()
+         running_loss = 0.
+         for data_dict in data_loader:
+             for name in data_dict:
+                 data_dict[name] = data_dict[name].to(device=self.device)
+             if phase == 'train':
+                 self.optimizer.zero_grad()
+                 with torch.enable_grad():
+                     pr_decs = self.model(data_dict['input'])
+                     batch_loss = criterion(pr_decs, data_dict)  # renamed from 'loss' to avoid shadowing the loss module
+                     batch_loss.backward()
+                     self.optimizer.step()
+             else:
+                 with torch.no_grad():
+                     pr_decs = self.model(data_dict['input'])
+                     batch_loss = criterion(pr_decs, data_dict)
+
+             running_loss += batch_loss.item()
+         epoch_loss = running_loss / len(data_loader)
+         print('{} loss: {}'.format(phase, epoch_loss))
+         return epoch_loss
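The custom collater above stacks each named numpy field across samples into one batched tensor. A quick, hedged sanity check with synthetic shapes (assuming the repo modules above are importable, and an input of 1024x512 with down_ratio 4):

    import numpy as np
    from train import collater

    samples = [{'input': np.zeros((3, 1024, 512), np.float32),
                'hm': np.zeros((1, 256, 128), np.float32)} for _ in range(4)]
    batch = collater(samples)
    print(batch['input'].shape)  # torch.Size([4, 3, 1024, 512])
    print(batch['hm'].shape)     # torch.Size([4, 1, 256, 128])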
transform.py ADDED
@@ -0,0 +1,181 @@
+ import numpy as np
+ from numpy import random
+ import cv2
+
+
+ def rescale_pts(pts, down_ratio):
+     return np.asarray(pts, np.float32) / float(down_ratio)
+
+
+ class Compose(object):
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, img, pts):
+         for t in self.transforms:
+             img, pts = t(img, pts)
+         return img, pts
+
+
+ class ConvertImgFloat(object):
+     def __call__(self, img, pts):
+         return img.astype(np.float32), pts.astype(np.float32)
+
+
+ class RandomContrast(object):
+     def __init__(self, lower=0.5, upper=1.5):
+         self.lower = lower
+         self.upper = upper
+         assert self.upper >= self.lower, "contrast upper must be >= lower."
+         assert self.lower >= 0, "contrast lower must be non-negative."
+
+     def __call__(self, img, pts):
+         if random.randint(2):
+             alpha = random.uniform(self.lower, self.upper)
+             img *= alpha
+         return img, pts
+
+
+ class RandomBrightness(object):
+     def __init__(self, delta=32):
+         assert 0.0 <= delta <= 255.0
+         self.delta = delta
+
+     def __call__(self, img, pts):
+         if random.randint(2):
+             delta = random.uniform(-self.delta, self.delta)
+             img += delta
+         return img, pts
+
+
+ class SwapChannels(object):
+     def __init__(self, swaps):
+         self.swaps = swaps
+
+     def __call__(self, img):
+         img = img[:, :, self.swaps]
+         return img
+
+
+ class RandomLightingNoise(object):
+     def __init__(self):
+         self.perms = ((0, 1, 2), (0, 2, 1),
+                       (1, 0, 2), (1, 2, 0),
+                       (2, 0, 1), (2, 1, 0))
+
+     def __call__(self, img, pts):
+         if random.randint(2):
+             swap = self.perms[random.randint(len(self.perms))]
+             shuffle = SwapChannels(swap)
+             img = shuffle(img)
+         return img, pts
+
+
+ class PhotometricDistort(object):
+     def __init__(self):
+         self.pd = RandomContrast()
+         self.rb = RandomBrightness()
+         self.rln = RandomLightingNoise()
+
+     def __call__(self, img, pts):
+         img, pts = self.rb(img, pts)
+         # the original if/else picked self.pd in both branches, so apply it directly
+         img, pts = self.pd(img, pts)
+         img, pts = self.rln(img, pts)
+         return img, pts
+
+
+ class Expand(object):
+     def __init__(self, max_scale=1.5, mean=(0.5, 0.5, 0.5)):
+         self.mean = mean
+         self.max_scale = max_scale
+
+     def __call__(self, img, pts):
+         if random.randint(2):
+             return img, pts
+         h, w, c = img.shape
+         ratio = random.uniform(1, self.max_scale)
+         y1 = random.uniform(0, h*ratio-h)
+         x1 = random.uniform(0, w*ratio-w)
+         if np.max(pts[:, 0])+int(x1) > w-1 or np.max(pts[:, 1])+int(y1) > h-1:  # keep all the pts
+             return img, pts
+         else:
+             expand_img = np.zeros(shape=(int(h*ratio), int(w*ratio), c), dtype=img.dtype)
+             expand_img[:, :, :] = self.mean
+             expand_img[int(y1):int(y1+h), int(x1):int(x1+w)] = img
+             pts[:, 0] += int(x1)
+             pts[:, 1] += int(y1)
+             return expand_img, pts
+
+
+ class RandomSampleCrop(object):
+     def __init__(self, ratio=(0.5, 1.5), min_win=0.9):
+         self.sample_options = (
+             # use the entire original input image
+             None,
+             # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
+             # (0.1, None),
+             # (0.3, None),
+             (0.7, None),
+             (0.9, None),
+             # randomly sample a patch
+             (None, None),
+         )
+         self.ratio = ratio
+         self.min_win = min_win
+
+     def __call__(self, img, pts):
+         height, width, _ = img.shape
+         while True:
+             mode = random.choice(self.sample_options)
+             if mode is None:
+                 return img, pts
+             for _ in range(50):
+                 current_img = img
+                 current_pts = pts
+                 w = random.uniform(self.min_win*width, width)
+                 h = random.uniform(self.min_win*height, height)
+                 if h/w < self.ratio[0] or h/w > self.ratio[1]:
+                     continue
+                 # np.random.uniform(x) samples from (x, 1.0), not (0, x); pass both bounds explicitly
+                 y1 = random.uniform(0, height-h)
+                 x1 = random.uniform(0, width-w)
+                 rect = np.array([int(y1), int(x1), int(y1+h), int(x1+w)])
+                 current_img = current_img[rect[0]:rect[2], rect[1]:rect[3], :]
+                 current_pts[:, 0] -= rect[1]
+                 current_pts[:, 1] -= rect[0]
+                 # drop points that fall outside the crop
+                 # (the original `any(pt) < 0` compared a bool to 0 and never triggered)
+                 pts_new = []
+                 for pt in current_pts:
+                     if pt[0] < 0 or pt[1] < 0 or pt[0] > current_img.shape[1]-1 or pt[1] > current_img.shape[0]-1:
+                         continue
+                     else:
+                         pts_new.append(pt)
+
+                 return current_img, np.asarray(pts_new, np.float32)
+
+
+ class RandomMirror_w(object):
+     def __call__(self, img, pts):
+         _, w, _ = img.shape
+         if random.randint(2):
+             img = img[:, ::-1, :]
+             pts[:, 0] = w-pts[:, 0]
+         return img, pts
+
+
+ class RandomMirror_h(object):
+     def __call__(self, img, pts):
+         h, _, _ = img.shape
+         if random.randint(2):
+             img = img[::-1, :, :]
+             pts[:, 1] = h-pts[:, 1]
+         return img, pts
+
+
+ class Resize(object):
+     def __init__(self, h, w):
+         self.dsize = (w, h)
+
+     def __call__(self, img, pts):
+         h, w, c = img.shape
+         pts[:, 0] = pts[:, 0]/w*self.dsize[0]
+         pts[:, 1] = pts[:, 1]/h*self.dsize[1]
+         img = cv2.resize(img, dsize=self.dsize)
+         return img, np.asarray(pts)
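Finally, a hedged end-to-end check of the augmentation pipeline, using the same composition pre_proc.py builds for training (synthetic image and 68 synthetic corner points, i.e. 17 vertebrae x 4 corners; shapes are illustrative):

    import numpy as np
    import transform

    aug = transform.Compose([transform.ConvertImgFloat(),
                             transform.PhotometricDistort(),
                             transform.Expand(max_scale=1.5, mean=(0, 0, 0)),
                             transform.RandomMirror_w(),
                             transform.Resize(h=1024, w=512)])

    img = np.random.randint(0, 255, (2500, 1000, 3), np.uint8)      # stand-in X-ray
    pts = (np.random.rand(68, 2) * [999, 2499]).astype(np.float32)  # (x, y) corners
    out_img, out_pts = aug(img, pts)
    print(out_img.shape)  # (1024, 512, 3); points are rescaled into the same frame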