hasibzunair commited on
Commit
76f797b
ยท
1 Parent(s): 8a3583d

add app.py

Browse files
Files changed (2) hide show
  1. app.py +77 -0
  2. nyu.ipynb +0 -165
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import codecs
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ import gradio as gr
8
+
9
+ from PIL import Image
10
+
11
+ from unetplusplus import NestedUNet
12
+
13
+ torch.manual_seed(0)
14
+
15
+ if torch.cuda.is_available():
16
+ torch.backends.cudnn.deterministic = True
17
+
18
+ # Device
19
+ DEVICE = "cpu"
20
+ print(DEVICE)
21
+
22
+ # Load color map
23
+ cmap = np.load('cmap.npy')
24
+
25
+ # Make directories
26
+ os.system("mkdir ./models")
27
+
28
+ # Get model weights
29
+ if not os.path.exists("./models/masksupnyu39.31d.pth"):
30
+ os.system("wget -O ./models/masksupnyu39.31d.pth https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth")
31
+
32
+ # Load model
33
+ model = NestedUNet(num_classes=40)
34
+ checkpoint = torch.load("./models/masksupnyu39.31d.pth")
35
+ model.load_state_dict(checkpoint)
36
+ model = model.to(DEVICE)
37
+ model.eval()
38
+
39
+
40
+ # Main inference function
41
+ def inference(img_path):
42
+ image = Image.open(img_path).convert("RGB")
43
+ transforms_image = transforms.Compose(
44
+ [
45
+ transforms.Resize((224, 224)),
46
+ transforms.CenterCrop((224, 224)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
49
+ ]
50
+ )
51
+
52
+ image = transforms_image(image)
53
+ image = image[None, :]
54
+ # Predict
55
+ with torch.no_grad():
56
+ output = torch.sigmoid(model(image.to(DEVICE).float()))
57
+ output = torch.softmax(output, dim=1).argmax(dim=1)[0].float().cpu().numpy().astype(np.uint8)
58
+ pred = cmap[output]
59
+ return pred
60
+
61
+ # App
62
+ title = "Masked Supervised Learning for Semantic Segmentation"
63
+ description = codecs.open("description.html", "r", "utf-8").read()
64
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"
65
+
66
+ gr.Interface(
67
+ inference,
68
+ gr.inputs.Image(type='file', label="Input Image"),
69
+ gr.outputs.Image(type="file", label="Predicted Output"),
70
+ examples=["./sample_images/a.png", "./sample_images/b.png",
71
+ "./sample_images/c.png", "./sample_images/d.png"],
72
+ title=title,
73
+ description=description,
74
+ article=article,
75
+ allow_flagging=False,
76
+ analytics_enabled=False,
77
+ ).launch(debug=True, enable_queue=True)
nyu.ipynb DELETED
@@ -1,165 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import os\n",
10
- "import numpy as np\n",
11
- "import cv2\n",
12
- "import codecs\n",
13
- "import torch\n",
14
- "import torchvision.transforms as transforms\n",
15
- "import gradio as gr\n",
16
- "\n",
17
- "from PIL import Image\n",
18
- "\n",
19
- "from unetplusplus import NestedUNet\n",
20
- "\n",
21
- "torch.manual_seed(0)\n",
22
- "\n",
23
- "if torch.cuda.is_available():\n",
24
- " torch.backends.cudnn.deterministic = True\n",
25
- "\n",
26
- "# Device\n",
27
- "DEVICE = \"cpu\"\n",
28
- "print(DEVICE)\n",
29
- "\n",
30
- "# Load color map\n",
31
- "cmap = np.load('cmap.npy')\n",
32
- "\n",
33
- "# Make directories\n",
34
- "os.system(\"mkdir ./models\")\n",
35
- "\n",
36
- "# Get model weights\n",
37
- "if not os.path.exists(\"./models/masksupnyu39.31d.pth\"):\n",
38
- " os.system(\"wget -O ./models/masksupnyu39.31d.pth https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth\")\n",
39
- "\n",
40
- "# Load model\n",
41
- "model = NestedUNet(num_classes=40)\n",
42
- "checkpoint = torch.load(\"./models/masksupnyu39.31d.pth\")\n",
43
- "model.load_state_dict(checkpoint)\n",
44
- "model = model.to(DEVICE)\n",
45
- "model.eval()\n",
46
- "\n",
47
- "\n",
48
- "# Main inference function\n",
49
- "def inference(img_path):\n",
50
- " image = Image.open(img_path).convert(\"RGB\")\n",
51
- " transforms_image = transforms.Compose(\n",
52
- " [\n",
53
- " transforms.Resize((224, 224)),\n",
54
- " transforms.CenterCrop((224, 224)),\n",
55
- " transforms.ToTensor(),\n",
56
- " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
57
- " ]\n",
58
- " )\n",
59
- "\n",
60
- " image = transforms_image(image)\n",
61
- " image = image[None, :]\n",
62
- " # Predict\n",
63
- " with torch.no_grad():\n",
64
- " output = torch.sigmoid(model(image.to(DEVICE).float()))\n",
65
- " output = torch.softmax(output, dim=1).argmax(dim=1)[0].float().cpu().numpy().astype(np.uint8)\n",
66
- " pred = cmap[output]\n",
67
- " return pred\n",
68
- "\n",
69
- "# App\n",
70
- "title = \"Masked Supervised Learning for Semantic Segmentation\"\n",
71
- "description = codecs.open(\"description.html\", \"r\", \"utf-8\").read()\n",
72
- "article = \"<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>\"\n",
73
- "\n",
74
- "gr.Interface(\n",
75
- " inference,\n",
76
- " gr.inputs.Image(type='file', label=\"Input Image\"),\n",
77
- " gr.outputs.Image(type=\"file\", label=\"Predicted Output\"),\n",
78
- " examples=[\"./sample_images/a.png\", \"./sample_images/b.png\", \n",
79
- " \"./sample_images/c.png\", \"./sample_images/d.png\"],\n",
80
- " title=title,\n",
81
- " description=description,\n",
82
- " article=article,\n",
83
- " allow_flagging=False,\n",
84
- " analytics_enabled=False,\n",
85
- " ).launch(debug=True, enable_queue=True)"
86
- ]
87
- },
88
- {
89
- "cell_type": "code",
90
- "execution_count": null,
91
- "metadata": {},
92
- "outputs": [],
93
- "source": []
94
- },
95
- {
96
- "cell_type": "code",
97
- "execution_count": null,
98
- "metadata": {},
99
- "outputs": [],
100
- "source": []
101
- },
102
- {
103
- "cell_type": "code",
104
- "execution_count": null,
105
- "metadata": {},
106
- "outputs": [],
107
- "source": []
108
- },
109
- {
110
- "cell_type": "code",
111
- "execution_count": null,
112
- "metadata": {},
113
- "outputs": [],
114
- "source": []
115
- },
116
- {
117
- "cell_type": "code",
118
- "execution_count": null,
119
- "metadata": {},
120
- "outputs": [],
121
- "source": []
122
- },
123
- {
124
- "cell_type": "code",
125
- "execution_count": null,
126
- "metadata": {},
127
- "outputs": [],
128
- "source": []
129
- },
130
- {
131
- "cell_type": "code",
132
- "execution_count": null,
133
- "metadata": {},
134
- "outputs": [],
135
- "source": []
136
- }
137
- ],
138
- "metadata": {
139
- "kernelspec": {
140
- "display_name": "Python 3.8.12 ('fifa')",
141
- "language": "python",
142
- "name": "python3"
143
- },
144
- "language_info": {
145
- "codemirror_mode": {
146
- "name": "ipython",
147
- "version": 3
148
- },
149
- "file_extension": ".py",
150
- "mimetype": "text/x-python",
151
- "name": "python",
152
- "nbconvert_exporter": "python",
153
- "pygments_lexer": "ipython3",
154
- "version": "3.8.12"
155
- },
156
- "orig_nbformat": 4,
157
- "vscode": {
158
- "interpreter": {
159
- "hash": "5a4cff4f724f20f3784f32e905011239b516be3fadafd59414871df18d0dad63"
160
- }
161
- }
162
- },
163
- "nbformat": 4,
164
- "nbformat_minor": 2
165
- }