Zaiiida commited on
Commit
807b1e0
·
verified ·
1 Parent(s): 61badd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -61
app.py CHANGED
@@ -14,20 +14,23 @@ from PIL import Image
14
  import sf3d.utils as sf3d_utils
15
  from sf3d.system import SF3D
16
 
 
17
  rembg_session = rembg.new_session()
18
 
 
19
  COND_WIDTH = 512
20
  COND_HEIGHT = 512
21
  COND_DISTANCE = 1.6
22
  COND_FOVY_DEG = 40
23
  BACKGROUND_COLOR = [0.5, 0.5, 0.5]
24
 
25
- # Cached. Doesn't change
26
  c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
27
  intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
28
  COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
29
  )
30
 
 
31
  model = SF3D.from_pretrained(
32
  "stabilityai/stable-fast-3d",
33
  config_name="config.yaml",
@@ -35,11 +38,12 @@ model = SF3D.from_pretrained(
35
  )
36
  model.eval().cuda()
37
 
 
38
  example_files = [
39
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
40
  ]
41
 
42
-
43
  def run_model(input_image):
44
  start = time.time()
45
  with torch.no_grad():
@@ -49,16 +53,13 @@ def run_model(input_image):
49
  trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
50
  trimesh_mesh = trimesh_mesh[0]
51
 
52
- # Create new tmp file
53
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
54
-
55
  trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
56
 
57
  print("Generation took:", time.time() - start, "s")
58
-
59
  return tmp_file.name
60
 
61
-
62
  def create_batch(input_image: Image) -> dict[str, Any]:
63
  img_cond = (
64
  torch.from_numpy(
@@ -80,11 +81,10 @@ def create_batch(input_image: Image) -> dict[str, Any]:
80
  "intrinsic_cond": intrinsic.unsqueeze(0),
81
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
82
  }
83
- # Add batch dim
84
  batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
85
  return batched
86
 
87
-
88
  @lru_cache
89
  def checkerboard(squares: int, size: int, min_value: float = 0.5):
90
  base = np.zeros((squares, squares)) + min_value
@@ -98,11 +98,9 @@ def checkerboard(squares: int, size: int, min_value: float = 0.5):
98
  .repeat(3, axis=-1)
99
  )
100
 
101
-
102
  def remove_background(input_image: Image) -> Image:
103
  return rembg.remove(input_image, session=rembg_session)
104
 
105
-
106
  def resize_foreground(
107
  image: Image,
108
  ratio: float,
@@ -126,26 +124,25 @@ def resize_foreground(
126
  fg,
127
  ((ph0, ph1), (pw0, pw1), (0, 0)),
128
  mode="constant",
129
- constant_values=((0, 0), (0, 0), (0, 0)),
130
  )
131
 
132
  # Compute padding according to the ratio
133
  new_size = int(new_image.shape[0] / ratio)
134
- # Pad to size, double side
135
  ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
136
  ph1, pw1 = new_size - size - ph0, new_size - size - pw0
137
  new_image = np.pad(
138
  new_image,
139
  ((ph0, ph1), (pw0, pw1), (0, 0)),
140
  mode="constant",
141
- constant_values=((0, 0), (0, 0), (0, 0)),
142
  )
143
  new_image = Image.fromarray(new_image, mode="RGBA").resize(
144
  (COND_WIDTH, COND_HEIGHT)
145
  )
146
  return new_image
147
 
148
-
149
  def square_crop(input_image: Image) -> Image:
150
  # Perform a center square crop
151
  min_size = min(input_image.size)
@@ -157,7 +154,6 @@ def square_crop(input_image: Image) -> Image:
157
  (COND_WIDTH, COND_HEIGHT)
158
  )
159
 
160
-
161
  def show_mask_img(input_image: Image) -> Image:
162
  img_numpy = np.array(input_image)
163
  alpha = img_numpy[:, :, 3] / 255.0
@@ -165,11 +161,9 @@ def show_mask_img(input_image: Image) -> Image:
165
  new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
166
  return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
167
 
168
-
169
  def run_button(run_btn, input_image, background_state, foreground_ratio):
170
  if run_btn == "Run":
171
  glb_file: str = run_model(background_state)
172
-
173
  return (
174
  gr.update(),
175
  gr.update(),
@@ -180,10 +174,8 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
180
  )
181
  elif run_btn == "Remove Background":
182
  rem_removed = remove_background(input_image)
183
-
184
  sqr_crop = square_crop(rem_removed)
185
  fr_res = resize_foreground(sqr_crop, foreground_ratio)
186
-
187
  return (
188
  gr.update(value="Run", visible=True),
189
  sqr_crop,
@@ -193,7 +185,6 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
193
  gr.update(visible=False),
194
  )
195
 
196
-
197
  def requires_bg_remove(image, fr):
198
  if image is None:
199
  return (
@@ -228,7 +219,6 @@ def requires_bg_remove(image, fr):
228
  gr.update(visible=False),
229
  )
230
 
231
-
232
  def update_foreground_ratio(img_proc, fr):
233
  foreground_res = resize_foreground(img_proc, fr)
234
  return (
@@ -236,38 +226,27 @@ def update_foreground_ratio(img_proc, fr):
236
  gr.update(value=show_mask_img(foreground_res)),
237
  )
238
 
239
- # Generate color shades for the primary hue
240
- from gradio.themes.colors import Color
241
-
242
- primary_hue_color = Color(
243
- name="custom",
244
- c50="#e7e7e8",
245
- c100="#c2c2c4",
246
- c200="#9d9da0",
247
- c300="#78787b",
248
- c400="#525357",
249
- c500="#2d2e33",
250
- c600="#191a1e", # Base color
251
- c700="#141517",
252
- c800="#0f1012",
253
- c900="#0a0a0c",
254
- c950="#050506",
255
- )
256
-
257
- # Define the custom theme
258
  class CustomTheme(gr.themes.Base):
259
- primary_hue = primary_hue_color
260
- background_fill_primary = "#191a1e"
261
- background_fill_secondary = "#191a1e"
262
- background_fill_tertiary = "#191a1e"
263
- text_color_primary = "#FFFFFF"
264
- text_color_secondary = "#FFFFFF"
265
- text_color_tertiary = "#FFFFFF"
266
- input_background_fill = "#191a1e"
267
- input_text_color = "#FFFFFF"
 
 
 
 
 
 
268
 
 
269
  css = """
270
- /* Apply background color to the entire page */
271
  body {
272
  background-color: #191a1e !important;
273
  margin: 0;
@@ -275,43 +254,46 @@ body {
275
  }
276
 
277
  /* Apply fonts */
278
- @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap');
279
  body, input, button, textarea, select, .gr-button {
280
  font-family: 'Poppins', sans-serif;
281
  color: #FFFFFF;
282
  }
283
 
284
- /* Header settings */
285
  h1, h2, h3, h4, h5, h6 {
286
  font-family: 'Poppins', sans-serif;
287
  font-weight: 700;
288
  color: #FFFFFF;
289
  }
290
 
291
- /* Button colors */
292
- .gr-button {
293
  background-color: #5271FF !important;
294
  color: #FFFFFF !important;
295
  border: none;
296
  font-weight: bold;
297
  }
298
 
299
- /* Container colors */
300
  .gr-block {
301
  background-color: #1c1c24 !important;
302
  border: 1px solid #5271FF !important;
303
  }
304
  """
305
 
 
306
  with gr.Blocks(theme=CustomTheme(), css=css) as demo:
307
  img_proc_state = gr.State()
308
  background_remove_state = gr.State()
309
-
310
  with gr.Row(variant="panel"):
311
  with gr.Column():
312
  with gr.Row():
313
  input_img = gr.Image(
314
- type="pil", label="Input Image", source="upload", image_mode="RGBA"
 
 
 
315
  )
316
  preview_removal = gr.Image(
317
  label="Preview Background Removal",
@@ -335,7 +317,12 @@ with gr.Blocks(theme=CustomTheme(), css=css) as demo:
335
  outputs=[background_remove_state, preview_removal],
336
  )
337
 
338
- run_btn = gr.Button("Run", variant="primary", visible=False)
 
 
 
 
 
339
 
340
  with gr.Column():
341
  output_3d = LitModel3D(
@@ -347,10 +334,13 @@ with gr.Blocks(theme=CustomTheme(), css=css) as demo:
347
  scale=1.0,
348
  )
349
  with gr.Column(visible=False, scale=1.0) as hdr_row:
350
- gr.Markdown("""## HDR Environment Map
 
 
351
 
352
- Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
353
- """)
 
354
 
355
  with gr.Row():
356
  hdr_illumination_file = gr.File(
@@ -407,4 +397,5 @@ Select an HDR environment map to light the 3D model. You can also upload your ow
407
  ],
408
  )
