gracewanggw committed on
Commit
a34b545
·
1 Parent(s): a291c56

everything so far SAM model

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
2
+ model_best_epoch_4_59.62.pth filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1 +1,3 @@
1
-
 
 
 
1
+ .DS_Store
2
+ venv/
3
+ __pycache__/
README.md CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  # 👩🏾‍💻 Project Starter Template
2
 
3
  [Project Description]
 
1
+ ---
2
+ title: Barnacle Counter
3
+ sdk: gradio
4
+ app_file: app.py
5
+ pinned: false
6
+ ---
7
+
8
  # 👩🏾‍💻 Project Starter Template
9
 
10
  [Project Description]
annotated.png ADDED
app.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import math
4
+ import torch
5
+ import random
6
+ from PIL import Image
7
+
8
+ from torch.utils.data import DataLoader
9
+ from torchvision.transforms import Resize
10
+
11
# Pin every random number generator used by this module to the same seed so
# that random draws are repeatable from one launch to the next.
np.random.seed(12345)
random.seed(12345)
torch.manual_seed(12345)

# Run on CUDA when PyTorch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
class WireframeExtractor:
    """Find the bounding box of the region enclosed by the wire frame.

    The frame is located by colour: pixels whose blurred HSV value falls
    inside ``[LOWER_HSV, UPPER_HSV]`` are masked out, and the largest
    remaining connected region that contains the image centre is taken to be
    the framed area.
    """

    # Inclusive HSV bounds passed to ``cv2.inRange``. Hue 70-130 on OpenCV's
    # 0-179 scale is roughly the green-cyan-blue band -- presumably the colour
    # of the wire; confirm against real photographs before retuning.
    LOWER_HSV = np.array([70, 20, 20])
    UPPER_HSV = np.array([130, 255, 255])

    def __call__(self, image: np.ndarray):
        """
        Extract the bounding box of the wire frame from a barnacle image.

        :param image: Numpy RGB image of shape (H, W, 3)
        :return: ``(x, y, w, h)`` -- top-left corner plus width and height, in
            pixels. When no suitable region is found (no contours at all, or
            none containing the image centre) the whole image is returned as
            ``(0, 0, W, H)`` so that callers never crop to an unrelated blob.
        """
        img_h, img_w = image.shape[:2]

        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        # Blur before thresholding so isolated noisy pixels do not break up
        # the colour mask.
        hsv_blurred = cv2.GaussianBlur(hsv, (9, 9), 0)

        frame_mask = cv2.inRange(hsv_blurred, self.LOWER_HSV, self.UPPER_HSV)
        # Invert: we want the regions that are NOT frame-coloured.
        not_frame = cv2.bitwise_not(frame_mask)

        contours, _ = cv2.findContours(
            not_frame, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )

        # Keep only regions that strictly contain the image centre
        # (pointPolygonTest returns +1 for "inside").
        center = (img_w / 2, img_h / 2)
        enclosing = [
            contour
            for contour in contours
            if cv2.pointPolygonTest(contour, center, False) == 1
        ]
        if not enclosing:
            return 0, 0, img_w, img_h

        best = max(enclosing, key=cv2.contourArea)
        x, y, box_w, box_h = cv2.boundingRect(best)
        return x, y, box_w, box_h


wireframe_extractor = WireframeExtractor()
52
+
53
def show_anns(anns):
    """Draw every mask in *anns* on the current matplotlib axes.

    Each entry must provide an ``'area'`` value (used for ordering) and a 2-D
    ``'segmentation'`` array. Masks are painted from largest to smallest area,
    each in a randomly drawn colour with an alpha of 0.35 where the mask is
    set and 0 elsewhere. Nothing is drawn, and ``None`` is returned, when
    *anns* is empty.
    """
    if len(anns) == 0:
        return

    axes = plt.gca()
    axes.set_autoscale_on(False)

    by_area_desc = sorted(anns, key=lambda entry: entry['area'], reverse=True)
    for entry in by_area_desc:
        seg = entry['segmentation']
        rows, cols = seg.shape[0], seg.shape[1]
        rgb = np.random.random((1, 3)).tolist()[0]

        # Constant-colour RGB image; the mask supplies the alpha channel.
        tint = np.ones((rows, cols, 3))
        for channel, value in enumerate(rgb):
            tint[:, :, channel] = value
        rgba = np.dstack((tint, seg * 0.35))
        axes.imshow(rgba)
