# sam / run.py
# Last commit: c1565a6 ("Update run.py") by Nguyen Thai Thao Uyen
# File size: 1.44 kB
import functools
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor

import app
@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the fine-tuned SAM model and its processor once, then reuse them.

    Returns:
        A ``(model, processor)`` tuple. The model uses the
        ``facebook/sam-vit-base`` architecture with weights from
        ``sam_model.pth``. The processor is built from the same checkpoint
        with keyword overrides read from ``sam-config.json``.
    """
    model_file = "sam_model.pth"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    model = SamModel(config=model_config)
    # map_location="cpu" lets weights saved from a GPU run load on a CPU-only
    # host. The model's parameters live on the CPU either way, because
    # load_state_dict copies into the existing tensors.
    model.load_state_dict(torch.load(model_file, map_location="cpu"))
    model.eval()

    # Modified processor configuration that overrides the checkpoint defaults.
    with open("sam-config.json", "r", encoding="utf-8") as f:
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
                                             **modified_config_dict)
    return model, processor


def pred(src, *, threshold=0.30, return_mask=False):
    """Run the fine-tuned SAM model on one image and threshold its mask.

    The model and processor are loaded on the first call and cached for
    later calls.

    Args:
        src: Anything ``PIL.Image.open`` accepts (a path or a binary file
            object).
        threshold: Probability above which a pixel counts as foreground.
            The default of 0.30 is the value this function always used.
        return_mask: If True, return the binary mask. If False (the
            default), return ``1`` as before, so existing callers are
            unaffected.

    Returns:
        ``1`` by default. If ``return_mask`` is True, a ``numpy.uint8`` array
        of 0/1 values from the squeezed ``pred_masks`` output.
        NOTE(review): the mask is not passed through the processor's
        ``post_process_masks``, so it is presumably at the model's mask
        resolution rather than the input image's. Confirm before use.
    """
    model, processor = _load_model_and_processor()

    # -- process image
    # The context manager releases the file handle Pillow opened once the
    # pixels have been copied into the numpy array.
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))
    print("Shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    # forward pass, requesting a single mask
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # Logits -> per-pixel probabilities -> hard 0/1 mask.
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > threshold).astype(np.uint8)

    if return_mask:
        return pred_prediction
    # Legacy placeholder result, kept for backward compatibility.
    return 1