michaelj committed on
Commit ac7cfc6 · verified · 1 Parent(s): 3aafe99
Files changed (1)
  1. app.py +53 -54
app.py CHANGED
@@ -2,60 +2,37 @@ import logging
 import os
 import tempfile
 import time
-
 import gradio as gr
 import numpy as np
 import rembg
 import torch
 from PIL import Image
 from functools import partial
-
 from tsr.system import TSR
 from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
-
-#HF_TOKEN = os.getenv("HF_TOKEN")
-
-HEADER = """
-**TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
-**Tips:**
-1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
-2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
-"""
-
-
+import argparse
 if torch.cuda.is_available():
     device = "cuda:0"
 else:
     device = "cpu"
-
-d = os.environ.get("DEVICE", None)
-if d != None:
-    device = d
-
 model = TSR.from_pretrained(
     "stabilityai/TripoSR",
     config_name="config.yaml",
     weight_name="model.ckpt",
-    # token=HF_TOKEN
 )
-model.renderer.set_chunk_size(131072)
+# adjust the chunk size to balance between speed and memory usage
+model.renderer.set_chunk_size(8192)
 model.to(device)
-
 rembg_session = rembg.new_session()
-
-
 def check_input_image(input_image):
     if input_image is None:
         raise gr.Error("No image uploaded!")
-
-
 def preprocess(input_image, do_remove_background, foreground_ratio):
     def fill_background(image):
         image = np.array(image).astype(np.float32) / 255.0
         image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
         image = Image.fromarray((image * 255.0).astype(np.uint8))
         return image
-
     if do_remove_background:
         image = input_image.convert("RGB")
         image = remove_background(image, rembg_session)
@@ -66,25 +43,26 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
         if image.mode == "RGBA":
             image = fill_background(image)
     return image
-
-
-def generate(image):
+def generate(image, mc_resolution, formats=["obj", "glb"]):
     scene_codes = model(image, device=device)
-    mesh = model.extract_mesh(scene_codes)[0]
+    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
     mesh = to_gradio_3d_orientation(mesh)
-    mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
-    mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
-    mesh.export(mesh_path.name)
-    mesh.export(mesh_path2.name)
-    return mesh_path.name, mesh_path2.name
-
+    rv = []
+    for format in formats:
+        mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
+        mesh.export(mesh_path.name)
+        rv.append(mesh_path.name)
+    return rv
 def run_example(image_pil):
     preprocessed = preprocess(image_pil, False, 0.9)
-    mesh_name, mesn_name2 = generate(preprocessed)
-    return preprocessed, mesh_name, mesh_name2
-
-with gr.Blocks() as demo:
-    gr.Markdown(HEADER)
+    mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
+    return preprocessed, mesh_name_obj, mesh_name_glb
+with gr.Blocks(title="TripoSR") as demo:
+    gr.Markdown(
+        """
+        Generate a 3D model from an image
+        """
+    )
     with gr.Row(variant="panel"):
         with gr.Column():
             with gr.Row():
@@ -108,30 +86,52 @@ with gr.Blocks() as demo:
                         value=0.85,
                         step=0.05,
                     )
+                    mc_resolution = gr.Slider(
+                        label="Marching Cubes Resolution",
+                        minimum=32,
+                        maximum=320,
+                        value=256,
+                        step=32
+                    )
             with gr.Row():
                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
         with gr.Column():
-            with gr.Tab("obj"):
-                output_model = gr.Model3D(
-                    label="Output Model",
+            with gr.Tab("OBJ"):
+                output_model_obj = gr.Model3D(
+                    label="Output Model (OBJ Format)",
                     interactive=False,
                 )
-            with gr.Tab("glb"):
-                output_model2 = gr.Model3D(
-                    label="Output Model",
+                gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
+            with gr.Tab("GLB"):
+                output_model_glb = gr.Model3D(
+                    label="Output Model (GLB Format)",
                     interactive=False,
                 )
+                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
     with gr.Row(variant="panel"):
         gr.Examples(
             examples=[
-                os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                "examples/hamburger.png",
+                "examples/poly_fox.png",
+                "examples/robot.png",
+                "examples/teapot.png",
+                "examples/tiger_girl.png",
+                "examples/horse.png",
+                "examples/flamingo.png",
+                "examples/unicorn.png",
+                "examples/chair.png",
+                "examples/iso_house.png",
+                "examples/marble.png",
+                "examples/police_woman.png",
+                "examples/captured.jpeg",
             ],
             inputs=[input_image],
-            outputs=[processed_image, output_model, output_model2],
-            #cache_examples=True,
+            outputs=[processed_image, output_model_obj, output_model_glb],
+            cache_examples=False,
            fn=partial(run_example),
            label="Examples",
-            examples_per_page=20
+            examples_per_page=20,
        )
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
@@ -139,9 +139,8 @@ with gr.Blocks() as demo:
         outputs=[processed_image],
     ).success(
         fn=generate,
-        inputs=[processed_image],
-        outputs=[output_model, output_model2],
+        inputs=[processed_image, mc_resolution],
+        outputs=[output_model_obj, output_model_glb],
     )
-
     demo.queue(max_size=10)
     demo.launch()
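The net effect of the change is that generate() now takes a marching-cubes resolution and a list of export formats, and returns one temporary mesh file per format. Below is a minimal sketch of the same pipeline run as a plain script, assuming the tsr package and the examples/ images listed in the diff are available locally; the image name, the 0.85 ratio, and resolution 256 are just the demo defaults.

import tempfile

import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(8192)  # same speed/memory trade-off as the commit
model.to(device)

# Same steps as preprocess() with "Remove Background" enabled:
# cut out the foreground, scale it to the chosen ratio, then composite onto 50% gray.
image = Image.open("examples/hamburger.png").convert("RGB")
image = remove_background(image, rembg.new_session())
image = resize_foreground(image, 0.85)
rgba = np.array(image).astype(np.float32) / 255.0
rgb = rgba[:, :, :3] * rgba[:, :, 3:4] + (1 - rgba[:, :, 3:4]) * 0.5
image = Image.fromarray((rgb * 255.0).astype(np.uint8))

# Same steps as the new generate(); to_gradio_3d_orientation is skipped here
# because it only reorients the mesh for the gr.Model3D viewer.
scene_codes = model(image, device=device)
mesh = model.extract_mesh(scene_codes, resolution=256)[0]
for fmt in ("obj", "glb"):
    out = tempfile.NamedTemporaryFile(suffix=f".{fmt}", delete=False)
    mesh.export(out.name)
    print(fmt, "->", out.name)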
 
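The fill_background helper, kept by both the old and new preprocess(), composites the RGB channels over a 50% gray background using the alpha channel, i.e. out = rgb * alpha + 0.5 * (1 - alpha). A tiny self-contained check of that blend (the synthetic 4x4 image is only for illustration):

import numpy as np
from PIL import Image

# Half-transparent red composited over the implicit 50% gray background.
rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 128))
arr = np.array(rgba).astype(np.float32) / 255.0
blended = arr[:, :, :3] * arr[:, :, 3:4] + (1 - arr[:, :, 3:4]) * 0.5
out = Image.fromarray((blended * 255.0).astype(np.uint8))
print(out.mode, tuple(np.array(out)[0, 0]))  # RGB, approximately (191, 63, 63)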
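The commit also adds import argparse without using it yet. A sketch of one way it could be wired in, replacing the final demo.queue() / demo.launch() lines of app.py: the flag names are illustrative only, and the launch keywords used (server_name, server_port, share) are arguments Gradio's launch() actually accepts.

import argparse

parser = argparse.ArgumentParser(description="TripoSR Gradio demo")
parser.add_argument("--listen", default="127.0.0.1",
                    help="interface to bind, forwarded as server_name")
parser.add_argument("--port", type=int, default=7860,
                    help="port to bind, forwarded as server_port")
parser.add_argument("--share", action="store_true",
                    help="create a public Gradio share link")
parser.add_argument("--chunk-size", type=int, default=8192,
                    help="renderer chunk size: larger is faster but uses more memory")
args = parser.parse_args()

# model and demo are the module-level objects defined earlier in app.py.
model.renderer.set_chunk_size(args.chunk_size)
demo.queue(max_size=10)
demo.launch(server_name=args.listen, server_port=args.port, share=args.share)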