Zaiiida commited on
Commit
94c8594
·
verified ·
1 Parent(s): 1aaf1b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -93
app.py CHANGED
@@ -13,124 +13,194 @@ from functools import partial
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
- # HF_TOKEN = os.getenv("HF_TOKEN")
17
-
18
- if torch.cuda.is_available():
19
- device = "cuda:0"
20
- else:
21
- device = "cpu"
22
-
23
- d = os.environ.get("DEVICE", None)
24
- if d!= None:
25
- device = d
26
-
27
- model = TSR.from_pretrained(
28
- "stabilityai/TripoSR",
29
- config_name="config.yaml",
30
- weight_name="model.ckpt",
31
- # token=HF_TOKEN
32
- )
33
- model.renderer.set_chunk_size(131072)
34
- # Увеличение разрешения входного изображения
35
- model.to(device)
36
 
 
 
 
 
 
 
 
37
  rembg_session = rembg.new_session()
38
 
39
- def preprocess(input_image, do_remove_background, foreground_ratio):
40
- def fill_background(image):
41
- image = np.array(image).astype(np.float32) / 255.0
42
- image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
43
- image = Image.fromarray((image * 255.0).astype(np.uint8))
44
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
 
 
 
 
 
 
 
46
  if do_remove_background:
47
- image = image.convert("RGB")
48
  image = remove_background(image, rembg_session)
49
  image = resize_foreground(image, foreground_ratio)
50
  image = fill_background(image)
51
  else:
 
52
  if image.mode == "RGBA":
53
  image = fill_background(image)
54
  return image
55
 
56
  def generate(image):
57
- scene_codes = model(image, device=device)
58
- # Увеличение разрешения входного изображения
 
 
 
 
 
 
 
 
59
  mesh = model.extract_mesh(scene_codes)[0]
60
  mesh = to_gradio_3d_orientation(mesh)
61
  mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
62
- mesh.export(mesh_path.name)
63
- # Увеличение разрешения входного изображения
64
  mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
 
65
  mesh.export(mesh_path2.name)
66
- # Увеличение разрешения входного изображения
67
  return mesh_path.name, mesh_path2.name
68
 
69
  def run_example(image_pil):
70
- # Увеличение разрешения входного изображения
71
  preprocessed = preprocess(image_pil, False, 0.9)
72
- mesh_name, mesh_name2 = generate(preprocessed) # Увеличение разрешения входного изображения
73
- return image_pil, mesh_name, mesh_name2
74
-
75
- with gr.Blocks() as demo:
76
- gr.Markdown("# **Input Image**")
77
- with gr.Column():
78
- input_image = gr.Image(
79
- label="",
80
- image_mode="RGBA",
81
- sources="upload",
82
- type="pil",
83
- elem_id="content_image",
84
- ) # Увеличение разрешения входного изображения
85
  with gr.Column():
86
- do_remove_background = gr.Checkbox(
87
- label="Remove Background",
88
- value=True,
89
- elem_id="remove_background",
90
- )
91
- foreground_ratio = gr.Slider(
92
- label="Foreground Ratio",
93
- minimum=0.5,
94
- maximum=1.0,
95
- value=0.85,
96
- step=0.05,
97
- elem_id="foreground_ratio",
98
- ) # Увеличение разрешения входного изображения
99
- with gr.Row():
100
- submit = gr.Button("Generate")
101
- submit.click(
102
- fn=check_input_image,
103
- inputs=[input_image],
104
- outputs=[],
105
- ).click(
106
- fn=preprocess,
107
- inputs=[input_image, do_remove_background, foreground_ratio],
108
- outputs=[input_image],
109
- ).click(
110
- fn=generate,
111
- inputs=[input_image],
112
- outputs=["mesh_name", "mesh_name2"],
113
- )
114
- with gr.Tab("obj"):
115
- output_model = gr.Model3D(
116
- label="Output Model",
117
- interactive=False,
118
  )
