Spaces: Running on Zero
Commit · a0f35d6
1 Parent(s): b6a0637
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
 
 from PIL import Image
 from gradio_imageslider import ImageSlider
-
+import transformers
 from torchvision import transforms
 
 import requests
@@ -18,6 +18,7 @@ from io import BytesIO
 import zipfile
 
 
+transformers.utils.move_cache()
 torch.set_float32_matmul_precision('high')
 torch.jit.script = lambda f: f
 
@@ -60,8 +61,9 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
 
 class ImagePreprocessor():
     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+        # Input resolution is on WxH.
         self.transform_image = transforms.Compose([
-            transforms.Resize(resolution),
+            transforms.Resize(resolution[::-1]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ])
@@ -84,10 +86,11 @@ usage_to_weights_file = {
     'HRSOD': 'BiRefNet-HRSOD',
     'COD': 'BiRefNet-COD',
     'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
-    'General-legacy': 'BiRefNet-legacy'
+    'General-legacy': 'BiRefNet-legacy',
+    'General-dynamic': 'BiRefNet_dynamic',
 }
 
-birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
+birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
 birefnet.to(device)
 birefnet.eval(); birefnet.half()
 
@@ -100,7 +103,7 @@ def predict(images, resolution, weights_file):
     # Load BiRefNet with chosen weights
     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
     print('Using weights: {}.'.format(_weights_file))
-    birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
+    birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
     birefnet.to(device)
     birefnet.eval(); birefnet.half()
 
@@ -114,7 +117,11 @@ def predict(images, resolution, weights_file):
     elif weights_file in ['General-reso_512']:
         resolution = (512, 512)
     else:
-
+        if weights_file in ['General-dynamic']:
+            resolution = None
+            print('Using the original size (div by 32) for inference.')
+        else:
+            resolution = (1024, 1024)
         print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
 
     if isinstance(images, list):
@@ -141,6 +148,10 @@ def predict(images, resolution, weights_file):
 
         image = image_ori.convert('RGB')
         # Preprocess the image
+        if resolution is None:
+            resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
+            if resolution_div_by_32 != resolution:
+                resolution = resolution_div_by_32
         image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image)
         image_proc = image_proc.unsqueeze(0)
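
Note on the Resize change: the app tracks resolution as (width, height), while torchvision's transforms.Resize takes its size argument as (height, width), so the updated code reverses the tuple with resolution[::-1]. A minimal sketch of the preprocessing step as it stands after this commit; the blank test image is just a stand-in for a real input:

from PIL import Image
from torchvision import transforms

# The app keeps resolution as (width, height); Resize expects (height, width).
resolution = (1024, 1024)

transform_image = transforms.Compose([
    transforms.Resize(resolution[::-1]),          # reverse to (H, W) for torchvision
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet mean
                         [0.229, 0.224, 0.225]),  # ImageNet std
])

image = Image.new('RGB', (1920, 1080))            # stand-in for a real input image
tensor = transform_image(image).unsqueeze(0)      # tensor of shape (1, 3, 1024, 1024)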
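The new 'General-dynamic' branch skips the fixed resolution: resolution is set to None and, per image, replaced by the original size rounded down to multiples of 32 (PIL's image.size is (width, height), matching the WxH convention above). A rough standalone sketch of that rounding step; the helper name is only for illustration, and the assumption is simply that the model wants spatial dims divisible by 32:

from PIL import Image

def round_size_down_to_32(size):
    # Round a (width, height) pair down to the nearest multiples of 32,
    # mirroring the list comprehension added in predict().
    return [int(int(dim) // 32 * 32) for dim in size]

img = Image.new('RGB', (1023, 770))          # stand-in input image
resolution = round_size_down_to_32(img.size)
print(resolution)                            # [992, 768]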