File size: 2,552 Bytes
38c5a71
 
 
 
 
 
 
 
 
 
 
 
 
95e25bf
38c5a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools
import json
import os
import subprocess
import tempfile
import threading

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

repo_url = "https://github.com/CASIA-IVA-Lab/FastSAM.git"
target_directory = "./FastSAM"

# Clone only on first start. On a restart the checkout already exists, and
# re-running ``git clone`` into it would just fail with an error.
if not os.path.isdir(target_directory):
    # check=True: fail fast with git's own error instead of a confusing
    # FileNotFoundError from the chdir below.
    subprocess.run(['git', 'clone', repo_url, target_directory], check=True)

# Switch into the checkout; the rest of this script uses paths relative to it
# (e.g. the model weights path).
os.chdir(target_directory)
print('pwd: ', os.getcwd())

from fastsam import FastSAM, FastSAMPrompt 
import ast
import torch
from PIL import Image
from utils.tools import convert_box_xywh_to_xyxy

# The FastSAM model below is loaded once and shared across Gradio requests.
# FastSAM presumably wraps an Ultralytics YOLO model, and Ultralytics advises
# against concurrent predict() calls on one shared instance. Inference is
# therefore serialized with this lock.
_INFERENCE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the FastSAM checkpoint once and reuse it for every request."""
    # Relative path: resolved against the working directory set by os.chdir.
    return FastSAM('./weights/FastSAM.pt')


def gradio_fn(pil_input_img):
    """Segment an image with FastSAM and return the annotated overlay.

    Args:
        pil_input_img: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The annotated image as a NumPy array, read back from the JPEG that
        ``FastSAMPrompt.plot`` writes. If ``plot`` writes no file, the
        RGB-converted input image is returned as a NumPy array instead.

    Raises:
        gr.Error: if no image was provided.
    """
    if pil_input_img is None:
        raise gr.Error("Please upload an image.")

    # Prompt configuration. With these defaults none of the box/text/point
    # conditions below hold, so the "segment everything" branch runs.
    box_prompt = convert_box_xywh_to_xyxy([[0, 0, 0, 0]])
    point_prompt = [[0, 0]]
    point_labels = [0]
    text_prompt = None

    rgb_img = pil_input_img.convert("RGB")
    bboxes = None
    points = None
    point_label = None

    with _INFERENCE_LOCK, tempfile.TemporaryDirectory() as tmp_dir:
        everything_results = _get_model()(
            rgb_img,
            device="cpu",
            retina_masks=True,
            imgsz=1024,
            conf=0.4,
            iou=0.9,
        )
        prompt_process = FastSAMPrompt(rgb_img, everything_results, device="cpu")
        if box_prompt[0][2] != 0 and box_prompt[0][3] != 0:
            ann = prompt_process.box_prompt(bboxes=box_prompt)
            bboxes = box_prompt
        elif text_prompt is not None:
            ann = prompt_process.text_prompt(text=text_prompt)
        elif point_prompt[0] != [0, 0]:
            ann = prompt_process.point_prompt(
                points=point_prompt, pointlabel=point_labels
            )
            points = point_prompt
            point_label = point_labels
        else:
            ann = prompt_process.everything_prompt()

        # Per-call output path. A shared "./output.jpg" could be overwritten
        # by another request, or returned stale from an earlier call if
        # plot() does not write a file this time.
        output_path = os.path.join(tmp_dir, "output.jpg")
        prompt_process.plot(
            annotations=ann,
            output_path=output_path,
            bboxes=bboxes,
            points=points,
            point_label=point_label,
            withContours=False,
            better_quality=False,
        )
        if not os.path.exists(output_path):
            return np.array(rgb_img)
        # ``with`` closes the file handle. np.array() loads the pixel data
        # before the temp directory is removed.
        with Image.open(output_path) as pil_image_output:
            return np.array(pil_image_output)

# Gradio UI wrapping gradio_fn. The description is Markdown. Real "\n"
# newlines are used because backslash line continuations inside a string
# literal join everything into one line, which breaks the bullet list.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="FAST-SAM Segment Everything",
    description=(
        "- **FastSAM** model that returns segmented RGB image of given input image.\n"
        "- **Credits** :\n"
        "    - https://huggingface.co/An-619\n"
        "    - https://github.com/CASIA-IVA-Lab/FastSAM"
    ),
)

# Start the server when run as a script; importing the module only builds
# ``demo``.
if __name__ == "__main__":
    demo.launch()