Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,6 @@ from utils import page_utils
|
|
10 |
|
11 |
class BasicBlock(nn.Module):
|
12 |
"""ResNet Basic Block.
|
13 |
-
|
14 |
Parameters
|
15 |
----------
|
16 |
in_channels : int
|
@@ -63,7 +62,6 @@ class BasicBlock(nn.Module):
|
|
63 |
|
64 |
class ResNet18(nn.Module):
|
65 |
"""Construct ResNet-18 Model.
|
66 |
-
|
67 |
Parameters
|
68 |
----------
|
69 |
input_channels : int
|
@@ -137,8 +135,8 @@ model = ResNet18(3, 7)
|
|
137 |
|
138 |
checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))
|
139 |
|
140 |
-
# The state dict will
|
141 |
-
# Our model doesn't
|
142 |
state_dict = checkpoint['state_dict']
|
143 |
for key in list(state_dict.keys()):
|
144 |
if 'net.' in key:
|
@@ -147,7 +145,8 @@ for key in list(state_dict.keys()):
|
|
147 |
|
148 |
model.load_state_dict(state_dict)
|
149 |
model.eval()
|
150 |
-
|
|
|
151 |
class_names = {
|
152 |
'akk': 'Actinic Keratosis',
|
153 |
'bcc': 'Basal Cell Carcinoma',
|
@@ -157,13 +156,8 @@ class_names = {
|
|
157 |
'nv': 'Melanocytic Nevi',
|
158 |
'vasc': 'Vascular Lesion'
|
159 |
}
|
160 |
-
"""
|
161 |
-
class_names = ['akk', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
|
162 |
-
class_names.sort()
|
163 |
-
examples_dir = "sample"
|
164 |
-
|
165 |
-
|
166 |
|
|
|
167 |
|
168 |
transformation_pipeline = transforms.Compose([
|
169 |
transforms.ToPILImage(),
|
@@ -173,43 +167,40 @@ transformation_pipeline = transforms.Compose([
|
|
173 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
174 |
])
|
175 |
|
176 |
-
|
177 |
def preprocess_image(image: np.ndarray):
|
178 |
"""Preprocess the input image.
|
179 |
-
|
180 |
Note that the input image is in RGB mode.
|
181 |
-
|
182 |
Parameters
|
183 |
----------
|
184 |
image: np.ndarray
|
185 |
Input image from callback.
|
186 |
"""
|
187 |
-
|
188 |
image = transformation_pipeline(image)
|
189 |
image = torch.unsqueeze(image, 0)
|
190 |
-
|
191 |
return image
|
192 |
|
193 |
-
|
194 |
def image_classifier(inp):
|
195 |
"""Image Classifier Function.
|
196 |
-
|
197 |
Parameters
|
198 |
----------
|
199 |
inp: Optional[np.ndarray] = None
|
200 |
Input image from callback
|
201 |
-
|
202 |
Returns
|
203 |
-------
|
204 |
Dict
|
205 |
A dictionary class names and its probability
|
206 |
"""
|
207 |
-
|
208 |
-
# If input not valid, return dummy data or raise error
|
209 |
if inp is None:
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
# preprocess
|
214 |
image = preprocess_image(inp)
|
215 |
image = image.to(dtype=torch.float32)
|
@@ -218,21 +209,12 @@ def image_classifier(inp):
|
|
218 |
result = model(image)
|
219 |
|
220 |
# postprocess
|
221 |
-
result = torch.nn.functional.softmax(result, dim=1)
|
222 |
-
result = result[0].detach().numpy().tolist()
|
223 |
-
labeled_result = {name:score for name, score in zip(class_names, result)}
|
224 |
|
225 |
return labeled_result
|
226 |
|
227 |
-
# gradio code block for input and output
|
228 |
-
with gr.Blocks() as app:
|
229 |
-
gr.Markdown("# Skin Cancer Classification")
|
230 |
-
|
231 |
-
with open('index.html', encoding="utf-8") as f:
|
232 |
-
description = f.read()
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
# gradio code block for input and output
|
237 |
with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR, secondary_hue=page_utils.KALBE_THEME_COLOR).set(
|
238 |
button_primary_background_fill="*primary_600",
|
@@ -252,10 +234,10 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
|
|
252 |
out_txt = gr.Label(label="Probabilities", num_top_classes=3)
|
253 |
|
254 |
process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
|
255 |
-
clear_btn.click(lambda:(
|
256 |
gr.update(value=None),
|
257 |
gr.update(value=None)
|
258 |
-
|
259 |
inputs=None,
|
260 |
outputs=[inp_img, out_txt])
|
261 |
|
@@ -266,7 +248,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
|
|
266 |
os.path.join(examples_dir, "bkl_1.jpeg"),
|
267 |
os.path.join(examples_dir, "akk.jpeg"),
|
268 |
os.path.join(examples_dir, "mel-_3_.jpeg"),
|
269 |
-
|
270 |
inputs=inp_img,
|
271 |
outputs=out_txt,
|
272 |
fn=image_classifier,
|
@@ -275,4 +257,5 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
|
|
275 |
gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P ([email protected]) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"> </div>')
|
276 |
|
277 |
# demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
|
278 |
-
app.launch(share=True)
|
|
|
|
10 |
|
11 |
class BasicBlock(nn.Module):
|
12 |
"""ResNet Basic Block.
|
|
|
13 |
Parameters
|
14 |
----------
|
15 |
in_channels : int
|
|
|
62 |
|
63 |
class ResNet18(nn.Module):
|
64 |
"""Construct ResNet-18 Model.
|
|
|
65 |
Parameters
|
66 |
----------
|
67 |
input_channels : int
|
|
|
135 |
|
136 |
checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))
|
137 |
|
138 |
+
# The state dict will contain net.layer_name
|
139 |
+
# Our model doesn't contain `net.` so we have to rename it
|
140 |
state_dict = checkpoint['state_dict']
|
141 |
for key in list(state_dict.keys()):
|
142 |
if 'net.' in key:
|
|
|
145 |
|
146 |
model.load_state_dict(state_dict)
|
147 |
model.eval()
|
148 |
+
|
149 |
+
# Updated class names
|
150 |
class_names = {
|
151 |
'akk': 'Actinic Keratosis',
|
152 |
'bcc': 'Basal Cell Carcinoma',
|
|
|
156 |
'nv': 'Melanocytic Nevi',
|
157 |
'vasc': 'Vascular Lesion'
|
158 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
+
examples_dir = "sample"
|
161 |
|
162 |
transformation_pipeline = transforms.Compose([
|
163 |
transforms.ToPILImage(),
|
|
|
167 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
168 |
])
|
169 |
|
|
|
170 |
def preprocess_image(image: np.ndarray):
|
171 |
"""Preprocess the input image.
|
|
|
172 |
Note that the input image is in RGB mode.
|
|
|
173 |
Parameters
|
174 |
----------
|
175 |
image: np.ndarray
|
176 |
Input image from callback.
|
177 |
"""
|
|
|
178 |
image = transformation_pipeline(image)
|
179 |
image = torch.unsqueeze(image, 0)
|
|
|
180 |
return image
|
181 |
|
|
|
182 |
def image_classifier(inp):
|
183 |
"""Image Classifier Function.
|
|
|
184 |
Parameters
|
185 |
----------
|
186 |
inp: Optional[np.ndarray] = None
|
187 |
Input image from callback
|
|
|
188 |
Returns
|
189 |
-------
|
190 |
Dict
|
191 |
A dictionary class names and its probability
|
192 |
"""
|
193 |
+
# If input not valid, return dummy data or raise an error
|
|
|
194 |
if inp is None:
|
195 |
+
return {
|
196 |
+
'Actinic Keratosis': 0.0,
|
197 |
+
'Basal Cell Carcinoma': 0.0,
|
198 |
+
'Benign Keratosis': 0.0,
|
199 |
+
'Dermatofibroma': 0.0,
|
200 |
+
'Melanoma': 0.0,
|
201 |
+
'Melanocytic Nevi': 0.0,
|
202 |
+
'Vascular Lesion': 0.0
|
203 |
+
}
|
204 |
# preprocess
|
205 |
image = preprocess_image(inp)
|
206 |
image = image.to(dtype=torch.float32)
|
|
|
209 |
result = model(image)
|
210 |
|
211 |
# postprocess
|
212 |
+
result = torch.nn.functional.softmax(result, dim=1) # apply softmax
|
213 |
+
result = result[0].detach().numpy().tolist() # take the first batch
|
214 |
+
labeled_result = {class_names[name]: score for name, score in zip(class_names, result)}
|
215 |
|
216 |
return labeled_result
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
# gradio code block for input and output
|
219 |
with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR, secondary_hue=page_utils.KALBE_THEME_COLOR).set(
|
220 |
button_primary_background_fill="*primary_600",
|
|
|
234 |
out_txt = gr.Label(label="Probabilities", num_top_classes=3)
|
235 |
|
236 |
process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
|
237 |
+
clear_btn.click(lambda: (
|
238 |
gr.update(value=None),
|
239 |
gr.update(value=None)
|
240 |
+
),
|
241 |
inputs=None,
|
242 |
outputs=[inp_img, out_txt])
|
243 |
|
|
|
248 |
os.path.join(examples_dir, "bkl_1.jpeg"),
|
249 |
os.path.join(examples_dir, "akk.jpeg"),
|
250 |
os.path.join(examples_dir, "mel-_3_.jpeg"),
|
251 |
+
],
|
252 |
inputs=inp_img,
|
253 |
outputs=out_txt,
|
254 |
fn=image_classifier,
|
|
|
257 |
gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P ([email protected]) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"> </div>')
|
258 |
|
259 |
# demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
|
260 |
+
app.launch(share=True)
|
261 |
+
|