# sam / run.py
# Nguyen Thai Thao Uyen
# UI
# 586d4f8
# raw
# history blame
# 1.45 kB
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
def pred(src):
    """Load the fine-tuned SAM model and preprocess the image at ``src``.

    The model architecture is built from the ``facebook/sam-vit-base``
    configuration and its weights are read from ``sam_model.pth`` (resolved
    against the current working directory). The image is only run through
    the SAM processor; the forward pass below is still disabled, so no mask
    is produced yet.

    Args:
        src: Image path or file object, as accepted by ``PIL.Image.open``.

    Returns:
        The placeholder value ``1``.
    """
    # Imported locally: the code relies on ``Image`` but the module never
    # imports it, which made the original call fail with NameError.
    from PIL import Image

    cache_dir = "/code/cache"
    # The checkpoint is mapped onto the CPU and the model is never moved,
    # so the inputs have to live on the CPU as well.
    device = torch.device("cpu")

    # Load the model configuration and the matching pre-processor.
    model_config = SamConfig.from_pretrained(
        "facebook/sam-vit-base", cache_dir=cache_dir
    )
    processor = SamProcessor.from_pretrained(
        "facebook/sam-vit-base", cache_dir=cache_dir
    )

    # Create the architecture from the configuration, then overwrite its
    # randomly initialised parameters with the weights from the saved file.
    model = SamModel(config=model_config)
    model.load_state_dict(torch.load("sam_model.pth", map_location=device))

    # Materialise the pixels inside the context manager so the underlying
    # file handle is closed as soon as the array has been built.
    with Image.open(src) as image:
        new_image = np.array(image)

    inputs = processor(new_image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # TODO: inference is currently disabled; re-enable to return a mask.
    # model.eval()
    # # forward pass
    # with torch.no_grad():
    #     outputs = model(**inputs, multimask_output=False)
    # # apply sigmoid
    # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # # convert soft mask to hard mask
    # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
    return 1