119
- with gr.Tab("glb"):
120
- output_model2 = gr.Model3D(
121
- label="Output Model",
122
- interactive=False,
123
- )
124
- #...
125
-
126
- with gr.Column():
127
- gr.Examples(
128
- examples=["examples/1.png"],
129
- inputs=[input_image],
130
- outputs=[input_image, output_model, output_model2],
131
- fn=run_example,
132
- cache_examples=True,
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  demo.queue(max_size=10)
136
  demo.launch()
 
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
+ # DEVICE
17
+ DEVICE_ENV = os.environ.get("DEVICE", None)
18
+ DEVICE = DEVICE_ENV if DEVICE_ENV else ("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # MODEL
21
+ MODEL_NAME = "stabilityai/TripoSR"
22
+ CONFIG_NAME = "config.yaml"
23
+ WEIGHT_NAME = "model.ckpt"
24
+ # HF_TOKEN = os.getenv("HF_TOKEN") # Если нужен токен
25
+
26
+ # REMBG
27
  rembg_session = rembg.new_session()
28
 
29
+ # THEME
30
+ HEADER = """
31
+ """
32
+
33
+ class CustomTheme(gr.themes.Base):
34
+ def __init__(self):
35
+ super().__init__()
36
+ self.primary_hue = "#191a1e"
37
+ self.background_fill_primary = "#191a1e"
38
+ self.background_fill_secondary = "#191a1e"
39
+ self.background_fill_tertiary = "#191a1e"
40
+ self.text_color_primary = "#FFFFFF"
41
+ self.text_color_secondary = "#FFFFFF"
42
+ self.text_color_tertiary = "#FFFFFF"
43
+ self.input_background_fill = "#191a1e"
44
+ self.input_text_color = "#FFFFFF"
45
+
46
+ css = """
47
+ /* Скрываем нижний колонтитул */
48
+ footer {
49
+ visibility: hidden;
50
+ height: 0;
51
+ margin: 0;
52
+ padding: 0;
53
+ overflow: hidden;
54
+ }
55
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap');
56
+ /* Применяем шрифты */
57
+ body, input, button, textarea, select,.gr-button {
58
+ font-family: 'Poppins', sans-serif;
59
+ background-color: #191a1e!important;
60
+ color: #FFFFFF;
61
+ }
62
+ /* Настройки заголовков */
63
+ h1, h2, h3, h4, h5, h6 {
64
+ font-family: 'Poppins', sans-serif;
65
+ font-weight: 700;
66
+ color: #FFFFFF;
67
+ }
68
+ /* Стиль для текстовых полей и кнопок */
69
+ input[type="text"], textarea {
70
+ background-color: #191a1e!important;
71
+ color: #FFFFFF;
72
+ border: 1px solid #FFFFFF;
73
+ }
74
+ /* Цвет кнопки Generate */
75
+ .generate-button {
76
+ background-color: #5271FF!important;
77
+ color: #FFFFFF!important;
78
+ border: none;
79
+ font-weight: bold;
80
+ }
81
+ .generate-button:hover {
82
+ background-color: #405BBF!important; /* Цвет при наведении */
83
+ }
84
+ /* Выравнивание элементов */
85
+ .drop-image-container {
86
+ display: flex;
87
+ flex-direction: column;
88
+ align-items: center;
89
+ }
90
+ .drop-image,.processed-image {
91
+ margin-bottom: 20px;
92
+ }
93
+ .foreground-ratio-container {
94
+ margin-top: 20px;
95
+ margin-bottom: 20px;
96
+ }
97
+ .generate-button {
98
+ margin-top: 20px;
99
+ margin-left: auto;
100
+ margin-right: auto;
101
+ }
102
+ """
103
+
104
+ def check_input_image(input_image):
105
+ if input_image is None:
106
+ raise gr.Error("No image uploaded!")
107
 
108
+ def fill_background(image):
109
+ image = np.array(image).astype(np.float32) / 255.0
110
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
111
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
112
+ return image
113
+
114
+ def preprocess(input_image, do_remove_background, foreground_ratio):
115
  if do_remove_background:
116
+ image = input_image.convert("RGB")
117
  image = remove_background(image, rembg_session)
118
  image = resize_foreground(image, foreground_ratio)
119
  image = fill_background(image)
120
  else:
121
+ image = input_image
122
  if image.mode == "RGBA":
123
  image = fill_background(image)
124
  return image
125
 
126
  def generate(image):
127
+ model = TSR.from_pretrained(
128
+ MODEL_NAME,
129
+ config_name=CONFIG_NAME,
130
+ weight_name=WEIGHT_NAME,
131
+ # token=HF_TOKEN
132
+ )
133
+ model.renderer.set_chunk_size(131072)
134
+ model.to(DEVICE)
135
+
136
+ scene_codes = model(image, device=DEVICE)
137
  mesh = model.extract_mesh(scene_codes)[0]
138
  mesh = to_gradio_3d_orientation(mesh)
139
  mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
 
 
140
  mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
141
+ mesh.export(mesh_path.name)
142
  mesh.export(mesh_path2.name)
 
143
  return mesh_path.name, mesh_path2.name
144
 
145
  def run_example(image_pil):
 
146
  preprocessed = preprocess(image_pil, False, 0.9)
147
+ mesh_name, mesh_name2 = generate(preprocessed)
148
+ return preprocessed, mesh_name, mesh_name2
149
+
150
+ with gr.Blocks(theme=CustomTheme(), css=css) as demo:
151
+ # **Header**
152
+ with gr.Row():
153
+ gr.Markdown("# 3D Model Generator", elem_id="title")
154
+
155
+ # **Input Section**
156
+ with gr.Accordion("Input Image", open=False):
 
 
 
157
  with gr.Column():
158
+ input_image = gr.Image(
159
+ label="Upload Image",
160
+ image_mode="RGBA",
161
+ sources="upload",
162
+ type="pil",
163
+ elem_id="content_image",
164
+ width=500,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
+ with gr.Row():
167
+ do_remove_background = gr.Checkbox(
168
+ label="Remove Background",
169
+ value=True,
170
+ )
171
+ foreground_ratio = gr.Slider(
172
+ label="Foreground Ratio",
173
+ minimum=0.5,
174
+ maximum=1.0,
175
+ value=0.85,
176
+ step=0.05,
177
+ )
178
+
179
+ # **Processing and Generation**
180
+ with gr.Accordion("Generate 3D Model", open=False):
181
+ with gr.Column():
182
+ submit = gr.Button("Generate", elem_classes="generate-button")
183
+ processed_image = gr.Image(label="Processed Image", interactive=False)
184
+ with gr.Tabs():
185
+ with gr.Tab("OBJ Model"):
186
+ output_model = gr.Model3D(label="Output Model", interactive=False)
187
+ with gr.Tab("GLB Model"):
188
+ output_model2 = gr.Model3D(label="Output Model", interactive=False)
189
+
190
+ # **Event Triggers**
191
+ submit.click(
192
+ fn=check_input_image,
193
+ inputs=[input_image],
194
+ outputs=[gr.update(label="Processing...", interactive=False)]
195
+ ).then(
196
+ fn=preprocess,
197
+ inputs=[input_image, do_remove_background, foreground_ratio],
198
+ outputs=[processed_image]
199
+ ).then(
200
+ fn=generate,
201
+ inputs=[processed_image],
202
+ outputs=[output_model, output_model2]
203
+ )
204
 
205
  demo.queue(max_size=10)
206
  demo.launch()