Upload 30 files
Browse files

- .gitattributes +6 -0
- CX0075_png.rf.f86acbac9d6c41151e8caed4914a3e89.jpg +3 -0
- DP1983_png.rf.3f2a58f7f0feb4f9ad7b34149149553b.jpg +3 -0
- HS1500_png.rf.8659b481c780f6b582532eb56d6f5349.jpg +3 -0
- LICENSE +21 -0
- README.md +57 -12
- app.py +399 -0
- best.pt +3 -0
- cobb_evaluate.py +124 -0
- dataset.py +88 -0
- decoder.py +77 -0
- dev_1.jpg +0 -0
- dev_2.jpg +0 -0
- dev_3.jpg +3 -0
- draw_gaussian.py +49 -0
- draw_loss.py +64 -0
- draw_points.py +92 -0
- eval.py +222 -0
- image_1.jpg +3 -0
- image_1_la.jpg +0 -0
- image_2.jpg +0 -0
- image_2_la.jpg +0 -0
- image_3.jpg +3 -0
- image_3_la.jpg +0 -0
- loss.py +69 -0
- main.py +40 -0
- make_requirements.py +73 -0
- pre_proc.py +146 -0
- test.py +123 -0
- train.py +169 -0
- transform.py +181 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+CX0075_png.rf.f86acbac9d6c41151e8caed4914a3e89.jpg filter=lfs diff=lfs merge=lfs -text
+dev_3.jpg filter=lfs diff=lfs merge=lfs -text
+DP1983_png.rf.3f2a58f7f0feb4f9ad7b34149149553b.jpg filter=lfs diff=lfs merge=lfs -text
+HS1500_png.rf.8659b481c780f6b582532eb56d6f5349.jpg filter=lfs diff=lfs merge=lfs -text
+image_1.jpg filter=lfs diff=lfs merge=lfs -text
+image_3.jpg filter=lfs diff=lfs merge=lfs -text
CX0075_png.rf.f86acbac9d6c41151e8caed4914a3e89.jpg
ADDED
(image stored with Git LFS)

DP1983_png.rf.3f2a58f7f0feb4f9ad7b34149149553b.jpg
ADDED
(image stored with Git LFS)

HS1500_png.rf.8659b481c780f6b582532eb56d6f5349.jpg
ADDED
(image stored with Git LFS)
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 yijingru

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,57 @@
# Vertebra-Focused-Landmark-Detection-Pytorch
Vertebra-Focused Landmark Detection for Scoliosis Assessment [[arXiv](https://arxiv.org/pdf/2001.03187.pdf)]

Accepted to ISBI 2020.

Please cite the article in your publications if it helps your research:

    @article{yi2020vertebra,
      title={Vertebra-Focused Landmark Detection for Scoliosis Assessment},
      author={Yi, Jingru and Wu, Pengxiang and Huang, Qiaoying and Qu, Hui and Metaxas, Dimitris N},
      booktitle={ISBI},
      year={2020}
    }

<p align="center">
  <img src="imgs/pic1.png" width="400">
</p>

<p align="center">
  <img src="imgs/pic2.png" width="800">
</p>

# Dependencies
Ubuntu 14.04, Python 3.6.4, PyTorch 1.1.0, OpenCV-Python 4.1.0.25

# How to start
## Prepare Dataset
To use dataset.py directly, arrange the dataset as follows:
```
/dataPath/data
    /train/*.jpg
    /val/*.jpg
    /test/*.jpg
/dataPath/labels/
    /train/*.mat
    /val/*.mat
    /test/*.mat
```
The source dataset is from [[dataset16](http://spineweb.digitalimaginggroup.ca/spineweb/index.php?n=Main.Datasets#Dataset_16.3A_609_spinal_anterior-posterior_x-ray_images)].
To adapt the code to your own dataset, modify dataset.py; for example, change the `load_gt_pts` function to match your own annotation format (a sketch follows this README section). The pretrained weights can be downloaded [here](https://drive.google.com/drive/folders/1LhKnGVE8dUw0nK9_x4vPNY_L7sPY2_aQ?usp=sharing).

## Train the model
```ruby
python main.py --data_dir dataPath --epochs 50 --batch_size 2 --dataset spinal --phase train
```

## Test the model
```ruby
python main.py --resume weightPath --data_dir dataPath --dataset spinal --phase test
```

## Evaluate the model
```ruby
python main.py --resume weightPath --data_dir dataPath --dataset spinal --phase eval
```
app.py
ADDED
@@ -0,0 +1,399 @@
import os
import sys
import shutil
import importlib.util
from io import BytesIO
from ultralytics import YOLO
from PIL import Image

import torch
# ─── FORCE CPU ONLY ─────────────────────────────────────────────────────────
torch.Tensor.cuda = lambda self, *args, **kwargs: self
torch.nn.Module.cuda = lambda self, *args, **kwargs: self
torch.cuda.synchronize = lambda *args, **kwargs: None
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0
_orig_to = torch.Tensor.to
def _to_cpu(self, *args, **kwargs):
    new_args = []
    for a in args:
        if isinstance(a, str) and a.lower().startswith("cuda"):
            new_args.append("cpu")
        elif isinstance(a, torch.device) and a.type == "cuda":
            new_args.append(torch.device("cpu"))
        else:
            new_args.append(a)
    if "device" in kwargs:
        dev = kwargs["device"]
        if (isinstance(dev, str) and dev.lower().startswith("cuda")) or \
           (isinstance(dev, torch.device) and dev.type == "cuda"):
            kwargs["device"] = torch.device("cpu")
    return _orig_to(self, *new_args, **kwargs)
torch.Tensor.to = _to_cpu

from torch.utils.data import DataLoader as _DL
def _dl0(ds, *a, **kw):
    kw['num_workers'] = 0
    return _DL(ds, *a, **kw)
import torch.utils.data as _du
_du.DataLoader = _dl0

import cv2
import numpy as np
import streamlit as st
from argparse import Namespace

# ─── DYNAMIC IMPORT ─────────────────────────────────────────────────────────
REPO = os.path.dirname(os.path.abspath(__file__))
sys.path.append(REPO)
models_dir = os.path.join(REPO, "models")
os.makedirs(models_dir, exist_ok=True)
open(os.path.join(models_dir, "__init__.py"), "a").close()

def load_mod(name, path):
    spec = importlib.util.spec_from_file_location(name, path)
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)
    sys.modules[name] = m
    return m

dataset_mod = load_mod("dataset", os.path.join(REPO, "dataset.py"))
decoder_mod = load_mod("decoder", os.path.join(REPO, "decoder.py"))
draw_mod    = load_mod("draw_points", os.path.join(REPO, "draw_points.py"))
test_mod    = load_mod("test", os.path.join(REPO, "test.py"))
load_mod("models.dec_net", os.path.join(models_dir, "dec_net.py"))
load_mod("models.model_parts", os.path.join(models_dir, "model_parts.py"))
load_mod("models.resnet", os.path.join(models_dir, "resnet.py"))
load_mod("models.spinal_net", os.path.join(models_dir, "spinal_net.py"))

BaseDataset = dataset_mod.BaseDataset
Network = test_mod.Network

# ─── STREAMLIT UI ───────────────────────────────────────────────────────────
st.set_page_config(layout="wide", page_title="Vertebral Compression Fracture")

st.markdown(
    """
    <div style='border: 2px solid #0080FF; border-radius: 5px; padding: 10px'>
      <h1 style='text-align: center; color: #0080FF'>
        🦴 Vertebral Compression Fracture Detection 🖼️
      </h1>
    </div>
    """, unsafe_allow_html=True)
st.markdown("")
st.markdown("")
st.markdown("")
col1, col2, col3, col4 = st.columns(4)

with col4:
    feature = st.selectbox(
        "🔀 Select Feature",
        ["How to use", "AP - Detection", "AP - Cobb angle", "LA - Image Segmentation", "Contact"],
        index=0,  # default to "How to use"
        help="Choose which view to display"
    )

if feature == "How to use":
    st.markdown("## 📖 How to use this app")

    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown(
            """
            <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h2>Step 1️⃣</h2>
              <p>Go to <b>AP - Detection</b> or <b>LA - Image Segmentation</b></p>
              <p>Select a sample image or upload your own image file.</p>
              <p style='color:#008000;'><b>✅ Tip:</b> Best with X-ray images with clear vertebra visibility.</p>
            </div>
            """,
            unsafe_allow_html=True
        )

    with col2:
        st.markdown(
            """
            <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h2>Step 2️⃣</h2>
              <p>Press the <b>Enter</b> button.</p>
              <p>The system will process your image automatically.</p>
              <p style='color:#FFA500;'><b>⏳ Note:</b> Processing time depends on image size.</p>
            </div>
            """,
            unsafe_allow_html=True
        )

    with col3:
        st.markdown(
            """
            <div style='border:2px solid #00BFFF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h2>Step 3️⃣</h2>
              <p>See the prediction results:</p>
              <p style='text-align:left'>1. Bounding boxes & landmarks (AP)</p>
              <p style='text-align:left'>2. Segmentation masks (LA)</p>
            </div>
            """,
            unsafe_allow_html=True
        )

    st.markdown(" ")
    st.info("You can switch features via Select Feature; each feature comes with an example showing what it does.")

elif feature == "AP - Detection":
    uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
    orig_w = orig_h = None  # store original dimensions
    img0 = None
    run = st.button("Enter", use_container_width=True)
    # ─── Maintain selected sample in session state ─────────
    if "sample_img" not in st.session_state:
        st.session_state.sample_img = None

    # ─── SAMPLE BUTTONS ─────────────────────────────────────
    with col1:
        if st.button(" 1️⃣ Example", use_container_width=True):
            st.session_state.sample_img = "image_1.jpg"
    with col2:
        if st.button(" 2️⃣ Example", use_container_width=True):
            st.session_state.sample_img = "image_2.jpg"
    with col3:
        if st.button(" 3️⃣ Example", use_container_width=True):
            st.session_state.sample_img = "image_3.jpg"

    # ─── UI FOR UPLOAD + DISPLAY ───────────────────────────
    col4, col5, col6 = st.columns(3)
    with col4:
        st.subheader("1️⃣ Upload & Run")

        sample_img = st.session_state.sample_img  # read persisted choice

        # case 1: uploaded file
        if uploaded:
            buf = uploaded.getvalue()
            arr = np.frombuffer(buf, np.uint8)
            img0 = cv2.imdecode(arr, cv2.IMREAD_COLOR)
            orig_h, orig_w = img0.shape[:2]
            st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB), caption="Uploaded Image", use_container_width=True)

        # case 2: selected sample image
        elif sample_img is not None:
            img_path = os.path.join(REPO, sample_img)
            img0 = cv2.imread(img_path)
            if img0 is not None:
                orig_h, orig_w = img0.shape[:2]
                st.image(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB),
                         caption=f"Sample Image: {sample_img}",
                         use_container_width=True)
            else:
                st.error(f"Cannot find {sample_img} in directory!")

    with col5:
        st.subheader("2️⃣ Predictions")
    with col6:
        st.subheader("3️⃣ Heatmap")

    # ─── ARGS & CHECKPOINT ─────────────────────────────────
    args = Namespace(
        resume="model_30.pth",
        data_dir=os.path.join(REPO, "dataPath"),
        dataset="spinal",
        phase="test",
        input_h=1024,
        input_w=512,
        down_ratio=4,
        num_classes=1,
        K=17,
        conf_thresh=0.2,
    )
    weights_dir = os.path.join(REPO, "weights_spinal")
    os.makedirs(weights_dir, exist_ok=True)
    src_ckpt = os.path.join(REPO, "model_backup", args.resume)
    dst_ckpt = os.path.join(weights_dir, args.resume)
    if os.path.isfile(src_ckpt) and not os.path.isfile(dst_ckpt):
        shutil.copy(src_ckpt, dst_ckpt)

    # ─── MAIN LOGIC ────────────────────────────────────────
    if img0 is not None and run and orig_w and orig_h:
        # determine name for saving
        if uploaded:
            name = os.path.splitext(uploaded.name)[0] + ".jpg"
        else:
            name = os.path.splitext(sample_img)[0] + ".jpg"

        testd = os.path.join(args.data_dir, "data", "test")
        os.makedirs(testd, exist_ok=True)
        cv2.imwrite(os.path.join(testd, name), img0)

        orig_init = BaseDataset.__init__
        def patched_init(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
            orig_init(self, data_dir, phase, input_h, input_w, down_ratio)
            if phase == "test":
                self.img_ids = [name]
        BaseDataset.__init__ = patched_init

        with st.spinner("Running model…"):
            net = Network(args)
            net.test(args, save=True)

        out_dir = os.path.join(REPO, f"results_{args.dataset}")
        pred_file = [f for f in os.listdir(out_dir)
                     if f.startswith(name) and f.endswith("_pred.jpg")][0]
        txtf = os.path.join(out_dir, f"{name}.txt")
        imgf = os.path.join(out_dir, pred_file)

        # ─── Annotated Predictions ─────────────────────────
        base = cv2.imread(imgf)
        txt = np.loadtxt(txtf)
        tlx, tly = txt[:, 2].astype(int), txt[:, 3].astype(int)
        trx, try_ = txt[:, 4].astype(int), txt[:, 5].astype(int)
        blx, bly = txt[:, 6].astype(int), txt[:, 7].astype(int)
        brx, bry = txt[:, 8].astype(int), txt[:, 9].astype(int)

        top_pts, bot_pts, mids, dists = [], [], [], []
        for (x1, y1), (x2, y2), (x3, y3), (x4, y4) in zip(
                zip(tlx, tly), zip(trx, try_),
                zip(blx, bly), zip(brx, bry)):
            tm = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
            bm = np.array([(x3 + x4) / 2, (y3 + y4) / 2])
            top_pts.append(tm)
            bot_pts.append(bm)
            mids.append((tm + bm) / 2)
            dists.append(np.linalg.norm(bm - tm))

        ref = dists[-1]
        ann = base.copy()
        for tm, bm in zip(top_pts, bot_pts):
            cv2.line(ann, tuple(tm.astype(int)), tuple(bm.astype(int)), (0, 255, 255), 2)
        for m, d in zip(mids, dists):
            pct = (d - ref) / ref * 100
            clr = (0, 255, 255) if pct <= 20 else (0, 165, 255) if pct <= 40 else (0, 0, 255)
            pos = (int(m[0]) + 40, int(m[1]) + 5)
            cv2.putText(ann, f"{pct:.0f}%", pos,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2, cv2.LINE_AA)

        ann_resized = cv2.resize(ann, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        with col5:
            st.image(cv2.cvtColor(ann_resized, cv2.COLOR_BGR2RGB), use_container_width=True)

        H, W = base.shape[:2]
        heat = np.zeros((H, W), np.float32)
        for cx, cy in [(int(m[0]), int(m[1])) for m in mids]:
            blob = np.zeros_like(heat)
            blob[cy, cx] = 1.0
            heat += cv2.GaussianBlur(blob, (0, 0), sigmaX=8, sigmaY=8)
        heat /= (heat.max() + 1e-8)
        hm8 = (heat * 255).astype(np.uint8)
        hm_c = cv2.applyColorMap(hm8, cv2.COLORMAP_JET)

        raw = cv2.imread(imgf, cv2.IMREAD_GRAYSCALE)
        raw_b = cv2.cvtColor(raw, cv2.COLOR_GRAY2BGR)
        overlay = cv2.addWeighted(raw_b, 0.6, hm_c, 0.4, 0)
        overlay_resized = cv2.resize(overlay, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

        with col6:
            st.image(cv2.cvtColor(overlay_resized, cv2.COLOR_BGR2RGB), use_container_width=True)

elif feature == "AP - Cobb angle":
    st.write("Under development")

elif feature == "LA - Image Segmentation":
    uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"])
    img0 = None

    # ─── Maintain selected sample in session state ─────────
    if "sample_img_la" not in st.session_state:
        st.session_state.sample_img_la = None

    # ─── SAMPLE BUTTONS ─────────────────────────────────────
    with col1:
        if st.button(" 1️⃣ Example ", use_container_width=True):
            st.session_state.sample_img_la = "image_1_la.jpg"
    with col2:
        if st.button(" 2️⃣ Example ", use_container_width=True):
            st.session_state.sample_img_la = "image_2_la.jpg"
    with col3:
        if st.button(" 3️⃣ Example ", use_container_width=True):
            st.session_state.sample_img_la = "image_3_la.jpg"

    # ─── UI FOR UPLOAD + DISPLAY ───────────────────────────
    run_la = st.button("Enter", use_container_width=True)
    col7, col8 = st.columns(2)

    with col7:
        st.subheader("🖼️ Original Image")

        sample_img_la = st.session_state.sample_img_la  # read persisted choice

        # case 1: uploaded file
        if uploaded:
            buf = uploaded.getvalue()
            img0 = Image.open(BytesIO(buf)).convert("RGB")
            st.image(img0, caption="Uploaded Image", use_container_width=True)

        # case 2: selected sample image
        elif sample_img_la is not None:
            img_path = os.path.join(REPO, sample_img_la)
            if os.path.isfile(img_path):
                img0 = Image.open(img_path).convert("RGB")
                st.image(img0, caption=f"Sample Image: {sample_img_la}", use_container_width=True)
            else:
                st.error(f"Cannot find {sample_img_la} in directory!")

    with col8:
        st.subheader("🔎 Predicted Image")

        # ─── PREDICTION ────────────────────────────────────
        if img0 is not None and run_la:
            img_np = np.array(img0)
            model = YOLO('./best.pt')  # or your correct path to best.pt
            with st.spinner("Running YOLO model…"):
                results = model(img_np, imgsz=640)
                pred_img = results[0].plot(boxes=False, probs=False)  # returns numpy image with annotations
            st.image(pred_img, caption="Prediction Result", use_container_width=True)

elif feature == "Contact":
    with col1:
        st.image("dev_1.jpg", caption=None, use_container_width=True)
        st.markdown(
            """
            <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h3>Thitsanapat Uma</h3>
              <a href='https://www.facebook.com/thitsanapat.uma' target='_blank'>
                🔗 Facebook Profile
              </a>
            </div>
            """,
            unsafe_allow_html=True
        )
    with col2:
        st.image("dev_2.jpg", caption=None, use_container_width=True)
        st.markdown(
            """
            <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h3>Santipab Tongchan</h3>
              <a href='https://www.facebook.com/santipab.tongchan.2025' target='_blank'>
                🔗 Facebook Profile
              </a>
            </div>
            """,
            unsafe_allow_html=True
        )
    with col3:
        st.image("dev_3.jpg", caption=None, use_container_width=True)
        st.markdown(
            """
            <div style='border:2px solid #0080FF; border-radius:10px; padding:15px; text-align:center; background-color:#F0F8FF'>
              <h3>Suphanat Kamphapan</h3>
              <a href='https://www.facebook.com/suphanat.kamphapan' target='_blank'>
                🔗 Facebook Profile
              </a>
            </div>
            """,
            unsafe_allow_html=True
        )
best.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b51882a04c2f47922248e7c57c712e659984f55a95ec55a7050109e7ae61a401
size 55847450
cobb_evaluate.py
ADDED
@@ -0,0 +1,124 @@
###########################################################################################
## This code is transferred from the matlab version of the MICCAI challenge
## Oct 1 2019
###########################################################################################
import numpy as np
import cv2


def is_S(mid_p_v):
    # mid_p_v: 34 x 2
    ll = []
    num = mid_p_v.shape[0]
    for i in range(num-2):
        term1 = (mid_p_v[i, 1]-mid_p_v[num-1, 1])/(mid_p_v[0, 1]-mid_p_v[num-1, 1])
        term2 = (mid_p_v[i, 0]-mid_p_v[num-1, 0])/(mid_p_v[0, 0]-mid_p_v[num-1, 0])
        ll.append(term1-term2)
    ll = np.asarray(ll, np.float32)[:, np.newaxis]   # 32 x 1
    ll_pair = np.matmul(ll, np.transpose(ll))        # 32 x 32
    a = sum(sum(ll_pair))
    b = sum(sum(abs(ll_pair)))
    if abs(a-b) < 1e-4:
        return False
    else:
        return True

def cobb_angle_calc(pts, image):
    pts = np.asarray(pts, np.float32)   # 68 x 2
    h, w, c = image.shape
    num_pts = pts.shape[0]              # number of points, 68
    vnum = num_pts//4-1

    mid_p_v = (pts[0::2, :]+pts[1::2, :])/2   # 34 x 2
    mid_p = []
    for i in range(0, num_pts, 4):
        pt1 = (pts[i, :]+pts[i+2, :])/2
        pt2 = (pts[i+1, :]+pts[i+3, :])/2
        mid_p.append(pt1)
        mid_p.append(pt2)
    mid_p = np.asarray(mid_p, np.float32)     # 34 x 2

    for pt in mid_p:
        cv2.circle(image,
                   (int(pt[0]), int(pt[1])),
                   12, (0, 255, 255), -1, 1)

    for pt1, pt2 in zip(mid_p[0::2, :], mid_p[1::2, :]):
        cv2.line(image,
                 (int(pt1[0]), int(pt1[1])),
                 (int(pt2[0]), int(pt2[1])),
                 color=(0, 0, 255),
                 thickness=5, lineType=1)

    vec_m = mid_p[1::2, :]-mid_p[0::2, :]                      # 17 x 2
    dot_v = np.matmul(vec_m, np.transpose(vec_m))              # 17 x 17
    mod_v = np.sqrt(np.sum(vec_m**2, axis=1))[:, np.newaxis]   # 17 x 1
    mod_v = np.matmul(mod_v, np.transpose(mod_v))              # 17 x 17
    cosine_angles = np.clip(dot_v/mod_v, a_min=0., a_max=1.)
    angles = np.arccos(cosine_angles)                          # 17 x 17
    pos1 = np.argmax(angles, axis=1)
    maxt = np.amax(angles, axis=1)
    pos2 = np.argmax(maxt)
    cobb_angle1 = np.amax(maxt)
    cobb_angle1 = cobb_angle1/np.pi*180
    flag_s = is_S(mid_p_v)
    if not flag_s:  # not S
        # print('Not S')
        cobb_angle2 = angles[0, pos2]/np.pi*180
        cobb_angle3 = angles[vnum, pos1[pos2]]/np.pi*180
        cv2.line(image,
                 (int(mid_p[pos2 * 2, 0]), int(mid_p[pos2 * 2, 1])),
                 (int(mid_p[pos2 * 2 + 1, 0]), int(mid_p[pos2 * 2 + 1, 1])),
                 color=(0, 255, 0), thickness=5, lineType=2)
        cv2.line(image,
                 (int(mid_p[pos1[pos2] * 2, 0]), int(mid_p[pos1[pos2] * 2, 1])),
                 (int(mid_p[pos1[pos2] * 2 + 1, 0]), int(mid_p[pos1[pos2] * 2 + 1, 1])),
                 color=(0, 255, 0), thickness=5, lineType=2)

    else:
        if (mid_p_v[pos2*2, 1]+mid_p_v[pos1[pos2]*2, 1]) < h:
            # print('Is S: condition1')
            angle2 = angles[pos2, :(pos2+1)]
            cobb_angle2 = np.max(angle2)
            pos1_1 = np.argmax(angle2)
            cobb_angle2 = cobb_angle2/np.pi*180

            angle3 = angles[pos1[pos2], pos1[pos2]:(vnum+1)]
            cobb_angle3 = np.max(angle3)
            pos1_2 = np.argmax(angle3)
            cobb_angle3 = cobb_angle3/np.pi*180
            pos1_2 = pos1_2 + pos1[pos2]-1

            cv2.line(image,
                     (int(mid_p[pos1_1 * 2, 0]), int(mid_p[pos1_1 * 2, 1])),
                     (int(mid_p[pos1_1 * 2+1, 0]), int(mid_p[pos1_1 * 2 + 1, 1])),
                     color=(0, 255, 0), thickness=5, lineType=2)

            cv2.line(image,
                     (int(mid_p[pos1_2 * 2, 0]), int(mid_p[pos1_2 * 2, 1])),
                     (int(mid_p[pos1_2 * 2+1, 0]), int(mid_p[pos1_2 * 2 + 1, 1])),
                     color=(0, 255, 0), thickness=5, lineType=2)

        else:
            # print('Is S: condition2')
            angle2 = angles[pos2, :(pos2+1)]
            cobb_angle2 = np.max(angle2)
            pos1_1 = np.argmax(angle2)
            cobb_angle2 = cobb_angle2/np.pi*180

            angle3 = angles[pos1_1, :(pos1_1+1)]
            cobb_angle3 = np.max(angle3)
            pos1_2 = np.argmax(angle3)
            cobb_angle3 = cobb_angle3/np.pi*180

            cv2.line(image,
                     (int(mid_p[pos1_1 * 2, 0]), int(mid_p[pos1_1 * 2, 1])),
                     (int(mid_p[pos1_1 * 2+1, 0]), int(mid_p[pos1_1 * 2 + 1, 1])),
                     color=(0, 255, 0), thickness=5, lineType=2)

            cv2.line(image,
                     (int(mid_p[pos1_2 * 2, 0]), int(mid_p[pos1_2 * 2, 1])),
                     (int(mid_p[pos1_2 * 2+1, 0]), int(mid_p[pos1_2 * 2 + 1, 1])),
                     color=(0, 255, 0), thickness=5, lineType=2)

    return [cobb_angle1, cobb_angle2, cobb_angle3]
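A minimal usage sketch for `cobb_angle_calc`, assuming a 68 x 2 landmark array is already available; the file names here are placeholders. The function draws the midpoint overlay on the image in place and returns the three Cobb angles in degrees:

```python
# Sketch: Cobb angles from 68 landmarks (placeholder file names).
import cv2
import numpy as np
import cobb_evaluate

image = cv2.imread("image_1.jpg")            # H x W x 3 BGR image
landmarks = np.load("landmarks.npy")         # assumed 68 x 2 (x, y) array
angles = cobb_evaluate.cobb_angle_calc(landmarks, image)  # draws overlay in place
print("Cobb angles (deg):", angles)          # [cobb_angle1, cobb_angle2, cobb_angle3]
cv2.imwrite("cobb_overlay.jpg", image)
```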
dataset.py
ADDED
@@ -0,0 +1,88 @@
import os
import torch.utils.data as data
import pre_proc
import cv2
from scipy.io import loadmat
import numpy as np


def rearrange_pts(pts):
    boxes = []
    for k in range(0, len(pts), 4):
        pts_4 = pts[k:k+4, :]
        x_inds = np.argsort(pts_4[:, 0])
        pt_l = np.asarray(pts_4[x_inds[:2], :])
        pt_r = np.asarray(pts_4[x_inds[2:], :])
        y_inds_l = np.argsort(pt_l[:, 1])
        y_inds_r = np.argsort(pt_r[:, 1])
        tl = pt_l[y_inds_l[0], :]
        bl = pt_l[y_inds_l[1], :]
        tr = pt_r[y_inds_r[0], :]
        br = pt_r[y_inds_r[1], :]
        # boxes.append([tl, tr, bl, br])
        boxes.append(tl)
        boxes.append(tr)
        boxes.append(bl)
        boxes.append(br)
    return np.asarray(boxes, np.float32)


class BaseDataset(data.Dataset):
    def __init__(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
        super(BaseDataset, self).__init__()
        self.data_dir = data_dir
        self.phase = phase
        self.input_h = input_h
        self.input_w = input_w
        self.down_ratio = down_ratio
        self.class_name = ['__background__', 'cell']
        self.num_classes = 68
        self.img_dir = os.path.join(data_dir, 'data', self.phase)
        self.img_ids = sorted(os.listdir(self.img_dir))

    def load_image(self, index):
        image = cv2.imread(os.path.join(self.img_dir, self.img_ids[index]))
        return image

    def load_gt_pts(self, annopath):
        pts = loadmat(annopath)['p2']  # num x 2 (x,y)
        pts = rearrange_pts(pts)
        return pts

    def load_annoFolder(self, img_id):
        return os.path.join(self.data_dir, 'labels', self.phase, img_id+'.mat')

    def load_annotation(self, index):
        img_id = self.img_ids[index]
        annoFolder = self.load_annoFolder(img_id)
        pts = self.load_gt_pts(annoFolder)
        return pts

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        image = self.load_image(index)
        if self.phase == 'test':
            images = pre_proc.processing_test(image=image, input_h=self.input_h, input_w=self.input_w)
            return {'images': images, 'img_id': img_id}
        else:
            aug_label = False
            if self.phase == 'train':
                aug_label = True
            pts = self.load_annotation(index)  # num_obj x h x w
            out_image, pts_2 = pre_proc.processing_train(image=image,
                                                         pts=pts,
                                                         image_h=self.input_h,
                                                         image_w=self.input_w,
                                                         down_ratio=self.down_ratio,
                                                         aug_label=aug_label,
                                                         img_id=img_id)

            data_dict = pre_proc.generate_ground_truth(image=out_image,
                                                       pts_2=pts_2,
                                                       image_h=self.input_h//self.down_ratio,
                                                       image_w=self.input_w//self.down_ratio,
                                                       img_id=img_id)
            return data_dict

    def __len__(self):
        return len(self.img_ids)
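A short sketch of how `BaseDataset` would be consumed, following the directory layout from the README (the `dataPath` location is assumed):

```python
# Sketch: iterate the test split; paths follow the README's /dataPath layout.
import torch.utils.data
from dataset import BaseDataset

dsets = BaseDataset(data_dir='dataPath', phase='test',
                    input_h=1024, input_w=512, down_ratio=4)
loader = torch.utils.data.DataLoader(dsets, batch_size=1, shuffle=False)
for batch in loader:
    print(batch['img_id'], batch['images'].shape)  # preprocessed image tensor
    break
```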
decoder.py
ADDED
@@ -0,0 +1,77 @@
import torch.nn.functional as F
import numpy as np
import torch

class DecDecoder(object):
    def __init__(self, K, conf_thresh):
        self.K = 17  # note: hard-coded to 17 vertebrae; the K argument is ignored
        self.conf_thresh = conf_thresh

    def _topk(self, scores):
        batch, cat, height, width = scores.size()

        topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), self.K)

        topk_inds = topk_inds % (height * width)
        topk_ys = (topk_inds / width).int().float()
        topk_xs = (topk_inds % width).int().float()

        topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), self.K)
        topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, self.K)
        topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, self.K)
        topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, self.K)

        return topk_score, topk_inds, topk_ys, topk_xs


    def _nms(self, heat, kernel=3):
        hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
        keep = (hmax == heat).float()
        return heat * keep

    def _gather_feat(self, feat, ind, mask=None):
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def _tranpose_and_gather_feat(self, feat, ind):
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        feat = self._gather_feat(feat, ind)
        return feat

    def ctdet_decode(self, heat, wh, reg):
        # output: num_obj x 7
        # 7: cenx, ceny, w, h, angle, score, cls
        batch, c, height, width = heat.size()
        heat = self._nms(heat)  # [1, 1, 256, 128]
        scores, inds, ys, xs = self._topk(heat)
        scores = scores.view(batch, self.K, 1)
        reg = self._tranpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, self.K, 2)
        xs = xs.view(batch, self.K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, self.K, 1) + reg[:, :, 1:2]
        wh = self._tranpose_and_gather_feat(wh, inds)
        wh = wh.view(batch, self.K, 2*4)

        tl_x = xs - wh[:, :, 0:1]
        tl_y = ys - wh[:, :, 1:2]
        tr_x = xs - wh[:, :, 2:3]
        tr_y = ys - wh[:, :, 3:4]
        bl_x = xs - wh[:, :, 4:5]
        bl_y = ys - wh[:, :, 5:6]
        br_x = xs - wh[:, :, 6:7]
        br_y = ys - wh[:, :, 7:8]

        pts = torch.cat([xs, ys,
                         tl_x, tl_y,
                         tr_x, tr_y,
                         bl_x, bl_y,
                         br_x, br_y,
                         scores], dim=2).squeeze(0)
        return pts.data.cpu().numpy()
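For orientation, a sketch of what `ctdet_decode` returns, using dummy tensors in place of real network outputs (the map shapes are assumed from the heads defined in eval.py):

```python
# Sketch: decode dummy network outputs; shapes assumed from the repo's heads.
import torch
from decoder import DecDecoder

dec = DecDecoder(K=17, conf_thresh=0.2)
hm  = torch.rand(1, 1, 256, 128)   # center heatmap at 1/4 input resolution
wh  = torch.rand(1, 8, 256, 128)   # offsets to the 4 corners (x, y each)
reg = torch.rand(1, 2, 256, 128)   # sub-pixel center refinement
pts = dec.ctdet_decode(hm, wh, reg)                # 17 x 11 numpy array
centers, corners, scores = pts[:, :2], pts[:, 2:10], pts[:, 10]
```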
dev_1.jpg
ADDED
(image)

dev_2.jpg
ADDED
(image)

dev_3.jpg
ADDED
(image stored with Git LFS)
draw_gaussian.py
ADDED
@@ -0,0 +1,49 @@
import numpy as np


def gaussian_radius(det_size, min_overlap=0.7):
    height, width = det_size

    a1 = 1
    b1 = (height + width)
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + sq1) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + sq2) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / 2
    return min(r1, r2, r3)

def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

def draw_umich_gaussian(heatmap, center, radius, k=1):
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)

    x, y = int(center[0]), int(center[1])

    height, width = heatmap.shape[0:2]

    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap
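A quick sketch of how these helpers combine when rendering a training heatmap; the box size and center below are made-up values:

```python
# Sketch: stamp one Gaussian peak onto a blank heatmap (made-up values).
import numpy as np
from draw_gaussian import gaussian_radius, draw_umich_gaussian

heatmap = np.zeros((256, 128), np.float32)        # 1/4-resolution heatmap
radius = max(0, int(gaussian_radius((24, 12))))   # radius for a 24x12 box
draw_umich_gaussian(heatmap, center=(64, 128), radius=radius)
print(heatmap.max())                              # 1.0 at the stamped center
```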
draw_loss.py
ADDED
@@ -0,0 +1,64 @@
import matplotlib.pyplot as plt
import numpy as np
import os

def load_data(filename):
    pts = []
    f = open(filename, "rb")
    for line in f:
        pts.append(float(line.strip()))
    f.close()
    return pts

dataset = 'spinal'
weights_path = 'weights_'+dataset

###############################################
# Load data
train_pts = load_data(os.path.join(weights_path, 'train_loss.txt'))
val_pts = load_data(os.path.join(weights_path, 'val_loss.txt'))

def draw_loss():
    x = np.linspace(0, len(train_pts), len(train_pts))
    plt.plot(x, train_pts, 'ro-', label='train')
    plt.plot(x, val_pts, 'bo-', label='val')
    # plt.axis([0, 50, 9.25, 11])
    plt.legend(loc='upper right')

    plt.xlabel('Epochs')
    plt.ylabel('Loss')

    plt.show()


def draw_loss_ap():
    ap05_pts = load_data(os.path.join(weights_path, 'ap_05_list.txt'))
    ap07_pts = load_data(os.path.join(weights_path, 'ap_07_list.txt'))

    x = np.linspace(0, len(train_pts), len(train_pts))
    x1 = np.linspace(0, len(train_pts), len(ap05_pts))

    fig, ax1 = plt.subplots()

    color = 'tab:red'
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(x, train_pts, 'ro-', label='train')
    ax1.plot(x, val_pts, 'bo-', label='val')
    ax1.tick_params(axis='y', labelcolor=color)
    plt.legend(loc='lower right')
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    color = 'tab:blue'
    ax2.set_ylabel('AP', color=color)  # we already handled the x-label with ax1
    ax2.plot(x1, ap05_pts, 'go-', label='AP@05')
    ax2.plot(x1, ap07_pts, 'yo-', label='AP@07')
    ax2.tick_params(axis='y', labelcolor=color)

    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.legend(loc='upper right')
    plt.show()


if __name__ == '__main__':
    draw_loss()
    # draw_loss_ap()
draw_points.py
ADDED
@@ -0,0 +1,92 @@
import cv2
import numpy as np

colors = [[0.76590096, 0.0266074, 0.9806378],
          [0.54197179, 0.81682527, 0.95081629],
          [0.0799733, 0.79737015, 0.15173816],
          [0.93240442, 0.8993321, 0.09901344],
          [0.73130136, 0.05366301, 0.98405681],
          [0.01664966, 0.16387004, 0.94158259],
          [0.54197179, 0.81682527, 0.45081629],
          # [0.92074915, 0.09919099, 0.97590748],
          [0.83445145, 0.97921679, 0.12250426],
          [0.7300924, 0.23253621, 0.29764521],
          [0.3856775, 0.94859286, 0.9910683],  # 10
          [0.45762137, 0.03766411, 0.98755338],
          [0.99496697, 0.09113071, 0.83322314],
          [0.96478873, 0.0233309, 0.13149931],
          [0.33240442, 0.9993321, 0.59901344],
          # [0.77690519, 0.81783954, 0.56220024],
          # [0.93240442, 0.8993321, 0.09901344],
          [0.95815068, 0.88436046, 0.55782268],
          [0.03728425, 0.0618827, 0.88641827],
          [0.05281129, 0.89572238, 0.08913828],
          ]


def draw_landmarks_regress_test(pts0, ori_image_regress, ori_image_points):
    for i, pt in enumerate(pts0):
        # color = np.random.rand(3)
        color = colors[i]
        # print(i+1, color)
        color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
        cv2.circle(ori_image_regress, (int(pt[0]), int(pt[1])), 6, color_255, -1, 1)
        # cv2.circle(ori_image, (int(pt[2]), int(pt[3])), 5, color_255, -1, 1)
        # cv2.circle(ori_image, (int(pt[4]), int(pt[5])), 5, color_255, -1, 1)
        # cv2.circle(ori_image, (int(pt[6]), int(pt[7])), 5, color_255, -1, 1)
        # cv2.circle(ori_image, (int(pt[8]), int(pt[9])), 5, color_255, -1, 1)
        cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[2]), int(pt[3])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[4]), int(pt[5])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[6]), int(pt[7])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(ori_image_regress, (int(pt[0]), int(pt[1])), (int(pt[8]), int(pt[9])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.putText(ori_image_regress, '{}'.format(i + 1),
                    (int(pt[4] + 10), int(pt[5] + 10)),
                    cv2.FONT_HERSHEY_DUPLEX,
                    1.2,
                    color_255,  # (255,255,255),
                    1,
                    1)
        # cv2.circle(ori_image, (int(pt[0]), int(pt[1])), 6, (255,255,255), -1, 1)
        cv2.circle(ori_image_points, (int(pt[2]), int(pt[3])), 5, color_255, -1, 1)
        cv2.circle(ori_image_points, (int(pt[4]), int(pt[5])), 5, color_255, -1, 1)
        cv2.circle(ori_image_points, (int(pt[6]), int(pt[7])), 5, color_255, -1, 1)
        cv2.circle(ori_image_points, (int(pt[8]), int(pt[9])), 5, color_255, -1, 1)
    return ori_image_regress, ori_image_points


def draw_landmarks_pre_proc(out_image, pts):
    for i in range(17):
        pts_4 = pts[4 * i:4 * i + 4, :]
        color = colors[i]
        color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
        cv2.circle(out_image, (int(pts_4[0, 0]), int(pts_4[0, 1])), 5, color_255, -1, 1)
        cv2.circle(out_image, (int(pts_4[1, 0]), int(pts_4[1, 1])), 5, color_255, -1, 1)
        cv2.circle(out_image, (int(pts_4[2, 0]), int(pts_4[2, 1])), 5, color_255, -1, 1)
        cv2.circle(out_image, (int(pts_4[3, 0]), int(pts_4[3, 1])), 5, color_255, -1, 1)
    return np.uint8(out_image)


def draw_regress_pre_proc(out_image, pts):
    for i in range(17):
        pts_4 = pts[4 * i:4 * i + 4, :]
        pt = np.mean(pts_4, axis=0)
        color = colors[i]
        color_255 = (255 * color[0], 255 * color[1], 255 * color[2])
        cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[0, 0]), int(pts_4[0, 1])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[1, 0]), int(pts_4[1, 1])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[2, 0]), int(pts_4[2, 1])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.arrowedLine(out_image, (int(pt[0]), int(pt[1])), (int(pts_4[3, 0]), int(pts_4[3, 1])), color_255, 2, 1,
                        tipLength=0.2)
        cv2.putText(out_image, '{}'.format(i + 1), (int(pts_4[1, 0] + 10), int(pts_4[1, 1] + 10)),
                    cv2.FONT_HERSHEY_DUPLEX, 1.2, color_255, 1, 1)
    return np.uint8(out_image)
eval.py
ADDED
@@ -0,0 +1,222 @@
import torch
import numpy as np
from models import spinal_net
import decoder
import os
from dataset import BaseDataset
import time
import cobb_evaluate

def apply_mask(image, mask, alpha=0.5):
    """Apply the given mask to the image."""
    color = np.random.rand(3)
    for c in range(3):
        image[:, :, c] = np.where(mask == 1,
                                  image[:, :, c] *
                                  (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    return image

class Network(object):
    def __init__(self, args):
        torch.manual_seed(317)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        heads = {'hm': args.num_classes,  # cen, tl, tr, bl, br
                 'reg': 2*args.num_classes,
                 'wh': 2*4,}

        self.model = spinal_net.SpineNet(heads=heads,
                                         pretrained=True,
                                         down_ratio=args.down_ratio,
                                         final_kernel=1,
                                         head_conv=256)
        self.num_classes = args.num_classes
        self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
        self.dataset = {'spinal': BaseDataset}

    def load_model(self, model, resume):
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
        state_dict_ = checkpoint['state_dict']
        model.load_state_dict(state_dict_, strict=False)
        return model


    def eval(self, args, save):
        save_path = 'weights_'+args.dataset
        self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=args.down_ratio)

        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        total_time = []
        landmark_dist = []
        pr_cobb_angles = []
        gt_cobb_angles = []
        for cnt, data_dict in enumerate(data_loader):
            begin_time = time.time()
            images = data_dict['images'][0]
            img_id = data_dict['img_id'][0]
            images = images.to('cuda')
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))

            with torch.no_grad():
                output = self.model(images)
                hm = output['hm']
                wh = output['wh']
                reg = output['reg']
            torch.cuda.synchronize(self.device)
            pts2 = self.decoder.ctdet_decode(hm, wh, reg)  # 17, 11
            pts0 = pts2.copy()
            pts0[:, :10] *= args.down_ratio
            x_index = range(0, 10, 2)
            y_index = range(1, 10, 2)
            ori_image = dsets.load_image(dsets.img_ids.index(img_id)).copy()
            h, w, c = ori_image.shape
            pts0[:, x_index] = pts0[:, x_index]/args.input_w*w
            pts0[:, y_index] = pts0[:, y_index]/args.input_h*h
            # sort the y axis
            sort_ind = np.argsort(pts0[:, 1])
            pts0 = pts0[sort_ind]
            pr_landmarks = []
            for i, pt in enumerate(pts0):
                pr_landmarks.append(pt[2:4])
                pr_landmarks.append(pt[4:6])
                pr_landmarks.append(pt[6:8])
                pr_landmarks.append(pt[8:10])
            pr_landmarks = np.asarray(pr_landmarks, np.float32)  # [68, 2]

            end_time = time.time()
            total_time.append(end_time-begin_time)

            gt_landmarks = dsets.load_gt_pts(dsets.load_annoFolder(img_id))
            for pr_pt, gt_pt in zip(pr_landmarks, gt_landmarks):
                landmark_dist.append(np.sqrt((pr_pt[0]-gt_pt[0])**2+(pr_pt[1]-gt_pt[1])**2))

            pr_cobb_angles.append(cobb_evaluate.cobb_angle_calc(pr_landmarks, ori_image))
            gt_cobb_angles.append(cobb_evaluate.cobb_angle_calc(gt_landmarks, ori_image))

        pr_cobb_angles = np.asarray(pr_cobb_angles, np.float32)
        gt_cobb_angles = np.asarray(gt_cobb_angles, np.float32)

        out_abs = abs(gt_cobb_angles - pr_cobb_angles)
        out_add = gt_cobb_angles + pr_cobb_angles

        term1 = np.sum(out_abs, axis=1)
        term2 = np.sum(out_add, axis=1)

        SMAPE = np.mean(term1 / term2 * 100)

        print('mse of landmarks is {}'.format(np.mean(landmark_dist)))
        print('SMAPE is {}'.format(SMAPE))

        total_time = total_time[1:]
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1./np.mean(total_time)))


    def SMAPE_single_angle(self, gt_cobb_angles, pr_cobb_angles):
        out_abs = abs(gt_cobb_angles - pr_cobb_angles)
        out_add = gt_cobb_angles + pr_cobb_angles

        term1 = out_abs
        term2 = out_add

        term2[term2 == 0] += 1e-5

        SMAPE = np.mean(term1 / term2 * 100)
        return SMAPE

    def eval_three_angles(self, args, save):
        save_path = 'weights_'+args.dataset
        self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=args.down_ratio)

        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        total_time = []
        landmark_dist = []
        pr_cobb_angles = []
        gt_cobb_angles = []
        for cnt, data_dict in enumerate(data_loader):
            begin_time = time.time()
            images = data_dict['images'][0]
            img_id = data_dict['img_id'][0]
            images = images.to('cuda')
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))

            with torch.no_grad():
                output = self.model(images)
                hm = output['hm']
                wh = output['wh']
                reg = output['reg']
            torch.cuda.synchronize(self.device)
            pts2 = self.decoder.ctdet_decode(hm, wh, reg)  # 17, 11
            pts0 = pts2.copy()
            pts0[:, :10] *= args.down_ratio
            x_index = range(0, 10, 2)
            y_index = range(1, 10, 2)
            ori_image = dsets.load_image(dsets.img_ids.index(img_id)).copy()
            h, w, c = ori_image.shape
            pts0[:, x_index] = pts0[:, x_index]/args.input_w*w
            pts0[:, y_index] = pts0[:, y_index]/args.input_h*h
            # sort the y axis
            sort_ind = np.argsort(pts0[:, 1])
            pts0 = pts0[sort_ind]
            pr_landmarks = []
            for i, pt in enumerate(pts0):
                pr_landmarks.append(pt[2:4])
                pr_landmarks.append(pt[4:6])
                pr_landmarks.append(pt[6:8])
                pr_landmarks.append(pt[8:10])
            pr_landmarks = np.asarray(pr_landmarks, np.float32)  # [68, 2]

            end_time = time.time()
            total_time.append(end_time-begin_time)

            gt_landmarks = dsets.load_gt_pts(dsets.load_annoFolder(img_id))
            for pr_pt, gt_pt in zip(pr_landmarks, gt_landmarks):
                landmark_dist.append(np.sqrt((pr_pt[0]-gt_pt[0])**2+(pr_pt[1]-gt_pt[1])**2))

            pr_cobb_angles.append(cobb_evaluate.cobb_angle_calc(pr_landmarks, ori_image))
            gt_cobb_angles.append(cobb_evaluate.cobb_angle_calc(gt_landmarks, ori_image))

        pr_cobb_angles = np.asarray(pr_cobb_angles, np.float32)
        gt_cobb_angles = np.asarray(gt_cobb_angles, np.float32)

        print('SMAPE1 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 0], pr_cobb_angles[:, 0])))
        print('SMAPE2 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 1], pr_cobb_angles[:, 1])))
        print('SMAPE3 is {}'.format(self.SMAPE_single_angle(gt_cobb_angles[:, 2], pr_cobb_angles[:, 2])))

        print('mse of landmarks is {}'.format(np.mean(landmark_dist)))

        total_time = total_time[1:]
        print('avg time is {}'.format(np.mean(total_time)))
        print('FPS is {}'.format(1./np.mean(total_time)))
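For reference, the SMAPE that `eval` prints reduces to the following computation; this standalone sketch uses dummy angle arrays, not real results:

```python
# Sketch: SMAPE over dummy Cobb angles (N images x 3 angles each).
import numpy as np

gt = np.array([[30.0, 20.0, 10.0], [45.0, 15.0, 12.0]], np.float32)  # dummy GT
pr = np.array([[28.0, 22.0, 11.0], [40.0, 18.0, 10.0]], np.float32)  # dummy preds
smape = np.mean(np.sum(np.abs(gt - pr), axis=1) / np.sum(gt + pr, axis=1) * 100)
print(f"SMAPE = {smape:.2f}%")  # same term1/term2 formulation as eval.py
```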
image_1.jpg
ADDED
(image stored with Git LFS)

image_1_la.jpg
ADDED
(image)

image_2.jpg
ADDED
(image)

image_2_la.jpg
ADDED
(image)

image_3.jpg
ADDED
(image stored with Git LFS)

image_3_la.jpg
ADDED
(image)
loss.py
ADDED
@@ -0,0 +1,69 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


class RegL1Loss(nn.Module):
    """Masked L1 loss for the offset ('reg') and corner-vector ('wh') heads."""
    def __init__(self):
        super(RegL1Loss, self).__init__()

    def _gather_feat(self, feat, ind, mask=None):
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat

    def _transpose_and_gather_feat(self, feat, ind):  # typo fixed: was '_tranpose_...'
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        feat = self._gather_feat(feat, ind)
        return feat

    def forward(self, output, mask, ind, target):
        pred = self._transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        loss = loss / (mask.sum() + 1e-4)
        return loss


class FocalLoss(nn.Module):
    """CenterNet-style penalty-reduced focal loss on the heatmap head."""
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)  # down-weight negatives close to a center

        loss = 0

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if num_pos == 0:
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
        return loss


class LossAll(torch.nn.Module):
    """Total loss: heatmap focal loss + center-offset L1 + corner-vector L1."""
    def __init__(self):
        super(LossAll, self).__init__()
        self.L_hm = FocalLoss()
        self.L_off = RegL1Loss()
        self.L_wh = RegL1Loss()

    def forward(self, pr_decs, gt_batch):
        hm_loss = self.L_hm(pr_decs['hm'], gt_batch['hm'])
        wh_loss = self.L_wh(pr_decs['wh'], gt_batch['reg_mask'], gt_batch['ind'], gt_batch['wh'])
        off_loss = self.L_off(pr_decs['reg'], gt_batch['reg_mask'], gt_batch['ind'], gt_batch['reg'])
        loss_dec = hm_loss + off_loss + wh_loss
        return loss_dec
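A minimal smoke test for LossAll with dummy tensors (shapes assumed from the ground-truth layout in pre_proc.py: 17 vertebrae, 4 corner vectors each); the heatmap prediction is passed through a sigmoid so torch.log stays finite:

import torch
from loss import LossAll

B, H, W = 2, 256, 128          # batch size and feature-map size (input / down_ratio)
criterion = LossAll()

pr_decs = {
    'hm':  torch.sigmoid(torch.randn(B, 1, H, W)),  # heatmap must lie in (0, 1)
    'wh':  torch.randn(B, 8, H, W),                 # 4 corner vectors, (dx, dy) each
    'reg': torch.randn(B, 2, H, W),                 # sub-pixel center offset
}
gt_batch = {
    'hm':       torch.zeros(B, 1, H, W),
    'ind':      torch.randint(0, H * W, (B, 17)),   # flattened center indices
    'reg_mask': torch.ones(B, 17, dtype=torch.uint8),
    'wh':       torch.randn(B, 17, 8),
    'reg':      torch.rand(B, 17, 2),
}
gt_batch['hm'][:, 0, 10, 10] = 1.0                  # at least one positive center

print(criterion(pr_decs, gt_batch))                 # a finite scalar tensor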
main.py
ADDED
@@ -0,0 +1,40 @@

import argparse
import train
import test
import eval

def parse_args():
    parser = argparse.ArgumentParser(description='CenterNet Modification Implementation')
    parser.add_argument('--num_epoch', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')  # help text fixed (was 'Number of epochs')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
    parser.add_argument('--init_lr', type=float, default=1.25e-4, help='Initial learning rate')
    parser.add_argument('--down_ratio', type=int, default=4, help='down ratio')
    parser.add_argument('--input_h', type=int, default=1024, help='input height')
    parser.add_argument('--input_w', type=int, default=512, help='input width')
    parser.add_argument('--K', type=int, default=100, help='maximum number of objects')
    parser.add_argument('--conf_thresh', type=float, default=0.2, help='confidence threshold')
    parser.add_argument('--seg_thresh', type=float, default=0.5, help='segmentation threshold')
    parser.add_argument('--num_classes', type=int, default=1, help='number of classes')
    parser.add_argument('--ngpus', type=int, default=0, help='number of gpus')
    parser.add_argument('--resume', type=str, default='model_last.pth', help='weights to be resumed')
    parser.add_argument('--data_dir', type=str, default='../../Datasets/spinal/', help='data directory')
    parser.add_argument('--phase', type=str, default='test', help='phase: train, test or eval')  # help text fixed
    parser.add_argument('--dataset', type=str, default='spinal', help='dataset name')  # help text fixed
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    if args.phase == 'train':
        is_object = train.Network(args)
        is_object.train_network(args)
    elif args.phase == 'test':
        is_object = test.Network(args)
        is_object.test(args, save=True)
    elif args.phase == 'eval':
        is_object = eval.Network(args)
        is_object.eval(args, save=True)
        # is_object.eval_three_angles(args, save=False)
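For a quick sanity check without the command line, the same entry points can be driven programmatically; a minimal sketch (the paths are placeholders that mirror the parse_args() defaults above, adjust for your local setup):

from argparse import Namespace
import test

# Hypothetical local configuration; values copied from parse_args() defaults.
args = Namespace(num_epoch=50, batch_size=2, num_workers=4, init_lr=1.25e-4,
                 down_ratio=4, input_h=1024, input_w=512, K=100,
                 conf_thresh=0.2, seg_thresh=0.5, num_classes=1, ngpus=0,
                 resume='model_last.pth', data_dir='../../Datasets/spinal/',
                 phase='test', dataset='spinal')

net = test.Network(args)
net.test(args, save=True)   # writes results_spinal/<img_id>.txt and <img_id>_pred.jpg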
make_requirements.py
ADDED
@@ -0,0 +1,73 @@

#!/usr/bin/env python3
import os
import ast
import sys
from importlib import metadata

# --- CONFIGURE THIS ---
PROJECT_PATH = r"C:\Users\santi\Desktop\Oto\Vertebra-Landmark-Detection"
OUTPUT_FILE = os.path.join(PROJECT_PATH, "requirements.txt")
# ----------------------

def find_py_files(root):
    for dirpath, dirnames, filenames in os.walk(root):
        # skip __pycache__
        dirnames[:] = [d for d in dirnames if d != "__pycache__"]
        for fname in filenames:
            if fname.endswith(".py"):
                yield os.path.join(dirpath, fname)

def collect_imports(py_path):
    with open(py_path, "r", encoding="utf8") as f:
        node = ast.parse(f.read(), filename=py_path)
    imports = set()
    for stmt in ast.walk(node):
        if isinstance(stmt, ast.Import):
            for n in stmt.names:
                imports.add(n.name.split(".")[0])
        elif isinstance(stmt, ast.ImportFrom):
            if stmt.module and stmt.level == 0:
                imports.add(stmt.module.split(".")[0])
    return imports

def is_local_module(mod_name, project_root):
    # if there's a folder or file matching mod_name in project, treat as local
    path1 = os.path.join(project_root, mod_name + ".py")
    path2 = os.path.join(project_root, mod_name)
    return os.path.exists(path1) or os.path.exists(path2)

def main():
    all_imports = set()
    for py in find_py_files(PROJECT_PATH):
        all_imports |= collect_imports(py)

    # filter out builtins, stdlib, and local modules
    externals = set()
    for mod in sorted(all_imports):
        if is_local_module(mod, PROJECT_PATH):
            continue
        try:
            # try to see if it's installed as a distribution
            dist = metadata.distribution(mod)
            externals.add(f"{dist.metadata['Name']}=={dist.version}")
        except metadata.PackageNotFoundError:
            # not a top-level distribution, maybe stdlib or nested import
            # crude check: if we can import and it's in site-packages, include it
            try:
                m = __import__(mod)
                if hasattr(m, "__file__") and "site-packages" in (m.__file__ or ""):
                    # lives in site-packages but dist metadata missing: include without version
                    externals.add(mod)
            except ImportError:
                pass

    # write requirements.txt
    with open(OUTPUT_FILE, "w", encoding="utf8") as out:
        for line in sorted(externals):
            out.write(line + "\n")

    print(f"Written {len(externals)} packages to {OUTPUT_FILE}")

if __name__ == "__main__":
    main()
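One caveat: metadata.distribution() is keyed by distribution name, while the AST walk collects import names, and the two can differ (cv2 ships as opencv-python, PIL as Pillow). In those cases the lookup raises PackageNotFoundError and the fallback writes the bare module name without a version, so the generated requirements.txt usually needs a quick manual pass afterwards.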
pre_proc.py
ADDED
@@ -0,0 +1,146 @@

import cv2
import numpy as np  # added: np was only available implicitly via the star import below
import torch
from draw_gaussian import *
import transform
import math


def processing_test(image, input_h, input_w):
    image = cv2.resize(image, (input_w, input_h))
    out_image = image.astype(np.float32) / 255.
    out_image = out_image - 0.5
    out_image = out_image.transpose(2, 0, 1).reshape(1, 3, input_h, input_w)
    out_image = torch.from_numpy(out_image)
    return out_image


def draw_spinal(pts, out_image):
    colors = [(0, 0, 255), (0, 255, 255), (255, 0, 255), (0, 255, 0)]
    for i in range(4):
        cv2.circle(out_image, (int(pts[i, 0]), int(pts[i, 1])), 3, colors[i], 1, 1)
        cv2.putText(out_image, '{}'.format(i+1), (int(pts[i, 0]), int(pts[i, 1])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), 1, 1)
    for i, j in zip([0, 1, 2, 3], [1, 2, 3, 0]):
        cv2.line(out_image,
                 (int(pts[i, 0]), int(pts[i, 1])),
                 (int(pts[j, 0]), int(pts[j, 1])),
                 color=colors[i], thickness=1, lineType=1)
    return out_image


def rearrange_pts(pts):
    # rearrange left-right sequence: order each group of 4 corners as tl, tr, bl, br
    boxes = []
    centers = []
    for k in range(0, len(pts), 4):
        pts_4 = pts[k:k+4, :]
        x_inds = np.argsort(pts_4[:, 0])
        pt_l = np.asarray(pts_4[x_inds[:2], :])   # two leftmost points
        pt_r = np.asarray(pts_4[x_inds[2:], :])   # two rightmost points
        y_inds_l = np.argsort(pt_l[:, 1])
        y_inds_r = np.argsort(pt_r[:, 1])
        tl = pt_l[y_inds_l[0], :]
        bl = pt_l[y_inds_l[1], :]
        tr = pt_r[y_inds_r[0], :]
        br = pt_r[y_inds_r[1], :]
        # boxes.append([tl, tr, bl, br])
        boxes.append(tl)
        boxes.append(tr)
        boxes.append(bl)
        boxes.append(br)
        centers.append(np.mean(pts_4, axis=0))
    bboxes = np.asarray(boxes, np.float32)
    # rearrange top-to-bottom sequence by box center
    centers = np.asarray(centers, np.float32)
    sort_tb = np.argsort(centers[:, 1])
    new_bboxes = []
    for sort_i in sort_tb:
        new_bboxes.append(bboxes[4*sort_i, :])
        new_bboxes.append(bboxes[4*sort_i+1, :])
        new_bboxes.append(bboxes[4*sort_i+2, :])
        new_bboxes.append(bboxes[4*sort_i+3, :])
    new_bboxes = np.asarray(new_bboxes, np.float32)
    return new_bboxes


def generate_ground_truth(image,
                          pts_2,
                          image_h,
                          image_w,
                          img_id):
    hm = np.zeros((1, image_h, image_w), dtype=np.float32)   # center heatmap
    wh = np.zeros((17, 2*4), dtype=np.float32)               # 4 center-to-corner vectors per box
    reg = np.zeros((17, 2), dtype=np.float32)                # sub-pixel center offset
    ind = np.zeros((17), dtype=np.int64)                     # flattened center index
    reg_mask = np.zeros((17), dtype=np.uint8)

    if pts_2[:, 0].max() > image_w:
        print('w is big', pts_2[:, 0].max())
    if pts_2[:, 1].max() > image_h:
        print('h is big', pts_2[:, 1].max())

    if pts_2.shape[0] != 68:
        print('ATTENTION!! image {} pts does not equal to 68!!! '.format(img_id))

    for k in range(17):
        pts = pts_2[4*k:4*k+4, :]
        bbox_h = np.mean([np.sqrt(np.sum((pts[0, :]-pts[2, :])**2)),
                          np.sqrt(np.sum((pts[1, :]-pts[3, :])**2))])
        bbox_w = np.mean([np.sqrt(np.sum((pts[0, :]-pts[1, :])**2)),
                          np.sqrt(np.sum((pts[2, :]-pts[3, :])**2))])
        cen_x, cen_y = np.mean(pts, axis=0)
        ct = np.asarray([cen_x, cen_y], dtype=np.float32)
        ct_int = ct.astype(np.int32)
        radius = gaussian_radius((math.ceil(bbox_h), math.ceil(bbox_w)))
        radius = max(0, int(radius))
        draw_umich_gaussian(hm[0, :, :], ct_int, radius=radius)
        ind[k] = ct_int[1] * image_w + ct_int[0]
        reg[k] = ct - ct_int
        reg_mask[k] = 1
        for i in range(4):
            wh[k, 2*i:2*i+2] = ct - pts[i, :]

    ret = {'input': image,
           'hm': hm,
           'ind': ind,
           'reg': reg,
           'wh': wh,
           'reg_mask': reg_mask,
           }

    return ret

# def filter_pts(pts, w, h):
#     pts_new = []
#     for pt in pts:
#         if np.any(pt < 0) or pt[0] > w - 1 or pt[1] > h - 1:
#             continue
#         else:
#             pts_new.append(pt)
#     return np.asarray(pts_new, np.float32)


def processing_train(image, pts, image_h, image_w, down_ratio, aug_label, img_id):
    # filter pts ----------------------------------------------------
    h, w, c = image.shape
    # pts = filter_pts(pts, w, h)
    # ---------------------------------------------------------------
    data_aug = {'train': transform.Compose([transform.ConvertImgFloat(),
                                            transform.PhotometricDistort(),
                                            transform.Expand(max_scale=1.5, mean=(0, 0, 0)),
                                            transform.RandomMirror_w(),
                                            transform.Resize(h=image_h, w=image_w)]),
                'val': transform.Compose([transform.ConvertImgFloat(),
                                          transform.Resize(h=image_h, w=image_w)])}
    if aug_label:
        out_image, pts = data_aug['train'](image.copy(), pts)
    else:
        out_image, pts = data_aug['val'](image.copy(), pts)

    out_image = np.clip(out_image, a_min=0., a_max=255.)
    out_image = np.transpose(out_image / 255. - 0.5, (2, 0, 1))
    pts = rearrange_pts(pts)
    pts2 = transform.rescale_pts(pts, down_ratio=down_ratio)

    return np.asarray(out_image, np.float32), pts2
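A quick toy check of rearrange_pts, to make the corner ordering concrete: four corners given in arbitrary order come back as tl, tr, bl, br (a minimal sketch, assuming pre_proc.py is on the import path):

import numpy as np
from pre_proc import rearrange_pts

# one vertebra, corners deliberately shuffled: br, tl, tr, bl
pts = np.array([[10., 12.], [0., 0.], [10., 1.], [0., 11.]], np.float32)
print(rearrange_pts(pts))
# -> [[ 0.  0.]    top-left
#     [10.  1.]    top-right
#     [ 0. 11.]    bottom-left
#     [10. 12.]]   bottom-right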
test.py
ADDED
@@ -0,0 +1,123 @@

import torch
import numpy as np
from models import spinal_net
import cv2
import decoder
import os
from dataset import BaseDataset
import draw_points

def apply_mask(image, mask, alpha=0.5):
    """Apply the given mask to the image."""
    color = np.random.rand(3)
    for c in range(3):
        image[:, :, c] = np.where(mask == 1,
                                  image[:, :, c] *
                                  (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    return image

class Network(object):
    def __init__(self, args):
        torch.manual_seed(317)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        heads = {'hm': args.num_classes,
                 'reg': 2*args.num_classes,
                 'wh': 2*4,}

        self.model = spinal_net.SpineNet(heads=heads,
                                         pretrained=True,
                                         down_ratio=args.down_ratio,
                                         final_kernel=1,
                                         head_conv=256)
        self.num_classes = args.num_classes
        self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
        self.dataset = {'spinal': BaseDataset}

    def load_model(self, model, resume):
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
        state_dict_ = checkpoint['state_dict']
        model.load_state_dict(state_dict_, strict=False)
        return model

    def map_mask_to_image(self, mask, img, color=None):
        if color is None:
            color = np.random.rand(3)
        mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        mskd = img * mask
        clmsk = np.ones(mask.shape) * mask
        clmsk[:, :, 0] = clmsk[:, :, 0] * color[0] * 256
        clmsk[:, :, 1] = clmsk[:, :, 1] * color[1] * 256
        clmsk[:, :, 2] = clmsk[:, :, 2] * color[2] * 256
        img = img + 1. * clmsk - 1. * mskd
        return np.uint8(img)


    def test(self, args, save):
        save_path = 'weights_'+args.dataset
        self.model = self.load_model(self.model, os.path.join(save_path, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()

        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=args.down_ratio)

        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)

        for cnt, data_dict in enumerate(data_loader):
            images = data_dict['images'][0]
            img_id = data_dict['img_id'][0]
            images = images.to(self.device)  # was hard-coded 'cuda'; use the configured device
            print('processing {}/{} image ... {}'.format(cnt, len(data_loader), img_id))
            with torch.no_grad():
                output = self.model(images)
                hm = output['hm']
                wh = output['wh']
                reg = output['reg']

            torch.cuda.synchronize(self.device)
            pts2 = self.decoder.ctdet_decode(hm, wh, reg)  # [17, 11]
            pts0 = pts2.copy()
            pts0[:, :10] *= args.down_ratio

            print('total pts num is {}'.format(len(pts2)))  # typo fixed: was 'totol'

            ori_image = dsets.load_image(dsets.img_ids.index(img_id))
            ori_image_regress = cv2.resize(ori_image, (args.input_w, args.input_h))
            ori_image_points = ori_image_regress.copy()

            h, w, c = ori_image.shape
            pts0 = np.asarray(pts0, np.float32)
            # pts0[:, 0::2] = pts0[:, 0::2]/args.input_w*w
            # pts0[:, 1::2] = pts0[:, 1::2]/args.input_h*h
            sort_ind = np.argsort(pts0[:, 1])
            pts0 = pts0[sort_ind]

            ori_image_regress, ori_image_points = draw_points.draw_landmarks_regress_test(pts0,
                                                                                          ori_image_regress,
                                                                                          ori_image_points)

            if save:
                # 1) set up the results folder  (comments below translated from Thai)
                save_dir = os.path.join('results_'+args.dataset)
                os.makedirs(save_dir, exist_ok=True)

                # 2) save the coordinates to a .txt file
                txt_path = os.path.join(save_dir, f'{img_id}.txt')
                # assume pts0 is an array of shape (N,2) or (N,4), depending on what you want to save
                np.savetxt(txt_path, pts0, fmt='%.4f')

                # 3) save the overlay image
                img_path = os.path.join(save_dir, f'{img_id}_pred.jpg')
                cv2.imwrite(img_path, ori_image_points)
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from models import spinal_net
|
6 |
+
import decoder
|
7 |
+
import loss
|
8 |
+
from dataset import BaseDataset
|
9 |
+
|
10 |
+
def collater(data):
|
11 |
+
out_data_dict = {}
|
12 |
+
for name in data[0]:
|
13 |
+
out_data_dict[name] = []
|
14 |
+
for sample in data:
|
15 |
+
for name in sample:
|
16 |
+
out_data_dict[name].append(torch.from_numpy(sample[name]))
|
17 |
+
for name in out_data_dict:
|
18 |
+
out_data_dict[name] = torch.stack(out_data_dict[name], dim=0)
|
19 |
+
return out_data_dict
|
20 |
+
|
21 |
+
class Network(object):
|
22 |
+
def __init__(self, args):
|
23 |
+
torch.manual_seed(317)
|
24 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
25 |
+
heads = {'hm': args.num_classes,
|
26 |
+
'reg': 2*args.num_classes,
|
27 |
+
'wh': 2*4,}
|
28 |
+
|
29 |
+
self.model = spinal_net.SpineNet(heads=heads,
|
30 |
+
pretrained=True,
|
31 |
+
down_ratio=args.down_ratio,
|
32 |
+
final_kernel=1,
|
33 |
+
head_conv=256)
|
34 |
+
self.num_classes = args.num_classes
|
35 |
+
self.decoder = decoder.DecDecoder(K=args.K, conf_thresh=args.conf_thresh)
|
36 |
+
self.dataset = {'spinal': BaseDataset}
|
37 |
+
|
38 |
+
|
39 |
+
def save_model(self, path, epoch, model):
|
40 |
+
if isinstance(model, torch.nn.DataParallel):
|
41 |
+
state_dict = model.module.state_dict()
|
42 |
+
else:
|
43 |
+
state_dict = model.state_dict()
|
44 |
+
data = {'epoch': epoch, 'state_dict': state_dict}
|
45 |
+
torch.save(data, path)
|
46 |
+
|
47 |
+
def load_model(self, model, resume, strict=True):
|
48 |
+
checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
|
49 |
+
print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
|
50 |
+
state_dict_ = checkpoint['state_dict']
|
51 |
+
state_dict = {}
|
52 |
+
|
53 |
+
for k in state_dict_:
|
54 |
+
if k.startswith('module') and not k.startswith('module_list'):
|
55 |
+
state_dict[k[7:]] = state_dict_[k]
|
56 |
+
else:
|
57 |
+
state_dict[k] = state_dict_[k]
|
58 |
+
model_state_dict = model.state_dict()
|
59 |
+
|
60 |
+
if not strict:
|
61 |
+
for k in state_dict:
|
62 |
+
if k in model_state_dict:
|
63 |
+
if state_dict[k].shape != model_state_dict[k].shape:
|
64 |
+
print('Skip loading parameter {}, required shape{}, ' \
|
65 |
+
'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))
|
66 |
+
state_dict[k] = model_state_dict[k]
|
67 |
+
else:
|
68 |
+
print('Drop parameter {}.'.format(k))
|
69 |
+
for k in model_state_dict:
|
70 |
+
if not (k in state_dict):
|
71 |
+
print('No param {}.'.format(k))
|
72 |
+
state_dict[k] = model_state_dict[k]
|
73 |
+
model.load_state_dict(state_dict, strict=False)
|
74 |
+
return model
|
75 |
+
|
76 |
+
def train_network(self, args):
|
77 |
+
save_path = 'weights_'+args.dataset
|
78 |
+
if not os.path.exists(save_path):
|
79 |
+
os.mkdir(save_path)
|
80 |
+
self.optimizer = torch.optim.Adam(self.model.parameters(), args.init_lr)
|
81 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.96, last_epoch=-1)
|
82 |
+
if args.ngpus>0:
|
83 |
+
if torch.cuda.device_count() > 1:
|
84 |
+
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
85 |
+
self.model = nn.DataParallel(self.model)
|
86 |
+
|
87 |
+
self.model.to(self.device)
|
88 |
+
|
89 |
+
criterion = loss.LossAll()
|
90 |
+
print('Setting up data...')
|
91 |
+
|
92 |
+
dataset_module = self.dataset[args.dataset]
|
93 |
+
|
94 |
+
dsets = {x: dataset_module(data_dir=args.data_dir,
|
95 |
+
phase=x,
|
96 |
+
input_h=args.input_h,
|
97 |
+
input_w=args.input_w,
|
98 |
+
down_ratio=args.down_ratio)
|
99 |
+
for x in ['train', 'val']}
|
100 |
+
|
101 |
+
dsets_loader = {'train': torch.utils.data.DataLoader(dsets['train'],
|
102 |
+
batch_size=args.batch_size,
|
103 |
+
shuffle=True,
|
104 |
+
num_workers=args.num_workers,
|
105 |
+
pin_memory=True,
|
106 |
+
drop_last=True,
|
107 |
+
collate_fn=collater),
|
108 |
+
|
109 |
+
'val':torch.utils.data.DataLoader(dsets['val'],
|
110 |
+
batch_size=1,
|
111 |
+
shuffle=False,
|
112 |
+
num_workers=1,
|
113 |
+
pin_memory=True,
|
114 |
+
collate_fn=collater)}
|
115 |
+
|
116 |
+
|
117 |
+
print('Starting training...')
|
118 |
+
train_loss = []
|
119 |
+
val_loss = []
|
120 |
+
for epoch in range(1, args.num_epoch+1):
|
121 |
+
print('-'*10)
|
122 |
+
print('Epoch: {}/{} '.format(epoch, args.num_epoch))
|
123 |
+
epoch_loss = self.run_epoch(phase='train',
|
124 |
+
data_loader=dsets_loader['train'],
|
125 |
+
criterion=criterion)
|
126 |
+
train_loss.append(epoch_loss)
|
127 |
+
scheduler.step(epoch)
|
128 |
+
|
129 |
+
epoch_loss = self.run_epoch(phase='val',
|
130 |
+
data_loader=dsets_loader['val'],
|
131 |
+
criterion=criterion)
|
132 |
+
val_loss.append(epoch_loss)
|
133 |
+
|
134 |
+
np.savetxt(os.path.join(save_path, 'train_loss.txt'), train_loss, fmt='%.6f')
|
135 |
+
np.savetxt(os.path.join(save_path, 'val_loss.txt'), val_loss, fmt='%.6f')
|
136 |
+
|
137 |
+
if epoch % 10 == 0 or epoch ==1:
|
138 |
+
self.save_model(os.path.join(save_path, 'model_{}.pth'.format(epoch)), epoch, self.model)
|
139 |
+
|
140 |
+
if len(val_loss)>1:
|
141 |
+
if val_loss[-1]<np.min(val_loss[:-1]):
|
142 |
+
self.save_model(os.path.join(save_path, 'model_last.pth'), epoch, self.model)
|
143 |
+
|
144 |
+
def run_epoch(self, phase, data_loader, criterion):
|
145 |
+
if phase == 'train':
|
146 |
+
self.model.train()
|
147 |
+
else:
|
148 |
+
self.model.eval()
|
149 |
+
running_loss = 0.
|
150 |
+
for data_dict in data_loader:
|
151 |
+
for name in data_dict:
|
152 |
+
data_dict[name] = data_dict[name].to(device=self.device)
|
153 |
+
if phase == 'train':
|
154 |
+
self.optimizer.zero_grad()
|
155 |
+
with torch.enable_grad():
|
156 |
+
pr_decs = self.model(data_dict['input'])
|
157 |
+
loss = criterion(pr_decs, data_dict)
|
158 |
+
loss.backward()
|
159 |
+
self.optimizer.step()
|
160 |
+
else:
|
161 |
+
with torch.no_grad():
|
162 |
+
pr_decs = self.model(data_dict['input'])
|
163 |
+
loss = criterion(pr_decs, data_dict)
|
164 |
+
|
165 |
+
running_loss += loss.item()
|
166 |
+
epoch_loss = running_loss / len(data_loader)
|
167 |
+
print('{} loss: {}'.format(phase, epoch_loss))
|
168 |
+
return epoch_loss
|
169 |
+
|
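Since train_network writes train_loss.txt and val_loss.txt after every epoch, the curves can be inspected with a few lines; a minimal sketch using matplotlib (paths assume the default dataset name 'spinal'):

import numpy as np
import matplotlib.pyplot as plt

train_loss = np.loadtxt('weights_spinal/train_loss.txt')
val_loss = np.loadtxt('weights_spinal/val_loss.txt')

plt.plot(train_loss, label='train')   # one value per epoch
plt.plot(val_loss, label='val')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig('loss_curve.png')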
transform.py
ADDED
@@ -0,0 +1,181 @@

import numpy as np
from numpy import random
import cv2


def rescale_pts(pts, down_ratio):
    return np.asarray(pts, np.float32) / float(down_ratio)


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, pts):
        for t in self.transforms:
            img, pts = t(img, pts)
        return img, pts

class ConvertImgFloat(object):
    def __call__(self, img, pts):
        return img.astype(np.float32), pts.astype(np.float32)

class RandomContrast(object):
    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    def __call__(self, img, pts):
        if random.randint(2):
            alpha = random.uniform(self.lower, self.upper)
            img *= alpha
        return img, pts


class RandomBrightness(object):
    def __init__(self, delta=32):
        assert delta >= 0.0
        assert delta <= 255.0
        self.delta = delta

    def __call__(self, img, pts):
        if random.randint(2):
            delta = random.uniform(-self.delta, self.delta)
            img += delta
        return img, pts

class SwapChannels(object):
    def __init__(self, swaps):
        self.swaps = swaps
    def __call__(self, img):
        img = img[:, :, self.swaps]
        return img


class RandomLightingNoise(object):
    def __init__(self):
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))
    def __call__(self, img, pts):
        if random.randint(2):
            swap = self.perms[random.randint(len(self.perms))]
            shuffle = SwapChannels(swap)
            img = shuffle(img)
        return img, pts


class PhotometricDistort(object):
    def __init__(self):
        self.pd = RandomContrast()
        self.rb = RandomBrightness()
        self.rln = RandomLightingNoise()

    def __call__(self, img, pts):
        img, pts = self.rb(img, pts)
        # the original if/else selected self.pd on both branches, so the
        # branch was dead code; apply the contrast distortion directly
        img, pts = self.pd(img, pts)
        img, pts = self.rln(img, pts)
        return img, pts


class Expand(object):
    def __init__(self, max_scale=1.5, mean=(0.5, 0.5, 0.5)):
        self.mean = mean
        self.max_scale = max_scale

    def __call__(self, img, pts):
        if random.randint(2):
            return img, pts
        h, w, c = img.shape
        ratio = random.uniform(1, self.max_scale)
        y1 = random.uniform(0, h*ratio - h)
        x1 = random.uniform(0, w*ratio - w)
        if np.max(pts[:, 0]) + int(x1) > w-1 or np.max(pts[:, 1]) + int(y1) > h-1:  # keep all the pts
            return img, pts
        else:
            expand_img = np.zeros(shape=(int(h*ratio), int(w*ratio), c), dtype=img.dtype)
            expand_img[:, :, :] = self.mean
            expand_img[int(y1):int(y1+h), int(x1):int(x1+w)] = img
            pts[:, 0] += int(x1)
            pts[:, 1] += int(y1)
            return expand_img, pts


class RandomSampleCrop(object):
    def __init__(self, ratio=(0.5, 1.5), min_win=0.9):
        self.sample_options = (
            # using entire original input image
            None,
            # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
            # (0.1, None),
            # (0.3, None),
            (0.7, None),
            (0.9, None),
            # randomly sample a patch
            (None, None),
        )
        self.ratio = ratio
        self.min_win = min_win

    def __call__(self, img, pts):
        height, width, _ = img.shape
        while True:
            # np.random.choice rejects a tuple of mixed None/tuples on newer numpy;
            # pick by index instead
            mode = self.sample_options[random.randint(len(self.sample_options))]
            if mode is None:
                return img, pts
            for _ in range(50):
                current_img = img
                current_pts = pts
                w = random.uniform(self.min_win*width, width)
                h = random.uniform(self.min_win*height, height)
                if h/w < self.ratio[0] or h/w > self.ratio[1]:
                    continue
                y1 = random.uniform(height-h)
                x1 = random.uniform(width-w)
                rect = np.array([int(y1), int(x1), int(y1+h), int(x1+w)])
                current_img = current_img[rect[0]:rect[2], rect[1]:rect[3], :]
                current_pts[:, 0] -= rect[1]
                current_pts[:, 1] -= rect[0]
                pts_new = []
                for pt in current_pts:
                    # was 'any(pt) < 0', which compares a bool to 0 and is never True;
                    # check for out-of-crop points element-wise instead
                    if np.any(pt < 0) or pt[0] > current_img.shape[1]-1 or pt[1] > current_img.shape[0]-1:
                        continue
                    else:
                        pts_new.append(pt)

                return current_img, np.asarray(pts_new, np.float32)

class RandomMirror_w(object):
    def __call__(self, img, pts):
        _, w, _ = img.shape
        if random.randint(2):
            img = img[:, ::-1, :]
            pts[:, 0] = w - pts[:, 0]
        return img, pts

class RandomMirror_h(object):
    def __call__(self, img, pts):
        h, _, _ = img.shape
        if random.randint(2):
            img = img[::-1, :, :]
            pts[:, 1] = h - pts[:, 1]
        return img, pts


class Resize(object):
    def __init__(self, h, w):
        self.dsize = (w, h)

    def __call__(self, img, pts):
        h, w, c = img.shape
        pts[:, 0] = pts[:, 0] / w * self.dsize[0]
        pts[:, 1] = pts[:, 1] / h * self.dsize[1]
        img = cv2.resize(img, dsize=self.dsize)
        return img, np.asarray(pts)
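As a smoke test for the augmentation chain, the 'val'-style pipeline used in pre_proc.py can be run on a dummy image; a minimal sketch (sizes match the main.py defaults):

import numpy as np
import transform

# dummy image and 68 landmark points (17 vertebrae x 4 corners)
img = np.random.randint(0, 255, size=(512, 256, 3), dtype=np.uint8)
pts = np.random.rand(68, 2).astype(np.float32) * [256, 512]

pipeline = transform.Compose([transform.ConvertImgFloat(),
                              transform.Resize(h=1024, w=512)])
out_img, out_pts = pipeline(img, pts)
print(out_img.shape, out_pts.min(), out_pts.max())  # (1024, 512, 3), pts rescaled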