Spaces:
Build error
Build error
from transformers import SamModel, SamConfig, SamProcessor | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import app | |
import os | |
import json | |
from PIL import Image | |
# Per-process cache of (model, processor) pairs keyed by the files they were
# built from, so weights and processor config are read from disk once per
# process instead of on every prediction.
_PREDICTOR_CACHE = {}


def _load_predictor(cache_dir, model_file, config_file):
    """Return a ``(model, processor)`` pair, building and caching it on first use.

    Args:
        cache_dir: Hugging Face download cache used for the
            ``facebook/sam-vit-base`` config and processor files.
        model_file: Path to a state dict loaded with ``torch.load`` and applied
            via ``load_state_dict``.
        config_file: Path to a JSON file whose keys are passed as keyword
            overrides to ``SamProcessor.from_pretrained``.

    Returns:
        The SAM model (on CPU, already in eval mode) and its processor.
    """
    key = (cache_dir, model_file, config_file)
    cached = _PREDICTOR_CACHE.get(key)
    if cached is not None:
        return cached

    # -- load model configuration and fine-tuned weights
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files. On torch >= 1.13, passing weights_only=True restricts it to tensors.
    model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
    model.eval()

    with open(config_file, "r") as f:  # modified config json file
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
                                             **modified_config_dict,
                                             cache_dir=cache_dir)

    _PREDICTOR_CACHE[key] = (model, processor)
    return model, processor


def pred(src, threshold=0.30):
    """Segment an image with the fine-tuned SAM model.

    The model and processor are loaded on the first call and reused by later
    calls in the same process.

    Args:
        src: Anything ``PIL.Image.open`` accepts (a path or a binary file object).
        threshold: Probability above which a pixel counts as foreground in the
            hard mask. Defaults to 0.30, the value previously hard-coded here.

    Returns:
        ``(pred_prob, pred_prediction)``: the sigmoid probability map as a NumPy
        array, and the thresholded mask as a ``uint8`` array of 0/1 values. Both
        have their singleton dimensions squeezed out.
    """
    # -- cache
    cache_dir = "/code/cache"
    model, processor = _load_predictor(cache_dir, "sam_model.pth", "sam-config.json")

    # -- process image; the context manager closes the underlying file handle,
    # and convert() loads the pixel data before that happens.
    with Image.open(src) as image:
        rgbim = image.convert("RGB")
    new_image = np.array(rgbim)
    print()
    print("image shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    # forward pass (no point/box prompts are supplied, only pixel values)
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn mask logits into probabilities
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # NOTE(review): pred_masks are used exactly as the model returns them; they
    # are not resized back to the input image size (no
    # processor.post_process_masks call). Confirm that callers expect this.

    # convert soft mask to hard mask
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > threshold).astype(np.uint8)
    return pred_prob, pred_prediction