File size: 1,438 Bytes
586d4f8
 
 
 
 
 
c1565a6
 
586d4f8
 
 
c1565a6
 
 
 
 
586d4f8
c1565a6
 
 
 
 
 
 
 
 
 
 
 
 
586d4f8
 
c1565a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586d4f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image

def pred(src):
    """Run the SAM model on the image at ``src`` and threshold its mask.

    Builds a ``SamModel`` from the ``facebook/sam-vit-base`` configuration,
    loads its weights from ``sam_model.pth``, and creates a ``SamProcessor``
    whose settings are overridden by the key/value pairs stored in
    ``sam-config.json``. The image is converted to RGB, passed through the
    model with only ``pixel_values`` (no prompts), and the sigmoid of the
    predicted mask is thresholded at 0.30.

    Args:
        src: Anything accepted by ``PIL.Image.open`` (the value is passed
            to it unchanged).

    Returns:
        Always ``1``. NOTE(review): the thresholded mask is computed but
        not returned or stored; the constant looks like a placeholder and
        is kept so existing callers keep working.

    Raises:
        Whatever the underlying file, JSON, PIL, torch, or transformers
        calls raise (e.g. a missing ``sam_model.pth`` or
        ``sam-config.json``); nothing is caught here.
    """
    MODEL_FILE = "sam_model.pth"
    BASE_CHECKPOINT = "facebook/sam-vit-base"
    PROBABILITY_THRES = 0.30

    # -- load model configuration and weights
    model_config = SamConfig.from_pretrained(BASE_CHECKPOINT)
    model = SamModel(config=model_config)
    # The model and its inputs stay on the CPU in this function, so map the
    # checkpoint there explicitly; without map_location a checkpoint saved
    # from CUDA tensors cannot be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_FILE, map_location="cpu"))
    model.eval()

    # -- build the processor once, with the modified config json applied
    with open("sam-config.json", "r", encoding="utf-8") as f:
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained(BASE_CHECKPOINT,
                                             **modified_config_dict)

    # -- process image (context manager releases the file handle)
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))
    print("Shape:", new_image.shape)

    inputs = processor(new_image, return_tensors="pt")

    # forward pass, no gradients needed for inference
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn the raw mask output into probabilities
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # convert soft mask to hard mask
    pred_prob = pred_prob.cpu().numpy().squeeze()
    hard_mask = (pred_prob > PROBABILITY_THRES).astype(np.uint8)  # noqa: F841
    # NOTE(review): hard_mask is currently discarded; presumably it should
    # be returned or saved once callers are ready for it -- confirm.

    return 1