aryanxxvii commited on
Commit
0c0c137
1 Parent(s): 99c494e
Files changed (1) hide show
  1. inference.py +34 -0
inference.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from u2net import U2NET
3
+ from torchvision import transforms
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch.nn.functional as F
7
+ import data_transforms
8
+
9
+ # Load the model
10
+ def load_model():
11
+ model = U2NET(3, 1)
12
+ model.load_state_dict(torch.load("u2net.pth", map_location="cpu"))
13
+ model.eval()
14
+ return model
15
+
16
+ # Preprocessing function (same as you defined locally)
17
+ def preprocess(image):
18
+ transform = transforms.Compose([data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)])
19
+ label_3 = np.zeros(image.shape)
20
+ label = np.zeros(label_3.shape[0:2])
21
+ sample = transform({"imidx": np.array([0]), "image": image, "label": label})
22
+ return sample
23
+
24
+ # Inference function
25
+ def infer(model, image):
26
+ input_size = [1024, 1024]
27
+ im_shp = image.shape[0:2]
28
+ im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
29
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
30
+ image = torch.divide(im_tensor, 255.0)
31
+ result = model(image)
32
+ result = torch.squeeze(F.upsample(result[0][0], im_shp, mode='bilinear'), 0)
33
+ result = (result - result.min()) / (result.max() - result.min())
34
+ return result.numpy()