ganteng88 commited on
Commit
ce7c2ce
·
1 Parent(s): b689dd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -40
app.py CHANGED
@@ -10,7 +10,6 @@ from utils import page_utils
10
 
11
  class BasicBlock(nn.Module):
12
  """ResNet Basic Block.
13
-
14
  Parameters
15
  ----------
16
  in_channels : int
@@ -63,7 +62,6 @@ class BasicBlock(nn.Module):
63
 
64
  class ResNet18(nn.Module):
65
  """Construct ResNet-18 Model.
66
-
67
  Parameters
68
  ----------
69
  input_channels : int
@@ -137,8 +135,8 @@ model = ResNet18(3, 7)
137
 
138
  checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))
139
 
140
- # The state dict will contains net.layer_name
141
- # Our model doesn't contains `net.` so we have to rename it
142
  state_dict = checkpoint['state_dict']
143
  for key in list(state_dict.keys()):
144
  if 'net.' in key:
@@ -147,7 +145,8 @@ for key in list(state_dict.keys()):
147
 
148
  model.load_state_dict(state_dict)
149
  model.eval()
150
- """
 
151
  class_names = {
152
  'akk': 'Actinic Keratosis',
153
  'bcc': 'Basal Cell Carcinoma',
@@ -157,13 +156,8 @@ class_names = {
157
  'nv': 'Melanocytic Nevi',
158
  'vasc': 'Vascular Lesion'
159
  }
160
- """
161
- class_names = ['akk', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
162
- class_names.sort()
163
- examples_dir = "sample"
164
-
165
-
166
 
 
167
 
168
  transformation_pipeline = transforms.Compose([
169
  transforms.ToPILImage(),
@@ -173,43 +167,40 @@ transformation_pipeline = transforms.Compose([
173
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
174
  ])
175
 
176
-
177
  def preprocess_image(image: np.ndarray):
178
  """Preprocess the input image.
179
-
180
  Note that the input image is in RGB mode.
181
-
182
  Parameters
183
  ----------
184
  image: np.ndarray
185
  Input image from callback.
186
  """
187
-
188
  image = transformation_pipeline(image)
189
  image = torch.unsqueeze(image, 0)
190
-
191
  return image
192
 
193
-
194
  def image_classifier(inp):
195
  """Image Classifier Function.
196
-
197
  Parameters
198
  ----------
199
  inp: Optional[np.ndarray] = None
200
  Input image from callback
201
-
202
  Returns
203
  -------
204
  Dict
205
  A dictionary class names and its probability
206
  """
207
-
208
- # If input not valid, return dummy data or raise error
209
  if inp is None:
210
- if inp is None:
211
- return {'cat': 0.3, 'dog': 0.7}
212
- #return {'Actinic Keratosis': 0.0, 'Basal Cell Carcinoma': 0.0, 'Benign Keratosis': 0.0, 'Dermatofibroma': 0.0, 'Melanoma': 0.0, 'Melanocytic Nevi': 0.0, 'Vascular Lesion': 0.0}
 
 
 
 
 
 
213
  # preprocess
214
  image = preprocess_image(inp)
215
  image = image.to(dtype=torch.float32)
@@ -218,21 +209,12 @@ def image_classifier(inp):
218
  result = model(image)
219
 
220
  # postprocess
221
- result = torch.nn.functional.softmax(result, dim=1) # apply softmax
222
- result = result[0].detach().numpy().tolist() # take the first batch
223
- labeled_result = {name:score for name, score in zip(class_names, result)}
224
 
225
  return labeled_result
226
 
227
- # gradio code block for input and output
228
- with gr.Blocks() as app:
229
- gr.Markdown("# Skin Cancer Classification")
230
-
231
- with open('index.html', encoding="utf-8") as f:
232
- description = f.read()
233
-
234
-
235
-
236
  # gradio code block for input and output
237
  with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR, secondary_hue=page_utils.KALBE_THEME_COLOR).set(
238
  button_primary_background_fill="*primary_600",
@@ -252,10 +234,10 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
252
  out_txt = gr.Label(label="Probabilities", num_top_classes=3)
253
 
254
  process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
255
- clear_btn.click(lambda:(
256
  gr.update(value=None),
257
  gr.update(value=None)
258
- ),
259
  inputs=None,
260
  outputs=[inp_img, out_txt])
261
 
@@ -266,7 +248,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
266
  os.path.join(examples_dir, "bkl_1.jpeg"),
267
  os.path.join(examples_dir, "akk.jpeg"),
268
  os.path.join(examples_dir, "mel-_3_.jpeg"),
269
- ],
270
  inputs=inp_img,
271
  outputs=out_txt,
272
  fn=image_classifier,
@@ -275,4 +257,5 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
275
  gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P ([email protected]) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"> </div>')
276
 
277
  # demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
278
- app.launch(share=True)
 
 
10
 
11
  class BasicBlock(nn.Module):
12
  """ResNet Basic Block.
 
13
  Parameters
14
  ----------
15
  in_channels : int
 
62
 
63
  class ResNet18(nn.Module):
64
  """Construct ResNet-18 Model.
 
65
  Parameters
66
  ----------
67
  input_channels : int
 
135
 
136
  checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))
