Spaces:
Running
on
L40S
Running
on
L40S
update
Browse files- infer_api.py +27 -31
infer_api.py
CHANGED
@@ -108,34 +108,32 @@ def set_seed(seed):
|
|
108 |
torch.manual_seed(seed)
|
109 |
torch.cuda.manual_seed_all(seed)
|
110 |
|
111 |
-
class BkgRemover:
|
112 |
-
def __init__(self, force_cpu: Optional[bool] = True):
|
113 |
-
session_infer_path = hf_hub_download(
|
114 |
-
repo_id="skytnt/anime-seg", filename="isnetis.onnx",
|
115 |
-
)
|
116 |
-
providers: list[str] = ["CPUExecutionProvider"]
|
117 |
-
if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
|
118 |
-
providers = ["CUDAExecutionProvider"]
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
139 |
|
140 |
|
141 |
def process_image(image, totensor, width, height):
|
@@ -168,7 +166,7 @@ def process_image(image, totensor, width, height):
|
|
168 |
|
169 |
@spaces.GPU
|
170 |
@torch.no_grad()
|
171 |
-
def inference(validation_pipeline,
|
172 |
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
|
173 |
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
|
174 |
set_seed(seed)
|
@@ -186,7 +184,7 @@ def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extrac
|
|
186 |
B = 1
|
187 |
if input_image.mode != "RGBA":
|
188 |
# remove background
|
189 |
-
input_image =
|
190 |
imgs_in = process_image(input_image, totensor, val_width, val_height)
|
191 |
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
|
192 |
|
@@ -869,11 +867,9 @@ class InferCanonicalAPI:
|
|
869 |
)
|
870 |
self.validation_pipeline.set_progress_bar_config(disable=True)
|
871 |
|
872 |
-
self.bkg_remover = BkgRemover()
|
873 |
-
|
874 |
def canonicalize(self, image, seed):
|
875 |
return inference(
|
876 |
-
self.validation_pipeline,
|
877 |
self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
|
878 |
use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
|
879 |
)
|
|
|
108 |
torch.manual_seed(seed)
|
109 |
torch.cuda.manual_seed_all(seed)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
+
session_infer_path = hf_hub_download(
|
113 |
+
repo_id="skytnt/anime-seg", filename="isnetis.onnx",
|
114 |
+
)
|
115 |
+
providers: list[str] = ["CPUExecutionProvider"]
|
116 |
+
if "CUDAExecutionProvider" in rt.get_available_providers():
|
117 |
+
providers = ["CUDAExecutionProvider"]
|
118 |
|
119 |
+
bkg_remover_session_infer = rt.InferenceSession(
|
120 |
+
session_infer_path, providers=providers,
|
121 |
+
)
|
122 |
+
|
123 |
+
@spaces.GPU
|
124 |
+
def remove_background(
|
125 |
+
img: np.ndarray,
|
126 |
+
alpha_min: float,
|
127 |
+
alpha_max: float,
|
128 |
+
) -> list:
|
129 |
+
img = np.array(img)
|
130 |
+
mask = get_mask(bkg_remover_session_infer, img)
|
131 |
+
mask[mask < alpha_min] = 0.0
|
132 |
+
mask[mask > alpha_max] = 1.0
|
133 |
+
img_after = (mask * img).astype(np.uint8)
|
134 |
+
mask = (mask * SCALE).astype(np.uint8)
|
135 |
+
img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
|
136 |
+
return Image.fromarray(img_after)
|
137 |
|
138 |
|
139 |
def process_image(image, totensor, width, height):
|
|
|
166 |
|
167 |
@spaces.GPU
|
168 |
@torch.no_grad()
|
169 |
+
def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
|
170 |
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
|
171 |
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
|
172 |
set_seed(seed)
|
|
|
184 |
B = 1
|
185 |
if input_image.mode != "RGBA":
|
186 |
# remove background
|
187 |
+
input_image = remove_background(input_image, 0.1, 0.9)
|
188 |
imgs_in = process_image(input_image, totensor, val_width, val_height)
|
189 |
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
|
190 |
|
|
|
867 |
)
|
868 |
self.validation_pipeline.set_progress_bar_config(disable=True)
|
869 |
|
|
|
|
|
870 |
def canonicalize(self, image, seed):
|
871 |
return inference(
|
872 |
+
self.validation_pipeline, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
|
873 |
self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
|
874 |
use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
|
875 |
)
|