68
+
69
+
70
+ # def find_contours(img, color):
71
+ # low = color - 10
72
+ # high = color + 10
73
+
74
+ # mask = cv2.inRange(img, low, high)
75
+ # contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
76
+
77
+ # print(f"Total Contours: {len(contours)}")
78
+ # nonempty_contours = list()
79
+ # for i in range(len(contours)):
80
+ # if hierarchy[0,i,3] == -1 and cv2.contourArea(contours[i]) > cv2.arcLength(contours[i], True):
81
+ # nonempty_contours += [contours[i]]
82
+ # print(f"Nonempty Contours: {len(nonempty_contours)}")
83
+ # contour_plot = img.copy()
84
+ # contour_plot = cv2.drawContours(contour_plot, nonempty_contours, -1, (0,255,0), -1)
85
+
86
+ # sorted_contours = sorted(nonempty_contours, key=cv2.contourArea, reverse= True)
87
+
88
+ # bounding_rects = [cv2.boundingRect(cnt) for cnt in contours]
89
+
90
+ # for (i,c) in enumerate(sorted_contours):
91
+ # M= cv2.moments(c)
92
+ # cx= int(M['m10']/M['m00'])
93
+ # cy= int(M['m01']/M['m00'])
94
+ # cv2.putText(contour_plot, text= str(i), org=(cx,cy),
95
+ # fontFace= cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255,255,255),
96
+ # thickness=1, lineType=cv2.LINE_AA)
97
+
98
+ # N = len(sorted_contours)
99
+ # H, W, C = img.shape
100
+ # boxes_array_xywh = [cv2.boundingRect(cnt) for cnt in sorted_contours]
101
+ # boxes_array_corners = [[x, y, x+w, y+h] for x, y, w, h in boxes_array_xywh]
102
+ # boxes = torch.tensor(boxes_array_corners)
103
+
104
+ # labels = torch.ones(N)
105
+ # masks = np.zeros([N, H, W])
106
+ # for idx in range(len(sorted_contours)):
107
+ # cnt = sorted_contours[idx]
108
+ # cv2.drawContours(masks[idx,:,:], [cnt], 0, (255), -1)
109
+ # masks = masks / 255.0
110
+ # masks = torch.tensor(masks)
111
+
112
+ # # for box in boxes:
113
+ # # cv2.rectangle(contour_plot, (box[0].item(), box[1].item()), (box[2].item(), box[3].item()), (255,0,0), 2)
114
+
115
+ # return contour_plot, (boxes, masks)
116
+
117
+
118
+ # def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
119
+ # full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(0)
120
+ # num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
121
+ # num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
122
+ # windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape(
123
+ # [1, 3, 50, 50, num_windows_h * num_windows_w]).permute([0, 4, 1, 2, 3]).squeeze()
124
+
125
+ # dataset_images = [windows[idx] for idx in range(len(windows))]
126
+ # dataset = list(dataset_images)
127
+ # return dataset
128
+
129
+
130
+ # def get_dataset(labeled_image, blank_image, color, filter_size=50, filter_stride=2, label_size=5):
131
+ # contour_plot, (blue_boxes, blue_masks) = find_contours(labeled_image, color)
132
+
133
+ # mask = torch.sum(blue_masks, 0)
134
+
135
+ # label_dim = int((labeled_image.shape[0] - filter_size) / filter_stride + 1)
136
+ # labels = torch.zeros(label_dim, label_dim)
137
+ # mask_labels = torch.zeros(label_dim, label_dim, filter_size, filter_size)
138
+
139
+ # for lx in range(label_dim):
140
+ # for ly in range(label_dim):
141
+ # mask_labels[lx, ly, :, :] = mask[
142
+ # lx * filter_stride: lx * filter_stride + filter_size,
143
+ # ly * filter_stride: ly * filter_stride + filter_size
144
+ # ]
145
+
146
+ # print(labels.shape)
147
+ # for box in blue_boxes:
148
+ # x = int((box[0] + box[2]) / 2)
149
+ # y = int((box[1] + box[3]) / 2)
150
+
151
+ # window_x = int((x - int(filter_size / 2)) / filter_stride)
152
+ # window_y = int((y - int(filter_size / 2)) / filter_stride)
153
+
154
+ # clamp = lambda n, minn, maxn: max(min(maxn, n), minn)
155
+
156
+ # labels[
157
+ # clamp(window_y - label_size, 0, labels.shape[0] - 1):clamp(window_y + label_size, 0, labels.shape[0] - 1),
158
+ # clamp(window_x - label_size, 0, labels.shape[0] - 1):clamp(window_x + label_size, 0, labels.shape[0] - 1),
159
+ # ] = 1
160
+
161
+ # positive_labels = labels.flatten() / labels.max()
162
+ # negative_labels = 1 - positive_labels
163
+ # pos_mask_labels = torch.flatten(mask_labels, end_dim=1)
164
+ # neg_mask_labels = 1 - pos_mask_labels
165
+ # mask_labels = torch.stack([pos_mask_labels, neg_mask_labels], dim=1)
166
+ # dataset_labels = torch.tensor(list(zip(positive_labels, negative_labels)))
167
+ # dataset = list(zip(
168
+ # get_dataset_x(blank_image, filter_size=filter_size, filter_stride=filter_stride),
169
+ # dataset_labels,
170
+ # mask_labels
171
+ # ))
172
+ # return dataset, (labels, mask_labels)
173
+
174
+
175
+ # from torchvision.models.resnet import resnet50
176
+ # from torchvision.models.resnet import ResNet50_Weights
177
+
178
+ # print("Loading resnet...")
179
+ # model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
180
+ # hidden_state_size = model.fc.in_features
181
+ # model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
182
+ # model.to(device)
183
+ # model.load_state_dict(torch.load("model_best_epoch_4_59.62.pth", map_location=torch.device(device)))
184
+ # model.to(device)
185
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
186
+
187
# Build the "default" entry of the SAM registry from the checkpoint file that
# sits next to this script, then move the module to the selected device.
# (``Module.to`` returns the module itself, so rebinding keeps the same object.)
model = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
model = model.to(device)

