from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from PIL import Image
def pred(src):
    # -- load model configuration and the fine-tuned weights
    MODEL_FILE = "sam_model.pth"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    model = SamModel(config=model_config)
    # map to CPU so the weights load even when no GPU is available
    model.load_state_dict(torch.load(MODEL_FILE, map_location="cpu"))

    # -- build the processor, overriding defaults with a modified config JSON file
    with open("sam-config.json", "r") as f:
        modified_config_dict = json.load(f)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
                                             **modified_config_dict)

    # -- process image
    image = Image.open(src)
    rgbim = image.convert("RGB")
    new_image = np.array(rgbim)
    print("Shape:", new_image.shape)
    inputs = processor(new_image, return_tensors="pt")

    model.eval()
    # forward pass (a single mask per prompt)
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"],
                        multimask_output=False)

    # apply sigmoid to turn mask logits into probabilities
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # convert soft mask to hard (binary) mask
    PROBABILITY_THRES = 0.30
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_prediction = (pred_prob > PROBABILITY_THRES).astype(np.uint8)
    # return the binary mask rather than a placeholder value
    return pred_prediction
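

# A minimal usage sketch. The image path "test.png" and the matplotlib display
# step are assumptions for illustration; they are not part of the original app.
if __name__ == "__main__":
    mask = pred("test.png")  # binary uint8 mask, shape (H, W)
    plt.imshow(mask, cmap="gray")
    plt.title("Predicted SAM mask")
    plt.axis("off")
    plt.show()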