JiantaoLin committed on
Commit 230d5cd · 1 Parent(s): afe40e2
Files changed (3):
  1. app.py +327 -433
  2. app_demo.py +0 -385
  3. app_demo_.py +491 -0
app.py CHANGED
@@ -1,16 +1,10 @@
-import os
 import gradio as gr
 import subprocess
-import spaces
-import ctypes
 import shlex
 import torch
-
-subprocess.run(
-    shlex.split(
-        "pip install ./custom_diffusers --force-reinstall --no-deps"
-    )
-)
 subprocess.run(
     shlex.split(
         "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
@@ -28,7 +22,6 @@ subprocess.run(
         "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
     )
 )
-# download cudatoolkit
 def install_cuda_toolkit():
     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
@@ -48,51 +41,271 @@ def install_cuda_toolkit():
     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
     print("==> finish install")
 install_cuda_toolkit()
-
-
-import base64
-import re
-import sys
-
-sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
-if 'OMP_NUM_THREADS' not in os.environ:
-    os.environ['OMP_NUM_THREADS'] = '32'
-
-import shutil
-import json
-import requests
-import shutil
-import threading
 from PIL import Image
-import time
-import trimesh
-
-import random
-import time
 import numpy as np
-from video_render import render_video_from_obj
-
-access_token = os.getenv("HUGGINGFACE_TOKEN")
-from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
-
-
-# Add logo file path and hyperlinks
-LOGO_PATH = "app_assets/logo_temp_.png"  # Update this to the actual path of your logo
-ARXIV_LINK = "https://arxiv.org/abs/example"
-GITHUB_LINK = "https://github.com/example"
-
-
-k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
-
-
-from models.ISOMER.scripts.utils import fix_vert_color_glb
-torch.backends.cuda.matmul.allow_tf32 = True
-
-def check_gpu():
-    os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
-    os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
-    # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
-    os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
     # explicitly load libnvrtc.so.12
     cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
     try:
@@ -100,392 +313,73 @@ def check_gpu():
         print(f"Successfully preloaded {cuda_lib_path}")
     except OSError as e:
         print(f"Failed to preload {cuda_lib_path}: {e}")
-check_gpu()
-print(f"GPU: {torch.cuda.is_available()}")
-subprocess.run(['nvidia-smi'])
-
-TEMP_MESH_ADDRESS = ''
-
-mesh_cache = None
-preprocessed_input_image = None
-
-def save_cached_mesh():
-    global mesh_cache
-    return mesh_cache
-    # if mesh_cache is None:
-    #     return None
-    # return save_py3dmesh_with_trimesh_fast(mesh_cache)
-
-def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
-    from pytorch3d.structures import Meshes
-    import trimesh
-
-    # convert from pytorch3d meshes to trimesh mesh
-    vertices = meshes.verts_packed().cpu().float().numpy()
-    triangles = meshes.faces_packed().cpu().long().numpy()
-    np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
-    if save_glb_path.endswith(".glb"):
-        # rotate 180 along +Y
-        vertices[:, [0, 2]] = -vertices[:, [0, 2]]
-
-    def srgb_to_linear(c_srgb):
-        c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
-        return c_linear.clip(0, 1.)
-    if apply_sRGB_to_LinearRGB:
-        np_color = srgb_to_linear(np_color)
-    assert vertices.shape[0] == np_color.shape[0]
-    assert np_color.shape[1] == 3
-    assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
-    mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
-    mesh.remove_unreferenced_vertices()
-    # save mesh
-    mesh.export(save_glb_path)
-    if save_glb_path.endswith(".glb"):
-        fix_vert_color_glb(save_glb_path)
-    print(f"saving to {save_glb_path}")
-#
-#
-# @spaces.GPU
-def text_to_detailed(prompt, seed=None):
-    # print(torch.cuda.is_available())
-    # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
-    return k3d_wrapper.get_detailed_prompt(prompt, seed)
-
-def text_to_image(prompt, seed=None, strength=1.0, lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
-    # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
-    k3d_wrapper.renew_uuid()
-    init_image = None
-    # if init_image_path is not None:
-    #     init_image = Image.open(init_image_path)
-    result = k3d_wrapper.generate_3d_bundle_image_text(
-        prompt,
-        image=init_image,
-        strength=strength,
-        lora_scale=lora_scale,
-        num_inference_steps=num_inference_steps,
-        seed=int(seed) if seed is not None else None,
-        redux_hparam=redux_hparam,
-        save_intermediate_results=True,
-        **kwargs)
-    return result[-1]
-
-def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
-    global preprocessed_input_image
-
-    seed = int(seed) if seed is not None else None
-
-    # TODO: delete this later
-    k3d_wrapper.del_llm_model()
-
-    input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
-
-    preprocessed_input_image = Image.open(input_image_save_path)
-    return reference_save_path, caption
-
-@spaces.GPU
-def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
-    global mesh_cache
-    seed = int(seed) if seed is not None else None
-
-
-    # TODO: delete this later
-    k3d_wrapper.del_llm_model()
-
-    input_image = preprocessed_input_image
-
-    reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2, 0, 1) / 255
-
-    gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
-    mesh_cache = recon_mesh_path
-
-
-    # gen_save_ = Image.open(gen_save_path)
-
-    if if_video:
-        video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
-        render_video_from_obj(recon_mesh_path, video_path)
-        print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
-        return gen_save_path, video_path
-    else:
-        return gen_save_path, recon_mesh_path
-    # return gen_save_path, recon_mesh_path
-
-@spaces.GPU
-def bundle_image_to_mesh(
-        gen_3d_bundle_image,
-        lrm_radius=4.15,
-        isomer_radius=4.5,
-        reconstruction_stage1_steps=10,
-        reconstruction_stage2_steps=50,
-        save_intermediate_results=True,
-        if_video=True
-    ):
-    global mesh_cache
-    print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
-    k3d_wrapper.recon_model.init_flexicubes_geometry("cuda:0", fovy=50.0)
-    # TODO: delete this later
-    k3d_wrapper.del_llm_model()
-
-    print(f"Before bundle_image_to_mesh after deleting llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
-
-    gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2, 0, 1) / 255
-    # recon from 3D Bundle image
-    recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
-    mesh_cache = recon_mesh_path
-
-    if if_video:
-        video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
-        # # check whether this video_path file exceeds 50KB; if not, treat it as empty and re-render
-        # if os.path.exists(video_path):
-        #     print(f"file size:{os.path.getsize(video_path)}")
-        #     if os.path.getsize(video_path) > 50*1024:
-        #         print(f"video path:{video_path}")
-        #         return video_path
-        render_video_from_obj(recon_mesh_path, video_path)
-        print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
-        return video_path
-    else:
-        return recon_mesh_path
-
-_HEADER_ = f"""
-<img src="{LOGO_PATH}">
-<h2><b>Official 🤗 Gradio Demo</b></h2><h2>
-<b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></a></h2>
-
-<p>**Kiss3DGen** is xxxxxxxxx</p>
-
-[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
-"""
-
-_CITE_ = r"""
-<h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='{""" + GITHUB_LINK + r"""}' target='_blank'>Github Repo</a>. Thanks!</h2>
-
-📝 **Citation**
-
-If you find our work useful for your research or applications, please cite using this bibtex:
-```bibtex
-@article{xxxx,
-  title={xxxx},
-  author={xxxx},
-  journal={xxxx},
-  year={xxxx}
-}
-```
-
-📋 **License**
-
-Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
-
-📧 **Contact**
-
-If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
-"""
-
-def image_to_base64(image_path):
-    """Converts an image file to a base64-encoded string."""
-    with open(image_path, "rb") as img_file:
-        return base64.b64encode(img_file.read()).decode('utf-8')
-
-def main():
-
-    torch.set_grad_enabled(False)
-
-    # Convert the logo image to base64
-    logo_base64 = image_to_base64(LOGO_PATH)
-    # with gr.Blocks() as demo:
-    with gr.Blocks(css="""
-        body {
-            display: flex;
-            justify-content: center;
-            align-items: center;
-            min-height: 100vh;
-            margin: 0;
-            padding: 0;
-        }
-        #col-container { margin: 0px auto; max-width: 200px; }
-
-
-        .gradio-container {
-            max-width: 1000px;
-            margin: auto;
-            width: 100%;
-        }
-        #center-align-column {
-            display: flex;
-            justify-content: center;
-            align-items: center;
-        }
-        #right-align-column {
-            display: flex;
-            justify-content: flex-end;
-            align-items: center;
-        }
-        h1 {text-align: center;}
-        h2 {text-align: center;}
-        h3 {text-align: center;}
-        p {text-align: center;}
-        img {text-align: right;}
-        .right {
-            display: block;
-            margin-left: auto;
-        }
-        .center {
-            display: block;
-            margin-left: auto;
-            margin-right: auto;
-            width: 50%;
-
-        #content-container {
-            max-width: 1200px;
-            margin: 0 auto;
-        }
-        #example-container {
-            max-width: 300px;
-            margin: 0 auto;
-        }
-    """, elem_id="col-container") as demo:
-        # Header Section
-        # gr.Image(value=LOGO_PATH, width=64, height=64)
-        # gr.Markdown(_HEADER_)
-        with gr.Row(elem_id="content-container"):
-            # with gr.Column(scale=1):
-            #     pass
-            # with gr.Column(scale=1, elem_id="right-align-column"):
-            #     # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
-            #     # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
-            #     # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
-            #     pass
-            with gr.Column(scale=7, elem_id="center-align-column"):
-                gr.Markdown(f"""
-                ## Official 🤗 Gradio Demo
-                # Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
-                gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
-
-                gr.HTML(f"""
-                <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
-                    <a href="{ARXIV_LINK}" target="_blank">
-                        <img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
-                    </a>
-                    <a href="{GITHUB_LINK}" target="_blank">
-                        <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
-                    </a>
-                </div>
-                """)
-
-
-            # gr.HTML(f"""
-            # <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
-            # """)
-
-            # gr.Markdown(f"""
-            # [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
-            # """, elem_id="title")
-            # with gr.Column(scale=1):
-            #     pass
-        # with gr.Row():
-        #     gr.Markdown(f"[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK})")
-        #     gr.Markdown(f"[![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})")
-
-        # Tabs Section
-        with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
-            with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
-                        seed1 = gr.Number(value=10, label="Seed")
-
-                        with gr.Row(elem_id="example-container"):
-                            gr.Examples(
-                                examples=[
-                                    # ["A tree with red leaves"],
-                                    # ["A dragon with black texture"],
-                                    ["A girl with pink hair"],
-                                    ["A boy playing guitar"],
-
-
-                                    ["A dog wearing a hat"],
-                                    ["A boy playing basketball"],
-                                    # [""],
-                                    # [""],
-                                    # [""],
-
-                                ],
-                                inputs=[prompt],  # fill the selected example into the prompt textbox
-                                label="Example Prompts"
-                            )
-                        btn_text2detailed = gr.Button("Refine to detailed prompt")
-                        detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="detailed prompt will be generated here based on your input prompt. You can also edit this prompt", lines=4, interactive=True)
-                        btn_text2img = gr.Button("Generate Images")
-
-                    with gr.Column(scale=1):
-                        output_image1 = gr.Image(label="Generated image", interactive=False)
-
-
-                        # lrm_radius = gr.Number(value=4.15, label="lrm_radius")
-                        # isomer_radius = gr.Number(value=4.5, label="isomer_radius")
-                        # reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
-                        # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
-
-                        btn_gen_mesh = gr.Button("Generate Mesh")
-                        output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
-                        btn_download1 = gr.Button("Download Mesh")
-
-                        file_output1 = gr.File()
-
-            with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        image = gr.Image(label="Input Image", type="pil")
-
-                        seed2 = gr.Number(value=10, label="Seed (0 for random)")
-
-                        btn_img2mesh_preprocess = gr.Button("Preprocess Image")
-
-                        image_caption = gr.Textbox(value="", label="Image Caption", placeholder="caption will be generated here based on your input image. You can also edit this caption", lines=4, interactive=True)
-
-                        output_image2 = gr.Image(label="Generated image", interactive=False)
-                        strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
-                        strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
-                        enable_redux = gr.Checkbox(label="enable redux", value=True)
-                        use_controlnet = gr.Checkbox(label="use controlnet", value=True)
-
-                        btn_img2mesh_main = gr.Button("Generate Mesh")
-
-                    with gr.Column(scale=1):
-
-                        # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
-                        output_image3 = gr.Image(label="gen save image", interactive=False)
-                        output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
-                        btn_download2 = gr.Button("Download Mesh")
-                        file_output2 = gr.File()
-
-        # Image2
-        btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
-
-        btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
-
-
-        btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
-
-
-        # Button Click Events
-        # Text2
-        btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
-        btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
-        btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1,], outputs=output_video1)
-        # btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
-
-        with gr.Row():
-            pass
-        with gr.Row():
-            gr.Markdown(_CITE_)
-
-    # demo.queue(default_concurrency_limit=1)
-    # demo.launch(server_name="0.0.0.0", server_port=9239)
-    # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
-    demo.launch()
-
-
-if __name__ == "__main__":
-    main()
 import gradio as gr
+import os
 import subprocess
 import shlex
+import spaces
 import torch
+access_token = os.getenv("HUGGINGFACE_TOKEN")
 subprocess.run(
     shlex.split(
         "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"

         "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
     )
 )
 def install_cuda_toolkit():
     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"

     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
     print("==> finish install")
 install_cuda_toolkit()
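Hard-coding `TORCH_CUDA_ARCH_LIST = "8.0;8.6"` works around the `arch_list[-1] += '+PTX'` IndexError noted in the function: PyTorch's extension builder derives the arch list from a visible GPU, and the list comes back empty when the build runs on a machine without one. A sketch (not part of the commit, assuming only `torch.cuda.get_device_capability`) that derives the value when a device is present:

```python
import os
import torch

# Prefer the compute capability of the attached GPU; fall back to the
# Ampere values ("8.0;8.6") hard-coded above when no device is visible.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
else:
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
```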
+@spaces.GPU
+def check_gpu():
+    os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
+    os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
+    # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
+    os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
+    subprocess.run(['nvidia-smi'])  # test whether CUDA is available
+    print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
+check_gpu()
+
 from PIL import Image
+from einops import rearrange
+from diffusers import FluxPipeline
+from models.lrm.utils.camera_util import get_flux_input_cameras
+from models.lrm.utils.infer_util import save_video
+from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
+from models.lrm.utils.render_utils import rotate_x, rotate_y
+from models.lrm.utils.train_util import instantiate_from_config
+from models.ISOMER.reconstruction_func import reconstruction
+from models.ISOMER.projection_func import projection
+import os
+from einops import rearrange
+from omegaconf import OmegaConf
+import torch
 import numpy as np
+import trimesh
+import torchvision
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import v2
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+from diffusers import FluxPipeline
+from pytorch_lightning import seed_everything
+import os
+from huggingface_hub import hf_hub_download
+
+
+from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
+
+device_0 = "cuda"
+device_1 = "cuda"
+resolution = 512
+save_dir = "./outputs"
+normal_transfer = NormalTransfer()
+isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
+isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
+isomer_radius = 4.5
+isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
+isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
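The constants above pin four ISOMER views to azimuths 0°/90°/180°/270° at 5° elevation on a radius-4.5 orbit. Purely as an illustration of that geometry (the real cameras come from `get_flux_input_cameras` and the ISOMER code, whose axis conventions may differ), spherical angles map to camera positions like so:

```python
import numpy as np

def orbit_position(azimuth_deg: float, elevation_deg: float, radius: float) -> np.ndarray:
    # Illustrative spherical-to-Cartesian conversion with +y up;
    # not taken from the repo, whose conventions may differ.
    azi, ele = np.deg2rad(azimuth_deg), np.deg2rad(elevation_deg)
    return radius * np.array([np.cos(ele) * np.sin(azi),
                              np.sin(ele),
                              np.cos(ele) * np.cos(azi)])

positions = [orbit_position(a, 5.0, 4.5) for a in (0, 90, 180, 270)]
```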
+
+# model initialization and loading
+# flux
+# # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
+# # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
+# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
+# # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
+# flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
+# flux_pipe.load_lora_weights(flux_lora_ckpt_path)
+# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
+# torch.cuda.empty_cache()
+# flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
+
+
+# lrm
+config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
+model_config = config.model_config
+infer_config = config.infer_config
+model = instantiate_from_config(model_config)
+model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+model.load_state_dict(state_dict, strict=True)
+model = model.to(device_1)
+torch.cuda.empty_cache()
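The `k[14:]` slice strips the `'lrm_generator.'` prefix (14 characters) that the Lightning training wrapper prepends to every submodule key. A small sketch of the same transformation, written so the offset is derived from the prefix and cannot drift if the submodule is renamed; the helper name is illustrative:

```python
import torch

def strip_prefix(state_dict: dict, prefix: str = "lrm_generator.") -> dict:
    # Equivalent to {k[14:]: v ...} above, but len(prefix) keeps the offset
    # in sync with the prefix even if the submodule name changes.
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

ckpt = torch.load("final_ckpt.ckpt", map_location="cpu")  # checkpoint as downloaded above
model_state = strip_prefix(ckpt["state_dict"])
```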
+@spaces.GPU
+def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
+    images = image.unsqueeze(0).to(device_1)
+    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
+    # breakpoint()
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        mesh_path_idx = os.path.join(save_path, f'{name}.obj')
+
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=export_texmap,
+            **infer_config,
+        )
+        if export_texmap:
+            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
+            save_obj_with_mtl(
+                vertices.data.cpu().numpy(),
+                uvs.data.cpu().numpy(),
+                faces.data.cpu().numpy(),
+                mesh_tex_idx.data.cpu().numpy(),
+                tex_map.permute(1, 2, 0).data.cpu().numpy(),
+                mesh_path_idx,
+            )
+        else:
+            vertices, faces, vertex_colors = mesh_out
+            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
+        print(f"Mesh saved to {mesh_path_idx}")
+
+        render_size = 512
+        if if_save_video:
+            video_path_idx = os.path.join(save_path, f'{name}.mp4')
+            render_size = infer_config.render_resolution
+            ENV = load_mipmap("models/lrm/env_mipmap/6")
+            materials = (0.0, 0.9)
+
+            all_mv, all_mvp, all_campos = get_render_cameras_video(
+                batch_size=1,
+                M=24,
+                radius=4.5,
+                elevation=(90, 60.0),
+                is_flexicubes=True,
+                fov=30
+            )
+
+            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
+                model,
+                planes,
+                render_cameras=all_mvp,
+                camera_pos=all_campos,
+                env=ENV,
+                materials=materials,
+                render_size=render_size,
+                chunk_size=20,
+                is_flexicubes=True,
+            )
+            normals = (torch.nn.functional.normalize(normals) + 1) / 2
+            normals = normals * alphas + (1 - alphas)
+            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
+
+            save_video(
+                all_frames,
+                video_path_idx,
+                fps=30,
+            )
+            print(f"Video saved to {video_path_idx}")
+
+    return vertices, faces
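In the video branch above, normals are renormalized, mapped from [-1, 1] into displayable [0, 1] RGB, and alpha-composited onto a white background. The same two-line pattern as a standalone sketch, assuming (N, 3, H, W) normals with an (N, 1, H, W) coverage mask:

```python
import torch
import torch.nn.functional as F

def normals_to_rgb(normals: torch.Tensor, alphas: torch.Tensor) -> torch.Tensor:
    # Re-normalize to unit length, shift [-1, 1] -> [0, 1] for display,
    # then composite onto white wherever nothing was rendered.
    n = (F.normalize(normals, dim=1) + 1) / 2
    return n * alphas + (1 - alphas)

rgb = normals_to_rgb(torch.randn(4, 3, 64, 64), torch.ones(4, 1, 64, 64))
```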
+
+
+def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
+    if local_normal_images.min() >= 0:
+        local_normal = local_normal_images.float() * 2 - 1
+    else:
+        local_normal = local_normal_images.float()
+    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
+    global_normal[..., 0] *= -1
+    global_normal = (global_normal + 1) / 2
+    global_normal = global_normal.permute(0, 3, 1, 2)
+    return global_normal
+
+# generate multi-view images
+@spaces.GPU(duration=120)
+def generate_multi_view_images(prompt, seed):
+    # torch.cuda.empty_cache()
+    # generator = torch.manual_seed(seed)
+    generator = torch.Generator().manual_seed(seed)
+    with torch.no_grad():
+        img = flux_pipe(
+            prompt=prompt,
+            num_inference_steps=5,
+            guidance_scale=3.5,
+            num_images_per_prompt=1,
+            width=resolution * 2,
+            height=resolution * 1,
+            output_type='np',
+            generator=generator,
+        ).images
+        # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+        #     prompt=prompt,
+        #     guidance_scale=3.5,
+        #     num_inference_steps=4,
+        #     width=resolution * 4,
+        #     height=resolution * 2,
+        #     generator=generator,
+        #     output_type="np",
+        #     good_vae=good_vae,
+        # ):
+        #     pass
+    # return the final image and seed (handled by the external caller)
+    return img
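`torch.Generator().manual_seed(seed)` (used above) scopes the seed to one generator object, unlike the commented-out `torch.manual_seed(seed)`, which reseeds the global RNG shared by everything else in the process. A minimal sketch of the difference:

```python
import torch

seed = 42

# Local: only draws routed through this generator are affected.
gen = torch.Generator().manual_seed(seed)
a = torch.randn(2, generator=gen)

# Global: reseeds the default RNG used by every unseeded op in the process.
torch.manual_seed(seed)
b = torch.randn(2)

# Re-seeding a local generator reproduces its stream exactly.
gen2 = torch.Generator().manual_seed(seed)
assert torch.equal(a, torch.randn(2, generator=gen2))
```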
+
+# reconstruct the 3D model
+@spaces.GPU
+def reconstruct_3d_model(images, prompt):
+    global model
+    model.init_flexicubes_geometry(device_1, fovy=50.0)
+    model = model.eval()
+    rgb_normal_grid = images
+    save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
+    os.makedirs(save_dir_path, exist_ok=True)
+
+    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
+    rgb_multi_view = images[:4, :3, :, :]
+    normal_multi_view = images[4:, :3, :, :]
+    multi_view_mask = get_background(normal_multi_view)
+    rgb_multi_view = rgb_multi_view * rgb_multi_view + (1 - multi_view_mask)
+    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
+    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
+    # local normal to global normal
+
+    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
+    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)
+
+    global_normal = global_normal.permute(0, 2, 3, 1)
+    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
+    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
+    vertices = torch.from_numpy(vertices).to(device_1)
+    faces = torch.from_numpy(faces).to(device_1)
+    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
+    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
+
+    # global_normal: B,H,W,3
+    # multi_view_mask: B,H,W
+    # rgb_multi_view: B,H,W,3
+
+    meshes = reconstruction(
+        normal_pils=global_normal,
+        masks=multi_view_mask,
+        weights=isomer_geo_weights,
+        fov=30,
+        radius=isomer_radius,
+        camera_angles_azi=isomer_azimuths,
+        camera_angles_ele=isomer_elevations,
+        expansion_weight_stage1=0.1,
+        init_type="file",
+        init_verts=vertices,
+        init_faces=faces,
+        stage1_steps=0,
+        stage2_steps=50,
+        start_edge_len_stage1=0.1,
+        end_edge_len_stage1=0.02,
+        start_edge_len_stage2=0.02,
+        end_edge_len_stage2=0.005,
+    )
+
+
+    save_glb_addr = projection(
+        meshes,
+        masks=multi_view_mask,
+        images=rgb_multi_view,
+        azimuths=isomer_azimuths,
+        elevations=isomer_elevations,
+        weights=isomer_color_weights,
+        fov=30,
+        radius=isomer_radius,
+        save_dir=f"{save_dir_path}/ISOMER/",
+    )
+
+    return save_glb_addr
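`reconstruct_3d_model` relies on the bundle layout being a 2×4 grid: the top row holds four RGB views and the bottom row the matching normal maps, so the einops pattern `'c (n h) (m w) -> (n m) c h w'` with `n=2, m=4` slices the 3×1024×2048 grid into eight 512×512 tiles in row-major order. A self-contained sketch of just that split, plus the white-background compositing idiom the function applies to the views:

```python
import torch
from einops import rearrange

grid = torch.rand(3, 1024, 2048)                        # RGB row on top, normals below
tiles = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)
assert tiles.shape == (8, 3, 512, 512)

rgb_views, normal_views = tiles[:4], tiles[4:]          # row-major: RGB tiles first

# Compositing a view onto white with a foreground mask in [0, 1]:
mask = torch.ones(4, 1, 512, 512)
rgb_on_white = rgb_views * mask + (1 - mask)
```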
 
+
+# Gradio interface function
+@spaces.GPU
+def gradio_pipeline(prompt, seed):
+    import ctypes
     # explicitly load libnvrtc.so.12
     cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
     try:
         ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
         print(f"Successfully preloaded {cuda_lib_path}")
     except OSError as e:
         print(f"Failed to preload {cuda_lib_path}: {e}")
+    # generate multi-view images
+    # rgb_normal_grid = generate_multi_view_images(prompt, seed)
+    rgb_normal_grid = np.load("rgb_normal_grid.npy")
+    image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
+
+    # 3d reconstruction
+
+
+    # reconstruct the 3D model and return the glb path
+    save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
+    # save_glb_addr = None
+    return image_preview, save_glb_addr
+
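Preloading `libnvrtc.so.12` with `RTLD_GLOBAL` makes its symbols visible to every module loaded afterwards, which sidesteps "undefined symbol" failures when a compiled CUDA plugin expects NVRTC to already be in the process. The pattern as a generic sketch; the library path is simply the one the app uses:

```python
import ctypes

def preload_shared_lib(path: str) -> bool:
    # RTLD_GLOBAL places the library's symbols in the global namespace,
    # so extensions imported later can resolve against it.
    try:
        ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
        return True
    except OSError as err:
        print(f"Failed to preload {path}: {err}")
        return False

preload_shared_lib("/usr/local/cuda-12.1/lib64/libnvrtc.so.12")  # path from the app
```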
+# Gradio Blocks app
+with gr.Blocks() as demo:
+    with gr.Row(variant="panel"):
+        # left: input area
+        with gr.Column():
+            with gr.Row():
+                prompt_input = gr.Textbox(
+                    label="Enter Prompt",
+                    placeholder="Describe your 3D model...",
+                    lines=2,
+                    elem_id="prompt_input"
+                )
+
+            with gr.Row():
+                sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                gr.Markdown("Examples:")
+                gr.Examples(
+                    examples=[
+                        ["a castle on a hill"],
+                        ["an owl wearing a hat"],
+                        ["a futuristic car"]
+                    ],
+                    inputs=[prompt_input],
+                    label="Prompt Examples"
+                )
+
+        # right: output area
+        with gr.Column():
+            with gr.Row():
+                rgb_normal_grid_image = gr.Image(
+                    label="RGB Normal Grid",
+                    type="pil",
+                    interactive=False
+                )
+
+            with gr.Row():
+                with gr.Tab("GLB"):
+                    output_glb_model = gr.Model3D(
+                        label="Generated 3D Model (GLB Format)",
+                        interactive=False
+                    )
+                    gr.Markdown("Download the model for proper visualization.")
+
+    # processing logic
+    submit.click(
+        fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
+        outputs=[rgb_normal_grid_image, output_glb_model]
+    )
+
+# launch the app
+# demo.queue(max_size=10)
+demo.launch()
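The commented-out `demo.queue(max_size=10)` is Gradio's request queue: enabling it caps how many requests wait while earlier generations hold the GPU, instead of letting them execute concurrently. A minimal variant of the launch above with the queue turned on, under the same `demo` object:

```python
# Hypothetical alternative to the plain launch above: at most ten prompts
# wait in the queue while earlier generations occupy the GPU.
demo.queue(max_size=10)
demo.launch()
```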
app_demo.py DELETED
@@ -1,385 +0,0 @@
-import gradio as gr
-import os
-import subprocess
-import shlex
-import spaces
-import torch
-access_token = os.getenv("HUGGINGFACE_TOKEN")
-subprocess.run(
-    shlex.split(
-        "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
-    )
-)
-
-subprocess.run(
-    shlex.split(
-        "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
-    )
-)
-
-subprocess.run(
-    shlex.split(
-        "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
-    )
-)
-def install_cuda_toolkit():
-    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
-    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
-    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
-    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
-    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
-    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
-    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
-
-    os.environ["CUDA_HOME"] = "/usr/local/cuda"
-    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
-    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
-        os.environ["CUDA_HOME"],
-        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
-    )
-    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
-    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
-    print("==> finish install")
-install_cuda_toolkit()
-@spaces.GPU
-def check_gpu():
-    os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
-    os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
-    # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
-    os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
-    subprocess.run(['nvidia-smi'])  # test whether CUDA is available
-    print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
-check_gpu()
-
-from PIL import Image
-from einops import rearrange
-from diffusers import FluxPipeline
-from models.lrm.utils.camera_util import get_flux_input_cameras
-from models.lrm.utils.infer_util import save_video
-from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
-from models.lrm.utils.render_utils import rotate_x, rotate_y
-from models.lrm.utils.train_util import instantiate_from_config
-from models.ISOMER.reconstruction_func import reconstruction
-from models.ISOMER.projection_func import projection
-import os
-from einops import rearrange
-from omegaconf import OmegaConf
-import torch
-import numpy as np
-import trimesh
-import torchvision
-import torch.nn.functional as F
-from PIL import Image
-from torchvision import transforms
-from torchvision.transforms import v2
-from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
-from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
-from diffusers import FluxPipeline
-from pytorch_lightning import seed_everything
-import os
-from huggingface_hub import hf_hub_download
-
-
-from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
-
-device_0 = "cuda"
-device_1 = "cuda"
-resolution = 512
-save_dir = "./outputs"
-normal_transfer = NormalTransfer()
-isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
-isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
-isomer_radius = 4.5
-isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
-isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
-
-# model initialization and loading
-# flux
-# # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
-# # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
-# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
-# # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
-# flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
-# flux_pipe.load_lora_weights(flux_lora_ckpt_path)
-# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
-# torch.cuda.empty_cache()
-# flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
-
-
-# lrm
-config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
-model_config = config.model_config
-infer_config = config.infer_config
-model = instantiate_from_config(model_config)
-model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
-state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
-model.load_state_dict(state_dict, strict=True)
-model = model.to(device_1)
-torch.cuda.empty_cache()
-@spaces.GPU
-def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
-    images = image.unsqueeze(0).to(device_1)
-    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
-    # breakpoint()
-    with torch.no_grad():
-        # get triplane
-        planes = model.forward_planes(images, input_cameras)
-
-        mesh_path_idx = os.path.join(save_path, f'{name}.obj')
-
-        mesh_out = model.extract_mesh(
-            planes,
-            use_texture_map=export_texmap,
-            **infer_config,
-        )
-        if export_texmap:
-            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
-            save_obj_with_mtl(
-                vertices.data.cpu().numpy(),
-                uvs.data.cpu().numpy(),
-                faces.data.cpu().numpy(),
-                mesh_tex_idx.data.cpu().numpy(),
-                tex_map.permute(1, 2, 0).data.cpu().numpy(),
-                mesh_path_idx,
-            )
-        else:
-            vertices, faces, vertex_colors = mesh_out
-            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
-        print(f"Mesh saved to {mesh_path_idx}")
-
-        render_size = 512
-        if if_save_video:
-            video_path_idx = os.path.join(save_path, f'{name}.mp4')
-            render_size = infer_config.render_resolution
-            ENV = load_mipmap("models/lrm/env_mipmap/6")
-            materials = (0.0, 0.9)
-
-            all_mv, all_mvp, all_campos = get_render_cameras_video(
-                batch_size=1,
-                M=24,
-                radius=4.5,
-                elevation=(90, 60.0),
-                is_flexicubes=True,
-                fov=30
-            )
-
-            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
-                model,
-                planes,
-                render_cameras=all_mvp,
-                camera_pos=all_campos,
-                env=ENV,
-                materials=materials,
-                render_size=render_size,
-                chunk_size=20,
-                is_flexicubes=True,
-            )
-            normals = (torch.nn.functional.normalize(normals) + 1) / 2
-            normals = normals * alphas + (1 - alphas)
-            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
-
-            save_video(
-                all_frames,
-                video_path_idx,
-                fps=30,
-            )
-            print(f"Video saved to {video_path_idx}")
-
-    return vertices, faces
-
-
-def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
-    if local_normal_images.min() >= 0:
-        local_normal = local_normal_images.float() * 2 - 1
-    else:
-        local_normal = local_normal_images.float()
-    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
-    global_normal[..., 0] *= -1
-    global_normal = (global_normal + 1) / 2
-    global_normal = global_normal.permute(0, 3, 1, 2)
-    return global_normal
-
-# generate multi-view images
-@spaces.GPU(duration=120)
-def generate_multi_view_images(prompt, seed):
-    # torch.cuda.empty_cache()
-    # generator = torch.manual_seed(seed)
-    generator = torch.Generator().manual_seed(seed)
-    with torch.no_grad():
-        img = flux_pipe(
-            prompt=prompt,
-            num_inference_steps=5,
-            guidance_scale=3.5,
-            num_images_per_prompt=1,
-            width=resolution * 2,
-            height=resolution * 1,
-            output_type='np',
-            generator=generator,
-        ).images
-        # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-        #     prompt=prompt,
-        #     guidance_scale=3.5,
-        #     num_inference_steps=4,
-        #     width=resolution * 4,
-        #     height=resolution * 2,
-        #     generator=generator,
-        #     output_type="np",
-        #     good_vae=good_vae,
-        # ):
-        #     pass
-    # return the final image and seed (handled by the external caller)
-    return img
-
-# reconstruct the 3D model
-@spaces.GPU
-def reconstruct_3d_model(images, prompt):
-    global model
-    model.init_flexicubes_geometry(device_1, fovy=50.0)
-    model = model.eval()
-    rgb_normal_grid = images
-    save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
-    os.makedirs(save_dir_path, exist_ok=True)
-
-    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
-    rgb_multi_view = images[:4, :3, :, :]
-    normal_multi_view = images[4:, :3, :, :]
-    multi_view_mask = get_background(normal_multi_view)
-    rgb_multi_view = rgb_multi_view * rgb_multi_view + (1 - multi_view_mask)
-    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
-    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
-    # local normal to global normal
-
-    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
-    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)
-
-    global_normal = global_normal.permute(0, 2, 3, 1)
-    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
-    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
-    vertices = torch.from_numpy(vertices).to(device_1)
-    faces = torch.from_numpy(faces).to(device_1)
-    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
-    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
-
-    # global_normal: B,H,W,3
-    # multi_view_mask: B,H,W
-    # rgb_multi_view: B,H,W,3
-
-    meshes = reconstruction(
-        normal_pils=global_normal,
-        masks=multi_view_mask,
-        weights=isomer_geo_weights,
-        fov=30,
-        radius=isomer_radius,
-        camera_angles_azi=isomer_azimuths,
-        camera_angles_ele=isomer_elevations,
-        expansion_weight_stage1=0.1,
-        init_type="file",
-        init_verts=vertices,
-        init_faces=faces,
-        stage1_steps=0,
-        stage2_steps=50,
-        start_edge_len_stage1=0.1,
-        end_edge_len_stage1=0.02,
-        start_edge_len_stage2=0.02,
-        end_edge_len_stage2=0.005,
-    )
-
-
-    save_glb_addr = projection(
-        meshes,
-        masks=multi_view_mask,
-        images=rgb_multi_view,
-        azimuths=isomer_azimuths,
-        elevations=isomer_elevations,
-        weights=isomer_color_weights,
-        fov=30,
-        radius=isomer_radius,
-        save_dir=f"{save_dir_path}/ISOMER/",
-    )
-
-    return save_glb_addr
-
-# Gradio interface function
-@spaces.GPU
-def gradio_pipeline(prompt, seed):
-    import ctypes
-    # explicitly load libnvrtc.so.12
-    cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
-    try:
-        ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
-        print(f"Successfully preloaded {cuda_lib_path}")
-    except OSError as e:
-        print(f"Failed to preload {cuda_lib_path}: {e}")
-    # generate multi-view images
-    # rgb_normal_grid = generate_multi_view_images(prompt, seed)
-    rgb_normal_grid = np.load("rgb_normal_grid.npy")
-    image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
-
-    # 3d reconstruction
-
-
-    # reconstruct the 3D model and return the glb path
-    save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
-    # save_glb_addr = None
-    return image_preview, save_glb_addr
-
-# Gradio Blocks app
-with gr.Blocks() as demo:
-    with gr.Row(variant="panel"):
-        # left: input area
-        with gr.Column():
-            with gr.Row():
-                prompt_input = gr.Textbox(
-                    label="Enter Prompt",
-                    placeholder="Describe your 3D model...",
-                    lines=2,
-                    elem_id="prompt_input"
-                )
-
-            with gr.Row():
-                sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
-
-            with gr.Row():
-                submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
-            with gr.Row(variant="panel"):
-                gr.Markdown("Examples:")
-                gr.Examples(
-                    examples=[
-                        ["a castle on a hill"],
-                        ["an owl wearing a hat"],
-                        ["a futuristic car"]
-                    ],
-                    inputs=[prompt_input],
-                    label="Prompt Examples"
-                )
-
-        # right: output area
-        with gr.Column():
-            with gr.Row():
-                rgb_normal_grid_image = gr.Image(
-                    label="RGB Normal Grid",
-                    type="pil",
-                    interactive=False
-                )
-
-            with gr.Row():
-                with gr.Tab("GLB"):
-                    output_glb_model = gr.Model3D(
-                        label="Generated 3D Model (GLB Format)",
-                        interactive=False
-                    )
-                    gr.Markdown("Download the model for proper visualization.")
-
-    # processing logic
-    submit.click(
-        fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
-        outputs=[rgb_normal_grid_image, output_glb_model]
-    )
-
-# launch the app
-# demo.queue(max_size=10)
-demo.launch()
app_demo_.py ADDED
@@ -0,0 +1,491 @@
1
+ import os
2
+ import gradio as gr
3
+ import subprocess
4
+ import spaces
5
+ import ctypes
6
+ import shlex
7
+ import torch
8
+
9
+ subprocess.run(
10
+ shlex.split(
11
+ "pip install ./custom_diffusers --force-reinstall --no-deps"
12
+ )
13
+ )
14
+ subprocess.run(
15
+ shlex.split(
16
+ "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
17
+ )
18
+ )
19
+
20
+ subprocess.run(
21
+ shlex.split(
22
+ "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
23
+ )
24
+ )
25
+
26
+ subprocess.run(
27
+ shlex.split(
28
+ "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
29
+ )
30
+ )
31
+ # download cudatoolkit
32
+ def install_cuda_toolkit():
33
+ # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
34
+ # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
35
+ CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
36
+ CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
37
+ subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
38
+ subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
39
+ subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
40
+
41
+ os.environ["CUDA_HOME"] = "/usr/local/cuda"
42
+ os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
43
+ os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
44
+ os.environ["CUDA_HOME"],
45
+ "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
46
+ )
47
+ # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
48
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
49
+ print("==> finfish install")
50
+ install_cuda_toolkit()
51
+
52
+
53
+ import base64
54
+ import re
55
+ import sys
56
+
57
+ sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
58
+ if 'OMP_NUM_THREADS' not in os.environ:
59
+ os.environ['OMP_NUM_THREADS'] = '32'
60
+
61
+ import shutil
62
+ import json
63
+ import requests
64
+ import shutil
65
+ import threading
66
+ from PIL import Image
67
+ import time
68
+ import trimesh
69
+
70
+ import random
71
+ import time
72
+ import numpy as np
73
+ from video_render import render_video_from_obj
74
+
75
+ access_token = os.getenv("HUGGINGFACE_TOKEN")
76
+ from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
77
+
78
+
79
+ # Add logo file path and hyperlinks
80
+ LOGO_PATH = "app_assets/logo_temp_.png" # Update this to the actual path of your logo
81
+ ARXIV_LINK = "https://arxiv.org/abs/example"
82
+ GITHUB_LINK = "https://github.com/example"
83
+
84
+
85
+ k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
86
+
87
+
88
+ from models.ISOMER.scripts.utils import fix_vert_color_glb
89
+ torch.backends.cuda.matmul.allow_tf32 = True
90
+
91
+ def check_gpu():
92
+ os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
93
+ os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
94
+ # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
95
+ os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
96
+ # 显式加载 libnvrtc.so.12
97
+ cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
98
+ try:
99
+ ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
100
+ print(f"Successfully preloaded {cuda_lib_path}")
101
+ except OSError as e:
102
+ print(f"Failed to preload {cuda_lib_path}: {e}")
103
+ check_gpu()
104
+ print(f"GPU: {torch.cuda.is_available()}")
105
+ subprocess.run(['nvidia-smi'])
106
+
107
+ TEMP_MESH_ADDRESS=''
108
+
109
+ mesh_cache = None
110
+ preprocessed_input_image = None
111
+
112
+ def save_cached_mesh():
113
+ global mesh_cache
114
+ return mesh_cache
115
+ # if mesh_cache is None:
116
+ # return None
117
+ # return save_py3dmesh_with_trimesh_fast(mesh_cache)
118
+
119
+ def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
120
+ from pytorch3d.structures import Meshes
121
+ import trimesh
122
+
123
+ # convert from pytorch3d meshes to trimesh mesh
124
+ vertices = meshes.verts_packed().cpu().float().numpy()
125
+ triangles = meshes.faces_packed().cpu().long().numpy()
126
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
127
+ if save_glb_path.endswith(".glb"):
128
+ # rotate 180 along +Y
129
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
130
+
131
+ def srgb_to_linear(c_srgb):
132
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
133
+ return c_linear.clip(0, 1.)
134
+ if apply_sRGB_to_LinearRGB:
135
+ np_color = srgb_to_linear(np_color)
136
+ assert vertices.shape[0] == np_color.shape[0]
137
+ assert np_color.shape[1] == 3
138
+ assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
139
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
140
+ mesh.remove_unreferenced_vertices()
141
+ # save mesh
142
+ mesh.export(save_glb_path)
143
+ if save_glb_path.endswith(".glb"):
144
+ fix_vert_color_glb(save_glb_path)
145
+ print(f"saving to {save_glb_path}")
146
+ #
147
+ #
148
+ # @spaces.GPU
149
+ def text_to_detailed(prompt, seed=None):
150
+ # print(torch.cuda.is_available())
151
+ # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
152
+ return k3d_wrapper.get_detailed_prompt(prompt, seed)
153
+
154
+ def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
155
+ # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
156
+ k3d_wrapper.renew_uuid()
157
+ init_image = None
158
+ # if init_image_path is not None:
159
+ # init_image = Image.open(init_image_path)
160
+ result = k3d_wrapper.generate_3d_bundle_image_text(
161
+ prompt,
162
+ image=init_image,
163
+ strength=strength,
164
+ lora_scale=lora_scale,
165
+ num_inference_steps=num_inference_steps,
166
+ seed=int(seed) if seed is not None else None,
167
+ redux_hparam=redux_hparam,
168
+ save_intermediate_results=True,
169
+ **kwargs)
170
+ return result[-1]
+
+ def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
+     global preprocessed_input_image
+
+     seed = int(seed) if seed is not None else None
+
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
+
+     preprocessed_input_image = Image.open(input_image_save_path)
+     return reference_save_path, caption
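+
+ # Example usage (a sketch; assumes `img` is a PIL image uploaded by the user):
+ #
+ #     reference_path, caption = image2mesh_preprocess_(img, seed=10)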
+
+ @spaces.GPU
+ def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
+     global mesh_cache
+     seed = int(seed) if seed is not None else None
+
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     input_image = preprocessed_input_image
+
+     # convert HWC uint8 image to CHW float tensor in [0, 1]
+     reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2, 0, 1) / 255
+
+     gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
+     mesh_cache = recon_mesh_path
+
+     # gen_save_ = Image.open(gen_save_path)
+
+     if if_video:
+         video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
+         render_video_from_obj(recon_mesh_path, video_path)
+         print(f"After image2mesh_main_: {torch.cuda.memory_allocated() / 1024**3} GB")
+         return gen_save_path, video_path
+     else:
+         return gen_save_path, recon_mesh_path
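+
+ # Example usage (a sketch; follows image2mesh_preprocess_ above, with the
+ # reference bundle image passed in as an HxWx3 uint8 array):
+ #
+ #     ref = np.array(Image.open(reference_path))
+ #     gen_path, video_path = image2mesh_main_(ref, caption, seed=10)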
+
+ @spaces.GPU
+ def bundle_image_to_mesh(
+         gen_3d_bundle_image,
+         lrm_radius=4.15,
+         isomer_radius=4.5,
+         reconstruction_stage1_steps=10,
+         reconstruction_stage2_steps=50,
+         save_intermediate_results=True,
+         if_video=True
+     ):
+     global mesh_cache
+     print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
+     k3d_wrapper.recon_model.init_flexicubes_geometry("cuda:0", fovy=50.0)
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     print(f"After deleting the LLM model: {torch.cuda.memory_allocated() / 1024**3} GB")
+
+     # convert HWC uint8 image to CHW float tensor in [0, 1]
+     gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2, 0, 1) / 255
+     # reconstruct from the 3D bundle image
+     recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
+     mesh_cache = recon_mesh_path
+
+     if if_video:
+         video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
+         # # Check whether the video file exceeds 50 KB; if not, treat it as
+         # # empty and re-render it
+         # if os.path.exists(video_path):
+         #     print(f"file size: {os.path.getsize(video_path)}")
+         #     if os.path.getsize(video_path) > 50 * 1024:
+         #         print(f"video path: {video_path}")
+         #         return video_path
+         render_video_from_obj(recon_mesh_path, video_path)
+         print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
+         return video_path
+     else:
+         return recon_mesh_path
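+
+ # Example usage (a sketch; `bundle_image` is the HxWx3 uint8 array returned by
+ # text_to_image, and the call yields a rendered .mp4 turntable by default):
+ #
+ #     video_path = bundle_image_to_mesh(bundle_image)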
+
+ _HEADER_ = f"""
+ <img src="{LOGO_PATH}">
+ <h2><b>Official 🤗 Gradio Demo</b></h2>
+ <h2><b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></h2>
+
+ <p><b>Kiss3DGen</b> is xxxxxxxxx</p>
+
+ [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
+ """
+
+ _CITE_ = r"""
+ <h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='""" + GITHUB_LINK + r"""' target='_blank'>GitHub repo</a>. Thanks!</h2>
+
+ 📝 **Citation**
+
+ If you find our work useful for your research or applications, please cite using this bibtex:
+ ```bibtex
+ @article{xxxx,
+   title={xxxx},
+   author={xxxx},
+   journal={xxxx},
+   year={xxxx}
+ }
+ ```
+
+ 📋 **License**
+
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
+
+ 📧 **Contact**
+
+ If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
+ """
+
+ def image_to_base64(image_path):
+     """Converts an image file to a base64-encoded string."""
+     with open(image_path, "rb") as img_file:
+         return base64.b64encode(img_file.read()).decode('utf-8')
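+
+ # Example (a sketch): embed the logo directly in HTML without serving a file.
+ #
+ #     logo_b64 = image_to_base64(LOGO_PATH)
+ #     html = f"<img src='data:image/png;base64,{logo_b64}'>"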
+
+ def main():
+
+     torch.set_grad_enabled(False)
+
+     # Convert the logo image to base64
+     logo_base64 = image_to_base64(LOGO_PATH)
+     # with gr.Blocks() as demo:
+     with gr.Blocks(css="""
+         body {
+             display: flex;
+             justify-content: center;
+             align-items: center;
+             min-height: 100vh;
+             margin: 0;
+             padding: 0;
+         }
+         #col-container { margin: 0px auto; max-width: 200px; }
+
+         .gradio-container {
+             max-width: 1000px;
+             margin: auto;
+             width: 100%;
+         }
+         #center-align-column {
+             display: flex;
+             justify-content: center;
+             align-items: center;
+         }
+         #right-align-column {
+             display: flex;
+             justify-content: flex-end;
+             align-items: center;
+         }
+         h1 {text-align: center;}
+         h2 {text-align: center;}
+         h3 {text-align: center;}
+         p {text-align: center;}
+         img {text-align: right;}
+         .right {
+             display: block;
+             margin-left: auto;
+         }
+         .center {
+             display: block;
+             margin-left: auto;
+             margin-right: auto;
+             width: 50%;
+         }
+         #content-container {
+             max-width: 1200px;
+             margin: 0 auto;
+         }
+         #example-container {
+             max-width: 300px;
+             margin: 0 auto;
+         }
+     """, elem_id="col-container") as demo:
+         # Header Section
+         # gr.Image(value=LOGO_PATH, width=64, height=64)
+         # gr.Markdown(_HEADER_)
+         with gr.Row(elem_id="content-container"):
+             # with gr.Column(scale=1):
+             #     pass
+             # with gr.Column(scale=1, elem_id="right-align-column"):
+             #     # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
+             #     # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
+             #     # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
+             #     pass
+             with gr.Column(scale=7, elem_id="center-align-column"):
+                 gr.Markdown(f"""
+                 ## Official 🤗 Gradio Demo
+                 # Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
+                 gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
+
+                 gr.HTML(f"""
+                 <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
+                     <a href="{ARXIV_LINK}" target="_blank">
+                         <img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
+                     </a>
+                     <a href="{GITHUB_LINK}" target="_blank">
+                         <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
+                     </a>
+                 </div>
+                 """)
+
+                 # gr.HTML(f"""
+                 # <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
+                 # """)
+
+                 # gr.Markdown(f"""
+                 # [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
+                 # """, elem_id="title")
+             # with gr.Column(scale=1):
+             #     pass
+         # with gr.Row():
+         #     gr.Markdown(f"[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK})")
+         #     gr.Markdown(f"[![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})")
+
+         # Tabs Section
+         with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
+             with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
+                         seed1 = gr.Number(value=10, label="Seed")
+
+                         with gr.Row(elem_id="example-container"):
+                             gr.Examples(
+                                 examples=[
+                                     # ["A tree with red leaves"],
+                                     # ["A dragon with black texture"],
+                                     ["A girl with pink hair"],
+                                     ["A boy playing guitar"],
+                                     ["A dog wearing a hat"],
+                                     ["A boy playing basketball"],
+                                 ],
+                                 inputs=[prompt],  # fill the prompt textbox with the selected example
+                                 label="Example Prompts"
+                             )
+                         btn_text2detailed = gr.Button("Refine to detailed prompt")
+                         detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="The detailed prompt will be generated here based on your input prompt. You can also edit it.", lines=4, interactive=True)
+                         btn_text2img = gr.Button("Generate Images")
+
+                     with gr.Column(scale=1):
+                         output_image1 = gr.Image(label="Generated image", interactive=False)
+
+                         # lrm_radius = gr.Number(value=4.15, label="lrm_radius")
+                         # isomer_radius = gr.Number(value=4.5, label="isomer_radius")
+                         # reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
+                         # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
+
+                         btn_gen_mesh = gr.Button("Generate Mesh")
+                         output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
+                         btn_download1 = gr.Button("Download Mesh")
+
+                         file_output1 = gr.File()
+
+             with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         image = gr.Image(label="Input Image", type="pil")
+
+                         seed2 = gr.Number(value=10, label="Seed (0 for random)")
+
+                         btn_img2mesh_preprocess = gr.Button("Preprocess Image")
+
+                         image_caption = gr.Textbox(value="", label="Image Caption", placeholder="The caption will be generated here based on your input image. You can also edit it.", lines=4, interactive=True)
+
+                         output_image2 = gr.Image(label="Generated image", interactive=False)
+                         strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
+                         strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
+                         enable_redux = gr.Checkbox(label="enable redux", value=True)
+                         use_controlnet = gr.Checkbox(label="use controlnet", value=True)
+
+                         btn_img2mesh_main = gr.Button("Generate Mesh")
+
+                     with gr.Column(scale=1):
+                         # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
+                         output_image3 = gr.Image(label="Generated image", interactive=False)
+                         output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
+                         btn_download2 = gr.Button("Download Mesh")
+                         file_output2 = gr.File()
+
+             # Image-to-3D events
+             btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
+             btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
+             btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
+
+             # Button Click Events
+             # Text-to-3D events
+             btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
+             btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
+             btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1,], outputs=output_video1)
+             # btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
+             # Wire the Text-to-3D download button the same way as the Image-to-3D
+             # one (the original left it unconnected, so it did nothing when clicked)
+             btn_download1.click(fn=save_cached_mesh, inputs=[], outputs=file_output1)
+
+         with gr.Row():
+             pass
+         with gr.Row():
+             gr.Markdown(_CITE_)
+
+     # demo.queue(default_concurrency_limit=1)
+     # demo.launch(server_name="0.0.0.0", server_port=9239)
+     # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()