# Two wrappers around the same model instance: one that proposes masks for a
# whole image automatically, and one predictor object.
mask_generator = SamAutomaticMaskGenerator(model)
predictor = SamPredictor(model)
193
+
194
+ import gradio as gr
195
+
196
+ import matplotlib.pyplot as plt
197
+ import io
198
+
199
def check_circularity(segmentation, min_circularity=0.8, max_circularity=1.2):
    """Return whether the main blob of a binary mask is roughly circular.

    Circularity is computed as ``4 * pi * area / perimeter**2`` on the
    largest-area contour of the (lightly blurred) mask and is accepted when it
    lies strictly between *min_circularity* and *max_circularity*.

    The result is always a plain ``bool``. An earlier version returned the
    float circularity for non-circular shapes, which is truthy, so callers
    using ``if check_circularity(...)`` never rejected anything.

    :param segmentation: 2-D mask array; any non-zero entry counts as set
    :param min_circularity: exclusive lower bound for acceptance
    :param max_circularity: exclusive upper bound for acceptance
    :return: ``True`` if the shape is within the bounds; ``False`` otherwise,
        including when the mask has no contour or a zero-length perimeter
    """
    mask_u8 = segmentation.astype(np.uint8)
    # Blur then re-binarise (anything > 0 becomes 255) to smooth the outline.
    blurred = cv2.GaussianBlur(mask_u8, (5, 5), 0)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return False

    # Judge the dominant blob rather than whichever contour happens to be
    # listed first (which may be a speck or an inner hole).
    outline = max(contours, key=cv2.contourArea)
    perimeter = cv2.arcLength(outline, True)
    if perimeter == 0:
        return False

    circularity = 4 * math.pi * cv2.contourArea(outline) / (perimeter * perimeter)
    return min_circularity < circularity < max_circularity
