ZhengPeng7 committed on
Commit
b6a0637
·
1 Parent(s): 5f799ae
Files changed (1) hide show
  1. app.py +6 -17
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
- import transformers
14
  from torchvision import transforms
15
 
16
  import requests
@@ -18,7 +18,6 @@ from io import BytesIO
18
  import zipfile
19
 
20
 
21
- transformers.utils.move_cache()
22
  torch.set_float32_matmul_precision('high')
23
  torch.jit.script = lambda f: f
24
 
@@ -61,9 +60,8 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
61
 
62
  class ImagePreprocessor():
63
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
64
- # Input resolution is on WxH.
65
  self.transform_image = transforms.Compose([
66
- transforms.Resize(resolution[::-1]),
67
  transforms.ToTensor(),
68
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
69
  ])
@@ -86,11 +84,10 @@ usage_to_weights_file = {
86
  'HRSOD': 'BiRefNet-HRSOD',
87
  'COD': 'BiRefNet-COD',
88
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
89
- 'General-legacy': 'BiRefNet-legacy',
90
- 'General-dynamic': 'BiRefNet_dynamic',
91
  }
92
 
93
- birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
94
  birefnet.to(device)
95
  birefnet.eval(); birefnet.half()
96
 
@@ -103,7 +100,7 @@ def predict(images, resolution, weights_file):
103
  # Load BiRefNet with chosen weights
104
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
105
  print('Using weights: {}.'.format(_weights_file))
106
- birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
107
  birefnet.to(device)
108
  birefnet.eval(); birefnet.half()
109
 
@@ -117,11 +114,7 @@ def predict(images, resolution, weights_file):
117
  elif weights_file in ['General-reso_512']:
118
  resolution = (512, 512)
119
  else:
120
- if weights_file in ['General-dynamic']:
121
- resolution = None
122
- print('Using the original size (div by 32) for inference.')
123
- else:
124
- resolution = (1024, 1024)
125
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
126
 
127
  if isinstance(images, list):
@@ -148,10 +141,6 @@ def predict(images, resolution, weights_file):
148
 
149
  image = image_ori.convert('RGB')
150
  # Preprocess the image
151
- if resolution is None:
152
- resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
153
- if resolution_div_by_32 != resolution:
154
- resolution = resolution_div_by_32
155
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
156
  image_proc = image_preprocessor.proc(image)
157
  image_proc = image_proc.unsqueeze(0)
 
10
 
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
+ from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
 
16
  import requests
 
18
  import zipfile
19
 
20
 
 
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
 
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
 
63
  self.transform_image = transforms.Compose([
64
+ transforms.Resize(resolution),
65
  transforms.ToTensor(),
66
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
67
  ])
 
84
  'HRSOD': 'BiRefNet-HRSOD',
85
  'COD': 'BiRefNet-COD',
86
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
87
+ 'General-legacy': 'BiRefNet-legacy'
 
88
  }
89
 
90
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
91
  birefnet.to(device)
92
  birefnet.eval(); birefnet.half()
93
 
 
100
  # Load BiRefNet with chosen weights
101
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
102
  print('Using weights: {}.'.format(_weights_file))
103
+ birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
104
  birefnet.to(device)
105
  birefnet.eval(); birefnet.half()
106
 
 
114
  elif weights_file in ['General-reso_512']:
115
  resolution = (512, 512)
116
  else:
117
+ resolution = (1024, 1024)
 
 
 
 
118
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
119
 
120
  if isinstance(images, list):
 
141
 
142
  image = image_ori.convert('RGB')
143
  # Preprocess the image
 
 
 
 
144
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
145
  image_proc = image_preprocessor.proc(image)
146
  image_proc = image_proc.unsqueeze(0)