jiten6555 commited on
Commit
e486d98
·
verified ·
1 Parent(s): 9d38c97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -122
app.py CHANGED
@@ -1,161 +1,216 @@
1
  import torch
2
- import torchvision.transforms as transforms
3
  import numpy as np
4
- import open3d as o3d
5
  from PIL import Image
 
6
  import cv2
7
- import gc
8
- from transformers import DPTForDepthEstimation, DPTFeatureExtractor
9
- from torchvision.models.segmentation import deeplabv3_resnet50 # Optional for segmentation
 
 
 
10
 
11
- class MultiModel3DReconstruction:
12
  def __init__(self):
13
- self.device = torch.device("cpu")
 
14
 
 
15
  try:
16
- # Load DPT model for depth estimation
17
- self.dpt_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(self.device)
18
- self.dpt_feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
19
- self.dpt_model.eval()
20
-
21
- print("DPT model successfully loaded.")
22
  except Exception as e:
23
- print(f"Error loading DPT model: {e}")
24
- self.dpt_model = None
25
 
 
26
  try:
27
- # Optional fallback: Load ZoeDepth model
28
- self.zoe_model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_N", pretrained=True).to(self.device)
29
- self.zoe_model.eval()
 
 
30
 
31
- print("ZoeDepth model successfully loaded.")
 
 
 
 
32
  except Exception as e:
33
- print(f"Error loading ZoeDepth model: {e}")
34
- self.zoe_model = None
35
-
36
- def preprocess_image(self, input_image):
37
- """
38
- Preprocess input image for models.
39
- """
40
- if not isinstance(input_image, Image.Image):
41
- input_image = Image.fromarray(input_image)
42
-
43
- if input_image.mode != "RGB":
44
- input_image = input_image.convert("RGB")
45
-
46
- return input_image
47
-
48
- def estimate_depth(self, input_image):
49
  """
50
- Estimate depth using the best available model.
51
  """
52
- input_image = self.preprocess_image(input_image)
 
53
 
54
- try:
55
- # Use DPT for depth estimation
56
- inputs = self.dpt_feature_extractor(images=input_image, return_tensors="pt").to(self.device)
57
- with torch.no_grad():
58
- outputs = self.dpt_model(**inputs)
59
- depth = outputs.predicted_depth.squeeze().cpu().numpy()
60
- return depth
61
 
62
- except Exception as dpt_error:
63
- print(f"DPT model error: {dpt_error}")
 
64
 
65
- try:
66
- # Fallback: Use ZoeDepth for depth estimation
67
- zoe_input = transforms.ToTensor()(input_image).unsqueeze(0).to(self.device)
68
- with torch.no_grad():
69
- depth = self.zoe_model.infer(zoe_input).squeeze().cpu().numpy()
70
- return depth
71
-
72
- except Exception as zoe_error:
73
- print(f"ZoeDepth fallback error: {zoe_error}")
74
-
75
- return None
76
-
77
- def refine_depth(self, depth_map):
78
  """
79
- Smooth and refine the depth map.
80
  """
81
- if depth_map is None:
82
- return None
83
-
84
- depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
85
- refined = cv2.bilateralFilter(depth_map_normalized, d=9, sigmaColor=75, sigmaSpace=75)
86
- return refined
87
-
88
- def create_point_cloud(self, image, depth_map):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  """
90
- Generate a point cloud from the depth map.
91
  """
92
- if depth_map is None:
93
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- img_array = np.array(image)
96
- depth_map_resized = cv2.resize(depth_map, (img_array.shape[1], img_array.shape[0]), interpolation=cv2.INTER_LINEAR)
97
- height, width = img_array.shape[:2]
98
 
99
- step = max(1, min(height, width) // 100) # Adjustable step size
100
- points, colors = [], []
101
-
102
- for y in range(0, height, step):
103
- for x in range(0, width, step):
104
- z = depth_map_resized[y, x] / 255.0 * 3 # Adjust depth scaling
105
- points.append([x, y, z])
106
- color = img_array[y, x][:3] / 255.0 if len(img_array[y, x]) >= 3 else [0.5, 0.5, 0.5]
107
- colors.append(color)
108
-
109
- pcd = o3d.geometry.PointCloud()
110
- pcd.points = o3d.utility.Vector3dVector(points)
111
- pcd.colors = o3d.utility.Vector3dVector(colors)
112
-
113
- return pcd
114
-
115
- def convert_to_mesh(self, point_cloud):
116
  """
117
- Convert a point cloud to a 3D mesh.
118
  """
119
- if point_cloud is None:
120
- return None
121
 
122
- try:
123
- point_cloud.estimate_normals()
124
- mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(point_cloud, depth=8)
125
- mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=10000)
126
- return mesh
127
 
