from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image


def pred(src):
    # -- cache directory for downloaded Hugging Face assets
    cache_dir = "/code/cache"

    # -- load the base model configuration and the fine-tuned weights
    MODEL_FILE = "sam_model.pth"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    model.load_state_dict(torch.load(MODEL_FILE, map_location=torch.device("cpu")))

    # -- build the processor, overriding defaults with the modified config JSON file
    with open("sam-config.json", "r") as f:
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained(
        "facebook/sam-vit-base", **modified_config_dict, cache_dir=cache_dir
    )

    # -- load the input image and convert it to an RGB numpy array
    image = Image.open(src)
    rgbim = image.convert("RGB")
    new_image = np.array(rgbim)
    print()
    print("image shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    # -- forward pass (inference only, no gradients)
    model.eval()
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"], multimask_output=False)

    # -- apply sigmoid to turn mask logits into probabilities
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # -- threshold the soft mask into a hard binary mask
    PROBABILITY_THRES = 0.30
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > PROBABILITY_THRES).astype(np.uint8)

    return pred_prob, pred_prediction
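

# A minimal usage sketch, not part of the original module: it assumes a local test
# image at the hypothetical path "test_image.png" and that "sam_model.pth" and
# "sam-config.json" exist alongside this script. Adjust the path before running.
if __name__ == "__main__":
    prob_map, binary_mask = pred("test_image.png")
    print("probability map shape:", prob_map.shape)
    print("binary mask values:", np.unique(binary_mask))

    # Optional visualization of the hard mask with matplotlib (already imported above).
    plt.imshow(binary_mask, cmap="gray")
    plt.title("Predicted mask")
    plt.axis("off")
    plt.show()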