ZhengPeng7 commited on
Commit
8af980d
·
1 Parent(s): 5b26e24

Add BiRefNet_dynamic for test and arbitary input size for it.

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -60,8 +60,9 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
 
63
  self.transform_image = transforms.Compose([
64
- transforms.Resize(resolution),
65
  transforms.ToTensor(),
66
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
67
  ])
@@ -84,7 +85,8 @@ usage_to_weights_file = {
84
  'HRSOD': 'BiRefNet-HRSOD',
85
  'COD': 'BiRefNet-COD',
86
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
87
- 'General-legacy': 'BiRefNet-legacy'
 
88
  }
89
 
90
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
@@ -114,7 +116,11 @@ def predict(images, resolution, weights_file):
114
  elif weights_file in ['General-reso_512']:
115
  resolution = (512, 512)
116
  else:
117
- resolution = (1024, 1024)
 
 
 
 
118
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
119
 
120
  if isinstance(images, list):
@@ -141,6 +147,10 @@ def predict(images, resolution, weights_file):
141
 
142
  image = image_ori.convert('RGB')
143
  # Preprocess the image
 
 
 
 
144
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
145
  image_proc = image_preprocessor.proc(image)
146
  image_proc = image_proc.unsqueeze(0)
 
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
63
+ # Input resolution is on WxH.
64
  self.transform_image = transforms.Compose([
65
+ transforms.Resize(resolution[::-1]),
66
  transforms.ToTensor(),
67
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
68
  ])
 
85
  'HRSOD': 'BiRefNet-HRSOD',
86
  'COD': 'BiRefNet-COD',
87
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
88
+ 'General-legacy': 'BiRefNet-legacy',
89
+ 'General-dynamic': 'BiRefNet_dynamic',
90
  }
91
 
92
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
 
116
  elif weights_file in ['General-reso_512']:
117
  resolution = (512, 512)
118
  else:
119
+ if weights_file in ['General-dynamic']:
120
+ resolution = None
121
+ print('Using the original size (div by 32) for inference.')
122
+ else:
123
+ resolution = (1024, 1024)
124
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
125
 
126
  if isinstance(images, list):
 
147
 
148
  image = image_ori.convert('RGB')
149
  # Preprocess the image
150
+ if resolution is None:
151
+ resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
152
+ if resolution_div_by_32 != resolution:
153
+ resolution = resolution_div_by_32
154
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
155
  image_proc = image_preprocessor.proc(image)
156
  image_proc = image_proc.unsqueeze(0)