409
 
 
410
  demo.launch()
 
14
  import sf3d.utils as sf3d_utils
15
  from sf3d.system import SF3D
16
 
17
+ # Initialize the rembg session
18
  rembg_session = rembg.new_session()
19
 
20
+ # Constants
21
  COND_WIDTH = 512
22
  COND_HEIGHT = 512
23
  COND_DISTANCE = 1.6
24
  COND_FOVY_DEG = 40
25
  BACKGROUND_COLOR = [0.5, 0.5, 0.5]
26
 
27
+ # Cached camera parameters
28
  c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
29
  intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
30
  COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
31
  )
32
 
33
+ # Load the model
34
  model = SF3D.from_pretrained(
35
  "stabilityai/stable-fast-3d",
36
  config_name="config.yaml",
 
38
  )
39
  model.eval().cuda()
40
 
41
+ # Load example files
42
  example_files = [
43
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
44
  ]
45
 
46
+ # Define functions
47
  def run_model(input_image):
48
  start = time.time()
49
  with torch.no_grad():
 
53
  trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
54
  trimesh_mesh = trimesh_mesh[0]
55
 
56
+ # Create new temporary file
57
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
 
58
  trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
59
 
60
  print("Generation took:", time.time() - start, "s")
 
61
  return tmp_file.name
