JiantaoLin committed
Commit 23ca1bc
1 Parent(s): 3c4968d
Files changed (2)
  1. app.py +26 -26
  2. app_demo.py +0 -384
app.py CHANGED
@@ -429,42 +429,42 @@ with gr.Blocks(css="""
 
     # file_output1 = gr.File()
 
-    with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
-        with gr.Row():
-            with gr.Column(scale=1):
-                image = gr.Image(label="Input Image", type="pil")
+    # with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
+    #     with gr.Row():
+    #         with gr.Column(scale=1):
+    #             image = gr.Image(label="Input Image", type="pil")
 
-                seed2 = gr.Number(value=10, label="Seed (0 for random)")
+    #             seed2 = gr.Number(value=10, label="Seed (0 for random)")
 
-                btn_img2mesh_preprocess = gr.Button("Preprocess Image")
+    #             btn_img2mesh_preprocess = gr.Button("Preprocess Image")
 
-                image_caption = gr.Textbox(value="", label="Image Caption", placeholder="caption will be generated here based on your input image. You can also edit this caption", lines=4, interactive=True)
+    #             image_caption = gr.Textbox(value="", label="Image Caption", placeholder="caption will be generated here based on your input image. You can also edit this caption", lines=4, interactive=True)
 
-                with gr.Accordion(label="Extra Settings", open=False):
-                    output_image2 = gr.Image(label="Generated image", interactive=False)
-                    strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="redux strength")
-                    strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="denoise strength")
-                    enable_redux = gr.Checkbox(label="enable redux", value=True)
-                    use_controlnet = gr.Checkbox(label="enable controlnet", value=True)
+    #             with gr.Accordion(label="Extra Settings", open=False):
+    #                 output_image2 = gr.Image(label="Generated image", interactive=False)
+    #                 strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="redux strength")
+    #                 strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="denoise strength")
+    #                 enable_redux = gr.Checkbox(label="enable redux", value=True)
+    #                 use_controlnet = gr.Checkbox(label="enable controlnet", value=True)
 
-                btn_img2mesh_main = gr.Button("Generate Mesh")
+    #             btn_img2mesh_main = gr.Button("Generate Mesh")
 
-            with gr.Column(scale=1):
+    #         with gr.Column(scale=1):
 
-                # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
-                output_image3 = gr.Image(label="Final Bundle Image", interactive=False)
-                output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
-                # btn_download2 = gr.Button("Download Mesh")
-                download_2 = gr.DownloadButton(label="Download mesh", interactive=False)
-                # file_output2 = gr.File()
+    #             # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
+    #             output_image3 = gr.Image(label="Final Bundle Image", interactive=False)
+    #             output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
+    #             # btn_download2 = gr.Button("Download Mesh")
+    #             download_2 = gr.DownloadButton(label="Download mesh", interactive=False)
+    #             # file_output2 = gr.File()
 
     # Image2
-    btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
+    # btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
 
-    btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2, download_2]).then(
-        lambda: gr.Button(interactive=True),
-        outputs=[download_2],
-    )
+    # btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2, download_2]).then(
+    #     lambda: gr.Button(interactive=True),
+    #     outputs=[download_2],
+    # )
 
 
     # btn_download1.click(fn=save_cached_mesh, inputs=[], outputs=file_output1)
app_demo.py DELETED
@@ -1,384 +0,0 @@
-import gradio as gr
-import os
-import subprocess
-import shlex
-import spaces
-import torch
-access_token = os.getenv("HUGGINGFACE_TOKEN")
-subprocess.run(
-    shlex.split(
-        "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
-    )
-)
-
-subprocess.run(
-    shlex.split(
-        "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
-    )
-)
-
-subprocess.run(
-    shlex.split(
-        "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
-    )
-)
-def install_cuda_toolkit():
-    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
-    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
-    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
-    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
-    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
-    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
-    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
-
-    os.environ["CUDA_HOME"] = "/usr/local/cuda"
-    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
-    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
-        os.environ["CUDA_HOME"],
-        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
-    )
-    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
-    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
-    print("==> finish install")
-install_cuda_toolkit()
-@spaces.GPU
-def check_gpu():
-    os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
-    os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
-    # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
-    os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
-    subprocess.run(['nvidia-smi'])  # test whether CUDA is available
-    print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
-check_gpu()
-
-from PIL import Image
-from einops import rearrange
-from diffusers import FluxPipeline
-from models.lrm.utils.camera_util import get_flux_input_cameras
-from models.lrm.utils.infer_util import save_video
-from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
-from models.lrm.utils.render_utils import rotate_x, rotate_y
-from models.lrm.utils.train_util import instantiate_from_config
-from models.ISOMER.reconstruction_func import reconstruction
-from models.ISOMER.projection_func import projection
-import os
-from einops import rearrange
-from omegaconf import OmegaConf
-import torch
-import numpy as np
-import trimesh
-import torchvision
-import torch.nn.functional as F
-from PIL import Image
-from torchvision import transforms
-from torchvision.transforms import v2
-from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
-from diffusers import FluxPipeline
-from pytorch_lightning import seed_everything
-import os
-from huggingface_hub import hf_hub_download
-
-
-from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
-
-device_0 = "cuda"
-device_1 = "cuda"
-resolution = 512
-save_dir = "./outputs"
-normal_transfer = NormalTransfer()
-isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
-isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
-isomer_radius = 4.5
-isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
-isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
-
-# model initialization and loading
-# flux
-# # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
-# # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
-# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
-# # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
-# flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
-# flux_pipe.load_lora_weights(flux_lora_ckpt_path)
-# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
-# torch.cuda.empty_cache()
-# flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
-
-
-# lrm
-config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
-model_config = config.model_config
-infer_config = config.infer_config
-model = instantiate_from_config(model_config)
-model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
-state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
-model.load_state_dict(state_dict, strict=True)
-model = model.to(device_1)
-torch.cuda.empty_cache()
-@spaces.GPU
-def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
-    images = image.unsqueeze(0).to(device_1)
-    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
-    # breakpoint()
-    with torch.no_grad():
-        # get triplane
-        planes = model.forward_planes(images, input_cameras)
-
-        mesh_path_idx = os.path.join(save_path, f'{name}.obj')
-
-        mesh_out = model.extract_mesh(
-            planes,
-            use_texture_map=export_texmap,
-            **infer_config,
-        )
-        if export_texmap:
-            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
-            save_obj_with_mtl(
-                vertices.data.cpu().numpy(),
-                uvs.data.cpu().numpy(),
-                faces.data.cpu().numpy(),
-                mesh_tex_idx.data.cpu().numpy(),
-                tex_map.permute(1, 2, 0).data.cpu().numpy(),
-                mesh_path_idx,
-            )
-        else:
-            vertices, faces, vertex_colors = mesh_out
-            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
-        print(f"Mesh saved to {mesh_path_idx}")
-
-        render_size = 512
-        if if_save_video:
-            video_path_idx = os.path.join(save_path, f'{name}.mp4')
-            render_size = infer_config.render_resolution
-            ENV = load_mipmap("models/lrm/env_mipmap/6")
-            materials = (0.0,0.9)
-
-            all_mv, all_mvp, all_campos = get_render_cameras_video(
-                batch_size=1,
-                M=24,
-                radius=4.5,
-                elevation=(90, 60.0),
-                is_flexicubes=True,
-                fov=30
-            )
-
-            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
-                model,
-                planes,
-                render_cameras=all_mvp,
-                camera_pos=all_campos,
-                env=ENV,
-                materials=materials,
-                render_size=render_size,
-                chunk_size=20,
-                is_flexicubes=True,
-            )
-            normals = (torch.nn.functional.normalize(normals) + 1) / 2
-            normals = normals * alphas + (1-alphas)
-            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
-
-            save_video(
-                all_frames,
-                video_path_idx,
-                fps=30,
-            )
-            print(f"Video saved to {video_path_idx}")
-
-    return vertices, faces
-
-
-def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
-    if local_normal_images.min() >= 0:
-        local_normal = local_normal_images.float() * 2 - 1
-    else:
-        local_normal = local_normal_images.float()
-    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
-    global_normal[...,0] *= -1
-    global_normal = (global_normal + 1) / 2
-    global_normal = global_normal.permute(0, 3, 1, 2)
-    return global_normal
-
-# generate the multi-view images
-@spaces.GPU(duration=120)
-def generate_multi_view_images(prompt, seed):
-    # torch.cuda.empty_cache()
-    # generator = torch.manual_seed(seed)
-    generator = torch.Generator().manual_seed(seed)
-    with torch.no_grad():
-        img = flux_pipe(
-            prompt=prompt,
-            num_inference_steps=5,
-            guidance_scale=3.5,
-            num_images_per_prompt=1,
-            width=resolution * 2,
-            height=resolution * 1,
-            output_type='np',
-            generator=generator,
-        ).images
-        # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-        #     prompt=prompt,
-        #     guidance_scale=3.5,
-        #     num_inference_steps=4,
-        #     width=resolution * 4,
-        #     height=resolution * 2,
-        #     generator=generator,
-        #     output_type="np",
-        #     good_vae=good_vae,
-        # ):
-        #     pass
-    # return the final image and seed (handled by the external caller)
-    return img
-
-# reconstruct the 3D model
-@spaces.GPU
-def reconstruct_3d_model(images, prompt):
-    global model
-    model.init_flexicubes_geometry(device_1, fovy=50.0)
-    model = model.eval()
-    rgb_normal_grid = images
-    save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
-    os.makedirs(save_dir_path, exist_ok=True)
-
-    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
-    rgb_multi_view = images[:4, :3, :, :]
-    normal_multi_view = images[4:, :3, :, :]
-    multi_view_mask = get_background(normal_multi_view)
-    rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
-    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
-    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
-    # local normal to global normal
-
-    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
-    global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
-
-    global_normal = global_normal.permute(0,2,3,1)
-    rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
-    multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
-    vertices = torch.from_numpy(vertices).to(device_1)
-    faces = torch.from_numpy(faces).to(device_1)
-    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
-    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
-
-    # global_normal: B,H,W,3
-    # multi_view_mask: B,H,W
-    # rgb_multi_view: B,H,W,3
-
-    meshes = reconstruction(
-        normal_pils=global_normal,
-        masks=multi_view_mask,
-        weights=isomer_geo_weights,
-        fov=30,
-        radius=isomer_radius,
-        camera_angles_azi=isomer_azimuths,
-        camera_angles_ele=isomer_elevations,
-        expansion_weight_stage1=0.1,
-        init_type="file",
-        init_verts=vertices,
-        init_faces=faces,
-        stage1_steps=0,
-        stage2_steps=50,
-        start_edge_len_stage1=0.1,
-        end_edge_len_stage1=0.02,
-        start_edge_len_stage2=0.02,
-        end_edge_len_stage2=0.005,
-    )
-
-
-    save_glb_addr = projection(
-        meshes,
-        masks=multi_view_mask,
-        images=rgb_multi_view,
-        azimuths=isomer_azimuths,
-        elevations=isomer_elevations,
-        weights=isomer_color_weights,
-        fov=30,
-        radius=isomer_radius,
-        save_dir=f"{save_dir_path}/ISOMER/",
-    )
-
-    return save_glb_addr
-
-# Gradio interface function
-@spaces.GPU
-def gradio_pipeline(prompt, seed):
-    import ctypes
-    # explicitly load libnvrtc.so.12
-    cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
-    try:
-        ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
-        print(f"Successfully preloaded {cuda_lib_path}")
-    except OSError as e:
-        print(f"Failed to preload {cuda_lib_path}: {e}")
-    # generate the multi-view images
-    # rgb_normal_grid = generate_multi_view_images(prompt, seed)
-    rgb_normal_grid = np.load("rgb_normal_grid.npy")
-    image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
-
-    # 3d reconstruction
-
-
-    # reconstruct the 3D model and return the glb path
-    save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
-    # save_glb_addr = None
-    return image_preview, save_glb_addr
-
-# Gradio Blocks application
-with gr.Blocks() as demo:
-    with gr.Row(variant="panel"):
-        # left: input area
-        with gr.Column():
-            with gr.Row():
-                prompt_input = gr.Textbox(
-                    label="Enter Prompt",
-                    placeholder="Describe your 3D model...",
-                    lines=2,
-                    elem_id="prompt_input"
-                )
-
-            with gr.Row():
-                sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
-
-            with gr.Row():
-                submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
-            with gr.Row(variant="panel"):
-                gr.Markdown("Examples:")
-                gr.Examples(
-                    examples=[
-                        ["a castle on a hill"],
-                        ["an owl wearing a hat"],
-                        ["a futuristic car"]
-                    ],
-                    inputs=[prompt_input],
-                    label="Prompt Examples"
-                )
-
-        # right: output area
-        with gr.Column():
-            with gr.Row():
-                rgb_normal_grid_image = gr.Image(
-                    label="RGB Normal Grid",
-                    type="pil",
-                    interactive=False
-                )
-
-            with gr.Row():
-                with gr.Tab("GLB"):
-                    output_glb_model = gr.Model3D(
-                        label="Generated 3D Model (GLB Format)",
-                        interactive=False
-                    )
-                    gr.Markdown("Download the model for proper visualization.")
-
-    # processing logic
-    submit.click(
-        fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
-        outputs=[rgb_normal_grid_image, output_glb_model]
-    )
-
-# launch the application
-# demo.queue(max_size=10)
-demo.launch()