Spaces:
Runtime error
Runtime error
File size: 1,035 Bytes
fe70fd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from model import DoubleConv,UNET
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
convert_tensor = transforms.ToTensor()
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)
model = UNET(in_channels=3, out_channels=1).to(device)
model=torch.load("Unet_acc_94.pth",map_location=torch.device('cpu'))
# test_img=np.array(Image.open("profilepic - Copy.jpeg").resize((160,240)))
test_img=Image.open("104.jpg").resize((240,160))
# test_img=torch.tensor(test_img).permute(2,1,0)
# test_img=test_img.unsqueeze(0)
test_img=convert_tensor(test_img).unsqueeze(0)
print(test_img.shape)
preds=model(test_img.float())
preds=torch.sigmoid(preds)
preds=(preds > 0.5).float()
print(preds.shape)
im=preds.squeeze(0).permute(1,2,0).detach()
print(im.shape)
fig,axs=plt.subplots(1,2)
axs[0].imshow(im)
axs[1].imshow(test_img.squeeze(0).permute(1,2,0).detach())
plt.show()
|