abreza commited on
Commit
07ace78
·
1 Parent(s): 4d6f443
Files changed (2) hide show
  1. app.py +91 -27
  2. requirements.txt +8 -3
app.py CHANGED
@@ -1,18 +1,21 @@
1
  import os
2
  import shutil
3
  import tempfile
 
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import rembg
8
  import spaces
9
  import torch
10
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
11
  from einops import rearrange
12
  from huggingface_hub import hf_hub_download
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
  from pytorch_lightning import seed_everything
 
16
  from torchvision.transforms import v2
17
  from tqdm import tqdm
18
 
@@ -22,6 +25,26 @@ from src.utils.infer_util import (remove_background, resize_foreground)
22
  from src.utils.mesh_util import save_glb, save_obj
23
  from src.utils.train_util import instantiate_from_config
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def find_cuda():
27
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
@@ -52,7 +75,7 @@ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexi
52
 
53
  def check_input_image(input_image):
54
  if input_image is None:
55
- raise gr.Error("No image uploaded!")
56
 
57
 
58
  def preprocess(input_image, do_remove_background):
@@ -125,6 +148,21 @@ def make3d(images):
125
  return mesh_fpath, mesh_glb_fpath
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # Configuration
129
  cuda_path = find_cuda()
130
  config_path = 'configs/instant-mesh-large.yaml'
@@ -166,6 +204,21 @@ model.load_state_dict(state_dict, strict=True)
166
 
167
  model = model.to(device)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  print('Loading Finished!')
170
 
171
  # Gradio UI
@@ -173,19 +226,28 @@ with gr.Blocks() as demo:
173
  with gr.Row(variant="panel"):
174
  with gr.Column():
175
  with gr.Row():
176
- input_image = gr.Image(
177
- label="Input Image",
178
- image_mode="RGBA",
179
- sources="upload",
180
- type="pil",
181
- elem_id="content_image",
182
- )
183
- processed_image = gr.Image(
184
- label="Processed Image",
 
 
 
 
 
 
 
 
185
  image_mode="RGBA",
186
  type="pil",
187
  interactive=False
188
  )
 
189
  with gr.Row():
190
  with gr.Group():
