Bono93 committed
Commit 19d592d · 1 Parent(s): 354c411

feat: sam inference example
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *.npy
+ mask.png
Makefile ADDED
@@ -0,0 +1,14 @@
+ PYTHON=3.9
+ BASENAME=$(shell basename $(CURDIR))
+ CURRENT_DIR = $(shell pwd)
+
+ env:
+ 	conda create -n $(BASENAME) -y python=$(PYTHON)
+
+ setup:
+ 	pip install -r requirements.txt
+ 	pip install git+https://github.com/facebookresearch/segment-anything.git
+
+ load-model:
+ 	mkdir -p models
+ 	curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth && mv sam_vit_h_4b8939.pth models/sam_vit_h_4b8939.pth
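
Note: `make load-model` downloads a ~2.4 GB checkpoint. A minimal smoke test (not part of this commit) to confirm the file deserializes after the download:

```
import torch

# The SAM checkpoint is expected to be a state dict of named weight tensors.
state_dict = torch.load("models/sam_vit_h_4b8939.pth", map_location="cpu")
print(f"loaded {len(state_dict)} tensors")
```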
README.md CHANGED
@@ -10,4 +10,16 @@ pinned: false
  license: apache-2.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Before you start
+ - Requirements: Conda
+ ```
+ make env
+ conda activate sam-inference
+ make setup
+ make load-model
+ ```
+
+ ## Example inference script
+ ```
+ python scripts/example.py --image samples/bears.jpg
+ ```
models/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
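
The LFS pointer above records the checkpoint's SHA-256 and byte size, so the file fetched by `make load-model` can be verified against it. A minimal sketch, not part of the commit:

```
import hashlib
import os

EXPECTED_SHA256 = "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e"
EXPECTED_SIZE = 2564550879  # bytes, from the LFS pointer
path = "models/sam_vit_h_4b8939.pth"

# Hash in 1 MiB chunks to avoid holding the 2.4 GB file in memory.
digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"
assert digest.hexdigest() == EXPECTED_SHA256, "checksum mismatch"
print("checkpoint verified")
```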
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ opencv-python
+ matplotlib
+ gradio
+ torch
+ torchvision
samples/bears.jpg ADDED
scripts/example.py ADDED
@@ -0,0 +1,47 @@
+ import argparse
+ import cv2
+ import numpy as np
+ from segment_anything import SamPredictor, sam_model_registry
+
+ # Argument parser
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", "--image", required=True, help="Path to the image")
+ args = parser.parse_args()
+
+ # Set hyperparameters
+ sam_checkpoint = "./models/sam_vit_h_4b8939.pth"
+ model_type = "vit_h"
+ device = "cpu"
+
+ # Load model
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ predictor = SamPredictor(sam)
+
+ # Preprocess the image: OpenCV loads BGR, but SamPredictor expects RGB
+ image = cv2.imread(args.image)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ predictor.set_image(image)
+
+ # SAM encoder: compute and cache the image embedding
+ embedding = predictor.get_image_embedding()
+ np.save("models/embedding.npy", embedding.cpu().numpy())
+
+ # SAM decoder: segment the object at a single foreground point prompt
+ input_point = np.array([[1300, 950]])
+ input_label = np.array([1])  # 1 = foreground, 0 = background
+ masks, scores, logits = predictor.predict(
+     point_coords=input_point,
+     point_labels=input_label,
+     multimask_output=False,  # return one mask instead of three candidates
+ )
+
+ # Save output: masks has shape (1, H, W) with boolean values
+ h, w = masks.shape[-2:]
+ mask = masks.reshape(h, w, 1)
+
+ # Convert the boolean mask to 0/255 pixel values
+ mask = (mask * 255).astype(np.uint8)
+
+ # Save mask image
+ cv2.imwrite("mask.png", mask)
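
The script writes the raw mask to mask.png. Since opencv-python is already in requirements.txt, a short follow-up sketch (not part of the commit; the overlay.png filename is chosen here for illustration) can blend the mask over the input image to eyeball the result:

```
import cv2

image = cv2.imread("samples/bears.jpg")
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)

# Paint masked pixels red, then blend 50/50 with the original for visibility.
overlay = image.copy()
overlay[mask > 0] = (0, 0, 255)  # BGR red
blended = cv2.addWeighted(image, 0.5, overlay, 0.5, 0)
cv2.imwrite("overlay.png", blended)
```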