YulianSa committed
Commit 8d53de2 · 1 Parent(s): 0aa1fe9
Files changed (1)
  1. infer_api.py +27 -31
infer_api.py CHANGED
@@ -108,34 +108,32 @@ def set_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-class BkgRemover:
-    def __init__(self, force_cpu: Optional[bool] = True):
-        session_infer_path = hf_hub_download(
-            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
-        )
-        providers: list[str] = ["CPUExecutionProvider"]
-        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
-            providers = ["CUDAExecutionProvider"]
-
-        self.session_infer = rt.InferenceSession(
-            session_infer_path, providers=providers,
-        )
-
-    @spaces.GPU
-    def remove_background(
-        self,
-        img: np.ndarray,
-        alpha_min: float,
-        alpha_max: float,
-    ) -> list:
-        img = np.array(img)
-        mask = get_mask(self.session_infer, img)
-        mask[mask < alpha_min] = 0.0
-        mask[mask > alpha_max] = 1.0
-        img_after = (mask * img).astype(np.uint8)
-        mask = (mask * SCALE).astype(np.uint8)
-        img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
-        return Image.fromarray(img_after)
+
+session_infer_path = hf_hub_download(
+    repo_id="skytnt/anime-seg", filename="isnetis.onnx",
+)
+providers: list[str] = ["CPUExecutionProvider"]
+if "CUDAExecutionProvider" in rt.get_available_providers():
+    providers = ["CUDAExecutionProvider"]
+
+bkg_remover_session_infer = rt.InferenceSession(
+    session_infer_path, providers=providers,
+)
+
+@spaces.GPU
+def remove_background(
+    img: np.ndarray,
+    alpha_min: float,
+    alpha_max: float,
+) -> list:
+    img = np.array(img)
+    mask = get_mask(bkg_remover_session_infer, img)
+    mask[mask < alpha_min] = 0.0
+    mask[mask > alpha_max] = 1.0
+    img_after = (mask * img).astype(np.uint8)
+    mask = (mask * SCALE).astype(np.uint8)
+    img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
+    return Image.fromarray(img_after)
 
 
 def process_image(image, totensor, width, height):
@@ -168,7 +166,7 @@ def process_image(image, totensor, width, height):
 
 @spaces.GPU
 @torch.no_grad()
-def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
+def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
              text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
              use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
     set_seed(seed)
@@ -186,7 +184,7 @@ def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extrac
     B = 1
     if input_image.mode != "RGBA":
         # remove background
-        input_image = bkg_remover.remove_background(input_image, 0.1, 0.9)
+        input_image = remove_background(input_image, 0.1, 0.9)
     imgs_in = process_image(input_image, totensor, val_width, val_height)
     imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
 
@@ -869,11 +867,9 @@ class InferCanonicalAPI:
         )
         self.validation_pipeline.set_progress_bar_config(disable=True)
 
-        self.bkg_remover = BkgRemover()
-
     def canonicalize(self, image, seed):
         return inference(
-            self.validation_pipeline, self.bkg_remover, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
+            self.validation_pipeline, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
            self.pretrained_model_path, self.validation, self.width_input, self.height_input, self.unet_condition_type,
            use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
        )
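For context, a minimal sketch (not part of the commit) of how the refactored, module-level remove_background can be exercised on its own. It assumes infer_api.py and its dependencies (including the spaces package) are importable, that get_mask and SCALE are defined elsewhere in the module as the diff implies, and the input file name is made up; inference() itself calls it with the same 0.1 / 0.9 thresholds.

# Illustrative usage only; names of files are assumptions.
from PIL import Image

import infer_api  # downloads isnetis.onnx and builds the ONNX session at import time

rgb = Image.open("character.png").convert("RGB")  # any non-RGBA input

# Returns a PIL RGBA image: thresholded mask applied to the pixels, mask kept as alpha
rgba = infer_api.remove_background(rgb, alpha_min=0.1, alpha_max=0.9)
rgba.save("character_rgba.png")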
 
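One behavioural difference worth noting: the removed BkgRemover defaulted to force_cpu=True, so segmentation stayed on CPUExecutionProvider unless a caller opted out, while the module-level code now picks CUDAExecutionProvider whenever onnxruntime reports it, and does so at import time rather than when a remover object is constructed. A small sketch of the old opt-in pattern, should CPU-only inference need to be restored; the FORCE_CPU name is hypothetical.

import onnxruntime as rt

FORCE_CPU = True  # hypothetical switch; the commit removes the old force_cpu option

providers = ["CPUExecutionProvider"]
if not FORCE_CPU and "CUDAExecutionProvider" in rt.get_available_providers():
    providers = ["CUDAExecutionProvider"]
# pass `providers` to rt.InferenceSession(...) as in the diff above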