hideosnes committed
Commit bd8ac0f · verified · 1 Parent(s): 5c041f6

Update app.py

Files changed (1)
app.py +52 -28
app.py CHANGED
@@ -16,9 +16,8 @@ import gradio as gr
 from huggingface_hub import hf_hub_download, snapshot_download
 from ip_adapter import IPAdapterXL
 from safetensors.torch import load_file
-# remove the background
-from tsr.system import TSR
-from tsr.utils import remove_background, resize_foreground
+from torchvision.models.detection import deeplabv3_resnet101
+from torchvision.transforms import functional as F

 snapshot_download(
     repo_id="h94/IP-Adapter", allow_patterns="sdxl_models/*", local_dir="."
@@ -103,31 +102,8 @@ def resize_img(
             np.array(input_image)
         )
         input_image = Image.fromarray(res)
-    # return input_image
-    return resized_image
-
-#added
-
-def preprocess(resized_image, do_remove_background, foreground_ratio):
-    def fill_background(image):
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
-        image = Image.fromarray((image * 255.0).astype(np.uint8))
-        return image
-
-    if do_remove_background:
-        image = input_image.convert("RGB")
-        image = remove_background(image, rembg_session)
-        image = resize_foreground(image, foreground_ratio)
-        image = fill_background(image)
-    else:
-        image = input_image
-        if image.mode == "RGBA":
-            image = fill_background(image)
     return input_image

-#/added
-
 examples = [
     [
         "./asset/0.jpg",
@@ -162,6 +138,43 @@ def run_for_examples(style_image, source_image, prompt, scale, control_scale):
         neg_content_scale=0,
     )

+# Add the background removal function
+def remove_background(input_image):
+    # Load the deep learning model
+    model = deeplabv3_resnet101(pretrained=True)
+    model.eval()
+
+    # Preprocess the image
+    preprocess = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+    input_tensor = preprocess(input_image)
+    input_batch = input_tensor.unsqueeze(0)  # Create a mini-batch as expected by the model
+
+    # Move the input and model to GPU for speed if available
+    if torch.cuda.is_available():
+        input_batch = input_batch.to('cuda')
+        model.to('cuda')
+
+    with torch.no_grad():
+        output = model(input_batch)['out'][0]
+    output_predictions = output.argmax(0)
+
+    # Create a binary (black and white) mask of the profile foreground
+    mask = output_predictions.byte().cpu().numpy()
+    background = np.zeros(mask.shape)
+    bin_mask = np.where(mask, 255, background).astype(np.uint8)
+
+    # Create a transparent foreground
+    b, g, r = cv2.split(np.array(input_image).astype('uint8'))
+    a = np.ones(bin_mask.shape, dtype='uint8') * 255
+    alpha_im = cv2.merge([b, g, r, a], 4)
+    bg = np.zeros(alpha_im.shape)
+    new_mask = np.stack([bin_mask, bin_mask, bin_mask, bin_mask], axis=2)
+    foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
+
+    return foreground

 @spaces.GPU
 def create_image(
@@ -208,6 +221,17 @@ def create_image(
         cv_input_image = pil_to_cv2(input_image)
         detected_map = cv2.Canny(cv_input_image, 50, 200)
         canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
+
+        # Remove background from the input image
+        foreground = remove_background(input_image)
+        # Convert the foreground back to a PIL image if necessary
+        foreground_pil = Image.fromarray(foreground)
+
+        # Use foreground_pil instead of input_image for further processing
+        # Note: You might need to adjust the following lines based on how you intend to use the foreground_pil
+        # For example, if you're passing it to the IP-Adapter, ensure it's in the correct format
+
+        # Continue with the existing logic for generating the image...
     else:
         canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
         control_scale = 0
@@ -215,7 +239,7 @@ def create_image(
     if float(control_scale) == 0:
         canny_map = canny_map.resize((1024, 1024))

-    if len(neg_content_prompt) > 0 and neg_content_scale != 0:
+    if len(neg_content_prompt) > 0 and neg_content_scale!= 0:
         images = ip_model.generate(
             pil_image=image_pil,
             prompt=prompt,
@@ -282,7 +306,7 @@ with block:
     with gr.Row():
         with gr.Column():
             image_pil = gr.Image(label="Style Image", type="pil")
-            processed_image = gr.Image(label="Preprocess uWu", interactive=False)
+            # processed_image = gr.Image(label="Preprocess uWu", interactive=False)
         with gr.Column():
             prompt = gr.Textbox(
                 label="Prompt",
 