jamino30 commited on
Commit
fc92636
·
verified ·
1 Parent(s): 75dfe24

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +3 -6
  2. inference.py +4 -11
app.py CHANGED
@@ -33,11 +33,8 @@ def load_model_without_module(model, model_path):
33
  model = VGG_19().to(device).eval()
34
  for param in model.parameters():
35
  param.requires_grad = False
36
- # sod_model = U2Net().to(device).eval()
37
- # load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts.pt')
38
- sod_model = models.segmentation.deeplabv3_resnet101(
39
- weights='DEFAULT'
40
- ).to(device).eval()
41
 
42
  style_files = os.listdir('./style_images')
43
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
@@ -166,4 +163,4 @@ with gr.Blocks(css=css) as demo:
166
 
167
  demo.queue = False
168
  demo.config['queue'] = False
169
- demo.launch(show_api=False)
 
33
  model = VGG_19().to(device).eval()
34
  for param in model.parameters():
35
  param.requires_grad = False
36
+ sod_model = U2Net().to(device).eval()
37
+ load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts.pt')
 
 
 
38
 
39
  style_files = os.listdir('./style_images')
40
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
 
163
 
164
  demo.queue = False
165
  demo.config['queue'] = False
166
+ demo.launch(show_api=False)
inference.py CHANGED
@@ -56,19 +56,12 @@ def inference(
56
  resized_bg_masks = []
57
  salient_object_ratio = None
58
  if apply_to_background:
59
- # original
60
- segmentation_output = sod_model(content_image)['out'] # [1, 21, 512, 512]
61
- segmentation_mask = segmentation_output.argmax(dim=1) # [1, 512, 512]
62
  background_mask = (segmentation_mask == 0).float()
63
  foreground_mask = 1 - background_mask
64
 
65
- # new
66
- # segmentation_output = sod_model(content_image)[0]
67
- # segmentation_output = torch.sigmoid(segmentation_output)
68
- # segmentation_mask = (segmentation_output > 0.7).float()
69
- # background_mask = (segmentation_mask == 0).float()
70
- # foreground_mask = 1 - background_mask
71
-
72
  salient_object_pixel_count = foreground_mask.sum().item()
73
  total_pixel_count = segmentation_mask.numel()
74
  salient_object_ratio = salient_object_pixel_count / total_pixel_count
@@ -99,4 +92,4 @@ def inference(
99
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
100
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
101
 
102
- return generated_image, salient_object_ratio
 
56
  resized_bg_masks = []
57
  salient_object_ratio = None
58
  if apply_to_background:
59
+ segmentation_output = sod_model(content_image)[0]
60
+ segmentation_output = torch.sigmoid(segmentation_output)
61
+ segmentation_mask = (segmentation_output > 0.7).float()
62
  background_mask = (segmentation_mask == 0).float()
63
  foreground_mask = 1 - background_mask
64
 
 
 
 
 
 
 
 
65
  salient_object_pixel_count = foreground_mask.sum().item()
66
  total_pixel_count = segmentation_mask.numel()
67
  salient_object_ratio = salient_object_pixel_count / total_pixel_count
 
92
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
93
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
94
 
95
+ return generated_image, salient_object_ratio