Mehmet Batuhan Duman committed
Commit 2b06c57 · 1 Parent(s): e9be013

Changed scan func

Files changed (2)
  1. .idea/workspace.xml +14 -14
  2. app.py +92 -31
.idea/workspace.xml CHANGED
@@ -34,20 +34,20 @@
      <option name="hideEmptyMiddlePackages" value="true" />
      <option name="showLibraryContents" value="true" />
    </component>
-   <component name="PropertiesComponent">{
-   &quot;keyToString&quot;: {
-     &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
-     &quot;WebServerToolWindowFactoryState&quot;: &quot;false&quot;,
-     &quot;git-widget-placeholder&quot;: &quot;main&quot;,
-     &quot;last_opened_file_path&quot;: &quot;D:/fastApiProjects/pythonProject1/shipnet&quot;,
-     &quot;node.js.detected.package.eslint&quot;: &quot;true&quot;,
-     &quot;node.js.detected.package.tslint&quot;: &quot;true&quot;,
-     &quot;node.js.selected.package.eslint&quot;: &quot;(autodetect)&quot;,
-     &quot;node.js.selected.package.tslint&quot;: &quot;(autodetect)&quot;,
-     &quot;settings.editor.selected.configurable&quot;: &quot;shared-indexes&quot;,
-     &quot;vue.rearranger.settings.migration&quot;: &quot;true&quot;
+   <component name="PropertiesComponent"><![CDATA[{
+   "keyToString": {
+     "RunOnceActivity.ShowReadmeOnStart": "true",
+     "WebServerToolWindowFactoryState": "false",
+     "git-widget-placeholder": "main",
+     "last_opened_file_path": "D:/fastApiProjects/pythonProject1",
+     "node.js.detected.package.eslint": "true",
+     "node.js.detected.package.tslint": "true",
+     "node.js.selected.package.eslint": "(autodetect)",
+     "node.js.selected.package.tslint": "(autodetect)",
+     "settings.editor.selected.configurable": "shared-indexes",
+     "vue.rearranger.settings.migration": "true"
    }
-   }</component>
+   }]]></component>
    <component name="RecentsManager">
      <key name="CopyFile.RECENT_KEYS">
        <recent name="D:\fastApiProjects\pythonProject1" />
@@ -65,7 +65,7 @@
        <workItem from="1683665300392" duration="7649000" />
        <workItem from="1683708398011" duration="1235000" />
        <workItem from="1684437905081" duration="110000" />
-       <workItem from="1686602174110" duration="1036000" />
+       <workItem from="1686602174110" duration="2045000" />
      </task>
      <servers />
    </component>
app.py CHANGED
@@ -2,7 +2,6 @@ import cv2
  import numpy as np
  import gradio as gr
  from PIL import Image, ImageOps
- import matplotlib.pyplot as plt
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -11,7 +10,10 @@ import os
  import time
  import io
  import base64
-
+ import torch
+ import cv2
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches

  class Net2(nn.Module):
      def __init__(self):
@@ -137,53 +139,112 @@ else:
      print("Model file not found at", model_path)


- def process_image(input_image):
-     image = Image.fromarray(input_image).convert("RGB")
-
-     start_time = time.time()
-     heatmap = scanmap(np.array(image), model)
-     elapsed_time = time.time() - start_time
-     heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
-
-     heatmap_img = heatmap_img.resize(image.size)
-
-     return image, heatmap_img, int(elapsed_time)
-
-
- def scanmap(image_np, model):
-     image_np = image_np.astype(np.float32) / 255.0
+ # def process_image(input_image):
+ #     image = Image.fromarray(input_image).convert("RGB")
+ #
+ #     start_time = time.time()
+ #     heatmap = scanmap(np.array(image), model)
+ #     elapsed_time = time.time() - start_time
+ #     heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
+ #
+ #     heatmap_img = heatmap_img.resize(image.size)
+ #
+ #     return image, heatmap_img, int(elapsed_time)
+ #
+ #
+ # def scanmap(image_np, model):
+ #     image_np = image_np.astype(np.float32) / 255.0
+ #
+ #     window_size = (80, 80)
+ #     stride = 10
+ #
+ #     height, width, channels = image_np.shape
+ #
+ #     probabilities_map = []
+ #
+ #     for y in range(0, height - window_size[1] + 1, stride):
+ #         row_probabilities = []
+ #         for x in range(0, width - window_size[0] + 1, stride):
+ #             cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
+ #             cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
+ #
+ #             with torch.no_grad():
+ #                 probabilities = model(cropped_window_torch)
+ #
+ #             row_probabilities.append(probabilities[0, 1].item())
+ #
+ #         probabilities_map.append(row_probabilities)
+ #
+ #     probabilities_map = np.array(probabilities_map)
+ #     return probabilities_map
+ #
+ # def gradio_process_image(input_image):
+ #     original, heatmap, elapsed_time = process_image(input_image)
+ #     return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"
+ #
+ # inputs = gr.Image(label="Upload Image")
+ # outputs = [
+ #     gr.Image(label="Original Image"),
+ #     gr.Image(label="Heatmap"),
+ #     gr.Textbox(label="Elapsed Time")
+ # ]
+ #
+ # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
+ # iface.launch()
+ def scanmap(image_path, model, device, threshold=0.5):
+     satellite_image = cv2.imread(image_path)
+     satellite_image = satellite_image.astype(np.float32) / 255.0

      window_size = (80, 80)
      stride = 10

-     height, width, channels = image_np.shape
+     height, width, channels = satellite_image.shape
+
+     model.to(device)  # ensure model is on correct device
+
+     fig, ax = plt.subplots(1)
+     ax.imshow(satellite_image)

-     probabilities_map = []
+     ship_images = []

      for y in range(0, height - window_size[1] + 1, stride):
-         row_probabilities = []
          for x in range(0, width - window_size[0] + 1, stride):
-             cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
-             cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
+             cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
+             cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
+             cropped_window_torch = cropped_window_torch.to(device)  # move data to the same device as model

              with torch.no_grad():
                  probabilities = model(cropped_window_torch)

-             row_probabilities.append(probabilities[0, 1].item())
+             # if probability is greater than threshold, draw a bounding box and add to ship_images
+             if probabilities[0, 1].item() > threshold:
+                 rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r',
+                                          facecolor='none')
+                 ax.add_patch(rect)
+                 ship_images.append(cropped_window)
+
+     plt.show()
+
+     return ship_images
+
+
+ def process_image(input_image, model, device, threshold=0.5):
+     start_time = time.time()
+     ship_images = scanmap(input_image, model, device, threshold)
+     elapsed_time = time.time() - start_time
+
+     return ship_images, int(elapsed_time)
+

-         probabilities_map.append(row_probabilities)
+ def gradio_process_image(input_image, model, device, threshold=0.5):
+     ship_images, elapsed_time = process_image(input_image, model, device, threshold)

-     probabilities_map = np.array(probabilities_map)
-     return probabilities_map
+     return ship_images, f"Elapsed Time (seconds): {elapsed_time}"

- def gradio_process_image(input_image):
-     original, heatmap, elapsed_time = process_image(input_image)
-     return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"

  inputs = gr.Image(label="Upload Image")
  outputs = [
-     gr.Image(label="Original Image"),
-     gr.Image(label="Heatmap"),
+     gr.Image(label="Detected Ships"),
      gr.Textbox(label="Elapsed Time")
  ]

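
The new scanmap reads its input from disk with cv2.imread, and the new gradio_process_image takes extra model, device, and threshold arguments, while gr.Interface only passes the uploaded image to its fn. Below is a minimal wiring sketch for the tail of app.py, not part of this commit: it assumes the model instance loaded earlier in the file, picks a CUDA-or-CPU device, binds the extra arguments with functools.partial, and requests the upload as a file path so cv2.imread can open it. The gr.Gallery output is also an assumption, swapped in because scanmap returns a list of window crops rather than a single image.

# Minimal wiring sketch (not part of this commit); assumes `model` and
# `gradio_process_image` as defined above in app.py.
import functools

import torch
import gradio as gr

# assumed device selection; the committed code leaves the device choice to the caller
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# scanmap() calls cv2.imread(image_path), so the upload is requested as a file path
inputs = gr.Image(label="Upload Image", type="filepath")
outputs = [
    # scanmap() returns a list of window crops, so a Gallery is used here
    # instead of the single gr.Image the commit declares
    gr.Gallery(label="Detected Ships"),
    gr.Textbox(label="Elapsed Time"),
]

# bind model/device/threshold so Gradio only has to supply the image path
iface = gr.Interface(
    fn=functools.partial(gradio_process_image, model=model, device=device, threshold=0.5),
    inputs=inputs,
    outputs=outputs,
)
iface.launch()

Note that the crops come back as float BGR arrays scaled to [0, 1] (cv2.imread followed by the /255.0 division), so they may need conversion to RGB uint8 before display, and the plt.show() call inside scanmap has no visible effect when the app runs headless.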