Mehmet Batuhan Duman commited on
Commit
f746c21
·
1 Parent(s): f40de53

Changed scan func

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +1 -1
  2. app.py +11 -21
.idea/workspace.xml CHANGED
@@ -65,7 +65,7 @@
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
- <workItem from="1686602174110" duration="2679000" />
69
  </task>
70
  <servers />
71
  </component>
 
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
+ <workItem from="1686602174110" duration="3352000" />
69
  </task>
70
  <servers />
71
  </component>
app.py CHANGED
@@ -122,8 +122,8 @@ class Net(nn.Module):
122
  model = None
123
  model_path = "models1.pth"
124
 
125
- model2 = None
126
- model2_path = "model4.pth"
127
 
128
  if os.path.exists(model_path):
129
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
@@ -192,9 +192,7 @@ else:
192
  #
193
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
194
  # iface.launch()
195
- def scanmap(satellite_image, model, device, threshold=0.5):
196
- # No need to read the image, you already have it as a NumPy array.
197
- # Just normalize it.
198
  satellite_image = satellite_image.astype(np.float32) / 255.0
199
 
200
  window_size = (80, 80)
@@ -202,8 +200,6 @@ def scanmap(satellite_image, model, device, threshold=0.5):
202
 
203
  height, width, channels = satellite_image.shape
204
 
205
- model.to(device) # ensure model is on correct device
206
-
207
  fig, ax = plt.subplots(1)
208
  ax.imshow(satellite_image)
209
 
@@ -213,15 +209,12 @@ def scanmap(satellite_image, model, device, threshold=0.5):
213
  for x in range(0, width - window_size[0] + 1, stride):
214
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
215
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
216
- cropped_window_torch = cropped_window_torch.to(device) # move data to the same device as model
217
 
218
  with torch.no_grad():
219
  probabilities = model(cropped_window_torch)
220
 
221
- # if probability is greater than threshold, draw a bounding box and add to ship_images
222
  if probabilities[0, 1].item() > threshold:
223
- rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r',
224
- facecolor='none')
225
  ax.add_patch(rect)
226
  ship_images.append(cropped_window)
227
 
@@ -230,16 +223,16 @@ def scanmap(satellite_image, model, device, threshold=0.5):
230
  return ship_images
231
 
232
 
233
- def process_image(input_image, model, device, threshold=0.5):
234
  start_time = time.time()
235
- ship_images = scanmap(input_image, model, device, threshold)
236
  elapsed_time = time.time() - start_time
237
 
238
  return ship_images, int(elapsed_time)
239
 
240
 
241
- def gradio_process_image(input_image, model, device, threshold=0.5):
242
- ship_images, elapsed_time = process_image(input_image, model, device, threshold)
243
 
244
  return ship_images, f"Elapsed Time (seconds): {elapsed_time}"
245
 
@@ -249,13 +242,10 @@ outputs = [
249
  gr.Image(label="Detected Ships"),
250
  gr.Textbox(label="Elapsed Time")
251
  ]
252
- model = None # TODO: initialize your model
253
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
254
- # Here I'm using 0.5 as the threshold, but adjust according to your needs
255
- gradio_process_image_partial = partial(gradio_process_image, model=model, device=device, threshold=0.5)
256
 
257
  iface = gr.Interface(fn=gradio_process_image_partial, inputs=inputs, outputs=outputs)
258
  iface.launch()
259
 
260
-
261
-
 
122
  model = None
123
  model_path = "models1.pth"
124
 
125
+ # model2 = None
126
+ # model2_path = "model4.pth"
127
 
128
  if os.path.exists(model_path):
129
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
 
192
  #
193
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
194
  # iface.launch()
195
+ def scanmap(satellite_image, model, threshold=0.5):
 
 
196
  satellite_image = satellite_image.astype(np.float32) / 255.0
197
 
198
  window_size = (80, 80)
 
200
 
201
  height, width, channels = satellite_image.shape
202
 
 
 
203
  fig, ax = plt.subplots(1)
204
  ax.imshow(satellite_image)
205
 
 
209
  for x in range(0, width - window_size[0] + 1, stride):
210
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
211
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
 
212
 
213
  with torch.no_grad():
214
  probabilities = model(cropped_window_torch)
215
 
 
216
  if probabilities[0, 1].item() > threshold:
217
+ rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r', facecolor='none')
 
218
  ax.add_patch(rect)
219
  ship_images.append(cropped_window)
220
 
 
223
  return ship_images
224
 
225
 
226
+ def process_image(input_image, model, threshold=0.5):
227
  start_time = time.time()
228
+ ship_images = scanmap(input_image, model, threshold)
229
  elapsed_time = time.time() - start_time
230
 
231
  return ship_images, int(elapsed_time)
232
 
233
 
234
+ def gradio_process_image(input_image, model, threshold=0.5):
235
+ ship_images, elapsed_time = process_image(input_image, model, threshold)
236
 
237
  return ship_images, f"Elapsed Time (seconds): {elapsed_time}"
238
 
 
242
  gr.Image(label="Detected Ships"),
243
  gr.Textbox(label="Elapsed Time")
244
  ]
245
+
246
+ # Use 0.5 as the threshold, but adjust according to your needs
247
+ gradio_process_image_partial = partial(gradio_process_image, model=model, threshold=0.5)
 
248
 
249
  iface = gr.Interface(fn=gradio_process_image_partial, inputs=inputs, outputs=outputs)
250
  iface.launch()
251