# sam / run.py
# Author: Nguyen Thai Thao Uyen
# Last commit: 59dad06 - "run.py update device" (file size 1.55 kB)
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

import app
# Hugging Face checkpoint that supplies the SAM architecture config and processor.
_SAM_CHECKPOINT = "facebook/sam-vit-base"
# Local directory used as the Hugging Face download cache.
_CACHE_DIR = "/code/cache"
# Fine-tuned weights (a state dict) loaded on top of the base architecture.
_WEIGHTS_PATH = "sam_model.pth"


@functools.lru_cache(maxsize=None)
def _load_model(device):
    """Build the fine-tuned SAM model and its processor once per device.

    The architecture config and the processor come from ``facebook/sam-vit-base``
    (downloaded into ``/code/cache``). The model weights are replaced by the state
    dict stored in ``sam_model.pth``. The result is cached, so repeated calls with
    the same ``device`` string reuse the already-loaded model instead of reloading
    it from disk.

    Args:
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        Tuple ``(model, processor)``. The model is on ``device`` and in eval mode.
    """
    model_config = SamConfig.from_pretrained(_SAM_CHECKPOINT, cache_dir=_CACHE_DIR)
    processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT, cache_dir=_CACHE_DIR)
    # Instantiate the architecture from the config, then load the saved weights.
    model = SamModel(config=model_config)
    # Load onto the CPU first so a checkpoint saved on a GPU machine also loads on
    # a CPU-only host; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(_WEIGHTS_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, processor


def pred(src):
    """Predict a binary segmentation mask for the image at ``src``.

    Runs the fine-tuned SAM model without any point/box prompts and thresholds
    the sigmoid of its single predicted mask at 0.5.

    Args:
        src: Anything accepted by ``PIL.Image.open`` (a path or file object).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: 1 where the mask probability exceeds
        0.5, 0 elsewhere. The mask is returned at the model's output resolution; it
        is not resized back to the input image size.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Model is created before it is used (the original called model.to() first).
    model, processor = _load_model(device)

    # The context manager closes the file handle. convert("RGB") makes grayscale or
    # RGBA inputs arrive at the processor as 3-channel images.
    with Image.open(src) as image:
        new_image = np.array(image.convert("RGB"))

    inputs = processor(new_image, return_tensors="pt")
    # No prompts are supplied, so the image tensor is the only model input needed.
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, multimask_output=False)

    # pred_masks holds mask logits; the sigmoid turns them into per-pixel probabilities.
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    # Convert the soft mask into a hard 0/1 mask.
    return (single_patch_prob > 0.5).astype(np.uint8)