import numpy as np
import torch
import matplotlib.pyplot as plt
import gradio as gr
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
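
# The checkpoint files referenced below must be downloaded separately; the
# official SAM weights are published by Meta AI under
# https://dl.fbaipublicfiles.com/segment_anything/
# (sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth).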

# Run SAM on the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_masks(image, model_type):
    """Segment an image with SAM's automatic mask generator and overlay the masks."""
    # Load the requested SAM variant. Note: these must be elif branches; with
    # independent if statements the trailing else would overwrite the vit_h and
    # vit_b models with vit_l.
    if model_type == 'vit_h':
        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    elif model_type == 'vit_b':
        sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    else:  # default to vit_l for any other selection
        sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
    
    # Move the model to the selected device before generating masks.
    sam.to(device=device)

    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    composite_image = np.zeros_like(image)
    colors = plt.cm.jet(np.linspace(0, 1, len(masks)))  # Generate distinct colors

    for i, mask_data in enumerate(masks):
        mask = mask_data['segmentation']  # boolean HxW array selecting this region
        color = colors[i]
        composite_image[mask] = (color[:3] * 255).astype(np.uint8)  # Apply color to mask
    
    # Blend the colored masks with the original image at equal weight.
    # composite_image was created with np.zeros_like(image), so the shapes
    # already match and no resizing is needed.
    overlayed_image = (composite_image * 0.5 + image * 0.5).astype(np.uint8)

    return overlayed_image
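
# A minimal post-processing sketch, assuming the record format documented for
# SamAutomaticMaskGenerator.generate: a list of dicts with keys including
# 'segmentation' (boolean HxW array), 'area' (pixel count), 'bbox', and
# 'stability_score'. filter_small_masks and its 500-pixel threshold are
# illustrative additions, not part of the original app; dropping tiny regions
# before coloring reduces visual noise in the overlay,
# e.g. masks = filter_small_masks(mask_generator.generate(image)).
def filter_small_masks(masks, min_area=500):
    """Keep only masks covering at least min_area pixels."""
    return [m for m in masks if m['area'] >= min_area]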




iface = gr.Interface(
    fn=get_masks,
    inputs=["image", gr.components.Dropdown(choices=['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
    outputs="image",
    title="SAM Automatic Mask Segmentation",
    description="Upload an image and select a SAM model type to receive the image overlaid with colored segmentation masks."
)

if __name__ == "__main__":
    iface.launch()
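
# iface.launch() serves the app locally (Gradio defaults to http://127.0.0.1:7860).
# For a temporary public link, Gradio also accepts iface.launch(share=True).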