File size: 2,476 Bytes
f6ed1c1
 
 
 
 
 
 
 
 
e3fa94a
 
 
 
 
 
 
f6ed1c1
 
 
e3fa94a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6ed1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# IMPORTS
import torch
import numpy as np
from PIL import Image
from lang_sam import LangSAM
import gradio as gr


def run_lang_sam(input_image, text_prompt, model):
    """Segment *input_image* according to *text_prompt* and return the image
    blended with a red overlay over the unified segmentation mask.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Image to segment; it is converted to RGB and resized to 256x256.
    text_prompt : str
        Natural-language description of what to segment.
    model : LangSAM
        Loaded LangSAM model; ``model.predict(image, prompt)`` is expected to
        return ``(masks, boxes, phrases, logits)`` — only ``masks`` is used.

    Returns
    -------
    PIL.Image.Image
        256x256 RGB image with the masked area tinted red; if the prompt
        matches nothing, the resized image is returned unchanged.
    """
    height = width = 256
    # NOTE: PIL's resize takes (width, height); square here, so order is moot.
    image = input_image.convert("RGB").resize((width, height))

    # Get the per-instance masks for the prompt from the model.
    masks, _, _, _ = model.predict(image, text_prompt)

    # No detections: an empty tensor would make max(dim=0) raise, so return
    # the resized image as-is instead of crashing.
    if masks.numel() == 0:
        return image

    # Union of all instance masks -> a single boolean mask.
    masks_int = masks.to(torch.uint8)
    masks_max, _ = masks_int.max(dim=0, keepdim=True)
    unified_mask = masks_max.squeeze(0).to(torch.bool)

    # Paint the mask area red on a black canvas. The canvas now follows
    # height/width (was hard-coded 256x256, duplicating the constant above).
    color = (255, 0, 0)  # overlay color in RGB
    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    colored_mask[unified_mask] = color  # apply the color to the mask area

    # Convert the colored mask to PIL for blending.
    colored_mask_pil = Image.fromarray(colored_mask)

    # Blend the colored mask with the original image; alpha controls the
    # transparency of the overlay (0 = image only, 1 = mask only).
    alpha = 0.5
    blended_image = Image.blend(image, colored_mask_pil, alpha=alpha)

    return blended_image


def setup_gradio_interface(model):
    """Build and return the Gradio Blocks UI wired to *model*."""

    def segment(image, prompt):
        # Thin adapter so both the button and the examples close over the
        # already-loaded model instead of each defining its own lambda.
        return run_lang_sam(image, prompt, model)

    block = gr.Blocks()

    with block:
        gr.Markdown("<h1><center>Lang SAM<h1><center>")

        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")

            # Right column: segmentation result.
            with gr.Column():
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")

        run_button.click(
            fn=segment,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )

        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=segment,
            cache_examples=True,
            label="Try this example input!",
        )

    return block


if __name__ == "__main__":
    # Load the segmentation model once, build the UI around it, and serve
    # it locally (no public share link, no API page, surface errors in UI).
    sam_model = LangSAM()
    demo = setup_gradio_interface(sam_model)
    demo.launch(share=False, show_api=False, show_error=True)