62
 
 
63
  def create_batch(input_image: Image) -> dict[str, Any]:
64
  img_cond = (
65
  torch.from_numpy(
 
81
  "intrinsic_cond": intrinsic.unsqueeze(0),
82
  "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
83
  }
84
+ # Add batch dimension
85
  batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
86
  return batched
87
 
 
88
  @lru_cache
89
  def checkerboard(squares: int, size: int, min_value: float = 0.5):
90
  base = np.zeros((squares, squares)) + min_value
 
98
  .repeat(3, axis=-1)
99
  )
100
 
 
101
  def remove_background(input_image: Image) -> Image:
102
  return rembg.remove(input_image, session=rembg_session)
103
 
 
104
  def resize_foreground(
105
  image: Image,
106
  ratio: float,
 
124
  fg,
125
  ((ph0, ph1), (pw0, pw1), (0, 0)),
126
  mode="constant",
127
+ constant_values=0,
128
  )
129
 
130
  # Compute padding according to the ratio
131
  new_size = int(new_image.shape[0] / ratio)
132
+ # Pad to new size
133
  ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
134
  ph1, pw1 = new_size - size - ph0, new_size - size - pw0
135
  new_image = np.pad(
136
  new_image,
137
  ((ph0, ph1), (pw0, pw1), (0, 0)),
138
  mode="constant",
139
+ constant_values=0,
140
  )
141
  new_image = Image.fromarray(new_image, mode="RGBA").resize(
142
  (COND_WIDTH, COND_HEIGHT)
143
  )
144
  return new_image
145
 
 
146
  def square_crop(input_image: Image) -> Image:
147
  # Perform a center square crop
148
  min_size = min(input_image.size)
 
154
  (COND_WIDTH, COND_HEIGHT)
155
  )
156
 
 
157
  def show_mask_img(input_image: Image) -> Image:
158
  img_numpy = np.array(input_image)
159
  alpha = img_numpy[:, :, 3] / 255.0
 
161
  new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
162
  return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
163
 
 
164
  def run_button(run_btn, input_image, background_state, foreground_ratio):