191
  do_remove_background = gr.Checkbox(
@@ -196,18 +258,8 @@ with gr.Blocks() as demo:
196
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
197
 
198
  with gr.Row():
199
- submit = gr.Button(
200
- "Generate", elem_id="generate", variant="primary")
201
-
202
- with gr.Row(variant="panel"):
203
- gr.Examples(
204
- examples=[os.path.join("examples", img_name)
205
- for img_name in sorted(os.listdir("examples"))],
206
- inputs=[input_image],
207
- label="Examples",
208
- cache_examples=False,
209
- examples_per_page=16
210
- )
211
 
212
  with gr.Column():
213
  with gr.Row():
@@ -241,13 +293,25 @@ with gr.Blocks() as demo:
241
 
242
  mv_images = gr.State()
243
 
244
- submit.click(fn=check_input_image, inputs=[input_image]).success(
 
 
 
 
 
 
 
 
 
 
 
 
245
  fn=preprocess,
246
- inputs=[input_image, do_remove_background],
247
- outputs=[processed_image],
248
  ).success(
249
  fn=generate_mvs,
250
- inputs=[processed_image, sample_steps, sample_seed],
251
  outputs=[mv_images, mv_show_images]
252
  ).success(
253
  fn=make3d,
 
1
  import os
2
  import shutil
3
  import tempfile
4
+ import time
5
+ from os import path
6
 
7
  import gradio as gr
8
  import numpy as np
9
  import rembg
10
  import spaces
11
  import torch
12
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline, LCMScheduler
13
  from einops import rearrange
14
  from huggingface_hub import hf_hub_download
15
  from omegaconf import OmegaConf
16
  from PIL import Image
17
  from pytorch_lightning import seed_everything
18
+ from safetensors.torch import load_file
19
  from torchvision.transforms import v2
20
  from tqdm import tqdm
21
 
 
25
  from src.utils.mesh_util import save_glb, save_obj
26
  from src.utils.train_util import instantiate_from_config
27
 
28
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
29
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
30
+ os.environ["HF_HUB_CACHE"] = cache_path
31
+ os.environ["HF_HOME"] = cache_path
32
+
33
+ torch.backends.cuda.matmul.allow_tf32 = True
34
+
35
+
36
+ class timer:
37
+ def __init__(self, method_name="timed process"):
38
+ self.method = method_name
39
+
40
+ def __enter__(self):
41
+ self.start = time.time()
42
+ print(f"{self.method} starts")
43
+
44
+ def __exit__(self, exc_type, exc_val, exc_tb):
45
+ end = time.time()
46
+ print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
+
48
 
49
  def find_cuda():
50
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
 
75
 
76
  def check_input_image(input_image):
77
  if input_image is None:
78
+ raise gr.Error("No image selected!")
79
 
80
 
81
  def preprocess(input_image, do_remove_background):
 
148
  return mesh_fpath, mesh_glb_fpath
149
 
150
 
151
+ @spaces.GPU
152
+ def process_image(num_images, height, width, prompt, seed):
153
+ global pipe
154
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
155
+ return pipe(
156
+ prompt=[prompt]*num_images,
157
+ generator=torch.Generator().manual_seed(int(seed)),
158
+ num_inference_steps=1,
159
+ guidance_scale=0.,
160
+ height=int(height),
161
+ width=int(width),
162
+ timesteps=[800]
163
+ ).images
164
+
165
+
166
  # Configuration
167
  cuda_path = find_cuda()
168
  config_path = 'configs/instant-mesh-large.yaml'
 
204
 
205
  model = model.to(device)
206
 
207
+ # Load text-to-image model
208
+ print('Loading text-to-image model ...')
209
+ if not path.exists(cache_path):
210
+ os.makedirs(cache_path, exist_ok=True)
211
+
212
+ pipe = StableDiffusionXLPipeline.from_pretrained(
213
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
214
+ pipe.to(device="cuda", dtype=torch.bfloat16)
215
+
216
+ unet_state = load_file(hf_hub_download(
217
+ "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cuda")
218
+ pipe.unet.load_state_dict(unet_state)
219
+ pipe.scheduler = LCMScheduler.from_config(
220
+ pipe.scheduler.config, timestep_spacing="trailing")
221
+
222
  print('Loading Finished!')
223
 
224
  # Gradio UI
 
226
  with gr.Row(variant="panel"):
227
  with gr.Column():
228
  with gr.Row():
229
+ num_images = gr.Slider(
230
+ label="Number of Images", minimum=1, maximum=8, step=1, value=4, interactive=True)
231
+ height = gr.Number(label="Image Height",
232
+ value=1024, interactive=True)
233
+ width = gr.Number(label="Image Width",
234
+ value=1024, interactive=True)
235
+ prompt = gr.Text(
236
+ label="Prompt", value="a photo of a cat", interactive=True)
237
+ seed = gr.Number(label="Seed", value=3413, interactive=True)
238
+ generate_2d_btn = gr.Button(value="Generate 2D Images")
239
+
240
+ with gr.Row():
241
+ generated_images = gr.Gallery(height=1024)
242
+
243
+ with gr.Row():
244
+ selected_image = gr.Image(
245
+ label="Selected Image",
246
  image_mode="RGBA",
247
  type="pil",
248
  interactive=False
249
  )
250
+
251
  with gr.Row():
252
  with gr.Group():
253
  do_remove_background = gr.Checkbox(
 
258
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
259
 
260
  with gr.Row():
261
+ generate_3d_btn = gr.Button(
262
+ "Generate 3D Model", elem_id="generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
263
 
264
  with gr.Column():
265
  with gr.Row():
 
293
 
294
  mv_images = gr.State()
295
 
296
+ generate_2d_btn.click(
297
+ fn=process_image,
298
+ inputs=[num_images, height, width, prompt, seed],
299
+ outputs=[generated_images]
300
+ )
301
+
302
+ generated_images.select(
303
+ fn=lambda x: x,
304
+ inputs=[generated_images],
305
+ outputs=[selected_image]
306
+ )
307
+
308
+ generate_3d_btn.click(fn=check_input_image, inputs=[selected_image]).success(
309
  fn=preprocess,
310
+ inputs=[selected_image, do_remove_background],
311
+ outputs=[selected_image],
312
  ).success(
313
  fn=generate_mvs,
314
+ inputs=[selected_image, sample_steps, sample_seed],
315
  outputs=[mv_images, mv_show_images]
316
  ).success(
317
  fn=make3d,
requirements.txt CHANGED
@@ -12,12 +12,17 @@ tensorboard
12
  PyMCubes
13
  trimesh
14
  rembg
15
- transformers==4.34.1
16
- diffusers==0.19.3
17
  bitsandbytes
18
  imageio[ffmpeg]
19
  xatlas
20
  plyfile
21
  xformers==0.0.22.post7
22
  git+https://github.com/NVlabs/nvdiffrast/
23
- huggingface-hub
 
 
 
 
 
 
12
  PyMCubes
13
  trimesh
14
  rembg
15
+ transformers==4.38.2
16
+ diffusers==0.25.0
17
  bitsandbytes
18
  imageio[ffmpeg]
19
  xatlas
20
  plyfile
21
  xformers==0.0.22.post7
22
  git+https://github.com/NVlabs/nvdiffrast/
23
+ huggingface-hub
24
+
25
+ httpx==0.23.0
26
+ flask
27
+ pillow
28
+ safetensors