137
 
138
+ # The state dict will contain net.layer_name
139
+ # Our model doesn't contain `net.` so we have to rename it
140
  state_dict = checkpoint['state_dict']
141
  for key in list(state_dict.keys()):
142
  if 'net.' in key:
 
145
 
146
  model.load_state_dict(state_dict)
147
  model.eval()
148
+
149
+ # Updated class names
150
  class_names = {
151
  'akk': 'Actinic Keratosis',
152
  'bcc': 'Basal Cell Carcinoma',
 
156
  'nv': 'Melanocytic Nevi',
157
  'vasc': 'Vascular Lesion'
158
  }
 
 
 
 
 
 
159
 
160
+ examples_dir = "sample"
161
 
162
  transformation_pipeline = transforms.Compose([
163
  transforms.ToPILImage(),
 
167
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
168
  ])
169
 
 
170
  def preprocess_image(image: np.ndarray):
171
  """Preprocess the input image.
 
172
  Note that the input image is in RGB mode.
 
173
  Parameters
174
  ----------
175
  image: np.ndarray
176
  Input image from callback.
177
  """
 
178
  image = transformation_pipeline(image)
179
  image = torch.unsqueeze(image, 0)
 
180
  return image
181
 
 
182
  def image_classifier(inp):
183
  """Image Classifier Function.
 
184
  Parameters
185
  ----------
186
  inp: Optional[np.ndarray] = None
187
  Input image from callback
 
188
  Returns
189
  -------
190
  Dict
191
  A dictionary class names and its probability
192
  """
193
+ # If input not valid, return dummy data or raise an error
 
194
  if inp is None:
195
+ return {
196
+ 'Actinic Keratosis': 0.0,
197
+ 'Basal Cell Carcinoma': 0.0,
198
+ 'Benign Keratosis': 0.0,
199
+ 'Dermatofibroma': 0.0,
200
+ 'Melanoma': 0.0,
201
+ 'Melanocytic Nevi': 0.0,
202
+ 'Vascular Lesion': 0.0
203
+ }
204
  # preprocess
205
  image = preprocess_image(inp)
206
  image = image.to(dtype=torch.float32)
 
209
  result = model(image)
210
 
211
  # postprocess
212
+ result = torch.nn.functional.softmax(result, dim=1) # apply softmax
213
+ result = result[0].detach().numpy().tolist() # take the first batch
214
+ labeled_result = {class_names[name]: score for name, score in zip(class_names, result)}
215
 
216
  return labeled_result
217
 
 
 
 
 
 
 
 
 
 
218
  # gradio code block for input and output
219
  with gr.Blocks(theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR, secondary_hue=page_utils.KALBE_THEME_COLOR).set(
220
  button_primary_background_fill="*primary_600",
 
234
  out_txt = gr.Label(label="Probabilities", num_top_classes=3)
235
 
236
  process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
237
+ clear_btn.click(lambda: (
238
  gr.update(value=None),
239
  gr.update(value=None)
240
+ ),
241
  inputs=None,
242
  outputs=[inp_img, out_txt])
243
 
 
248
  os.path.join(examples_dir, "bkl_1.jpeg"),
249
  os.path.join(examples_dir, "akk.jpeg"),
250
  os.path.join(examples_dir, "mel-_3_.jpeg"),
251
+ ],
252
  inputs=inp_img,
253
  outputs=out_txt,
254
  fn=image_classifier,
 
257
  gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P ([email protected]) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"> </div>')
258
 
259
  # demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
260
+ app.launch(share=True)
261
+