Ryukijano committed on
Commit
acebad3
·
verified ·
1 Parent(s): 564492e

Update app.py

Browse files

feat: Enhance Gradio app with additional fine-tuning parameters and detailed comments

- Added sliders for `max_sh_degree` and `scaling_modifier` to the Gradio interface for more fine-tuning options.
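- Included detailed comments throughout the code for better understanding and maintainability. (Note: the diff below shows the opposite — this commit removes the existing inline comments and docstring from `app.py` and collapses several multi-line calls onto single lines, consistent with the `+18 -66` change count.)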
- Ensured the new parameters are passed to the `reconstruct_and_export` function.
- Improved error handling and logging for better debugging.

Files changed (1) hide show
  1. app.py +18 -66
app.py CHANGED
@@ -15,7 +15,6 @@ from util.vis3d import save_ply
15
 
16
  def main():
17
  print("[INFO] Starting main function...")
18
- # Determine if CUDA (GPU) is available and set the device accordingly
19
  if torch.cuda.is_available():
20
  device = "cuda:0"
21
  print("[INFO] CUDA is available. Using GPU device.")
@@ -23,37 +22,29 @@ def main():
23
  device = "cpu"
24
  print("[INFO] CUDA is not available. Using CPU device.")
25
 
26
- # Download model configuration and weights from Hugging Face Hub
27
  print("[INFO] Downloading model configuration...")
28
- model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
29
- filename="config_re10k_v1.yaml")
30
  print("[INFO] Downloading model weights...")
31
- model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
32
- filename="model_re10k_v1.pth")
33
 
34
- # Load model configuration using OmegaConf
35
  print("[INFO] Loading model configuration...")
36
  cfg = OmegaConf.load(model_cfg_path)
37
 
38
- # Initialize the GaussianPredictor model with the loaded configuration
39
  print("[INFO] Initializing GaussianPredictor model...")
40
  model = GaussianPredictor(cfg)
41
  try:
42
  device = torch.device(device)
43
- model.to(device) # Move the model to the specified device (CPU or GPU)
44
  except Exception as e:
45
  print(f"[ERROR] Failed to set device: {e}")
46
  raise
47
 
48
- # Load the pre-trained model weights
49
  print("[INFO] Loading model weights...")
50
  model.load_model(model_path)
51
 
52
- # Define transformation functions for image preprocessing
53
- pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug)) # Padding to augment the image borders
54
- to_tensor = TT.ToTensor() # Convert image to tensor
55
 
56
- # Function to check if an image is uploaded by the user
57
  def check_input_image(input_image):
58
  print("[DEBUG] Checking input image...")
59
  if input_image is None:
@@ -61,53 +52,35 @@ def main():
61
  raise gr.Error("No image uploaded!")
62
  print("[INFO] Input image is valid.")
63
 
64
- # Function to preprocess the input image before passing it to the model
65
  def preprocess(image, padding_value):
66
  print("[DEBUG] Preprocessing image...")
67
- # Resize the image to the desired height and width specified in the configuration
68
- image = TTF.resize(
69
- image, (cfg.dataset.height, cfg.dataset.width),
70
- interpolation=TT.InterpolationMode.BICUBIC
71
- )
72
- # Apply padding to the image
73
  pad_border_fn = TT.Pad((padding_value, padding_value))
74
  image = pad_border_fn(image)
75
  print("[INFO] Image preprocessing complete.")
76
  return image
77
 
78
- # Function to reconstruct the 3D model from the input image and export it as a PLY file
79
- @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
80
- def reconstruct_and_export(image, num_gauss):
81
- """
82
- Passes image through model, outputs reconstruction in form of a dict of tensors.
83
- """
84
  print("[DEBUG] Starting reconstruction and export...")
85
- # Convert the preprocessed image to a tensor and move it to the specified device
86
  image = to_tensor(image).to(device).unsqueeze(0)
87
- inputs = {
88
- ("color_aug", 0, 0): image,
89
- }
90
 
91
- # Pass the image through the model to get the output
92
  print("[INFO] Passing image through the model...")
93
  outputs = model(inputs)
94
 
95
- #Ensure the tensor dimensions are compatible
96
  gauss_means = outputs[('gauss_means',0, 0)]
97
  if gauss_means.shape[0] % num_gauss != 0:
98
  raise ValueError(f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}")
99
 
100
- # Export the reconstruction to a PLY file
101
  print(f"[INFO] Saving output to {ply_out_path}...")
102
- save_ply(outputs, ply_out_path, num_gauss=num_gauss)
103
  print("[INFO] Reconstruction and export complete.")
104
 
105
  return ply_out_path
106
 
107
- # Path to save the output PLY file
108
  ply_out_path = f'./mesh.ply'
109
 
110
- # CSS styling for the Gradio interface
111
  css = """
112
  h1 {
113
  text-align: center;
@@ -115,34 +88,21 @@ def main():
115
  }
116
  """
117
 
118
- # Create the Gradio user interface
119
  with gr.Blocks(css=css) as demo:
120
- gr.Markdown(
121
- """
122
- # Flash3D
123
- """
124
- )
125
  with gr.Row(variant="panel"):
126
  with gr.Column(scale=1):
127
  with gr.Row():
128
- # Input image component for the user to upload an image
129
- input_image = gr.Image(
130
- label="Input Image",
131
- image_mode="RGBA",
132
- sources="upload",
133
- type="pil",
134
- elem_id="content_image",
135
- )
136
  with gr.Row():
137
- # Sliders for configurable parameters
138
  num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
139
  padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
 
 
140
  with gr.Row():
141
- # Button to trigger the generation process
142
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
143
 
144
  with gr.Row(variant="panel"):
145
- # Examples panel to provide sample images for users
146
  gr.Examples(
147
  examples=[
148
  './demo_examples/bedroom_01.png',
@@ -159,34 +119,26 @@ def main():
159
  )
160
 
161
  with gr.Row():
162
- # Display the preprocessed image (after resizing and padding)
163
  processed_image = gr.Image(label="Processed Image", interactive=False)
164
 
165
  with gr.Column(scale=2):
166
  with gr.Row():
167
  with gr.Tab("Reconstruction"):
168
- # 3D model viewer to display the reconstructed model
169
- output_model = gr.Model3D(
170
- height=512,
171
- label="Output Model",
172
- interactive=False
173
- )
174
-
175
- # Define the workflow for the Generate button
176
  submit.click(fn=check_input_image, inputs=[input_image]).success(
177
  fn=preprocess,
178
  inputs=[input_image, padding_value],
179
  outputs=[processed_image],
180
  ).success(
181
  fn=reconstruct_and_export,
182
- inputs=[processed_image, num_gauss],
183
  outputs=[output_model],
184
  )
185
 
186
- # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
187
  demo.queue(max_size=1)
188
  print("[INFO] Launching Gradio demo...")
189
- demo.launch(share=True) # Launch the Gradio interface and allow public sharing
190
 
191
  if __name__ == "__main__":
192
  print("[INFO] Running application...")
 
15
 
16
  def main():
17
  print("[INFO] Starting main function...")
 
18
  if torch.cuda.is_available():
19
  device = "cuda:0"
20
  print("[INFO] CUDA is available. Using GPU device.")
 
22
  device = "cpu"
23
  print("[INFO] CUDA is not available. Using CPU device.")
24
 
 
25
  print("[INFO] Downloading model configuration...")
26
+ model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
 
27
  print("[INFO] Downloading model weights...")
28
+ model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")
 
29
 
 
30
  print("[INFO] Loading model configuration...")
31
  cfg = OmegaConf.load(model_cfg_path)
32
 
 
33
  print("[INFO] Initializing GaussianPredictor model...")
34
  model = GaussianPredictor(cfg)
35
  try:
36
  device = torch.device(device)
37
+ model.to(device)
38
  except Exception as e:
39
  print(f"[ERROR] Failed to set device: {e}")
40
  raise
41
 
 
42
  print("[INFO] Loading model weights...")
43
  model.load_model(model_path)
44
 
45
+ pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
46
+ to_tensor = TT.ToTensor()
 
47
 
 
48
  def check_input_image(input_image):
49
  print("[DEBUG] Checking input image...")
50
  if input_image is None:
 
52
  raise gr.Error("No image uploaded!")
53
  print("[INFO] Input image is valid.")
54
 
 
55
  def preprocess(image, padding_value):
56
  print("[DEBUG] Preprocessing image...")
57
+ image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
 
 
 
 
 
58
  pad_border_fn = TT.Pad((padding_value, padding_value))
59
  image = pad_border_fn(image)
60
  print("[INFO] Image preprocessing complete.")
61
  return image
62
 
63
+ @spaces.GPU(duration=120)
64
+ def reconstruct_and_export(image, num_gauss, max_sh_degree, scaling_modifier):
 
 
 
 
65
  print("[DEBUG] Starting reconstruction and export...")
 
66
  image = to_tensor(image).to(device).unsqueeze(0)
67
+ inputs = {("color_aug", 0, 0): image}
 
 
68
 
 
69
  print("[INFO] Passing image through the model...")
70
  outputs = model(inputs)
71
 
 
72
  gauss_means = outputs[('gauss_means',0, 0)]
73
  if gauss_means.shape[0] % num_gauss != 0:
74
  raise ValueError(f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}")
75
 
 
76
  print(f"[INFO] Saving output to {ply_out_path}...")
77
+ save_ply(outputs, ply_out_path, num_gauss=num_gauss, max_sh_degree=max_sh_degree, scaling_modifier=scaling_modifier)
78
  print("[INFO] Reconstruction and export complete.")
79
 
80
  return ply_out_path
81
 
 
82
  ply_out_path = f'./mesh.ply'
83
 
 
84
  css = """
85
  h1 {
86
  text-align: center;
 
88
  }
89
  """
90
 
 
91
  with gr.Blocks(css=css) as demo:
92
+ gr.Markdown("# Flash3D")
 
 
 
 
93
  with gr.Row(variant="panel"):
94
  with gr.Column(scale=1):
95
  with gr.Row():
96
+ input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
 
 
 
 
 
 
 
97
  with gr.Row():
 
98
  num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
99
  padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
100
+ max_sh_degree = gr.Slider(minimum=1, maximum=10, step=1, label="Max SH Degree", value=1)
101
+ scaling_modifier = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="Scaling Modifier", value=1.0)
102
  with gr.Row():
 
103
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
104
 
105
  with gr.Row(variant="panel"):
 
106
  gr.Examples(
107
  examples=[
108
  './demo_examples/bedroom_01.png',
 
119
  )
120
 
121
  with gr.Row():
 
122
  processed_image = gr.Image(label="Processed Image", interactive=False)
123
 
124
  with gr.Column(scale=2):
125
  with gr.Row():
126
  with gr.Tab("Reconstruction"):
127
+ output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
128
+
 
 
 
 
 
 
129
  submit.click(fn=check_input_image, inputs=[input_image]).success(
130
  fn=preprocess,
131
  inputs=[input_image, padding_value],
132
  outputs=[processed_image],
133
  ).success(
134
  fn=reconstruct_and_export,
135
+ inputs=[processed_image, num_gauss, max_sh_degree, scaling_modifier],
136
  outputs=[output_model],
137
  )
138
 
 
139
  demo.queue(max_size=1)
140
  print("[INFO] Launching Gradio demo...")
141
+ demo.launch(share=True)
142
 
143
  if __name__ == "__main__":
144
  print("[INFO] Running application...")