junkmind veb-101 committed on
Commit 6bc5f9c · 0 Parent(s)

Duplicate from veb-101/UWMGI_Medical_Image_Segmentation


Co-authored-by: Vaibhav Singh <[email protected]>

Files changed (48)
  1. .gitattributes +35 -0
  2. .gitignore +2 -0
  3. README.md +27 -0
  4. app.py +77 -0
  5. load_lightning_SD_to_Usual_SD.ipynb +335 -0
  6. requirements.txt +4 -0
  7. samples/case101_day26_slice_0096_266_266_1.50_1.50.png +0 -0
  8. samples/case107_day0_slice_0089_266_266_1.50_1.50.png +0 -0
  9. samples/case107_day21_slice_0069_266_266_1.50_1.50.png +0 -0
  10. samples/case113_day12_slice_0108_360_310_1.50_1.50.png +0 -0
  11. samples/case119_day20_slice_0063_266_266_1.50_1.50.png +0 -0
  12. samples/case119_day25_slice_0075_266_266_1.50_1.50.png +0 -0
  13. samples/case119_day25_slice_0095_266_266_1.50_1.50.png +0 -0
  14. samples/case121_day14_slice_0057_266_266_1.50_1.50.png +0 -0
  15. samples/case122_day25_slice_0087_266_266_1.50_1.50.png +0 -0
  16. samples/case124_day19_slice_0110_266_266_1.50_1.50.png +0 -0
  17. samples/case124_day20_slice_0110_266_266_1.50_1.50.png +0 -0
  18. samples/case130_day0_slice_0106_266_266_1.50_1.50.png +0 -0
  19. samples/case134_day21_slice_0085_360_310_1.50_1.50.png +0 -0
  20. samples/case139_day0_slice_0062_234_234_1.50_1.50.png +0 -0
  21. samples/case139_day18_slice_0094_266_266_1.50_1.50.png +0 -0
  22. samples/case146_day25_slice_0053_276_276_1.63_1.63.png +0 -0
  23. samples/case147_day0_slice_0085_360_310_1.50_1.50.png +0 -0
  24. samples/case148_day0_slice_0113_360_310_1.50_1.50.png +0 -0
  25. samples/case149_day15_slice_0057_266_266_1.50_1.50.png +0 -0
  26. samples/case29_day0_slice_0065_266_266_1.50_1.50.png +0 -0
  27. samples/case2_day1_slice_0054_266_266_1.50_1.50.png +0 -0
  28. samples/case2_day1_slice_0077_266_266_1.50_1.50.png +0 -0
  29. samples/case32_day19_slice_0091_266_266_1.50_1.50.png +0 -0
  30. samples/case32_day19_slice_0100_266_266_1.50_1.50.png +0 -0
  31. samples/case33_day21_slice_0114_266_266_1.50_1.50.png +0 -0
  32. samples/case36_day16_slice_0064_266_266_1.50_1.50.png +0 -0
  33. samples/case40_day0_slice_0094_266_266_1.50_1.50.png +0 -0
  34. samples/case41_day25_slice_0049_266_266_1.50_1.50.png +0 -0
  35. samples/case63_day22_slice_0076_266_266_1.50_1.50.png +0 -0
  36. samples/case63_day26_slice_0093_266_266_1.50_1.50.png +0 -0
  37. samples/case65_day28_slice_0133_266_266_1.50_1.50.png +0 -0
  38. samples/case66_day36_slice_0101_266_266_1.50_1.50.png +0 -0
  39. samples/case67_day0_slice_0049_266_266_1.50_1.50.png +0 -0
  40. samples/case67_day0_slice_0086_266_266_1.50_1.50.png +0 -0
  41. samples/case74_day18_slice_0101_266_266_1.50_1.50.png +0 -0
  42. samples/case74_day19_slice_0084_266_266_1.50_1.50.png +0 -0
  43. samples/case81_day28_slice_0066_266_266_1.50_1.50.png +0 -0
  44. samples/case85_day29_slice_0102_360_310_1.50_1.50.png +0 -0
  45. samples/case89_day19_slice_0082_360_310_1.50_1.50.png +0 -0
  46. samples/case89_day20_slice_0087_266_266_1.50_1.50.png +0 -0
  47. segformer_trained_weights/config.json +82 -0
  48. segformer_trained_weights/pytorch_model.bin +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ artifacts
+ wandb
README.md ADDED
@@ -0,0 +1,27 @@
+ ---
+ title: Medical Image Segmentation Gradio App
+ emoji: 🏥🩺
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.37.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: veb-101/UWMGI_Medical_Image_Segmentation
+ ---
+
+ # Medical Image Segmentation Gradio App
+
+ For the Gradio app, we removed the PyTorch Lightning dependency that the rest of the project uses.
+ The `load_lightning_SD_to_Usual_SD.ipynb` notebook contains the steps used to convert the PyTorch Lightning checkpoint into a regular model checkpoint. This was done mainly to reduce the file size (977 MB --> 244 MB).
+
+ You can download the original saved checkpoint here: [wandb artifact](https://wandb.ai/veb-101/UM_medical_segmentation/artifacts/model/model-jsr2fn8v/v0/files)
+
+ Or via Python:
+
+ ```python
+ import wandb
+ run = wandb.init()
+ artifact = run.use_artifact('veb-101/UM_medical_segmentation/model-jsr2fn8v:v0', type='model')
+ artifact_dir = artifact.download()
+ ```
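For context, the conversion that the notebook performs amounts to stripping the `model.` prefix that Lightning adds to every parameter name and re-exporting the weights with `save_pretrained`. A minimal sketch of that flow, assuming the `model.ckpt` file downloaded from the wandb artifact above sits in the working directory:

```python
import torch
from collections import OrderedDict
from transformers import SegformerForSemanticSegmentation

# Rebuild the SegFormer-B4 architecture with a 4-class head (same call as in the notebook).
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b4-finetuned-ade-512-512",
    num_labels=4,
    ignore_mismatched_sizes=True,
)

# The Lightning checkpoint keeps the weights under "state_dict", each key prefixed with "model.".
ckpt = torch.load("model.ckpt", map_location="cpu")
new_state_dict = OrderedDict(
    (key.replace("model.", ""), value) for key, value in ckpt["state_dict"].items()
)
model.load_state_dict(new_state_dict)

# Export only the model weights and config; this is the directory that app.py loads.
model.save_pretrained("segformer_trained_weights")
```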
app.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ import numpy as np
+ import gradio as gr
+ from glob import glob
+ from functools import partial
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as TF
+ from transformers import SegformerForSemanticSegmentation
+
+
+ @dataclass
+ class Configs:
+     NUM_CLASSES: int = 4  # including background.
+     CLASSES: tuple = ("Large bowel", "Small bowel", "Stomach")
+     IMAGE_SIZE: tuple[int, int] = (288, 288)  # W, H
+     MEAN: tuple = (0.485, 0.456, 0.406)
+     STD: tuple = (0.229, 0.224, 0.225)
+     MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
+
+
+ def get_model(*, model_path, num_classes):
+     model = SegformerForSemanticSegmentation.from_pretrained(model_path, num_labels=num_classes, ignore_mismatched_sizes=True)
+     return model
+
+
+ @torch.inference_mode()
+ def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
+     shape_H_W = input_image.size[::-1]
+     input_tensor = preprocess_fn(input_image)
+     input_tensor = input_tensor.unsqueeze(0).to(device)
+
+     # Generate predictions
+     outputs = model(pixel_values=input_tensor.to(device), return_dict=True)
+     predictions = F.interpolate(outputs["logits"], size=shape_H_W, mode="bilinear", align_corners=False)
+
+     preds_argmax = predictions.argmax(dim=1).cpu().squeeze().numpy()
+
+     seg_info = [(preds_argmax == idx, class_name) for idx, class_name in enumerate(Configs.CLASSES, 1)]
+
+     return (input_image, seg_info)
+
+
+ if __name__ == "__main__":
+     class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
+
+     DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
+     model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
+     model.to(DEVICE)
+     model.eval()
+     _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
+
+     preprocess = TF.Compose(
+         [
+             TF.Resize(size=Configs.IMAGE_SIZE[::-1]),
+             TF.ToTensor(),
+             TF.Normalize(Configs.MEAN, Configs.STD, inplace=True),
+         ]
+     )
+
+     with gr.Blocks(title="Medical Image Segmentation") as demo:
+         gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
+         with gr.Row():
+             img_input = gr.Image(type="pil", height=360, width=360, label="Input image")
+             img_output = gr.AnnotatedImage(label="Predictions", height=360, width=360, color_map=class2hexcolor)
+
+         section_btn = gr.Button("Generate Predictions")
+         section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
+
+         images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
+         examples = [i for i in np.random.choice(images_dir, size=10, replace=False)]
+         gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
+
+     demo.launch()
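As a quick sanity check of the same inference path outside of the Gradio UI, the `predict` helper can be called directly on one of the sample slices bundled in this commit. A minimal sketch, assuming it is run from the Space's root directory with `model`, `preprocess`, `DEVICE`, and `predict` already set up as in the `__main__` block above:

```python
import os
from PIL import Image

# One of the PNGs added under samples/ in this commit.
sample_path = os.path.join("samples", "case101_day26_slice_0096_266_266_1.50_1.50.png")
image = Image.open(sample_path).convert("RGB")  # ensure 3 channels for the SegFormer input

# Same call that the Gradio button wires up via functools.partial.
original, seg_info = predict(image, model=model, preprocess_fn=preprocess, device=DEVICE)

# seg_info pairs a boolean mask with its class name, as expected by gr.AnnotatedImage.
for mask, class_name in seg_info:
    print(f"{class_name}: {int(mask.sum())} positive pixels")
```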
load_lightning_SD_to_Usual_SD.ipynb ADDED
@@ -0,0 +1,335 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Base Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "from transformers import SegformerForSemanticSegmentation\n",
+ "from dataclasses import dataclass\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class Configs:\n",
+ "    NUM_CLASSES = 4\n",
+ "    MODEL_PATH: str = \"nvidia/segformer-b4-finetuned-ade-512-512\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load Model To Inspect Parameter Names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "def get_model(*, model_path, num_classes):\n",
+ "    model = SegformerForSemanticSegmentation.from_pretrained(\n",
+ "        model_path,\n",
+ "        num_labels=num_classes,\n",
+ "        ignore_mismatched_sizes=True,\n",
+ "    )\n",
+ "    return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:\n",
+ "- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated\n",
+ "- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "segformer.encoder.patch_embeddings.0.proj.weight\n",
+ "segformer.encoder.patch_embeddings.0.proj.bias\n",
+ "segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
+ "segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
+ "segformer.encoder.patch_embeddings.1.proj.weight\n",
+ "segformer.encoder.patch_embeddings.1.proj.bias\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)\n",
+ "model_state_dict = model.state_dict()\n",
+ "\n",
+ "print()\n",
+ "for i, (key, val) in enumerate(model_state_dict.items()):\n",
+ "    print(key)\n",
+ "    if i == 5:\n",
+ "        break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Download & Load PyTorch-Lightning Checkpoint and Inspect Parameter Names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mveb-101\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.15.5"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in <code>c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-20230719_204221-w5qu5rqw</code>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run <strong><a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/w5qu5rqw' target=\"_blank\">ethereal-bush-2</a></strong> to <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation</a>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/w5qu5rqw' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/w5qu5rqw</a>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-fpgquxev:v0, 977.89MB. 1 files... \n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
+ "Done. 0:1:5.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "run = wandb.init()\n",
+ "artifact = run.use_artifact(\"veb-101/UM_medical_segmentation/model-fpgquxev:v0\", type=\"model\")\n",
+ "artifact_dir = artifact.download()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin', 'hparams_name', 'hyper_parameters'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "CKPT = torch.load(os.path.join(artifact_dir, \"model.ckpt\"), map_location=\"cpu\")\n",
+ "print(CKPT.keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "model.segformer.encoder.patch_embeddings.0.proj.weight\n",
+ "model.segformer.encoder.patch_embeddings.0.proj.bias\n",
+ "model.segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
+ "model.segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
+ "model.segformer.encoder.patch_embeddings.1.proj.weight\n",
+ "model.segformer.encoder.patch_embeddings.1.proj.bias\n"
+ ]
+ }
+ ],
+ "source": [
+ "TRAINED_CKPT_state_dict = CKPT[\"state_dict\"]\n",
+ "\n",
+ "for i, (key, val) in enumerate(TRAINED_CKPT_state_dict.items()):\n",
+ "    print(key)\n",
+ "    if i == 5:\n",
+ "        break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**The pytorch-lightning `state_dict()` has an extra `model.` prefix on every key; it comes from the attribute name that held the model inside the `LightningModule` class.**\n",
+ "\n",
+ "We can simply iterate over the parameters and change the parameter key name. We'll create a new `OrderedDict` for it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import OrderedDict\n",
+ "\n",
+ "new_state_dict = OrderedDict()\n",
+ "\n",
+ "for key_name, value in CKPT[\"state_dict\"].items():\n",
+ "    new_state_dict[key_name.replace(\"model.\", \"\")] = value"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<All keys matched successfully>"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.load_state_dict(new_state_dict)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# torch.save(model.state_dict(), \"Segformer_best_state_dict.ckpt\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save_pretrained(\"segformer_trained_weights\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To load the saved model, we simply need to pass the path to the directory \"segformer_trained_weights\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# model = get_model(model_path=os.path.join(os.getcwd(), \"segformer_trained_weights\"), num_classes=Configs.NUM_CLASSES)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pytorchx",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ --find-links https://download.pytorch.org/whl/torch_stable.html
+ torch==2.0.0+cpu
+ torchvision==0.15.0
+ transformers==4.30.2
samples/case101_day26_slice_0096_266_266_1.50_1.50.png ADDED
samples/case107_day0_slice_0089_266_266_1.50_1.50.png ADDED
samples/case107_day21_slice_0069_266_266_1.50_1.50.png ADDED
samples/case113_day12_slice_0108_360_310_1.50_1.50.png ADDED
samples/case119_day20_slice_0063_266_266_1.50_1.50.png ADDED
samples/case119_day25_slice_0075_266_266_1.50_1.50.png ADDED
samples/case119_day25_slice_0095_266_266_1.50_1.50.png ADDED
samples/case121_day14_slice_0057_266_266_1.50_1.50.png ADDED
samples/case122_day25_slice_0087_266_266_1.50_1.50.png ADDED
samples/case124_day19_slice_0110_266_266_1.50_1.50.png ADDED
samples/case124_day20_slice_0110_266_266_1.50_1.50.png ADDED
samples/case130_day0_slice_0106_266_266_1.50_1.50.png ADDED
samples/case134_day21_slice_0085_360_310_1.50_1.50.png ADDED
samples/case139_day0_slice_0062_234_234_1.50_1.50.png ADDED
samples/case139_day18_slice_0094_266_266_1.50_1.50.png ADDED
samples/case146_day25_slice_0053_276_276_1.63_1.63.png ADDED
samples/case147_day0_slice_0085_360_310_1.50_1.50.png ADDED
samples/case148_day0_slice_0113_360_310_1.50_1.50.png ADDED
samples/case149_day15_slice_0057_266_266_1.50_1.50.png ADDED
samples/case29_day0_slice_0065_266_266_1.50_1.50.png ADDED
samples/case2_day1_slice_0054_266_266_1.50_1.50.png ADDED
samples/case2_day1_slice_0077_266_266_1.50_1.50.png ADDED
samples/case32_day19_slice_0091_266_266_1.50_1.50.png ADDED
samples/case32_day19_slice_0100_266_266_1.50_1.50.png ADDED
samples/case33_day21_slice_0114_266_266_1.50_1.50.png ADDED
samples/case36_day16_slice_0064_266_266_1.50_1.50.png ADDED
samples/case40_day0_slice_0094_266_266_1.50_1.50.png ADDED
samples/case41_day25_slice_0049_266_266_1.50_1.50.png ADDED
samples/case63_day22_slice_0076_266_266_1.50_1.50.png ADDED
samples/case63_day26_slice_0093_266_266_1.50_1.50.png ADDED
samples/case65_day28_slice_0133_266_266_1.50_1.50.png ADDED
samples/case66_day36_slice_0101_266_266_1.50_1.50.png ADDED
samples/case67_day0_slice_0049_266_266_1.50_1.50.png ADDED
samples/case67_day0_slice_0086_266_266_1.50_1.50.png ADDED
samples/case74_day18_slice_0101_266_266_1.50_1.50.png ADDED
samples/case74_day19_slice_0084_266_266_1.50_1.50.png ADDED
samples/case81_day28_slice_0066_266_266_1.50_1.50.png ADDED
samples/case85_day29_slice_0102_360_310_1.50_1.50.png ADDED
samples/case89_day19_slice_0082_360_310_1.50_1.50.png ADDED
samples/case89_day20_slice_0087_266_266_1.50_1.50.png ADDED
segformer_trained_weights/config.json ADDED
@@ -0,0 +1,82 @@
+ {
+   "_name_or_path": "nvidia/segformer-b4-finetuned-ade-512-512",
+   "architectures": [
+     "SegformerForSemanticSegmentation"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "classifier_dropout_prob": 0.1,
+   "decoder_hidden_size": 768,
+   "depths": [
+     3,
+     8,
+     27,
+     3
+   ],
+   "downsampling_rates": [
+     1,
+     4,
+     8,
+     16
+   ],
+   "drop_path_rate": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_sizes": [
+     64,
+     128,
+     320,
+     512
+   ],
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3"
+   },
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2,
+     "LABEL_3": 3
+   },
+   "layer_norm_eps": 1e-06,
+   "mlp_ratios": [
+     4,
+     4,
+     4,
+     4
+   ],
+   "model_type": "segformer",
+   "num_attention_heads": [
+     1,
+     2,
+     5,
+     8
+   ],
+   "num_channels": 3,
+   "num_encoder_blocks": 4,
+   "patch_sizes": [
+     7,
+     3,
+     3,
+     3
+   ],
+   "reshape_last_stage": true,
+   "semantic_loss_ignore_index": 255,
+   "sr_ratios": [
+     8,
+     4,
+     2,
+     1
+   ],
+   "strides": [
+     4,
+     2,
+     2,
+     2
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.30.2"
+ }
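One small observation: the exported config keeps the placeholder names `LABEL_0`–`LABEL_3`, while `app.py` hard-codes the class names and colors. If readable labels were wanted in the config itself, they could be written in when re-saving the weights. A hedged sketch only: the background/class mapping below mirrors `app.py` (classes enumerated from index 1), and the output directory name is hypothetical, not something this commit does:

```python
from transformers import SegformerForSemanticSegmentation

# Illustrative mapping, assumed from app.py's Configs.CLASSES ordering (index 0 = background).
id2label = {0: "background", 1: "Large bowel", 2: "Small bowel", 3: "Stomach"}
label2id = {name: idx for idx, name in id2label.items()}

# Reload the exported weights and override the label maps in the config.
model = SegformerForSemanticSegmentation.from_pretrained(
    "segformer_trained_weights",
    id2label=id2label,
    label2id=label2id,
)
model.save_pretrained("segformer_trained_weights_named")  # hypothetical output directory
```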
segformer_trained_weights/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:423ff60b52bdbc5c0ea00f1a5648c42eccf2bdfbab550304bc95e28eb594cf0e
+ size 256300245