File size: 1,626 Bytes
586d4f8
 
 
 
 
 
c1565a6
 
586d4f8
 
2485872
 
586d4f8
c1565a6
 
81a19fa
c1565a6
586d4f8
a541988
c1565a6
 
 
 
 
c3f5b96
 
c1565a6
 
 
 
 
8e17922
 
586d4f8
 
c1565a6
 
 
8e17922
c1565a6
 
 
 
 
8e17922
c1565a6
 
 
 
 
 
 
fbc5057
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image

def _load_sam_model(model_file, cache_dir):
    """Build a SamModel from the facebook/sam-vit-base config and load local weights on CPU."""
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files here; consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference only: switch layers such as dropout to evaluation behaviour.
    model.eval()
    return model


def _load_sam_processor(config_file, cache_dir):
    """Load the SAM processor, passing the JSON object in config_file as keyword overrides."""
    with open(config_file, "r", encoding="utf-8") as f:  # modified config json file
        modified_config_dict = json.load(f)
    return SamProcessor.from_pretrained("facebook/sam-vit-base",
                                        **modified_config_dict,
                                        cache_dir=cache_dir)


def _load_rgb_array(src):
    """Open the image at src, convert it to RGB and return it as a NumPy array."""
    # The context manager releases the file handle PIL keeps open after
    # Image.open; the pixel data is copied into the array before closing.
    with Image.open(src) as image:
        return np.array(image.convert("RGB"))


def pred(src, *, model_file="sam_model.pth", config_file="sam-config.json",
         cache_dir="/code/cache", threshold=0.30):
    """Segment one image with the fine-tuned SAM model.

    The model and processor are (re)loaded on every call.

    Args:
        src: Path or file object accepted by ``PIL.Image.open``.
        model_file: Path to the state_dict loaded into the SamModel.
        config_file: JSON file whose keys are passed as keyword overrides to
            ``SamProcessor.from_pretrained``.
        cache_dir: Hugging Face cache directory for the base config/processor.
        threshold: Probabilities strictly above this value become 1 in the
            hard mask.

    Returns:
        ``(pred_prob, pred_prediction)``: the sigmoid probabilities as a NumPy
        array with all size-1 dimensions squeezed out, and the binary
        (``uint8``) mask obtained by thresholding them. NOTE(review): these
        are the model's raw ``pred_masks`` resolution; they are not resized
        back to the input image's size — confirm callers expect that.
    """
    model = _load_sam_model(model_file, cache_dir)
    processor = _load_sam_processor(config_file, cache_dir)

    # -- process image
    new_image = _load_rgb_array(src)
    print()
    print("image shape:", new_image.shape)

    inputs = processor(new_image, return_tensors="pt")

    # forward pass (no prompts given; single mask per prediction)
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn mask logits into probabilities
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # convert soft mask to hard mask
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > threshold).astype(np.uint8)

    return pred_prob, pred_prediction