Henry Scheible commited on
Commit
fc1a0c8
·
1 Parent(s): 5afcf8b

add gpu support

Browse files
Files changed (2) hide show
  1. app.py +23 -14
  2. requirements.txt +1 -0
app.py CHANGED
@@ -10,22 +10,27 @@ torch.manual_seed(12345)
10
  random.seed(12345)
11
  np.random.seed(12345)
12
 
 
13
  def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
14
- full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2,0,1).unsqueeze(0)
15
- num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size)/filter_stride) + 1
16
- num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size)/filter_stride) + 1
17
- windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape([1, 3, 50, 50, num_windows_h * num_windows_w]).permute([0,4,1,2,3]).squeeze()
 
18
 
19
  dataset_images = [windows[idx] for idx in range(len(windows))]
20
  dataset = list(dataset_images)
21
  return dataset
22
 
 
23
  from torchvision.models.resnet import resnet50
24
  from torchvision.models.resnet import ResNet50_Weights
 
25
  print("Loading resnet...")
26
  model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
27
  hidden_state_size = model.fc.in_features
28
  model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
 
29
 
30
  import gradio as gr
31
 
@@ -33,25 +38,27 @@ import gradio as gr
33
  def count_barnacles(input_img, progress=gr.Progress()):
34
  progress(0, desc="Loading Image")
35
  test_dataset = get_dataset_x(input_img)
36
- test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
37
  model.eval()
38
  predicted_labels_list = []
39
  for data in progress.tqdm(test_dataloader):
40
  with torch.no_grad():
 
41
  predicted_labels_list += [model(data)]
42
  predicted_labels = torch.cat(predicted_labels_list)
43
  x = int(math.sqrt(predicted_labels.shape[0]))
44
  predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
45
- label_img = predicted_labels[:,:,:1].cpu().numpy()
46
  label_img -= label_img.min()
47
  label_img /= label_img.max()
48
  label_img = (label_img * 255).astype(np.uint8)
49
  mask = np.array(label_img > 180, np.uint8)
50
  contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
 
51
  def extract_contour_center(cnt):
52
  M = cv2.moments(cnt)
53
- cx = int(M['m10']/M['m00'])
54
- cy = int(M['m01']/M['m00'])
55
  return cx, cy
56
 
57
  filter_width = 50
@@ -59,8 +66,8 @@ def count_barnacles(input_img, progress=gr.Progress()):
59
 
60
  def rev_window_transform(point):
61
  wx, wy = point
62
- x = int(filter_width/2) + wx*filter_stride
63
- y = int(filter_width/2) + wy*filter_stride
64
  return x, y
65
 
66
  nonempty_contours = filter(lambda cnt: cv2.contourArea(cnt) != 0, contours)
@@ -69,8 +76,10 @@ def count_barnacles(input_img, progress=gr.Progress()):
69
 
70
  blank_img_copy = input_img.copy()
71
  for x, y in points:
72
- blank_img_copy = cv2.circle(blank_img_copy, (x,y), radius=4, color=(255, 0, 0), thickness=-1)
73
- return blank_img_copy
 
74
 
75
- demo = gr.Interface(count_barnacles, gr.Image(shape=(500, 500), type="numpy"), gr.Image(type="numpy"))
76
- demo.queue(concurrency_count=10).launch()
 
 
10
  random.seed(12345)
11
  np.random.seed(12345)
12
 
13
+
14
  def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
15
+ full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(0)
16
+ num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
17
+ num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
18
+ windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape(
19
+ [1, 3, 50, 50, num_windows_h * num_windows_w]).permute([0, 4, 1, 2, 3]).squeeze()
20
 
21
  dataset_images = [windows[idx] for idx in range(len(windows))]
22
  dataset = list(dataset_images)
23
  return dataset
24
 
25
+
26
  from torchvision.models.resnet import resnet50
27
  from torchvision.models.resnet import ResNet50_Weights
28
+
29
  print("Loading resnet...")
30
  model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
31
  hidden_state_size = model.fc.in_features
32
  model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
33
+ model.to("cuda")
34
 
35
  import gradio as gr
36
 
 
38
  def count_barnacles(input_img, progress=gr.Progress()):
39
  progress(0, desc="Loading Image")
40
  test_dataset = get_dataset_x(input_img)
41
+ test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
42
  model.eval()
43
  predicted_labels_list = []
44
  for data in progress.tqdm(test_dataloader):
45
  with torch.no_grad():
46
+ data.to("cuda")
47
  predicted_labels_list += [model(data)]
48
  predicted_labels = torch.cat(predicted_labels_list)
49
  x = int(math.sqrt(predicted_labels.shape[0]))
50
  predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
51
+ label_img = predicted_labels[:, :, :1].cpu().numpy()
52
  label_img -= label_img.min()
53
  label_img /= label_img.max()
54
  label_img = (label_img * 255).astype(np.uint8)
55
  mask = np.array(label_img > 180, np.uint8)
56
  contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
57
+
58
  def extract_contour_center(cnt):
59
  M = cv2.moments(cnt)
60
+ cx = int(M['m10'] / M['m00'])
61
+ cy = int(M['m01'] / M['m00'])
62
  return cx, cy
63
 
64
  filter_width = 50
 
66
 
67
  def rev_window_transform(point):
68
  wx, wy = point
69
+ x = int(filter_width / 2) + wx * filter_stride
70
+ y = int(filter_width / 2) + wy * filter_stride
71
  return x, y
72
 
73
  nonempty_contours = filter(lambda cnt: cv2.contourArea(cnt) != 0, contours)
 
76
 
77
  blank_img_copy = input_img.copy()
78
  for x, y in points:
79
+ blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
80
+ return blank_img_copy, len(list(points))
81
+
82
 
83
+ demo = gr.Interface(count_barnacles, gr.Image(shape=(500, 500), type="numpy"),
84
+ outputs=[gr.Image(type="numpy"), "number"])
85
+ demo.queue(concurrency_count=10).launch()
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  opencv-python
2
  numpy
 
3
  torch
4
  torchvision
5
  gradio
 
1
  opencv-python
2
  numpy
3
+ --extra-index-url https://download.pytorch.org/whl/cu113
4
  torch
5
  torchvision
6
  gradio