ZhengPeng7 commited on
Commit
a0f35d6
·
1 Parent(s): b6a0637

Turn back back to the latest version.

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
- from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
 
16
  import requests
@@ -18,6 +18,7 @@ from io import BytesIO
18
  import zipfile
19
 
20
 
 
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
@@ -60,8 +61,9 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
 
63
  self.transform_image = transforms.Compose([
64
- transforms.Resize(resolution),
65
  transforms.ToTensor(),
66
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
67
  ])
@@ -84,10 +86,11 @@ usage_to_weights_file = {
84
  'HRSOD': 'BiRefNet-HRSOD',
85
  'COD': 'BiRefNet-COD',
86
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
87
- 'General-legacy': 'BiRefNet-legacy'
 
88
  }
89
 
90
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
91
  birefnet.to(device)
92
  birefnet.eval(); birefnet.half()
93
 
@@ -100,7 +103,7 @@ def predict(images, resolution, weights_file):
100
  # Load BiRefNet with chosen weights
101
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
102
  print('Using weights: {}.'.format(_weights_file))
103
- birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
104
  birefnet.to(device)
105
  birefnet.eval(); birefnet.half()
106
 
@@ -114,7 +117,11 @@ def predict(images, resolution, weights_file):
114
  elif weights_file in ['General-reso_512']:
115
  resolution = (512, 512)
116
  else:
117
- resolution = (1024, 1024)
 
 
 
 
118
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
119
 
120
  if isinstance(images, list):
@@ -141,6 +148,10 @@ def predict(images, resolution, weights_file):
141
 
142
  image = image_ori.convert('RGB')
143
  # Preprocess the image
 
 
 
 
144
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
145
  image_proc = image_preprocessor.proc(image)
146
  image_proc = image_proc.unsqueeze(0)
 
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
+ import transformers
14
  from torchvision import transforms
15
 
16
  import requests
 
18
  import zipfile
19
 
20
 
21
+ transformers.utils.move_cache()
22
  torch.set_float32_matmul_precision('high')
23
  torch.jit.script = lambda f: f
24
 
 
61
 
62
  class ImagePreprocessor():
63
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
64
+ # Input resolution is on WxH.
65
  self.transform_image = transforms.Compose([
66
+ transforms.Resize(resolution[::-1]),
67
  transforms.ToTensor(),
68
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
69
  ])
 
86
  'HRSOD': 'BiRefNet-HRSOD',
87
  'COD': 'BiRefNet-COD',
88
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
89
+ 'General-legacy': 'BiRefNet-legacy',
90
+ 'General-dynamic': 'BiRefNet_dynamic',
91
  }
92
 
93
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
94
  birefnet.to(device)
95
  birefnet.eval(); birefnet.half()
96
 
 
103
  # Load BiRefNet with chosen weights
104
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
105
  print('Using weights: {}.'.format(_weights_file))
106
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
107
  birefnet.to(device)
108
  birefnet.eval(); birefnet.half()
109
 
 
117
  elif weights_file in ['General-reso_512']:
118
  resolution = (512, 512)
119
  else:
120
+ if weights_file in ['General-dynamic']:
121
+ resolution = None
122
+ print('Using the original size (div by 32) for inference.')
123
+ else:
124
+ resolution = (1024, 1024)
125
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
126
 
127
  if isinstance(images, list):
 
148
 
149
  image = image_ori.convert('RGB')
150
  # Preprocess the image
151
+ if resolution is None:
152
+ resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
153
+ if resolution_div_by_32 != resolution:
154
+ resolution = resolution_div_by_32
155
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
156
  image_proc = image_preprocessor.proc(image)
157
  image_proc = image_proc.unsqueeze(0)