Ryukijano commited on
Commit
8403619
·
verified ·
1 Parent(s): f2e0064

Demake app.py

Browse files

We wait for now.

Files changed (1) hide show
  1. app.py +71 -21
app.py CHANGED
@@ -13,9 +13,9 @@ import numpy as np
13
  from networks.gaussian_predictor import GaussianPredictor
14
  from util.vis3d import save_ply
15
 
16
-
17
  def main():
18
  print("[INFO] Starting main function...")
 
19
  if torch.cuda.is_available():
20
  device = "cuda:0"
21
  print("[INFO] CUDA is available. Using GPU device.")
@@ -23,25 +23,33 @@ def main():
23
  device = "cpu"
24
  print("[INFO] CUDA is not available. Using CPU device.")
25
 
 
26
  print("[INFO] Downloading model configuration...")
27
- model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
 
28
  print("[INFO] Downloading model weights...")
29
- model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")
 
30
 
 
31
  print("[INFO] Loading model configuration...")
32
  cfg = OmegaConf.load(model_cfg_path)
33
 
 
34
  print("[INFO] Initializing GaussianPredictor model...")
35
  model = GaussianPredictor(cfg)
36
  device = torch.device(device)
37
- model.to(device)
38
 
 
39
  print("[INFO] Loading model weights...")
40
  model.load_model(model_path)
41
 
42
- pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
43
- to_tensor = TT.ToTensor()
 
44
 
 
45
  def check_input_image(input_image):
46
  print("[DEBUG] Checking input image...")
47
  if input_image is None:
@@ -49,27 +57,47 @@ def main():
49
  raise gr.Error("No image uploaded!")
50
  print("[INFO] Input image is valid.")
51
 
52
- def preprocess(image, resolution):
 
53
  print("[DEBUG] Preprocessing image...")
54
- image = TTF.resize(image, (resolution, resolution), interpolation=TT.InterpolationMode.BICUBIC)
 
 
 
 
 
55
  image = pad_border_fn(image)
56
  print("[INFO] Image preprocessing complete.")
57
  return image
58
 
59
- @spaces.GPU(duration=120)
60
- def reconstruct_and_export(image, num_gauss):
 
 
 
 
61
  print("[DEBUG] Starting reconstruction and export...")
 
62
  image = to_tensor(image).to(device).unsqueeze(0)
63
- inputs = {("color_aug", 0, 0): image}
 
 
 
 
64
  print("[INFO] Passing image through the model...")
65
  outputs = model(inputs)
 
 
66
  print(f"[INFO] Saving output to {ply_out_path}...")
67
- save_ply(outputs, ply_out_path, num_gauss=num_gauss)
68
  print("[INFO] Reconstruction and export complete.")
 
69
  return ply_out_path
70
 
 
71
  ply_out_path = f'./mesh.ply'
72
 
 
73
  css = """
74
  h1 {
75
  text-align: center;
@@ -77,15 +105,30 @@ def main():
77
  }
78
  """
79
 
 
80
  with gr.Blocks(css=css) as demo:
81
- gr.Markdown("# Flash3D")
 
 
 
 
82
  with gr.Row(variant="panel"):
83
  with gr.Column(scale=1):
84
  with gr.Row():
85
- input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
 
 
 
 
 
 
 
86
  with gr.Row():
 
87
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
 
88
  with gr.Row(variant="panel"):
 
89
  gr.Examples(
90
  examples=[
91
  './demo_examples/bedroom_01.png',
@@ -100,29 +143,36 @@ def main():
100
  label="Examples",
101
  examples_per_page=20,
102
  )
 
103
  with gr.Row():
 
104
  processed_image = gr.Image(label="Processed Image", interactive=False)
 
105
  with gr.Column(scale=2):
106
  with gr.Row():
107
  with gr.Tab("Reconstruction"):
108
- output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
109
- with gr.Row():
110
- resolution = gr.Slider(minimum=256, maximum=1024, step=64, label="Image Resolution", value=cfg.dataset.height)
111
- num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)
 
 
112
 
 
113
  submit.click(fn=check_input_image, inputs=[input_image]).success(
114
  fn=preprocess,
115
- inputs=[input_image, resolution],
116
  outputs=[processed_image],
117
  ).success(
118
  fn=reconstruct_and_export,
119
- inputs=[processed_image, num_gauss],
120
  outputs=[output_model],
121
  )
122
 
 
123
  demo.queue(max_size=1)
124
  print("[INFO] Launching Gradio demo...")
125
- demo.launch(share=True)
126
 
127
  if __name__ == "__main__":
128
  print("[INFO] Running application...")
 
13
  from networks.gaussian_predictor import GaussianPredictor
14
  from util.vis3d import save_ply
15
 
 
16
  def main():
17
  print("[INFO] Starting main function...")
18
+ # Determine if CUDA (GPU) is available and set the device accordingly
19
  if torch.cuda.is_available():
20
  device = "cuda:0"
21
  print("[INFO] CUDA is available. Using GPU device.")
 
23
  device = "cpu"
24
  print("[INFO] CUDA is not available. Using CPU device.")
25
 
26
+ # Download model configuration and weights from Hugging Face Hub
27
  print("[INFO] Downloading model configuration...")
