wahaha commited on
Commit
dfb75bd
·
1 Parent(s): 3cd36ea
Files changed (1) hide show
  1. app.py +110 -6
app.py CHANGED
@@ -6,15 +6,58 @@ import sys
6
  sys.path.insert(0, 'U-2-Net')
7
 
8
  from skimage import io, transform
 
 
 
 
 
 
 
 
9
 
10
  import numpy as np
11
  from PIL import Image
 
12
 
 
 
 
 
13
 
 
 
14
 
15
  from modnet import ModNet
16
  import huggingface_hub
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
 
@@ -23,15 +66,76 @@ modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
23
  force_filename='modnet.onnx')
24
  modnet = ModNet(modnet_path)
25
 
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def process(im):
28
- image = modnet.segment(im.name)
29
- im_path = os.path.abspath(os.path.basename(im.name))
30
- Image.fromarray(np.uint8(image)).save(im_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
-
33
 
34
- return Image.open(im_path)
35
 
36
  title = "U-2-Net"
37
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
 
6
  sys.path.insert(0, 'U-2-Net')
7
 
8
  from skimage import io, transform
9
+ import torch
10
+ import torchvision
11
+ from torch.autograd import Variable
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torchvision import transforms#, utils
16
+ # import torch.optim as optim
17
 
18
  import numpy as np
19
  from PIL import Image
20
+ import glob
21
 
22
+ from data_loader import RescaleT
23
+ from data_loader import ToTensor
24
+ from data_loader import ToTensorLab
25
+ from data_loader import SalObjDataset
26
 
27
+ from model import U2NET # full size version 173.6 MB
28
+ from model import U2NETP # small version u2net 4.7 MB
29
 
30
  from modnet import ModNet
31
  import huggingface_hub
32
 
33
+ # normalize the predicted SOD probability map
34
+ def normPRED(d):
35
+ ma = torch.max(d)
36
+ mi = torch.min(d)
37
+
38
+ dn = (d-mi)/(ma-mi)
39
+
40
+ return dn
41
+ def save_output(image_name,pred,d_dir):
42
+ predict = pred
43
+ predict = predict.squeeze()
44
+ predict_np = predict.cpu().data.numpy()
45
+
46
+ im = Image.fromarray(predict_np*255).convert('RGB')
47
+ img_name = image_name.split(os.sep)[-1]
48
+ image = io.imread(image_name)
49
+ imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
50
+
51
+ pb_np = np.array(imo)
52
+
53
+ aaa = img_name.split(".")
54
+ bbb = aaa[0:-1]
55
+ imidx = bbb[0]
56
+ for i in range(1,len(bbb)):
57
+ imidx = imidx + "." + bbb[i]
58
+
59
+ imo.save(d_dir+'/'+imidx+'.png')
60
+ return d_dir+'/'+imidx+'.png'
61
 
62
 
63
 
 
66
  force_filename='modnet.onnx')
67
  modnet = ModNet(modnet_path)
68
 
69
+ # --------- 1. get image path and name ---------
70
+ model_name='u2net_portrait'#u2netp
71
+
72
+
73
+ image_dir = 'portrait_im'
74
+ prediction_dir = 'portrait_results'
75
+ if(not os.path.exists(prediction_dir)):
76
+ os.mkdir(prediction_dir)
77
+
78
+ model_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')
79
+
80
+
81
+ # --------- 3. model define ---------
82
+
83
+ print("...load U2NET---173.6 MB")
84
+ net = U2NET(3,1)
85
+
86
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
87
+ # if torch.cuda.is_available():
88
+ # net.cuda()
89
+ net.eval()
90
+
91
+
92
  def process(im):
93
+ #image = modnet.segment(im.name)
94
+ #im_path = os.path.abspath(os.path.basename(im.name))
95
+ #Image.fromarray(np.uint8(image)).save(im_path)
96
+
97
+ img_name_list = [im.name]
98
+ print("Number of images: ", len(img_name_list))
99
+ # --------- 2. dataloader ---------
100
+ # 1. dataloader
101
+ test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
102
+ lbl_name_list=[],
103
+ transform=transforms.Compose([RescaleT(512),
104
+ ToTensorLab(flag=0)])
105
+ )
106
+ test_salobj_dataloader = DataLoader(test_salobj_dataset,
107
+ batch_size=1,
108
+ shuffle=False,
109
+ num_workers=1)
110
+
111
+ results = []
112
+ # --------- 4. inference for each image ---------
113
+ for i_test, data_test in enumerate(test_salobj_dataloader):
114
+
115
+ print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
116
+
117
+ inputs_test = data_test['image']
118
+ inputs_test = inputs_test.type(torch.FloatTensor)
119
+
120
+ # if torch.cuda.is_available():
121
+ # inputs_test = Variable(inputs_test.cuda())
122
+ # else:
123
+ inputs_test = Variable(inputs_test)
124
+
125
+ d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
126
+
127
+ # normalization
128
+ pred = 1.0 - d1[:, 0, :, :]
129
+ pred = normPRED(pred)
130
+
131
+ # save results to test_results folder
132
+ results.append(save_output(img_name_list[i_test], pred, prediction_dir))
133
+
134
+ del d1, d2, d3, d4, d5, d6, d7
135
 
136
+ print(results)
137
 
138
+ return Image.open(results[0])
139
 
140
  title = "U-2-Net"
141
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"