File size: 1,873 Bytes
0bfc0a1
0a65b5f
 
 
 
7bc59eb
0bfc0a1
 
0a65b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfc0a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image


# Prefer the GPU when one is available; the model weights are placed on it below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Config, processor and weights are all read from the local ./checkpoint
# directory; local_files_only=True means no Hub download is attempted.
# NOTE(review): model_config is not used by the visible code — kept in case
# something else reads it.
model_config = SamConfig.from_pretrained("./checkpoint", local_files_only=True)
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
# Place the weights on `device`: inference inputs must live on the same device
# as the weights, otherwise a CUDA host mixes CPU weights with GPU tensors.
model = SamModel.from_pretrained("./checkpoint", local_files_only=True).to(device)
# Inference-only app: switch off training-mode behaviour once at load time.
model.eval()



def get_bbox(gt_map, max_jitter=20):
    """Return a jittered box ``[x_min, y_min, x_max, y_max]`` around the positive pixels of *gt_map*.

    Each side of the tight box around ``gt_map > 0`` is pushed outward by an
    independent random amount in ``[0, max_jitter)`` pixels, drawn with
    ``np.random.randint``. The result is clamped to the map bounds.

    Args:
        gt_map: Array of shape ``(H, W)``. Extra trailing axes such as colour
            channels ``(H, W, C)`` are also accepted; a pixel then counts as
            foreground when any of its channels is ``> 0``.
        max_jitter: Exclusive upper bound of the random outward padding per
            side. ``0`` (or less) disables jitter and returns the tight box.

    Returns:
        A list of four Python ints. When no pixel is positive, the whole map
        ``[0, 0, W, H]`` is returned.

    Raises:
        ValueError: If *gt_map* has fewer than two dimensions.
    """
    foreground = np.asarray(gt_map) > 0
    if foreground.ndim < 2:
        raise ValueError(f"gt_map must be at least 2-D, got shape {foreground.shape}")
    if foreground.ndim > 2:
        # Fold every axis after (H, W) together and OR across it.
        foreground = foreground.reshape(*foreground.shape[:2], -1).any(axis=-1)
    H, W = foreground.shape

    # Use the same `> 0` predicate as the coordinate search below. A sum-based
    # check disagrees with it when the map contains negative values.
    if not foreground.any():
        return [0, 0, W, H]

    y_indices, x_indices = np.nonzero(foreground)
    x_min, x_max = int(x_indices.min()), int(x_indices.max())
    y_min, y_max = int(y_indices.min()), int(y_indices.max())

    def _pad():
        # randint's upper bound is exclusive, and it rejects an empty range.
        return int(np.random.randint(0, max_jitter)) if max_jitter > 0 else 0

    # Keep a fixed draw order (x_min, x_max, y_min, y_max) so seeded runs
    # produce reproducible boxes.
    x_min = max(0, x_min - _pad())
    x_max = min(W, x_max + _pad())
    y_min = max(0, y_min - _pad())
    y_max = min(H, y_max + _pad())

    return [x_min, y_min, x_max, y_max]



def _load_image(image):
    """Return *image* as a 256x256 PIL image.

    Accepts a numpy array, a ``PIL.Image.Image``, or a path / file object that
    ``Image.open`` can read. NOTE(review): Gradio's ``"image"`` input is
    assumed to deliver a numpy array by default — confirm for the installed
    Gradio version.
    """
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    elif isinstance(image, Image.Image):
        pil_image = image
    else:
        # Path or file object: resize() yields an independent, loaded copy,
        # so the opened file can be closed straight away.
        with Image.open(image) as opened:
            return opened.resize((256, 256))
    return pil_image.resize((256, 256))


def greet(image):
    """Segment an image with SAM, prompted by a box around its own non-zero pixels.

    The input is resized to 256x256. Its non-zero pixels are passed to
    ``get_bbox`` to build the box prompt, so background is expected to be 0.
    NOTE(review): confirm that using the input itself as the prompt source is
    intended. The single predicted mask is thresholded at probability 0.5.

    Args:
        image: A numpy array, PIL image, or path / file object.

    Returns:
        A 2-D uint8 array: 255 where the model predicts foreground, 0
        elsewhere. The values are scaled to 255 so the mask is visible when
        displayed as an image.
    """
    pil_image = _load_image(image)

    gt_mask = np.array(pil_image)
    if gt_mask.ndim == 3:
        # Collapse colour channels so get_bbox receives a 2-D map; a pixel
        # counts as foreground when any channel is non-zero.
        gt_mask = gt_mask.max(axis=-1)
    prompt = get_bbox(gt_mask)

    # The SAM image encoder works on 3-channel input.
    inputs = processor(
        images=pil_image.convert("RGB"),
        input_boxes=[[prompt]],
        return_tensors="pt",
    )
    # Send tensors to wherever the model's weights actually live, so this works
    # whether or not the model was moved to `device` at load time.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        # SamModel.forward's keyword is `multimask_output` (singular); False
        # requests a single mask per box instead of three candidates.
        outputs = model(**inputs, multimask_output=False)

    # Drop the singleton batch/box/mask axes, leaving a 2-D map. This assumes
    # one image, one box and multimask_output=False.
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().numpy()
    return (seg_prob > 0.5).astype(np.uint8) * 255


# Expose greet() as a one-image-in / one-image-out web UI, then launch it.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="image",
    title="Greeter",
)
iface.launch()