rezasalatin committed on
Commit 0b70b07 · verified · 1 Parent(s): 24ad6cc

Upload 8 files

Files changed (8)
  1. .gitignore +5 -0
  2. README.md +32 -41
  3. environment.yml +22 -0
  4. test_image_seg.py +173 -0
  5. test_video_seg.py +132 -0
  6. test_video_seg.sh +4 -0
  7. train_video_seg.py +143 -0
  8. train_video_seg.sh +15 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .idea/
+ .vscode/
+ __pycache__/
+
+ logs/
README.md CHANGED
@@ -1,12 +1,3 @@
- ---
- license: mit
- language:
- - en
- pipeline_tag: image-segmentation
- tags:
- - climate
- ---
-
  # V-BeachNet
  
  This repository contains the official PyTorch implementation for the paper "A New Framework for Quantifying Alongshore Variability of Swash Motion Using Fully Convolutional Networks." V-BeachNet is built upon V-FloodNet.
@@ -19,46 +10,46 @@ Liang, Y., Li, X., Tsai, B., Chen, Q., & Jafari, N. (2023). V-FloodNet: A video
  
  ## Prerequisites
  
- This code is tested on a newly installed Ubuntu 24.04 with the default version of Python and an Nvidia GPU.
+ Install Conda on Ubuntu 24.04 with the default version of Python and an Nvidia GPU.
  
  1. Install the Anaconda prerequisites (also listed [here](https://docs.anaconda.com/anaconda/install/linux/)):
  ```sh
  sudo apt update && \
  sudo apt install libgl1-mesa-dri libegl1 libglu1-mesa libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2-data libasound2-plugins libxi6 libxtst6
  ```
  
  2. Download Anaconda3:
  ```sh
  curl -O https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
  ```
  
  3. Locate the downloaded file and install it:
  ```sh
  bash Anaconda3-2024.06-1-Linux-x86_64.sh
  ```
  
  ## Steps
  
  1. Clone this repository and change directory:
  ```sh
- git clone https://huggingface.co/rezasalatin/V-BeachNet.git
+ git clone https://github.com/rezasalatin/V-BeachNet.git
  cd V-BeachNet
  ```
  
  2. Create the virtual environment with the requirements:
  ```sh
  conda env create -f environment.yml
  conda activate vbeach
  ```
  
- 3. Visit the "Training_Station" folder and copy your manually segmented (using [labelme](https://github.com/labelmeai/labelme)) dataset to this directory. Open the following file to change any of the variables and save it. Then execute it to train the model:
+ 3. Visit the "Training_Station" folder and copy your manually segmented dataset to this directory. Open the following file to change any of the variables and save it. Then execute it to train the model:
  ```sh
  ./train_video_seg.sh
  ```
  Access your trained model from the `logs/` directory.
  
  4. Visit the "Testing_Station" folder and copy your data to this directory. Open the following file to change any of the variables (especially the model path from the `logs/` folder) and save it. Then execute it to test the model:
  ```sh
  ./test_video_seg.sh
  ```
  Access your segmented data from the `output/` directory.
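Note: steps 3 and 4 invoke the helper scripts directly, so they must be executable. If the clone did not preserve the executable bit, a quick fix (plain shell, not specific to this repo) is:

```sh
# Hypothetical one-time step: mark the helper scripts executable
# so ./train_video_seg.sh and ./test_video_seg.sh can be run directly.
chmod +x train_video_seg.sh test_video_seg.sh
```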
environment.yml ADDED
@@ -0,0 +1,22 @@
+ name: vbeach
+ channels:
+ - conda-forge
+ - pytorch
+ - nvidia
+ dependencies:
+ - tqdm
+ - pytorch
+ - torchvision
+ - torchaudio
+ - pytorch-cuda=12.1
+ - scipy
+ - numpy
+ - matplotlib
+ - pandas
+ - scikit-learn
+ - opencv
+ - gxx_linux-64
+ - pip
+ - pip:
+   - torch-scatter -f https://data.pyg.org/whl/torch-2.1.2+cu121.html
+   - segmentation-models-pytorch==0.2.0
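Note: a minimal sanity check (a sketch, not part of this upload), assuming `conda env create -f environment.yml` succeeded, is to confirm that the `pytorch-cuda=12.1` build actually sees the GPU:

```sh
# Hypothetical check: prints the torch version and whether CUDA is visible.
conda activate vbeach
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```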
test_image_seg.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import argparse
+ import cv2
+ import numpy as np
+
+ from PIL import Image
+ from pathlib import Path
+ from glob import glob
+ from tqdm import tqdm
+
+ import torch
+ import torchvision.transforms as tf
+
+ # Linknet must be importable so torch.load can unpickle the saved model object.
+ from segmentation_models_pytorch import Linknet
+ import myutils
+
+ ROOT_DIR = './'
+ DEFAULT_OUT = os.path.join(ROOT_DIR, 'output', 'segs')
+
+
+ def norm_imagenet(img_pil, dims):
+     """
+     Normalizes and resizes an input image.
+     :param img_pil: PIL Image
+     :param dims: Model's expected input dimensions
+     :return: Normalized image as a Tensor
+     """
+     # Mean and stddev of the ImageNet dataset
+     mean = torch.tensor([0.485, 0.456, 0.406])
+     std = torch.tensor([0.229, 0.224, 0.225])
+
+     # Resize, convert to tensor, normalize
+     transform_norm = tf.Compose([
+         tf.Resize([dims[0], dims[1]]),
+         tf.ToTensor(),
+         tf.Normalize(mean, std)
+     ])
+
+     return transform_norm(img_pil)
+
+
+ def predict_one(path, model, mask_outdir, overlay_outdir, device):
+     """
+     Predicts a single image from path.
+     :param path: Path to image
+     :param model: Loaded Torch model
+     :param mask_outdir: Filepath to mask output directory
+     :param overlay_outdir: Filepath to overlay output directory
+     :param device: Torch device
+     :return: None
+     """
+     img_pil = myutils.load_image_in_PIL(path)
+
+     # Prediction is a PIL Image of 0s and 1s
+     prediction = predict_pil(model, img_pil, model_dims=(416, 416), device=device)
+
+     basename = str(Path(os.path.basename(path)).stem)
+     mask_savepth = os.path.join(mask_outdir, basename + '.png')
+     prediction.save(mask_savepth)
+
+     over_savepth = os.path.join(overlay_outdir, basename + '.png')
+     img_np = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+     overlay_np = myutils.add_overlay(img_np, np.array(prediction))
+     cv2.imwrite(over_savepth, overlay_np)
+
+
+ def predict_pil(model, img_pil, model_dims, device):
+     """
+     Predicts a single PIL Image.
+     :param model: Loaded PyTorch model
+     :param img_pil: PIL image
+     :param model_dims: Model input dimensions
+     :param device: Torch device
+     :return: Segmentation prediction as PIL Image
+     """
+     img_np = np.array(img_pil)
+     img_tensor_norm = norm_imagenet(img_pil, model_dims)
+
+     # Pipeline to resize the prediction back to the original image dimensions
+     pred_resize = tf.Compose([tf.Resize([img_np.shape[0], img_np.shape[1]])])
+
+     # Add a batch dimension: the model expects input of shape 1*3*dimX*dimY
+     input_data = img_tensor_norm.unsqueeze(0)
+
+     try:
+         prediction = model.predict(input_data.to(device))
+     except RuntimeError:
+         print("Did not convert input image to cuda.")
+         prediction = model.predict(input_data)
+
+     prediction = pred_resize(prediction)
+     prediction = myutils.postprocessing_pred(prediction.squeeze().cpu().round().numpy().astype(np.uint8))
+     prediction = Image.fromarray(prediction).convert('P')
+     prediction.putpalette(myutils.color_palette)
+     return prediction
+
+
+ def test_waterseg(model_path, test_path, test_name, out_path, device):
+     """
+     Tests either a single image or an entire folder of images.
+     :param model_path: Path to the saved model
+     :param test_path: Path to a single image or a directory of images
+     :param test_name: Name for this test run
+     :param out_path: Output directory
+     :param device: Torch device
+     :return: None
+     """
+     model = torch.load(model_path)
+     out_path = os.path.join(out_path, test_name)
+
+     mask_out = os.path.join(out_path, 'mask')
+     overlay_out = os.path.join(out_path, 'overlay')
+     os.makedirs(mask_out, exist_ok=True)
+     os.makedirs(overlay_out, exist_ok=True)
+
+     if os.path.isfile(test_path):
+         predict_one(test_path, model, mask_out, overlay_out, device)
+     elif os.path.isdir(test_path):
+         paths = glob(os.path.join(test_path, '*.jpg')) + glob(os.path.join(test_path, '*.png'))
+         for path in tqdm(paths):
+             predict_one(path, model, mask_out, overlay_out, device)
+     else:
+         print("Error: Unknown path type:", test_path)
+
+
+ if __name__ == '__main__':
+     # Hyperparameters
+     parser = argparse.ArgumentParser(description='V-FloodNet: Water Image Segmentation')
+     # Required: Path to the .pth file.
+     parser.add_argument('--model-path',
+                         default='./records/link_efficientb4_model.pth',
+                         type=str,
+                         metavar='PATH',
+                         help='Path to the model')
+     # Required: Path to either a single file or a directory of .jpg or .png images
+     parser.add_argument('--test-path',
+                         type=str,
+                         metavar='PATH',
+                         required=True,
+                         help='Can point to a folder or an individual jpg/png image')
+     parser.add_argument('--test-name',
+                         type=str,
+                         required=True,
+                         help='Test name')
+     parser.add_argument('--out-path',
+                         default=DEFAULT_OUT,
+                         type=str,
+                         metavar='PATH',
+                         help='(OPTIONAL) Path to output folder, defaults to project root/output')
+     args = parser.parse_args()
+
+     # Device: prefer CUDA when available, otherwise fall back to CPU.
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     test_waterseg(args.model_path, args.test_path, args.test_name, args.out_path, device)
+
+     print(myutils.gct(), 'Test image segmentation done.')
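Note: besides being imported by `test_video_seg.py` to bootstrap the first-frame mask, this script can be run standalone. A hypothetical invocation using the argparse flags defined above (the test path and name are placeholders; `--model-path` falls back to its `./records/link_efficientb4_model.pth` default):

```sh
# Hypothetical standalone run; writes masks and overlays under ./output/segs/<test-name>/.
python test_image_seg.py \
    --test-path=Testing_Station/some_frames/ \
    --test-name=some_frames
```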
test_video_seg.py ADDED
@@ -0,0 +1,132 @@
+ import numpy as np
+ from tqdm import tqdm
+ import os
+ import argparse
+ from glob import glob
+ import torch
+ from torch.utils.data import DataLoader
+ from torch.nn import functional as F
+ from torchvision.transforms import functional as TF
+ from torchvision.transforms import InterpolationMode
+
+ from video_module.dataset import Video_DS
+ from video_module.model import AFB_URR, FeatureBank
+ from test_image_seg import test_waterseg
+ import myutils
+
+ torch.set_grad_enabled(False)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='V-FloodNet: Water Video Segmentation')
+     parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
+     parser.add_argument('--budget', type=int, default=250000, help='Max number of features in the feature bank.')
+     parser.add_argument('--viz', action='store_true', default=True, help='Visualize data.')
+     parser.add_argument('--model-path', type=str, required=True, help='Path to the checkpoint.')
+     parser.add_argument('--update-rate', type=float, default=0.1, help='Update rate for merging new features.')
+     parser.add_argument('--merge-thres', type=float, default=0.95, help='Merging rate threshold.')
+     parser.add_argument('--test-path', type=str, required=True, help='Path to the test video frames.')
+     parser.add_argument('--test-name', type=str, required=True, help='Name for the test video.')
+     return parser.parse_args()
+
+
+ def main(args, device):
+     model = AFB_URR(device, update_bank=True, load_imagenet_params=False)
+     model = model.to(device)
+     model.eval()
+
+     downsample_size = 480
+
+     if os.path.isfile(args.model_path):
+         checkpoint = torch.load(args.model_path)
+         end_epoch = checkpoint['epoch']
+         model.load_state_dict(checkpoint['model'], strict=False)
+         train_loss = checkpoint['loss']
+         seed = checkpoint['seed']
+         print(myutils.gct(),
+               f'Loaded checkpoint {args.model_path}. (end_epoch: {end_epoch}, train_loss: {train_loss}, seed: {seed})')
+     else:
+         raise IOError(f'No checkpoint found at {args.model_path}')
+
+     img_list = sorted(glob(os.path.join(args.test_path, '*.jpg')) + glob(os.path.join(args.test_path, '*.png')))
+     first_frame = myutils.load_image_in_PIL(img_list[0])
+     first_name = os.path.basename(img_list[0])[:-4]
+
+     # If no first-frame mask exists yet, bootstrap one with the image segmentation model.
+     out_dir = './output/segs'
+     mask_dir = os.path.join(out_dir, args.test_name, 'mask')
+     mask_path = os.path.join(mask_dir, first_name + '.png')
+     if not os.path.exists(mask_path):
+         image_model_path = './records/link_efficientb4_model.pth'
+         test_waterseg(image_model_path, img_list[0], args.test_name, out_dir, device)
+
+     first_mask = myutils.load_image_in_PIL(mask_path, 'P')
+     seq_dataset = Video_DS(img_list, first_frame, first_mask)
+     seq_loader = DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=1)
+
+     seg_dir = mask_dir  # segmentation masks go to the same directory as the bootstrap mask
+     os.makedirs(seg_dir, exist_ok=True)
+     if args.viz:
+         overlay_dir = os.path.join(out_dir, args.test_name, 'overlay')
+         os.makedirs(overlay_dir, exist_ok=True)
+
+     obj_n = seq_dataset.obj_n
+     fb = FeatureBank(obj_n, args.budget, device, update_rate=args.update_rate, thres_close=args.merge_thres)
+
+     ori_first_frame = seq_dataset.first_frame.unsqueeze(0).to(device)
+     ori_first_mask = seq_dataset.first_mask.unsqueeze(0).to(device)
+
+     first_frame = TF.resize(ori_first_frame, downsample_size, InterpolationMode.BICUBIC)
+     first_mask = TF.resize(ori_first_mask, downsample_size, InterpolationMode.NEAREST)
+
+     pred = torch.argmax(ori_first_mask[0], dim=0).cpu().numpy().astype(np.uint8)
+     seg_path = os.path.join(seg_dir, f'{first_name}.png')
+     myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
+
+     if args.viz:
+         overlay_path = os.path.join(overlay_dir, f'{first_name}.png')
+         myutils.save_overlay(ori_first_frame[0], pred, overlay_path, myutils.color_palette)
+
+     # Memorize the first frame and mask to initialize the feature bank.
+     with torch.no_grad():
+         k4_list, v4_list = model.memorize(first_frame, first_mask)
+         fb.init_bank(k4_list, v4_list)
+
+     for idx, (frame, frame_name) in enumerate(tqdm(seq_loader)):
+         ori_frame = frame.to(device)
+         ori_size = ori_frame.shape[-2:]
+         frame = TF.resize(ori_frame, downsample_size, InterpolationMode.BICUBIC)
+         score, _ = model.segment(frame, fb)
+         pred_mask = F.softmax(score, dim=1)
+
+         # Memorize the current frame and fold its features into the bank.
+         k4_list, v4_list = model.memorize(frame, pred_mask)
+         fb.update(k4_list, v4_list, idx + 1)
+
+         pred = TF.resize(pred_mask, ori_size, InterpolationMode.BICUBIC)
+         pred = torch.argmax(pred[0], dim=0).cpu().numpy().astype(np.uint8)
+         pred = myutils.postprocessing_pred(pred)
+         seg_path = os.path.join(seg_dir, f'{frame_name[0]}.png')
+         myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
+         if args.viz:
+             overlay_path = os.path.join(overlay_dir, f'{frame_name[0]}.png')
+             myutils.save_overlay(ori_frame[0], pred, overlay_path, myutils.color_palette)
+
+     fb.print_peak_mem()
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     print(myutils.gct(), 'Args =', args)
+
+     if args.gpu >= 0 and torch.cuda.is_available():
+         device = torch.device('cuda', args.gpu)
+     else:
+         raise ValueError('CUDA is required. --gpu must be >= 0.')
+
+     assert os.path.isdir(args.test_path)
+
+     main(args, device)
+
+     print(myutils.gct(), 'Test video segmentation done.')
test_video_seg.sh ADDED
@@ -0,0 +1,4 @@
+ python test_video_seg.py \
+     --test-path=Testing_Station/Duck_Rectified/20211017_1400_UTC/ \
+     --test-name=Duck_Rectified \
+     --model-path=logs/20240724-104343/model/best.pth
train_video_seg.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import time
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+ from glob import glob
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from video_module.dataset import Water_Image_Train_DS
+ from video_module.model import AFB_URR, FeatureBank
+ import myutils
+
+ # Set to '1' to enable CUDA launch blocking when debugging; '0' disables it.
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='Train V-BeachNet')
+     parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
+     parser.add_argument('--dataset', type=str, required=True, help='Dataset folder.')
+     parser.add_argument('--seed', type=int, default=-1, help='Random seed.')
+     parser.add_argument('--log', action='store_true', help='Save the training results.')
+     parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate (default: 1e-5).')
+     parser.add_argument('--lu', type=float, default=0.5, help='Regularization factor (default: 0.5).')
+     parser.add_argument('--resume', type=str, help='Path to checkpoint (default: none).')
+     parser.add_argument('--new', action='store_true', help='Train the model from scratch.')
+     parser.add_argument('--scheduler-step', type=int, default=25, help='Scheduler step size (default: 25).')
+     parser.add_argument('--total-epochs', type=int, default=100, help='Total number of epochs (default: 100).')
+     parser.add_argument('--budget', type=int, default=300000, help='Maximum number of features in the feature bank (default: 300000).')
+     parser.add_argument('--obj-n', type=int, default=2, help='Maximum number of objects trained simultaneously.')
+     parser.add_argument('--clip-n', type=int, default=6, help='Maximum number of frames in a batch.')
+     return parser.parse_args()
+
+
+ def train_model(model, dataloader, criterion, optimizer):
+     stats = myutils.AvgMeter()
+     uncertainty_stats = myutils.AvgMeter()
+     progress_bar = tqdm(dataloader)
+
+     for _, sample in enumerate(progress_bar):
+         frames, masks, obj_n, info = sample
+
+         # Skip clips that contain only the background object.
+         if obj_n.item() == 1:
+             continue
+
+         frames, masks = frames[0].to(device), masks[0].to(device)
+
+         # Initialize the feature bank from the first frame of the clip.
+         fb_global = FeatureBank(obj_n.item(), args.budget, device)
+         k4_list, v4_list = model.memorize(frames[:1], masks[:1])
+         fb_global.init_bank(k4_list, v4_list)
+
+         # Segment the remaining frames against the bank.
+         scores, uncertainty = model.segment(frames[1:], fb_global)
+         label = torch.argmax(masks[1:], dim=1).long()
+
+         optimizer.zero_grad()
+         loss = criterion(scores, label) + args.lu * uncertainty
+         loss.backward()
+         optimizer.step()
+
+         uncertainty_stats.update(uncertainty.item())
+         stats.update(loss.item())
+         progress_bar.set_postfix(loss=f'{loss.item():.5f} (Avg: {stats.avg:.5f}, Uncertainty Avg: {uncertainty_stats.avg:.5f})')
+
+     return stats.avg
+
+
+ def main():
+     dataset = Water_Image_Train_DS(root=args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
+     dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
+     print(myutils.gct(), f'Dataset with {len(dataset)} training cases.')
+
+     model = AFB_URR(device, update_bank=False, load_imagenet_params=True).to(device)
+     model.train()
+     model.apply(myutils.set_bn_eval)
+
+     optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), args.lr)
+
+     start_epoch, best_loss = 0, float('inf')
+     if args.resume:
+         if os.path.isfile(args.resume):
+             checkpoint = torch.load(args.resume)
+             model.load_state_dict(checkpoint['model'], strict=False)
+             seed = checkpoint.get('seed', int(time.time()))
+
+             if not args.new:
+                 start_epoch = checkpoint['epoch'] + 1
+                 optimizer.load_state_dict(checkpoint['optimizer'])
+                 best_loss = checkpoint['loss']
+                 print(myutils.gct(), f'Resumed from checkpoint {args.resume} (Epoch: {start_epoch-1}, Best Loss: {best_loss}).')
+             else:
+                 print(myutils.gct(), f'Loaded checkpoint {args.resume}. Training from scratch.')
+         else:
+             raise FileNotFoundError(f'No checkpoint found at {args.resume}')
+     else:
+         seed = args.seed if args.seed >= 0 else int(time.time())
+
+     print(myutils.gct(), 'Random seed:', seed)
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+     criterion = torch.nn.CrossEntropyLoss().to(device)
+     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.5, last_epoch=start_epoch-1)
+
+     for epoch in range(start_epoch, args.total_epochs):
+         print(f'\n{myutils.gct()} Epoch: {epoch}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
+         loss = train_model(model, dataloader, criterion, optimizer)
+
+         if args.log:
+             checkpoint = {
+                 'epoch': epoch,
+                 'model': model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'loss': loss,
+                 'seed': seed
+             }
+             torch.save(checkpoint, os.path.join(model_path, 'final.pth'))
+             if loss < best_loss:
+                 best_loss = loss
+                 torch.save(checkpoint, os.path.join(model_path, 'best.pth'))
+                 print('Updated best model.')
+
+         scheduler.step()
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     print(myutils.gct(), f'Arguments: {args}')
+
+     if args.gpu >= 0 and torch.cuda.is_available():
+         device = torch.device('cuda', args.gpu)
+     else:
+         raise ValueError('CUDA is required. Ensure --gpu is set to >= 0.')
+
+     if args.log:
+         log_dir = os.path.join('logs', time.strftime('%Y%m%d-%H%M%S'))
+         model_path = os.path.join(log_dir, 'model')
+         os.makedirs(model_path, exist_ok=True)
+         # Snapshot the training scripts alongside the checkpoints for reproducibility.
+         myutils.save_scripts(log_dir, scripts_to_save=glob('*.*'))
+         myutils.save_scripts(log_dir, scripts_to_save=glob('dataset/*.py', recursive=True))
+         myutils.save_scripts(log_dir, scripts_to_save=glob('model/*.py', recursive=True))
+         myutils.save_scripts(log_dir, scripts_to_save=glob('myutils/*.py', recursive=True))
+         print(myutils.gct(), f'Created log directory: {log_dir}')
+
+     main()
+     print(myutils.gct(), 'Training completed.')
train_video_seg.sh ADDED
@@ -0,0 +1,15 @@
+ python3 train_video_seg.py \
+     --gpu=0 \
+     --dataset=Training_Station/Duck_Rectified/ \
+     --seed=-1 \
+     --log \
+     --lr=1e-5 \
+     --lu=0.5 \
+     --scheduler-step=25 \
+     --total-epochs=100 \
+     --budget=300000 \
+     --obj-n=2 \
+     --clip-n=6 \
+     --new
+ # --resume=logs/bestmodel/model/best.pth
+ # Enable either --new (train from scratch) or --resume (continue from a checkpoint).
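Note: to continue from a checkpoint instead of training from scratch, drop `--new` and pass `--resume`, as the trailing comments suggest. A sketch of that variant (the checkpoint path is a placeholder):

```sh
# Hypothetical resume variant of the script above.
python3 train_video_seg.py \
    --gpu=0 \
    --dataset=Training_Station/Duck_Rectified/ \
    --log \
    --resume=logs/bestmodel/model/best.pth
```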