128
- except Exception as e:
129
- print(f"Mesh conversion error: {e}")
130
- return None
131
-
132
- def process_image(self, input_image):
133
  """
134
- Complete pipeline: Depth estimation -> Point Cloud -> Mesh.
135
  """
136
  try:
137
- depth_map = self.estimate_depth(input_image)
138
- if depth_map is None:
139
- raise ValueError("Depth estimation failed.")
140
 
141
- refined_depth = self.refine_depth(depth_map)
142
- point_cloud = self.create_point_cloud(input_image, refined_depth)
143
 
144
- if point_cloud is None:
145
- raise ValueError("Point cloud generation failed.")
146
 
147
- mesh = self.convert_to_mesh(point_cloud)
148
- if mesh is None:
149
- raise ValueError("Mesh conversion failed.")
150
 
151
- output_file = f"/tmp/3d_model_{uuid.uuid4()}.obj"
152
- o3d.io.write_triangle_mesh(output_file, mesh)
 
153
 
154
- return output_file
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  except Exception as e:
157
- print(f"Pipeline error: {e}")
158
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- # Instantiate and test the pipeline
161
- converter = MultiModel3DReconstruction()
 
1
  import torch
2
+ import gradio as gr
3
  import numpy as np
 
4
  from PIL import Image
5
+ import trimesh
6
  import cv2
7
+ import open3d as o3d
8
+
9
+ # Critical Model Imports
10
+ from transformers import pipeline, AutoFeatureExtractor, AutoModelForImageToImage
11
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
12
+ from huggingface_hub import hf_hub_download
13
 
14
+ class CompleteMeshGenerator:
15
  def __init__(self):
16
+ # Critical Model Configuration
17
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
+ # Depth Estimation Model
20
  try:
21
+ self.depth_estimator = pipeline(
22
+ "depth-estimation",
23
+ model="Intel/dpt-large",
24
+ device=self.device
25
+ )
 
26
  except Exception as e:
27
+ print(f"Depth Estimation Model Load Error: {e}")
28
+ self.depth_estimator = None
29
 
30
+ # Multi-View Generation Setup
31
  try:
32
+ # Load ControlNet for multi-view generation
33
+ self.controlnet = ControlNetModel.from_pretrained(
34
+ "lllyasviel/control_v11f1p_sd15_depth",
35
+ torch_dtype=torch.float32
36
+ )
37
 
38
+ self.multi_view_pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
39
+ "runwayml/stable-diffusion-v1-5",
40
+ controlnet=self.controlnet,
41
+ torch_dtype=torch.float32
42
+ ).to(self.device)
43
  except Exception as e:
