Upload folder using huggingface_hub
- app.py +3 -6
- inference.py +4 -11
app.py CHANGED
@@ -33,11 +33,8 @@ def load_model_without_module(model, model_path):
 model = VGG_19().to(device).eval()
 for param in model.parameters():
     param.requires_grad = False
-
-
-sod_model = models.segmentation.deeplabv3_resnet101(
-    weights='DEFAULT'
-).to(device).eval()
+sod_model = U2Net().to(device).eval()
+load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts.pt')
 
 style_files = os.listdir('./style_images')
 style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
@@ -166,4 +163,4 @@ with gr.Blocks(css=css) as demo:
 
 demo.queue = False
 demo.config['queue'] = False
-demo.launch(show_api=False)
+demo.launch(show_api=False)
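The new code path calls `load_model_without_module`, whose body is outside this diff; only its signature appears in the hunk header. As a hedged illustration of what a loader with that name typically does, the sketch below strips the `module.` prefix that `torch.nn.DataParallel` adds to checkpoint keys. This is an assumption, not the repository's actual implementation:

```python
import torch

# Hypothetical sketch of load_model_without_module (the diff shows only its
# name and call site). Assumption: the checkpoint was saved from a
# DataParallel-wrapped model, so state-dict keys carry a 'module.' prefix.
def load_model_without_module(model, model_path):
    state_dict = torch.load(model_path, map_location='cpu')
    cleaned = {k.removeprefix('module.'): v for k, v in state_dict.items()}
    model.load_state_dict(cleaned)
    return model
```

If the checkpoint was saved without DataParallel, `removeprefix` leaves the keys untouched, so a helper like this is safe to call either way.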
inference.py CHANGED
@@ -56,19 +56,12 @@ def inference(
     resized_bg_masks = []
     salient_object_ratio = None
     if apply_to_background:
-
-        segmentation_output = 
-        segmentation_mask = segmentation_output.
+        segmentation_output = sod_model(content_image)[0]
+        segmentation_output = torch.sigmoid(segmentation_output)
+        segmentation_mask = (segmentation_output > 0.7).float()
         background_mask = (segmentation_mask == 0).float()
         foreground_mask = 1 - background_mask
 
-        # new
-        # segmentation_output = sod_model(content_image)[0]
-        # segmentation_output = torch.sigmoid(segmentation_output)
-        # segmentation_mask = (segmentation_output > 0.7).float()
-        # background_mask = (segmentation_mask == 0).float()
-        # foreground_mask = 1 - background_mask
-
         salient_object_pixel_count = foreground_mask.sum().item()
         total_pixel_count = segmentation_mask.numel()
         salient_object_ratio = salient_object_pixel_count / total_pixel_count
@@ -99,4 +92,4 @@ def inference(
     foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
     generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
 
-    return generated_image, salient_object_ratio
+    return generated_image, salient_object_ratio
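Taken together, the two hunks swap the previous segmentation model for U2Net saliency thresholding, then paste the original foreground back over the stylized output. Below is a minimal, self-contained sketch of that flow; `DummySOD`, the tensor shapes, and the image sizes are illustrative assumptions standing in for the real model and inputs:

```python
import torch
import torch.nn.functional as F

# Illustrative stand-in for the U2Net detector: like the real sod_model,
# it returns a tuple whose first element is a logit saliency map.
class DummySOD(torch.nn.Module):
    def forward(self, x):
        return (torch.randn(x.shape[0], 1, x.shape[2], x.shape[3]),)

sod_model = DummySOD().eval()
content_image = torch.rand(1, 3, 256, 256)    # assumed NCHW float input
generated_image = torch.rand(1, 3, 256, 256)  # assumed stylized output

# Mirror the diff: sigmoid -> 0.7 threshold -> foreground/background masks.
segmentation_output = torch.sigmoid(sod_model(content_image)[0])  # (N,1,H,W)
segmentation_mask = (segmentation_output > 0.7).float()
background_mask = (segmentation_mask == 0).float()
foreground_mask = (1 - background_mask).squeeze(1)  # (N,H,W); the diff's unsqueeze(1) implies a 3D mask

salient_object_ratio = foreground_mask.sum().item() / segmentation_mask.numel()

# Resize the mask to the generated image and keep the original foreground.
fg = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
blended = generated_image * (1 - fg) + content_image * fg
print(f"salient object ratio: {salient_object_ratio:.3f}; blended shape: {tuple(blended.shape)}")
```

The returned `salient_object_ratio` presumably lets the caller detect when thresholding finds little or no salient object, in which case a background-only style transfer would cover nearly the whole frame.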