214
+
215
def _tile_edges(length, parts):
    """Return ``parts + 1`` cut positions splitting ``range(length)`` into near-equal spans."""
    return [length * i // parts for i in range(parts + 1)]


def count_barnacles(image_raw, split_num=2, progress=gr.Progress()):
    """Count barnacles inside the wire frame of a photograph.

    The image is cropped to the box reported by ``wireframe_extractor``, the
    crop is cut into a ``split_num`` x ``split_num`` grid, and
    ``mask_generator`` is run on each tile. A mask is counted when its area is
    strictly between 500 and 10000 pixels and ``check_circularity`` accepts
    its segmentation.

    :param image_raw: RGB image as a numpy array of shape (H, W, 3)
    :param split_num: number of tiles per side. Anything that cannot be
        converted to an integer >= 1 falls back to 2; this keeps the function
        working when the UI does not supply a value for this parameter.
    :param progress: Gradio progress tracker
    :return: ``(figure, count, centers)`` -- a matplotlib figure showing
        *image_raw* with a red "x" on each counted mask, the number of counted
        masks, and the list of ``(x, y)`` bounding-box centres of those masks
        in full-image pixel coordinates
    :raises gr.Error: if no image was provided
    """
    # Imported locally: pyplot-managed figures stay registered globally until
    # explicitly closed, which leaks one figure per request in a long-running
    # server. A bare Figure is simply garbage-collected after use.
    from matplotlib.figure import Figure

    default_split = 2
    min_area, max_area = 500, 10000

    if image_raw is None:
        raise gr.Error("Please provide an input image.")

    try:
        split_num = int(split_num)
    except (TypeError, ValueError):
        split_num = default_split
    if split_num < 1:
        split_num = default_split

    progress(0, desc="Finding bounding wire")
    crop_x, crop_y, crop_w, crop_h = wireframe_extractor(image_raw)
    cropped_image = image_raw[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w, :]

    # Axis 0 is rows (y) and axis 1 is columns (x). The last edge of each list
    # equals the crop size, so remainder pixels are not left unscanned.
    col_edges = _tile_edges(cropped_image.shape[1], split_num)
    row_edges = _tile_edges(cropped_image.shape[0], split_num)

    centers = []
    total_tiles = split_num * split_num
    tiles_done = 0
    for left, right in zip(col_edges, col_edges[1:]):
        for top, bottom in zip(row_edges, row_edges[1:]):
            progress(tiles_done / total_tiles, desc="Generating masks")
            tiles_done += 1

            small_image = cropped_image[top:bottom, left:right, :]
            if small_image.size == 0:
                # Degenerate crop (zero width or height): nothing to segment.
                continue

            for mask in mask_generator.generate(small_image):
                if not min_area < mask['area'] < max_area:
                    continue
                if not check_circularity(mask['segmentation']):
                    continue
                # 'bbox' is (x, y, w, h) relative to the tile; shift by the
                # tile and crop offsets to get full-image coordinates.
                box_x, box_y, box_w, box_h = mask['bbox']
                centers.append((
                    box_x + box_w / 2 + crop_x + left,
                    box_y + box_h / 2 + crop_y + top,
                ))

    progress(1.0, desc="Generating Plot")
    fig = Figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.imshow(image_raw)
    if centers:
        xs, ys = zip(*centers)
        ax.scatter(xs, ys, marker="x", color="red", s=32)
    ax.axis('off')

    return fig, len(centers), centers
360
+
361
+
362
+ # return len(masks)
363
+
364
+ # progress(0, desc="Resizing Image")
365
+ # cropped_img = raw_input_img[x:x+w, y:y+h]
366
+ # cropped_image_tensor = torch.transpose(torch.tensor(cropped_img).to(device), 0, 2)
367
+ # resize = Resize((1500, 1500))
368
+ # input_img = cropped_image_tensor
369
+ # blank_img_copy = torch.transpose(input_img, 0, 2).to("cpu").detach().numpy().copy()
370
+
371
+ # progress(0, desc="Generating Windows")
372
+ # test_dataset = get_dataset_x(input_img)
373
+ # test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
374
+ # model.eval()
375
+ # predicted_labels_list = []
376
+ # for data in progress.tqdm(test_dataloader):
377
+ # with torch.no_grad():
378
+ # data = data.to(device)
379
+ # predicted_labels_list += [model(data)]
380
+ # predicted_labels = torch.cat(predicted_labels_list)
381
+ # x = int(math.sqrt(predicted_labels.shape[0]))
382
+ # predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
383
+ # label_img = predicted_labels[:, :, :1].cpu().numpy()
384
+ # label_img -= label_img.min()
385
+ # label_img /= label_img.max()
386
+ # label_img = (label_img * 255).astype(np.uint8)
387
+ # mask = np.array(label_img > 180, np.uint8)
388
+ # contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\
389
+
390
+ # gt_contours = find_contours(labeled_input_img[x:x+w, y:y+h], cropped_img, np.array([59, 76, 160]))
391
+
392
+
393
+
394
+ # def extract_contour_center(cnt):
395
+ # M = cv2.moments(cnt)
396
+ # cx = int(M['m10'] / M['m00'])
397
+ # cy = int(M['m01'] / M['m00'])
398
+ # return cx, cy
399
+
400
+ # filter_width = 50
401
+ # filter_stride = 2
402
+
403
+ # def rev_window_transform(point):
404
+ # wx, wy = point
405
+ # x = int(filter_width / 2) + wx * filter_stride
406
+ # y = int(filter_width / 2) + wy * filter_stride
407
+ # return x, y
408
+
409
+ # nonempty_contours = filter(lambda cnt: cv2.contourArea(cnt) != 0, contours)
410
+ # windows = map(extract_contour_center, nonempty_contours)
411
+ # points = list(map(rev_window_transform, windows))
412
+ # for x, y in points:
413
+ # blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
414
+ # print(f"pointlist: {len(points)}")
415
+ # return blank_img_copy, len(points)
416
+
417
+
418
# Gradio UI: one image upload in; an annotated plot, the predicted barnacle
# count and a table of marker coordinates out.
demo = gr.Interface(
    fn=count_barnacles,
    inputs=[gr.Image(type="numpy", label="Input Image")],
    outputs=[
        gr.Plot(label="Annotated Image"),
        gr.Number(label="Predicted Number of Barnacles"),
        gr.Dataframe(type="array", headers=["x", "y"], label="Mask centers"),
    ],
)

# Enable the request queue with the configured concurrency, then start the
# server. ``queue`` returns the same Interface, so the two calls are
# equivalent to the chained form.
demo.queue(concurrency_count=10)
demo.launch()
examples/new_blank_image.png ADDED
examples/without_crop.png ADDED
examples/without_crop2.png ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ input_img,output 0,output 1,flag,username,timestamp
2
+ ,,0,,,2023-02-22 15:46:27.797108
model_best_epoch_4_59.62.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8ff81d32b5d8e4d9776386e6cbbe6baada9ea7ad95584d871bac1fea0a843cd
3
+ size 94371235
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ numpy
3
+ --extra-index-url https://download.pytorch.org/whl/cu113
4
+ torch
5
+ torchvision
6
+ gradio
7
+ git+https://github.com/facebookresearch/segment-anything.git
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879