44
+ print(f"Multi-View Generation Model Load Error: {e}")
45
+ self.multi_view_pipeline = None
46
+
47
+ def generate_depth_map(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
48
  """
49
+ Advanced Depth Map Generation
50
  """
51
+ if self.depth_estimator is None:
52
+ raise ValueError("Depth estimation model not loaded")
53
 
54
+ # Ensure image is in correct format
55
+ if isinstance(image, np.ndarray):
56
+ image = Image.fromarray(image)
 
 
 
 
57
 
58
+ # Estimate depth
59
+ depth_result = self.depth_estimator(image)
60
+ depth_map = np.array(depth_result['depth'])
61
 
62
+ return depth_map
63
+
64
+ def generate_multi_view_images(self, input_image, num_views=4):
 
 
 
 
 
 
 
 
 
 
65
  """
66
+ Generate Multiple View Images
67
  """
68
+ if self.multi_view_pipeline is None:
69
+ raise ValueError("Multi-view generation pipeline not loaded")
70
+
71
+ # Estimate initial depth map
72
+ depth_map = self.generate_depth_map(input_image)
73
+
74
+ # Convert depth map to PIL Image
75
+ depth_image = Image.fromarray((depth_map * 255).astype(np.uint8))
76
+
77
+ # View generation parameters
78
+ view_angles = [
79
+ (30, "Side view"),
80
+ (150, "Opposite side"),
81
+ (90, "Top view"),
82
+ (270, "Bottom view")
83
+ ]
84
+
85
+ multi_view_images = []
86
+
87
+ for angle, description in view_angles[:num_views]:
88
+ try:
89
+ generated_image = self.multi_view_pipeline(
90
+ prompt=f"3D object view from {description}",
91
+ image=input_image,
92
+ control_image=depth_image,
93
+ controlnet_conditioning_scale=1.0,
94
+ rotation=angle,
95
+ guidance_scale=7.5
96
+ ).images[0]
97
+
98
+ multi_view_images.append(generated_image)
99
+ except Exception as e:
100
+ print(f"View generation error for angle {angle}: {e}")
101
+
102
+ return multi_view_images
103
+
104
+ def advanced_point_cloud_reconstruction(self, depth_maps):
105
  """
106
+ Advanced Point Cloud Reconstruction
107
  """
108
+ point_clouds = []
109
+
110
+ for depth_map in depth_maps:
111
+ # Create point cloud from depth map
112
+ height, width = depth_map.shape
113
+ x = np.linspace(0, width-1, width)
114
+ y = np.linspace(0, height-1, height)
115
+ xx, yy = np.meshgrid(x, y)
116
+
117
+ # Convert depth to 3D points
118
+ points_3d = np.column_stack([
119
+ xx.ravel(),
120
+ yy.ravel(),
121
+ depth_map.ravel()
122
+ ])
123
+
124
+ # Create Open3D point cloud
125
+ pcd = o3d.geometry.PointCloud()
126
+ pcd.points = o3d.utility.Vector3dVector(points_3d)
127
+
128
+ point_clouds.append(pcd)
129
+
130
+ # Merge point clouds
131
+ merged_pcd = point_clouds[0]
132
+ for pcd in point_clouds[1:]:
133
+ merged_pcd += pcd
134
 
135
+ return merged_pcd
 
 
136
 
137
+ def mesh_reconstruction(self, point_cloud):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  """
139
+ Advanced Mesh Reconstruction
140
  """
141
+ # Poisson surface reconstruction
142
+ mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(point_cloud, depth=9)
143
 
144
+ # Clean and smooth mesh
145
+ mesh.compute_vertex_normals()
146
+ mesh = mesh.filter_smooth_laplacian(number_of_iterations=10)
 
 
147
 
148
+ return mesh
149
+
150
+ def create_3d_model(self, input_image):
 
 
151
  """
152
+ Comprehensive 3D Model Generation Pipeline
153
  """
154
  try:
155
+ # Generate multi-view images
156
+ multi_view_images = self.generate_multi_view_images(input_image)
 
157
 
158
+ # Extract depth maps
159
+ depth_maps = [np.array(self.generate_depth_map(img)) for img in multi_view_images]
160
 
161
+ # Advanced point cloud reconstruction
162
+ point_cloud = self.advanced_point_cloud_reconstruction(depth_maps)
163
 
164
+ # Mesh generation
165
+ mesh = self.mesh_reconstruction(point_cloud)
 
166
 
167
+ # Save mesh in multiple formats
168
+ output_path = "reconstructed_3d_model"
169
+ o3d.io.write_triangle_mesh(f"{output_path}.ply", mesh)
170
 
171
+ # Convert to trimesh for additional formats
172
+ trimesh_mesh = trimesh.Trimesh(
173
+ vertices=np.asarray(mesh.vertices),
174
+ faces=np.asarray(mesh.triangles)
175
+ )
176
+ trimesh_mesh.export(f"{output_path}.obj")
177
+ trimesh_mesh.export(f"{output_path}.stl")
178
+
179
+ return (
180
+ "3D Model Generated Successfully!",
181
+ multi_view_images,
182
+ [f"{output_path}.ply", f"{output_path}.obj", f"{output_path}.stl"]
183
+ )
184
 
185
  except Exception as e:
186
+ return f"3D Model Generation Error: {str(e)}", None, None
187
+
188
+ def create_gradio_interface(self):
189
+ interface = gr.Interface(
190
+ fn=self.create_3d_model,
191
+ inputs=gr.Image(type="pil", label="Upload Image for 3D Reconstruction"),
192
+ outputs=[
193
+ gr.Textbox(label="Generation Status"),
194
+ gr.Gallery(label="Generated Multi-View Images"),
195
+ gr.File(label="Reconstructed 3D Model Files")
196
+ ],
197
+ title="Advanced 3D Model Generator",
198
+ description="""
199
+ Generate comprehensive 3D models from single images using:
200
+ - Multi-view image generation
201
+ - Advanced depth estimation
202
+ - Point cloud reconstruction
203
+ - Mesh generation
204
+ """,
205
+ allow_flagging="never"
206
+ )
207
+
208
+ return interface
209
+
210
+ def main():
211
+ mesh_generator = CompleteMeshGenerator()
212
+ interface = mesh_generator.create_gradio_interface()
213
+ interface.launch(debug=True)
214
 
215
+ if __name__ == "__main__":
216
+ main()