rezasalatin
committed on
Upload 8 files
- .gitignore +5 -0
- README.md +32 -41
- environment.yml +22 -0
- test_image_seg.py +173 -0
- test_video_seg.py +132 -0
- test_video_seg.sh +4 -0
- train_video_seg.py +143 -0
- train_video_seg.sh +15 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.idea/
+.vscode/
+__pycache__/
+
+logs/
README.md
CHANGED
@@ -1,12 +1,3 @@
----
-license: mit
-language:
-- en
-pipeline_tag: image-segmentation
-tags:
-- climate
----
-
# V-BeachNet

This repository contains the official PyTorch implementation for the paper "A New Framework for Quantifying Alongshore Variability of Swash Motion Using Fully Convolutional Networks." V-BeachNet is built upon V-FloodNet.
@@ -19,46 +10,46 @@ Liang, Y., Li, X., Tsai, B., Chen, Q., & Jafari, N. (2023). V-FloodNet: A video

## Prerequisites

+Install Conda on Ubuntu 24.04 with the default version of Python and an NVIDIA GPU.

1. Install the Anaconda prerequisites (can also be accessed from [here](https://docs.anaconda.com/anaconda/install/linux/)):
+```sh
+sudo apt update && \
+sudo apt install libgl1-mesa-dri libegl1 libglu1-mesa libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2-data libasound2-plugins libxi6 libxtst6
+```

2. Download Anaconda3:
+```sh
+curl -O https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
+```

+3. Locate the downloaded file and install it:
+```sh
+bash Anaconda3-2024.06-1-Linux-x86_64.sh
+```

## Steps

1. Clone this repository and change directory:
+```sh
+git clone https://github.com/rezasalatin/V-BeachNet.git
+cd V-BeachNet
+```

2. Create the virtual environment with the requirements:
+```sh
+conda env create -f environment.yml
+conda activate vbeach
+```
+
+3. Visit the "Training_Station" folder and copy your manually segmented dataset to this directory. Open the following file to change any of the variables and save it. Then execute it to train the model:
+```sh
+./train_video_seg.sh
+```
+Access your trained model in the logs/ directory.
+
+4. Visit the "Testing_Station" folder and copy your data to this directory. Open the following file to change any of the variables (especially the model path from the logs/ folder) and save it. Then execute it to test the model:
+```sh
+./test_video_seg.sh
+```
+Access your segmented data in the output/ directory.
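For orientation, the segmentation scripts in this commit write their results under output/segs/, keyed by the test name you pass. A sketch of the layout, with the directory names taken from test_image_seg.py and test_video_seg.py:

```sh
# Output layout produced by the test scripts; <test-name> is the value of --test-name.
#   output/segs/<test-name>/mask/     - indexed PNG masks, one per frame
#   output/segs/<test-name>/overlay/  - input frames with the predicted mask overlaid
```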
environment.yml
ADDED
@@ -0,0 +1,22 @@
+name: vbeach
+channels:
+  - conda-forge
+  - pytorch
+  - nvidia
+dependencies:
+  - tqdm
+  - pytorch
+  - torchvision
+  - torchaudio
+  - pytorch-cuda=12.1
+  - scipy
+  - numpy
+  - matplotlib
+  - pandas
+  - scikit-learn
+  - opencv
+  - gxx_linux-64
+  - pip
+  - pip:
+    - torch-scatter -f https://data.pyg.org/whl/torch-2.1.2+cu121.html
+    - segmentation-models-pytorch==0.2.0
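Both the training and testing entry points below refuse to run without CUDA, so it is worth verifying the freshly created environment before going further. A minimal check (the command is illustrative, not part of the repository):

```sh
conda activate vbeach
# Should print the PyTorch version and "True" for CUDA availability.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```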
test_image_seg.py
ADDED
@@ -0,0 +1,173 @@
+import os
+import argparse
+import cv2
+import numpy as np
+
+from PIL import Image
+from pathlib import Path
+from glob import glob
+from tqdm import tqdm
+
+import torch
+import torchvision.transforms as tf
+
+# Imported so torch.load can unpickle the saved Linknet model.
+from segmentation_models_pytorch import Linknet
+import myutils
+
+# ROOT_DIR = str(Path(__file__).resolve().parents[0])
+ROOT_DIR = './'
+# time_str = time.strftime("%Y-%m-%d %H-%M-%S")
+# DEFAULT_OUT = os.path.join(ROOT_DIR, 'output', 'test_waterseg', time_str)
+DEFAULT_OUT = os.path.join(ROOT_DIR, 'output', 'segs')
+# DEFAULT_PALETTE = os.path.join(ROOT_DIR, "assets", "mask_palette.png")
+
+
+def norm_imagenet(img_pil, dims):
+    """
+    Normalizes and resizes an input image.
+    :param img_pil: PIL Image
+    :param dims: Model's expected input dimensions
+    :return: Normalized image as a Tensor
+    """
+    # Mean and stddev of the ImageNet dataset
+    mean = torch.tensor([0.485, 0.456, 0.406])
+    std = torch.tensor([0.229, 0.224, 0.225])
+
+    # Resize, convert to tensor, normalize
+    transform_norm = tf.Compose([
+        tf.Resize([dims[0], dims[1]]),
+        tf.ToTensor(),
+        tf.Normalize(mean, std)
+    ])
+
+    img_norm = transform_norm(img_pil)
+    return img_norm
+
+
+def predict_one(path, model, mask_outdir, overlay_outdir, device):
+    """
+    Predicts a single image from path.
+    :param path: Path to image
+    :param model: Loaded Torch model
+    :param mask_outdir: Path to the mask output directory
+    :param overlay_outdir: Path to the overlay output directory
+    :param device: Torch device to run inference on
+    :return: None
+    """
+    img_pil = myutils.load_image_in_PIL(path)
+
+    # Prediction is a PIL Image of 0s and 1s
+    prediction = predict_pil(model, img_pil, model_dims=(416, 416), device=device)
+
+    basename = str(Path(os.path.basename(path)).stem)
+    mask_savepth = os.path.join(mask_outdir, basename + '.png')
+    prediction.save(mask_savepth)
+
+    over_savepth = os.path.join(overlay_outdir, basename + '.png')
+    img_np = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+    overlay_np = myutils.add_overlay(img_np, np.array(prediction))
+    cv2.imwrite(over_savepth, overlay_np)
+
+
+def predict_pil(model, img_pil, model_dims, device):
+    """
+    Predicts a single PIL Image.
+    :param model: Loaded PyTorch model
+    :param img_pil: PIL image
+    :param model_dims: Model input dimensions
+    :param device: Torch device to run inference on
+    :return: Segmentation prediction as a PIL Image
+    """
+    img_np = np.array(img_pil)
+    img_tensor_norm = norm_imagenet(img_pil, model_dims)
+
+    # Pipeline to resize the prediction to the original image dimensions
+    pred_resize = tf.Compose([tf.Resize([img_np.shape[0], img_np.shape[1]])])
+
+    # Add an extra dimension at the front as the model expects input of shape 1*3*dimX*dimY (batch size of 1)
+    input_data = img_tensor_norm.unsqueeze(0)
+
+    try:
+        prediction = model.predict(input_data.to(device))
+    except RuntimeError:
+        print("Did not convert input image to cuda.")
+        prediction = model.predict(input_data)
+
+    prediction = pred_resize(prediction)
+    prediction = myutils.postprocessing_pred(prediction.squeeze().cpu().round().numpy().astype(np.uint8))
+    prediction = Image.fromarray(prediction).convert('P')
+    prediction.putpalette(myutils.color_palette)
+    return prediction
+
+
+def test_waterseg(model_path, test_path, test_name, out_path, device):
+    """
+    Tests either a single image or an entire folder of images.
+    :param model_path: Path to the .pth model file
+    :param test_path: Path to a single image or a directory of images
+    :param test_name: Test name; used as the output subdirectory
+    :param out_path: Path to the output directory
+    :param device: Torch device to run inference on
+    :return: None
+    """
+    model = torch.load(model_path)
+    out_path = os.path.join(out_path, test_name)
+
+    mask_out = os.path.join(out_path, 'mask')
+    overlay_out = os.path.join(out_path, 'overlay')
+    if not os.path.exists(mask_out):
+        os.makedirs(mask_out)
+    if not os.path.exists(overlay_out):
+        os.makedirs(overlay_out)
+
+    if os.path.isfile(test_path):
+        predict_one(test_path, model, mask_out, overlay_out, device)
+    elif os.path.isdir(test_path):
+        paths = glob(os.path.join(test_path, '*.jpg')) + glob(os.path.join(test_path, '*.png'))
+        for path in tqdm(paths):
+            predict_one(path, model, mask_out, overlay_out, device)
+    else:
+        print("Error: Unknown path type:", test_path)
+
+
+if __name__ == '__main__':
+    # Command-line arguments
+    parser = argparse.ArgumentParser(description='V-FloodNet: Water Image Segmentation')
+    # Path to the .pth file
+    parser.add_argument('--model-path',
+                        default='./records/link_efficientb4_model.pth',
+                        type=str,
+                        metavar='PATH',
+                        help='Path to the model')
+    # Path to either a single file or a directory of .jpg or .png images
+    parser.add_argument('--test-path',
+                        type=str,
+                        metavar='PATH',
+                        required=True,
+                        help='Can point to a folder or an individual jpg/png image')
+    parser.add_argument('--test-name',
+                        type=str,
+                        required=True,
+                        help='Test name')
+    parser.add_argument('--out-path',
+                        default=DEFAULT_OUT,
+                        type=str,
+                        metavar='PATH',
+                        help='(OPTIONAL) Path to the output folder; defaults to <project root>/output')
+    args = parser.parse_args()
+
+    # Device
+    device = torch.device('cpu')
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+
+    test_waterseg(args.model_path, args.test_path, args.test_name, args.out_path, device)
+
+    print(myutils.gct(), 'Test image segmentation done.')
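test_image_seg.py can also be run on its own for a single image or a folder of images. A sketch of an invocation, using the script's own default model path (the test path and name here are illustrative, borrowed from test_video_seg.sh below):

```sh
python test_image_seg.py \
    --model-path=./records/link_efficientb4_model.pth \
    --test-path=Testing_Station/Duck_Rectified/20211017_1400_UTC/ \
    --test-name=Duck_Rectified
```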
test_video_seg.py
ADDED
@@ -0,0 +1,132 @@
+import numpy as np
+from tqdm import tqdm
+import os
+import argparse
+from glob import glob
+import torch
+from torch import utils
+from torch.nn import functional as F
+from torchvision.transforms import functional as TF
+from torchvision.transforms import InterpolationMode
+
+from video_module.dataset import Video_DS
+from video_module.model import AFB_URR, FeatureBank
+from test_image_seg import test_waterseg
+import myutils
+
+torch.set_grad_enabled(False)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='V-FloodNet: Water Video Segmentation')
+    parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
+    parser.add_argument('--budget', type=int, default=250000, help='Max number of features in the feature bank.')
+    # NOTE: with default=True this flag is effectively always on.
+    parser.add_argument('--viz', action='store_true', default=True, help='Visualize data.')
+    parser.add_argument('--model-path', type=str, required=True, help='Path to the checkpoint.')
+    parser.add_argument('--update-rate', type=float, default=0.1, help='Update rate for merging new features.')
+    parser.add_argument('--merge-thres', type=float, default=0.95, help='Merging rate threshold.')
+    parser.add_argument('--test-path', type=str, required=True, help='Path to the test video frames.')
+    parser.add_argument('--test-name', type=str, required=True, help='Name for the test video.')
+    return parser.parse_args()
+
+
+def main(args, device):
+    model = AFB_URR(device, update_bank=True, load_imagenet_params=False)
+    model = model.to(device)
+    model.eval()
+
+    downsample_size = 480
+
+    if os.path.isfile(args.model_path):
+        checkpoint = torch.load(args.model_path)
+        end_epoch = checkpoint['epoch']
+        model.load_state_dict(checkpoint['model'], strict=False)
+        train_loss = checkpoint['loss']
+        seed = checkpoint['seed']
+        print(myutils.gct(),
+              f'Loaded checkpoint {args.model_path}. (end_epoch: {end_epoch}, train_loss: {train_loss}, seed: {seed})')
+    else:
+        raise IOError(f'No checkpoint found at {args.model_path}')
+
+    img_list = sorted(glob(os.path.join(args.test_path, '*.jpg')) + glob(os.path.join(args.test_path, '*.png')))
+    first_frame = myutils.load_image_in_PIL(img_list[0])
+    first_name = os.path.basename(img_list[0])[:-4]
+
+    out_dir = './output/segs'
+    mask_dir = os.path.join(out_dir, args.test_name, 'mask')
+    mask_path = os.path.join(mask_dir, first_name + '.png')
+    # If the first frame has no mask yet, bootstrap one with the image segmentation model.
+    if not os.path.exists(mask_path):
+        image_model_path = './records/link_efficientb4_model.pth'
+        test_waterseg(image_model_path, img_list[0], args.test_name, out_dir, device)
+
+    first_mask = myutils.load_image_in_PIL(mask_path, 'P')
+    seq_dataset = Video_DS(img_list, first_frame, first_mask)
+
+    seq_loader = utils.data.DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+    seg_dir = os.path.join(out_dir, args.test_name, 'mask')
+    os.makedirs(seg_dir, exist_ok=True)
+    if args.viz:
+        overlay_dir = os.path.join(out_dir, args.test_name, 'overlay')
+        os.makedirs(overlay_dir, exist_ok=True)
+
+    obj_n = seq_dataset.obj_n
+    fb = FeatureBank(obj_n, args.budget, device, update_rate=args.update_rate, thres_close=args.merge_thres)
+
+    ori_first_frame = seq_dataset.first_frame.unsqueeze(0).to(device)
+    ori_first_mask = seq_dataset.first_mask.unsqueeze(0).to(device)
+
+    first_frame = TF.resize(ori_first_frame, downsample_size, InterpolationMode.BICUBIC)
+    first_mask = TF.resize(ori_first_mask, downsample_size, InterpolationMode.NEAREST)
+
+    pred = torch.argmax(ori_first_mask[0], dim=0).cpu().numpy().astype(np.uint8)
+    seg_path = os.path.join(seg_dir, f'{first_name}.png')
+    myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
+
+    if args.viz:
+        overlay_path = os.path.join(overlay_dir, f'{first_name}.png')
+        myutils.save_overlay(ori_first_frame[0], pred, overlay_path, myutils.color_palette)
+
+    with torch.no_grad():
+        k4_list, v4_list = model.memorize(first_frame, first_mask)
+        fb.init_bank(k4_list, v4_list)
+
+        for idx, (frame, frame_name) in enumerate(tqdm(seq_loader)):
+            ori_frame = frame.to(device)
+            ori_size = ori_frame.shape[-2:]
+            frame = TF.resize(ori_frame, downsample_size, InterpolationMode.BICUBIC)
+            score, _ = model.segment(frame, fb)
+            pred_mask = F.softmax(score, dim=1)
+
+            # Memorize the current frame and merge its features into the bank.
+            k4_list, v4_list = model.memorize(frame, pred_mask)
+            fb.update(k4_list, v4_list, idx + 1)
+
+            pred = TF.resize(pred_mask, ori_size, InterpolationMode.BICUBIC)
+            pred = torch.argmax(pred[0], dim=0).cpu().numpy().astype(np.uint8)
+            pred = myutils.postprocessing_pred(pred)
+            seg_path = os.path.join(seg_dir, f'{frame_name[0]}.png')
+            myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
+            if args.viz:
+                overlay_path = os.path.join(overlay_dir, f'{frame_name[0]}.png')
+                myutils.save_overlay(ori_frame[0], pred, overlay_path, myutils.color_palette)
+
+    fb.print_peak_mem()
+
+
+if __name__ == '__main__':
+    args = get_args()
+    print(myutils.gct(), 'Args =', args)
+
+    if args.gpu >= 0 and torch.cuda.is_available():
+        device = torch.device('cuda', args.gpu)
+    else:
+        raise ValueError('CUDA is required. --gpu must be >= 0.')
+
+    assert os.path.isdir(args.test_path)
+
+    main(args, device)
+
+    print(myutils.gct(), 'Test video segmentation done.')
test_video_seg.sh
ADDED
@@ -0,0 +1,4 @@
+python test_video_seg.py \
+    --test-path=Testing_Station/Duck_Rectified/20211017_1400_UTC/ \
+    --test-name=Duck_Rectified \
+    --model-path=logs/20240724-104343/model/best.pth
train_video_seg.py
ADDED
@@ -0,0 +1,143 @@
+import os
+import time
+import argparse
+import numpy as np
+from tqdm import tqdm
+from glob import glob
+
+import torch
+from torch.utils.data import DataLoader
+
+from video_module.dataset import Water_Image_Train_DS
+from video_module.model import AFB_URR, FeatureBank
+import myutils
+
+# Enable CUDA launch blocking for debugging; set to '0' to disable.
+os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Train V-BeachNet')
+    parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
+    parser.add_argument('--dataset', type=str, required=True, help='Dataset folder.')
+    parser.add_argument('--seed', type=int, default=-1, help='Random seed.')
+    parser.add_argument('--log', action='store_true', help='Save the training results.')
+    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate (default: 1e-5).')
+    parser.add_argument('--lu', type=float, default=0.5, help='Regularization factor (default: 0.5).')
+    parser.add_argument('--resume', type=str, help='Path to checkpoint (default: none).')
+    parser.add_argument('--new', action='store_true', help='Train the model from scratch.')
+    parser.add_argument('--scheduler-step', type=int, default=25, help='Scheduler step size (default: 25).')
+    parser.add_argument('--total-epochs', type=int, default=100, help='Total number of epochs (default: 100).')
+    parser.add_argument('--budget', type=int, default=300000, help='Maximum number of features in the feature bank (default: 300000).')
+    parser.add_argument('--obj-n', type=int, default=2, help='Maximum number of objects trained simultaneously.')
+    parser.add_argument('--clip-n', type=int, default=6, help='Maximum number of frames in a batch.')
+
+    return parser.parse_args()
+
+
+def train_model(model, dataloader, criterion, optimizer):
+    stats = myutils.AvgMeter()
+    uncertainty_stats = myutils.AvgMeter()
+    progress_bar = tqdm(dataloader)
+
+    for _, sample in enumerate(progress_bar):
+        frames, masks, obj_n, info = sample
+
+        if obj_n.item() == 1:
+            continue
+
+        frames, masks = frames[0].to(device), masks[0].to(device)
+        fb_global = FeatureBank(obj_n.item(), args.budget, device)
+        k4_list, v4_list = model.memorize(frames[:1], masks[:1])
+        fb_global.init_bank(k4_list, v4_list)
+
+        scores, uncertainty = model.segment(frames[1:], fb_global)
+        label = torch.argmax(masks[1:], dim=1).long()
+
+        optimizer.zero_grad()
+        loss = criterion(scores, label) + args.lu * uncertainty
+        loss.backward()
+        optimizer.step()
+
+        uncertainty_stats.update(uncertainty.item())
+        stats.update(loss.item())
+        progress_bar.set_postfix(loss=f'{loss.item():.5f} (Avg: {stats.avg:.5f}, Uncertainty Avg: {uncertainty_stats.avg:.5f})')
+
+    return stats.avg
+
+
+def main():
+    dataset = Water_Image_Train_DS(root=args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
+    print(myutils.gct(), f'Dataset with {len(dataset)} training cases.')
+
+    model = AFB_URR(device, update_bank=False, load_imagenet_params=True).to(device)
+    model.train()
+    model.apply(myutils.set_bn_eval)
+
+    optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), args.lr)
+
+    start_epoch, best_loss = 0, float('inf')
+    if args.resume:
+        if os.path.isfile(args.resume):
+            checkpoint = torch.load(args.resume)
+            model.load_state_dict(checkpoint['model'], strict=False)
+            seed = checkpoint.get('seed', int(time.time()))
+
+            if not args.new:
+                start_epoch = checkpoint['epoch'] + 1
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                best_loss = checkpoint['loss']
+                print(myutils.gct(), f'Resumed from checkpoint {args.resume} (Epoch: {start_epoch-1}, Best Loss: {best_loss}).')
+            else:
+                print(myutils.gct(), f'Loaded checkpoint {args.resume}. Training from scratch.')
+        else:
+            raise FileNotFoundError(f'No checkpoint found at {args.resume}')
+    else:
+        seed = args.seed if args.seed >= 0 else int(time.time())
+
+    print(myutils.gct(), 'Random seed:', seed)
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.5, last_epoch=start_epoch-1)
+
+    for epoch in range(start_epoch, args.total_epochs):
+        print(f'\n{myutils.gct()} Epoch: {epoch}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
+        loss = train_model(model, dataloader, criterion, optimizer)
+
+        if args.log:
+            checkpoint = {
+                'epoch': epoch,
+                'model': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'loss': loss,
+                'seed': seed
+            }
+            torch.save(checkpoint, os.path.join(model_path, 'final.pth'))
+            if loss < best_loss:
+                best_loss = loss
+                torch.save(checkpoint, os.path.join(model_path, 'best.pth'))
+                print('Updated best model.')
+
+        scheduler.step()
+
+
+if __name__ == '__main__':
+    args = get_args()
+    print(myutils.gct(), f'Arguments: {args}')
+
+    if args.gpu >= 0 and torch.cuda.is_available():
+        device = torch.device('cuda', args.gpu)
+    else:
+        raise ValueError('CUDA is required. Ensure --gpu is set to >= 0.')
+
+    if args.log:
+        log_dir = os.path.join('logs', time.strftime('%Y%m%d-%H%M%S'))
+        model_path = os.path.join(log_dir, 'model')
+        os.makedirs(model_path, exist_ok=True)
+        myutils.save_scripts(log_dir, scripts_to_save=glob('*.*'))
+        myutils.save_scripts(log_dir, scripts_to_save=glob('dataset/*.py', recursive=True))
+        myutils.save_scripts(log_dir, scripts_to_save=glob('model/*.py', recursive=True))
+        myutils.save_scripts(log_dir, scripts_to_save=glob('myutils/*.py', recursive=True))
+        print(myutils.gct(), f'Created log directory: {log_dir}')
+
+    main()
+    print(myutils.gct(), 'Training completed.')
train_video_seg.sh
ADDED
@@ -0,0 +1,15 @@
+python3 train_video_seg.py \
+    --gpu=0 \
+    --dataset=Training_Station/Duck_Rectified/ \
+    --seed=-1 \
+    --log \
+    --lr=1e-5 \
+    --lu=0.5 \
+    --scheduler-step=25 \
+    --total-epochs=100 \
+    --budget=300000 \
+    --obj-n=2 \
+    --clip-n=6 \
+    --new
+# --resume=logs/bestmodel/model/best.pth
+# Enable either --new or --resume, not both.