28
+ model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
29
+ filename="config_re10k_v1.yaml")
30
  print("[INFO] Downloading model weights...")
31
+ model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
32
+ filename="model_re10k_v1.pth")
33
 
34
+ # Load model configuration using OmegaConf
35
  print("[INFO] Loading model configuration...")
36
  cfg = OmegaConf.load(model_cfg_path)
37
 
38
+ # Initialize the GaussianPredictor model with the loaded configuration
39
  print("[INFO] Initializing GaussianPredictor model...")
40
  model = GaussianPredictor(cfg)
41
  device = torch.device(device)
42
+ model.to(device) # Move the model to the specified device (CPU or GPU)
43
 
44
+ # Load the pre-trained model weights
45
  print("[INFO] Loading model weights...")
46
  model.load_model(model_path)
47
 
48
+ # Define transformation functions for image preprocessing
49
+ pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug)) # Padding to augment the image borders
50
+ to_tensor = TT.ToTensor() # Convert image to tensor
51
 
52
+ # Function to check if an image is uploaded by the user
53
  def check_input_image(input_image):
54
  print("[DEBUG] Checking input image...")
55
  if input_image is None:
 
57
  raise gr.Error("No image uploaded!")
58
  print("[INFO] Input image is valid.")
59
 
60
+ # Function to preprocess the input image before passing it to the model
61
+ def preprocess(image):
62
  print("[DEBUG] Preprocessing image...")
63
+ # Resize the image to the desired height and width specified in the configuration
64
+ image = TTF.resize(
65
+ image, (cfg.dataset.height, cfg.dataset.width),
66
+ interpolation=TT.InterpolationMode.BICUBIC
67
+ )
68
+ # Apply padding to the image
69
  image = pad_border_fn(image)
70
  print("[INFO] Image preprocessing complete.")
71
  return image
72
 
73
+ # Function to reconstruct the 3D model from the input image and export it as a PLY file
74
+ @spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
75
+ def reconstruct_and_export(image):
76
+ """
77
+ Passes image through model, outputs reconstruction in form of a dict of tensors.
78
+ """
79
  print("[DEBUG] Starting reconstruction and export...")
80
+ # Convert the preprocessed image to a tensor and move it to the specified device
81
  image = to_tensor(image).to(device).unsqueeze(0)
82
+ inputs = {
83
+ ("color_aug", 0, 0): image,
84
+ }
85
+
86
+ # Pass the image through the model to get the output
87
  print("[INFO] Passing image through the model...")
88
  outputs = model(inputs)
89
+
90
+ # Export the reconstruction to a PLY file
91
  print(f"[INFO] Saving output to {ply_out_path}...")
92
+ save_ply(outputs, ply_out_path, num_gauss=2)
93
  print("[INFO] Reconstruction and export complete.")
94
+
95
  return ply_out_path
96
 
97
+ # Path to save the output PLY file
98
  ply_out_path = f'./mesh.ply'
99
 
100
+ # CSS styling for the Gradio interface
101
  css = """
102
  h1 {
103
  text-align: center;
 
105
  }
106
  """
107
 
108
+ # Create the Gradio user interface
109
  with gr.Blocks(css=css) as demo:
110
+ gr.Markdown(
111
+ """
112
+ # Flash3D
113
+ """
114
+ )
115
  with gr.Row(variant="panel"):
116
  with gr.Column(scale=1):
117
  with gr.Row():
118
+ # Input image component for the user to upload an image
119
+ input_image = gr.Image(
120
+ label="Input Image",
121
+ image_mode="RGBA",
122
+ sources="upload",
123
+ type="pil",
124
+ elem_id="content_image",
125
+ )
126
  with gr.Row():
127
+ # Button to trigger the generation process
128
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
129
+
130
  with gr.Row(variant="panel"):
131
+ # Examples panel to provide sample images for users
132
  gr.Examples(
133
  examples=[
134
  './demo_examples/bedroom_01.png',
 
143
  label="Examples",
144
  examples_per_page=20,
145
  )
146
+
147
  with gr.Row():
148
+ # Display the preprocessed image (after resizing and padding)
149
  processed_image = gr.Image(label="Processed Image", interactive=False)
150
+
151
  with gr.Column(scale=2):
152
  with gr.Row():
153
  with gr.Tab("Reconstruction"):
154
+ # 3D model viewer to display the reconstructed model
155
+ output_model = gr.Model3D(
156
+ height=512,
157
+ label="Output Model",
158
+ interactive=False
159
+ )
160
 
161
+ # Define the workflow for the Generate button
162
  submit.click(fn=check_input_image, inputs=[input_image]).success(
163
  fn=preprocess,
164
+ inputs=[input_image],
165
  outputs=[processed_image],
166
  ).success(
167
  fn=reconstruct_and_export,
168
+ inputs=[processed_image],
169
  outputs=[output_model],
170
  )
171
 
172
+ # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
173
  demo.queue(max_size=1)
174
  print("[INFO] Launching Gradio demo...")
175
+ demo.launch(share=True) # Launch the Gradio interface and allow public sharing
176
 
177
  if __name__ == "__main__":
178
  print("[INFO] Running application...")