# NOTE: removed non-code residue (file-viewer header, commit hashes,
# line-number gutter) left over from a web scrape of this file.
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image
def pred(src, model_file="sam_model.pth", config_file="sam-config.json",
         cache_dir="/code/cache", threshold=0.30):
    """Run SAM segmentation on the image at *src*.

    Rebuilds the base SAM architecture, loads a fine-tuned checkpoint on
    CPU, preprocesses the image, and returns both the soft probability
    map and a thresholded binary mask.

    Parameters
    ----------
    src : str or file-like
        Path to the input image (anything ``PIL.Image.open`` accepts).
    model_file : str
        Fine-tuned state-dict checkpoint (default keeps the original
        hard-coded ``"sam_model.pth"``).
    config_file : str
        JSON file holding processor-config overrides.
    cache_dir : str
        HuggingFace cache directory for the base config/processor.
    threshold : float
        Probability cutoff for the hard mask (default 0.30, as before).

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray]
        ``(pred_prob, pred_prediction)`` — the squeezed probability map
        and the uint8 binary mask.
    """
    # -- load model configuration and fine-tuned weights (CPU only)
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (weights_only=True is safer where available).
    model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))

    with open(config_file, "r") as f:  # modified config json file
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
                                             **modified_config_dict,
                                             cache_dir=cache_dir)

    # -- process image: force 3-channel RGB regardless of source mode
    image = Image.open(src)
    rgbim = image.convert("RGB")
    new_image = np.array(rgbim)
    print()
    print("image shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    model.eval()
    # forward pass — inference only, so no gradient bookkeeping
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn mask logits into probabilities
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # convert soft mask to hard mask at the given probability threshold
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > threshold).astype(np.uint8)
    return pred_prob, pred_prediction