Spaces:
Build error
Build error
from transformers import SamModel, SamConfig, SamProcessor | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import app | |
import os | |
import json | |
from PIL import Image | |
# Identifiers/paths used to build the fine-tuned SAM model and processor.
_BASE_CHECKPOINT = "facebook/sam-vit-base"
_MODEL_FILE = "sam_model.pth"  # fine-tuned state dict loaded into SamModel
_PROCESSOR_CONFIG_FILE = "sam-config.json"  # modified processor config overrides
_DEFAULT_PROBABILITY_THRES = 0.30

# Lazily populated (model, processor) pair so weights and configs are read
# once per process instead of on every prediction.
_SAM_CACHE = {}


def _load_sam():
    """Build the fine-tuned SAM model and its processor, caching the result.

    Returns:
        tuple: ``(model, processor)``; the model is already in eval mode.
    """
    if "bundle" not in _SAM_CACHE:
        model_config = SamConfig.from_pretrained(_BASE_CHECKPOINT)
        model = SamModel(config=model_config)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host;
        # the model built above lives on CPU anyway.
        model.load_state_dict(torch.load(_MODEL_FILE, map_location="cpu"))
        model.eval()
        with open(_PROCESSOR_CONFIG_FILE, "r", encoding="utf-8") as f:
            modified_config_dict = json.load(f)
        processor = SamProcessor.from_pretrained(
            _BASE_CHECKPOINT, **modified_config_dict
        )
        _SAM_CACHE["bundle"] = (model, processor)
    return _SAM_CACHE["bundle"]


def pred(src, *, threshold=_DEFAULT_PROBABILITY_THRES, return_mask=False):
    """Run the fine-tuned SAM model on one image and threshold its mask.

    Args:
        src: Anything ``PIL.Image.open`` accepts (a path or a file object).
        threshold: Probability cutoff; pixels whose sigmoid probability is
            strictly greater than this become 1 in the hard mask.
        return_mask: If True, return the binary mask. Defaults to False so
            existing callers keep receiving the original placeholder value.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` when ``return_mask`` is True,
        otherwise the integer ``1`` (placeholder kept from the original code).
    """
    model, processor = _load_sam()

    # Convert inside the context manager so the file handle is released;
    # convert() yields a new, fully loaded image independent of the file.
    with Image.open(src) as image:
        rgb_array = np.array(image.convert("RGB"))
    print("Shape:", rgb_array.shape)

    inputs = processor(rgb_array, return_tensors="pt")
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # Logits -> probabilities -> hard (0/1) mask.
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_mask = (pred_prob > threshold).astype(np.uint8)

    if return_mask:
        return pred_mask
    return 1