sam / run.py
Nguyen Thai Thao Uyen
Success!
8e17922
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image
def pred(src, threshold=0.30):
    """Run the fine-tuned SAM checkpoint on one image and return its masks.

    The model weights are read from ``sam_model.pth`` and the processor
    overrides from ``sam-config.json``; both paths are relative to the
    current working directory.

    Args:
        src: Image source, passed unchanged to ``PIL.Image.open``.
        threshold: Probability above which (strictly) a pixel is marked as
            foreground in the hard mask. Defaults to 0.30, the value that
            was previously hard-coded.

    Returns:
        A ``(pred_prob, pred_prediction)`` tuple of NumPy arrays:
        ``pred_prob`` holds the sigmoid of the predicted mask logits with
        singleton dimensions squeezed out, and ``pred_prediction`` is the
        ``uint8`` mask ``pred_prob > threshold``.
    """
    # -- cache
    cache_dir = "/code/cache"

    # -- load model configuration
    model_file = "sam_model.pth"
    model_config = SamConfig.from_pretrained(
        "facebook/sam-vit-base", cache_dir=cache_dir
    )
    model = SamModel(config=model_config)
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only point this at a trusted checkpoint.
    model.load_state_dict(
        torch.load(model_file, map_location=torch.device("cpu"))
    )

    with open("sam-config.json", "r") as f:  # modified config json file
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained(
        "facebook/sam-vit-base",
        **modified_config_dict,
        cache_dir=cache_dir,
    )

    # -- process image
    # Open inside a context manager so a file handle that Pillow opened
    # itself is released; the pixel data is copied into the array before
    # the block exits, so nothing refers to the closed file afterwards.
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))
    print()
    print("image shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    # Inference only: switch to eval mode and skip gradient tracking.
    model.eval()

    # forward pass
    print("predicting...")
    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            multimask_output=False,
        )

    # apply sigmoid
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # convert soft mask to hard mask
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > threshold).astype(np.uint8)
    return pred_prob, pred_prediction