wahaha commited on
Commit
3cd36ea
·
1 Parent(s): e6eaa82
Files changed (2) hide show
  1. README.md +0 -1
  2. app.py +3 -107
README.md CHANGED
@@ -1,5 +1,4 @@
1
  ---
2
- python_version: 3.7
3
  title: U2net_portrait
4
  emoji: 🦀
5
  colorFrom: indigo
 
1
  ---
 
2
  title: U2net_portrait
3
  emoji: 🦀
4
  colorFrom: indigo
app.py CHANGED
@@ -6,58 +6,15 @@ import sys
6
  sys.path.insert(0, 'U-2-Net')
7
 
8
  from skimage import io, transform
9
- import torch
10
- import torchvision
11
- from torch.autograd import Variable
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torch.utils.data import Dataset, DataLoader
15
- from torchvision import transforms#, utils
16
- # import torch.optim as optim
17
 
18
  import numpy as np
19
  from PIL import Image
20
- import glob
21
 
22
- from data_loader import RescaleT
23
- from data_loader import ToTensor
24
- from data_loader import ToTensorLab
25
- from data_loader import SalObjDataset
26
 
27
- from model import U2NET # full size version 173.6 MB
28
- from model import U2NETP # small version u2net 4.7 MB
29
 
30
  from modnet import ModNet
31
  import huggingface_hub
32
 
33
- # normalize the predicted SOD probability map
34
- def normPRED(d):
35
- ma = torch.max(d)
36
- mi = torch.min(d)
37
-
38
- dn = (d-mi)/(ma-mi)
39
-
40
- return dn
41
- def save_output(image_name,pred,d_dir):
42
- predict = pred
43
- predict = predict.squeeze()
44
- predict_np = predict.cpu().data.numpy()
45
-
46
- im = Image.fromarray(predict_np*255).convert('RGB')
47
- img_name = image_name.split(os.sep)[-1]
48
- image = io.imread(image_name)
49
- imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
50
-
51
- pb_np = np.array(imo)
52
-
53
- aaa = img_name.split(".")
54
- bbb = aaa[0:-1]
55
- imidx = bbb[0]
56
- for i in range(1,len(bbb)):
57
- imidx = imidx + "." + bbb[i]
58
-
59
- imo.save(d_dir+'/'+imidx+'.png')
60
- return d_dir+'/'+imidx+'.png'
61
 
62
 
63
 
@@ -66,76 +23,15 @@ modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
66
  force_filename='modnet.onnx')
67
  modnet = ModNet(modnet_path)
68
 
69
- # --------- 1. get image path and name ---------
70
- model_name='u2net_portrait'#u2netp
71
-
72
-
73
- image_dir = 'portrait_im'
74
- prediction_dir = 'portrait_results'
75
- if(not os.path.exists(prediction_dir)):
76
- os.mkdir(prediction_dir)
77
-
78
- model_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')
79
-
80
-
81
- # --------- 3. model define ---------
82
-
83
- print("...load U2NET---173.6 MB")
84
- net = U2NET(3,1)
85
-
86
- net.load_state_dict(torch.load(model_dir, map_location='cpu'))
87
- # if torch.cuda.is_available():
88
- # net.cuda()
89
- net.eval()
90
-
91
-
92
  def process(im):
93
  image = modnet.segment(im.name)
94
  im_path = os.path.abspath(os.path.basename(im.name))
95
  Image.fromarray(np.uint8(image)).save(im_path)
96
 
97
- img_name_list = [im_path]
98
- print("Number of images: ", len(img_name_list))
99
- # --------- 2. dataloader ---------
100
- # 1. dataloader
101
- test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
102
- lbl_name_list=[],
103
- transform=transforms.Compose([RescaleT(512),
104
- ToTensorLab(flag=0)])
105
- )
106
- test_salobj_dataloader = DataLoader(test_salobj_dataset,
107
- batch_size=1,
108
- shuffle=False,
109
- num_workers=1)
110
-
111
- results = []
112
- # --------- 4. inference for each image ---------
113
- for i_test, data_test in enumerate(test_salobj_dataloader):
114
-
115
- print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
116
-
117
- inputs_test = data_test['image']
118
- inputs_test = inputs_test.type(torch.FloatTensor)
119
-
120
- # if torch.cuda.is_available():
121
- # inputs_test = Variable(inputs_test.cuda())
122
- # else:
123
- inputs_test = Variable(inputs_test)
124
-
125
- d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
126
-
127
- # normalization
128
- pred = 1.0 - d1[:, 0, :, :]
129
- pred = normPRED(pred)
130
-
131
- # save results to test_results folder
132
- results.append(save_output(img_name_list[i_test], pred, prediction_dir))
133
-
134
- del d1, d2, d3, d4, d5, d6, d7
135
-
136
- print(results)
137
 
138
- return Image.open(results[0])
139
 
140
  title = "U-2-Net"
141
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
 
6
  sys.path.insert(0, 'U-2-Net')
7
 
8
  from skimage import io, transform
 
 
 
 
 
 
 
 
9
 
10
  import numpy as np
11
  from PIL import Image
 
12
 
 
 
 
 
13
 
 
 
14
 
15
  from modnet import ModNet
16
  import huggingface_hub
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
 
 
23
  force_filename='modnet.onnx')
24
  modnet = ModNet(modnet_path)
25
 
26
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def process(im):
28
  image = modnet.segment(im.name)
29
  im_path = os.path.abspath(os.path.basename(im.name))
30
  Image.fromarray(np.uint8(image)).save(im_path)
31
 
32
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ return Image.open(im_path)
35
 
36
  title = "U-2-Net"
37
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"