File size: 2,853 Bytes
acb3eab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import PIL
from functools import lru_cache

from random import randint
import gradio as gr
import cv2
import torch
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from typing import List

# SAM weights file, resolved relative to the current working directory.
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
# Key into segment_anything's `sam_model_registry` (presumably the ViT-H
# builder, matching the checkpoint name — confirm against the installed package).
MODEL_TYPE = "default"
# Largest side lengths (pixels) an input image is allowed to keep.
MAX_WIDTH = 800
MAX_HEIGHT = 800
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@lru_cache(maxsize=128)
def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
    """Build the SAM automatic mask generator, memoised by ``lru_cache``.

    The model is created from ``CHECKPOINT_PATH`` with the ``MODEL_TYPE``
    registry builder and moved to ``device``. Repeated calls with the same
    arguments return the cached generator instead of reloading the weights.

    Args:
        model_size: Not used to build the model; it only takes part in the
            cache key. Calls with different argument patterns (including
            passing the default explicitly) therefore load another copy of
            the same checkpoint.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    """
    build_sam = sam_model_registry[MODEL_TYPE]
    model = build_sam(checkpoint=CHECKPOINT_PATH)
    return SamAutomaticMaskGenerator(model.to(device))


def adjust_image_size(image: np.ndarray) -> np.ndarray:
    """Downscale *image* so that it fits inside MAX_WIDTH x MAX_HEIGHT.

    The aspect ratio is preserved to the nearest pixel, and images that
    already fit are never enlarged. Each limit is applied to its own axis,
    so the result stays correct even if MAX_WIDTH and MAX_HEIGHT differ.

    Args:
        image: Array whose first two dimensions are (height, width), e.g. as
            returned by ``cv2.imread``.

    Returns:
        The image passed through ``cv2.resize`` at the computed size.
    """
    height, width = image.shape[:2]
    # A single factor keeps the aspect ratio; 1.0 means "already fits".
    scale = min(1.0, MAX_WIDTH / width, MAX_HEIGHT / height)
    if scale < 1.0:
        # Round instead of truncating to stay closer to the original aspect
        # ratio. Clamp to the limit to absorb float error, and never let a
        # very thin image collapse to a 0-pixel side (cv2.resize rejects it).
        width = min(MAX_WIDTH, max(1, round(width * scale)))
        height = min(MAX_HEIGHT, max(1, round(height * scale)))
    # INTER_AREA is OpenCV's recommended interpolation when shrinking.
    return cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)


def draw_masks(
    image: np.ndarray, masks: List[dict], alpha: float = 0.7
) -> np.ndarray:
    """Blend a random colour over each mask and outline it.

    Args:
        image: uint8 image of shape (height, width, 3). It is not modified;
            a new array is produced whenever *masks* is non-empty.
        masks: Mask records such as those returned by
            ``SamAutomaticMaskGenerator.generate``. Only the
            ``"segmentation"`` entry is read: a (height, width) array that is
            truthy where the mask is set.
        alpha: Weight of the mask colour in the blend. The remaining
            ``1 - alpha`` comes from the underlying image.

    Returns:
        The annotated image, with the same shape and dtype as *image*.
    """
    for mask in masks:
        # Light random colour: every channel lies in [127, 255].
        color = [randint(127, 255) for _ in range(3)]
        # Coerce to bool so that a 0/1 integer mask selects pixels rather than
        # being interpreted as fancy indices.
        segmentation = np.asarray(mask["segmentation"], dtype=bool)

        # Paint the mask area with the solid colour, then blend the whole frame.
        # Outside the mask both inputs are identical, so only masked pixels
        # change. (Replaces a MaskedArray with a per-channel fill_value, which
        # NumPy deprecated.)
        overlay = image.copy()
        overlay[segmentation] = color
        image = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)

        # Outline the mask. Index [-2] selects the contour list under both the
        # OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours(
            segmentation.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        # (255, 0, 0) is blue for a BGR image and red for an RGB one.
        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
    return image


def segment(image_path: str, query: str) -> "PIL.Image.Image":
    """Segment every object in an image file with SAM and draw the masks.

    Args:
        image_path: Path of the image to segment. Gradio supplies the uploaded
            file's path because the input component uses ``type="filepath"``.
        query: Text from the second input box. NOTE(review): currently unused;
            the masks are not filtered by the text.

    Returns:
        The annotated image as an RGB PIL image.

    Raises:
        gr.Error: If no image was given or OpenCV cannot decode the file.
    """
    # `import PIL` alone does not guarantee that the Image submodule is loaded.
    import PIL.Image

    if not image_path:
        raise gr.Error("Please provide an image.")
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise gr.Error(f"Could not read image: {image_path}")

    # reduce the size to save gpu memory
    bgr = adjust_image_size(bgr)
    # OpenCV decodes to BGR, but SAM's predictor expects RGB input by default.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    masks = load_mask_generator().generate(rgb)

    # Draw on the BGR image, following OpenCV colour conventions, then convert
    # to RGB for PIL.
    annotated = draw_masks(bgr, masks)
    annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    # A (H, W, 3) uint8 array becomes an "RGB" mode image.
    return PIL.Image.fromarray(np.uint8(annotated))


demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(type="filepath"), "text"],
    outputs="image",
    allow_flagging="never",
    title="Segment Anything with CLIP",
    # Each example is [image path, query]. Paths are anchored to this file's
    # directory so the demo works from any working directory; the query box
    # is left empty.
    examples=[
        [os.path.join(os.path.dirname(__file__), relative_path), ""]
        for relative_path in (
            "examples/dog.jpg",
            "examples/city.jpg",
            "examples/food.jpg",
            "examples/horse.jpg",
        )
    ],
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    demo.launch()