165
  if run_btn == "Run":
166
  glb_file: str = run_model(background_state)
 
167
  return (
168
  gr.update(),
169
  gr.update(),
 
174
  )
175
  elif run_btn == "Remove Background":
176
  rem_removed = remove_background(input_image)
 
177
  sqr_crop = square_crop(rem_removed)
178
  fr_res = resize_foreground(sqr_crop, foreground_ratio)
 
179
  return (
180
  gr.update(value="Run", visible=True),
181
  sqr_crop,
 
185
  gr.update(visible=False),
186
  )
187
 
 
188
  def requires_bg_remove(image, fr):
189
  if image is None:
190
  return (
 
219
  gr.update(visible=False),
220
  )
221
 
 
222
  def update_foreground_ratio(img_proc, fr):
223
  foreground_res = resize_foreground(img_proc, fr)
224
  return (
 
226
  gr.update(value=show_mask_img(foreground_res)),
227
  )
228
 
229
+ # Define custom theme
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  class CustomTheme(gr.themes.Base):
231
+ def __init__(self):
232
+ super().__init__()
233
+ self.primary_hue = "#191a1e"
234
+ self.background_fill_primary = "#191a1e"
235
+ self.background_fill_secondary = "#191a1e"
236
+ self.background_fill_tertiary = "#191a1e"
237
+ self.text_color_primary = "#FFFFFF"
238
+ self.text_color_secondary = "#FFFFFF"
239
+ self.text_color_tertiary = "#FFFFFF"
240
+ self.input_background_fill = "#191a1e"
241
+ self.input_text_color = "#FFFFFF"
242
+ self.font = (
243
+ "Poppins",
244
+ "https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap",
245
+ )
246
 
247
+ # Custom CSS
248
  css = """
249
+ /* Set background color for the entire page */
250
  body {
251
  background-color: #191a1e !important;
252
  margin: 0;
 
254
  }
255
 
256
  /* Apply fonts */
 
257
  body, input, button, textarea, select, .gr-button {
258
  font-family: 'Poppins', sans-serif;
259
  color: #FFFFFF;
260
  }
261
 
262
+ /* Header styles */
263
  h1, h2, h3, h4, h5, h6 {
264
  font-family: 'Poppins', sans-serif;
265
  font-weight: 700;
266
  color: #FFFFFF;
267
  }
268
 
269
+ /* Button styles */
270
+ .generate-button {
271
  background-color: #5271FF !important;
272
  color: #FFFFFF !important;
273
  border: none;
274
  font-weight: bold;
275
  }
276
 
277
+ /* Container styles */
278
  .gr-block {
279
  background-color: #1c1c24 !important;
280
  border: 1px solid #5271FF !important;
281
  }
282
  """
283
 
284
+ # Build the Gradio interface
285
  with gr.Blocks(theme=CustomTheme(), css=css) as demo:
286
  img_proc_state = gr.State()
287
  background_remove_state = gr.State()
288
+
289
  with gr.Row(variant="panel"):
290
  with gr.Column():
291
  with gr.Row():
292
  input_img = gr.Image(
293
+ type="pil",
294
+ label="Input Image",
295
+ sources="upload",
296
+ image_mode="RGBA",
297
  )
298
  preview_removal = gr.Image(
299
  label="Preview Background Removal",
 
317
  outputs=[background_remove_state, preview_removal],
318
  )
319
 
320
+ run_btn = gr.Button(
321
+ "Run",
322
+ variant="primary",
323
+ visible=False,
324
+ elem_classes="generate-button",
325
+ )
326
 
327
  with gr.Column():
328
  output_3d = LitModel3D(
 
334
  scale=1.0,
335
  )
336
  with gr.Column(visible=False, scale=1.0) as hdr_row:
337
+ gr.Markdown(
338
+ """
339
+ ## HDR Environment Map
340
 
341
+ Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
342
+ """
343
+ )
344
 
345
  with gr.Row():
346
  hdr_illumination_file = gr.File(
 
397
  ],
398
  )
399
 
400
+ # Launch the interface
401
  demo.launch()