jiten6555 committed on
Commit
9d38c97
·
verified ·
1 Parent(s): 57cf4d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -143
app.py CHANGED
@@ -1,142 +1,128 @@
1
  import torch
2
  import torchvision.transforms as transforms
3
- import gradio as gr
4
  import numpy as np
5
  import open3d as o3d
6
  from PIL import Image
7
  import cv2
8
- import uuid
9
- import gc # Garbage collection
10
  from transformers import DPTForDepthEstimation, DPTFeatureExtractor
 
11
 
12
- class RobustDepthTo3DConverter:
13
  def __init__(self):
14
- # Explicitly use CPU
15
  self.device = torch.device("cpu")
16
 
17
  try:
18
- # Load Hugging Face DPT model and feature extractor
19
- self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(self.device)
20
- self.feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
21
- self.model.eval()
22
 
23
- print("DPT model successfully initialized")
 
 
 
24
 
 
 
 
 
 
 
25
  except Exception as e:
26
- print(f"Critical model initialization error: {e}")
27
- self.model = None
28
 
29
  def preprocess_image(self, input_image):
30
  """
31
- Preprocess image using Hugging Face's feature extractor
32
  """
33
- # Ensure input is PIL Image
34
  if not isinstance(input_image, Image.Image):
35
  input_image = Image.fromarray(input_image)
36
 
37
- # Convert to RGB if needed
38
- if input_image.mode != 'RGB':
39
- input_image = input_image.convert('RGB')
40
 
41
  return input_image
42
 
43
  def estimate_depth(self, input_image):
44
  """
45
- Estimate depth using Hugging Face DPT model
46
  """
47
- if self.model is None:
48
- raise ValueError("DPT model not properly initialized. Check model loading.")
49
 
50
  try:
51
- # Preprocess the image
52
- img = self.preprocess_image(input_image)
53
- inputs = self.feature_extractor(images=img, return_tensors="pt").to(self.device)
54
-
55
- # Estimate depth
56
  with torch.no_grad():
57
- outputs = self.model(**inputs)
58
  depth = outputs.predicted_depth.squeeze().cpu().numpy()
59
-
60
- # Normalize depth
61
- depth_normalized = cv2.normalize(
62
- depth, None, 0, 255,
63
- norm_type=cv2.NORM_MINMAX,
64
- dtype=cv2.CV_8U
65
- )
66
-
67
- # Manual memory cleanup
68
- torch.cuda.empty_cache()
69
- gc.collect()
70
-
71
- return depth_normalized
72
 
73
- except Exception as e:
74
- print(f"Depth estimation error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return None
 
 
 
 
76
 
77
  def create_point_cloud(self, image, depth_map):
78
  """
79
- Create point cloud with reduced resolution
80
  """
81
  if depth_map is None:
82
  return None
 
 
 
 
83
 
84
- try:
85
- img_array = np.array(image)
86
- depth_map_resized = cv2.resize(depth_map, (img_array.shape[1], img_array.shape[0]), interpolation=cv2.INTER_LINEAR)
87
- height, width = img_array.shape[:2]
88
-
89
- # Increase step size to reduce point cloud density
90
- step = max(1, min(height, width) // 100)
91
- points, colors = [], []
92
-
93
- for y in range(0, height, step):
94
- for x in range(0, width, step):
95
- z = depth_map_resized[y, x] / 255.0 * 3 # Reduced depth scale
96
- points.append([x, y, z])
97
- color = img_array[y, x][:3] / 255.0 if len(img_array[y, x]) >= 3 else [0.5, 0.5, 0.5]
98
- colors.append(color)
99
-
100
- pcd = o3d.geometry.PointCloud()
101
- pcd.points = o3d.utility.Vector3dVector(points)
102
- pcd.colors = o3d.utility.Vector3dVector(colors)
103
-
104
- return pcd
105
-
106
- except Exception as e:
107
- print(f"Point cloud creation error: {e}")
108
- return None
109
 
110
  def convert_to_mesh(self, point_cloud):
111
  """
112
- More robust mesh conversion with reduced complexity
113
  """
114
  if point_cloud is None:
115
  return None
116
 
117
  try:
118
- # Estimate normals with error handling
119
  point_cloud.estimate_normals()
120
- point_cloud.orient_normals_consistent_tangent_plane(100)
121
-
122
- # More flexible mesh reconstruction with lower depth
123
- try:
124
- mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
125
- point_cloud, depth=6 # Reduced depth for less memory usage
126
- )
127
- except RuntimeError:
128
- # Fallback method if Poisson reconstruction fails
129
- mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
130
- point_cloud, o3d.utility.DoubleVector([0.02, 0.04, 0.08])
131
- )
132
-
133
- # Simplify mesh to reduce memory and file size
134
  mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=10000)
135
-
136
- # Color the mesh safely
137
- if hasattr(point_cloud, 'colors'):
138
- mesh.vertex_colors = point_cloud.colors
139
-
140
  return mesh
141
 
142
  except Exception as e:
@@ -145,72 +131,31 @@ class RobustDepthTo3DConverter:
145
 
146
  def process_image(self, input_image):
147
  """
148
- Enhanced full pipeline with comprehensive error handling and memory optimization
149
  """
150
- # First, check if model is initialized
151
- if self.model is None:
152
- raise ValueError("DPT model initialization failed. Cannot process image.")
153
-
154
  try:
155
- # Preprocess and validate input
156
- input_image = self.preprocess_image(input_image)
157
-
158
- # Estimate depth
159
  depth_map = self.estimate_depth(input_image)
160
  if depth_map is None:
161
- raise ValueError("Depth estimation failed")
 
 
 
162
 
163
- # Create point cloud
164
- point_cloud = self.create_point_cloud(input_image, depth_map)
165
  if point_cloud is None:
166
- raise ValueError("Point cloud creation failed")
167
 
168
- # Convert to mesh
169
  mesh = self.convert_to_mesh(point_cloud)
170
  if mesh is None:
171
- raise ValueError("Mesh conversion failed")
172
-
173
- # Save mesh with unique filename
174
- output_filename = f"/tmp/3d_model_{uuid.uuid4()}.obj"
175
- o3d.io.write_triangle_mesh(output_filename, mesh)
176
 
177
- # Manual cleanup
178
- del depth_map, point_cloud, mesh
179
- gc.collect()
180
 
181
- return output_filename
182
 
183
  except Exception as e:
184
- print(f"Full pipeline error: {e}")
185
- raise # Re-raise the exception to be caught in the Gradio interface
186
-
187
- def create_huggingface_space():
188
- # Initialize converter
189
- converter = RobustDepthTo3DConverter()
190
-
191
- def convert_image(input_image):
192
- try:
193
- # Check model initialization before processing
194
- if converter.model is None:
195
- raise ValueError("DPT model failed to initialize. Cannot process image.")
196
-
197
- output_model = converter.process_image(input_image)
198
- return output_model
199
- except Exception as e:
200
- print(f"Conversion error: {e}")
201
- raise gr.Error(f"Conversion failed: {str(e)}")
202
-
203
- # Gradio Interface
204
- iface = gr.Interface(
205
- fn=convert_image,
206
- inputs=gr.Image(type="pil", label="Input Image"),
207
- outputs=gr.File(label="3D Model (OBJ)"),
208
- title="Optimized AI Image to 3D Converter",
209
- description="Convert images to 3D models with CPU optimization and reduced memory usage."
210
- )
211
-
212
- return iface
213
 
214
- # Launch the Gradio interface
215
- demo = create_huggingface_space()
216
- demo.launch(debug=True)
 
1
  import torch
2
  import torchvision.transforms as transforms
 
3
  import numpy as np
4
  import open3d as o3d
5
  from PIL import Image
6
  import cv2
7
+ import gc
 
8
  from transformers import DPTForDepthEstimation, DPTFeatureExtractor
9
+ from torchvision.models.segmentation import deeplabv3_resnet50 # Optional for segmentation
10
 
11
+ class MultiModel3DReconstruction:
12
  def __init__(self):
 
13
  self.device = torch.device("cpu")
14
 
15
  try:
16
+ # Load DPT model for depth estimation
17
+ self.dpt_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(self.device)
18
+ self.dpt_feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
19
+ self.dpt_model.eval()
20
 
21
+ print("DPT model successfully loaded.")
22
+ except Exception as e:
23
+ print(f"Error loading DPT model: {e}")
24
+ self.dpt_model = None
25
 
26
+ try:
27
+ # Optional fallback: Load ZoeDepth model
28
+ self.zoe_model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_N", pretrained=True).to(self.device)
29
+ self.zoe_model.eval()
30
+
31
+ print("ZoeDepth model successfully loaded.")
32
  except Exception as e:
33
+ print(f"Error loading ZoeDepth model: {e}")
34
+ self.zoe_model = None
35
 
36
  def preprocess_image(self, input_image):
37
  """
38
+ Preprocess input image for models.
39
  """
 
40
  if not isinstance(input_image, Image.Image):
41
  input_image = Image.fromarray(input_image)
42
 
43
+ if input_image.mode != "RGB":
44
+ input_image = input_image.convert("RGB")
 
45
 
46
  return input_image
47
 
48
  def estimate_depth(self, input_image):
49
  """
50
+ Estimate depth using the best available model.
51
  """
52
+ input_image = self.preprocess_image(input_image)
 
53
 
54
  try:
55
+ # Use DPT for depth estimation
56
+ inputs = self.dpt_feature_extractor(images=input_image, return_tensors="pt").to(self.device)
 
 
 
57
  with torch.no_grad():
58
+ outputs = self.dpt_model(**inputs)
59
  depth = outputs.predicted_depth.squeeze().cpu().numpy()
60
+ return depth
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ except Exception as dpt_error:
63
+ print(f"DPT model error: {dpt_error}")
64
+
65
+ try:
66
+ # Fallback: Use ZoeDepth for depth estimation
67
+ zoe_input = transforms.ToTensor()(input_image).unsqueeze(0).to(self.device)
68
+ with torch.no_grad():
69
+ depth = self.zoe_model.infer(zoe_input).squeeze().cpu().numpy()
70
+ return depth
71
+
72
+ except Exception as zoe_error:
73
+ print(f"ZoeDepth fallback error: {zoe_error}")
74
+
75
+ return None
76
+
77
def refine_depth(self, depth_map):
    """
    Normalize a depth map to 8-bit and smooth it.

    The map is min-max normalized into 0..255 as ``CV_8U`` and then passed
    through a bilateral filter. ``None`` in gives ``None`` out.
    """
    if depth_map is not None:
        as_uint8 = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        return cv2.bilateralFilter(as_uint8, d=9, sigmaColor=75, sigmaSpace=75)
    return None
87
 
88
  def create_point_cloud(self, image, depth_map):
89
  """
90
+ Generate a point cloud from the depth map.
91
  """
92
  if depth_map is None:
93
  return None
94
+
95
+ img_array = np.array(image)
96
+ depth_map_resized = cv2.resize(depth_map, (img_array.shape[1], img_array.shape[0]), interpolation=cv2.INTER_LINEAR)
97
+ height, width = img_array.shape[:2]
98
 
99
+ step = max(1, min(height, width) // 100) # Adjustable step size
100
+ points, colors = [], []
101
+
102
+ for y in range(0, height, step):
103
+ for x in range(0, width, step):
104
+ z = depth_map_resized[y, x] / 255.0 * 3 # Adjust depth scaling
105
+ points.append([x, y, z])
106
+ color = img_array[y, x][:3] / 255.0 if len(img_array[y, x]) >= 3 else [0.5, 0.5, 0.5]
107
+ colors.append(color)
108
+
109
+ pcd = o3d.geometry.PointCloud()
110
+ pcd.points = o3d.utility.Vector3dVector(points)
111
+ pcd.colors = o3d.utility.Vector3dVector(colors)
112
+
113
+ return pcd
 
 
 
 
 
 
 
 
 
 
114
 
115
  def convert_to_mesh(self, point_cloud):
116
  """
117
+ Convert a point cloud to a 3D mesh.
118
  """
119
  if point_cloud is None:
120
  return None
121
 
122
  try:
 
123
  point_cloud.estimate_normals()
124
+ mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(point_cloud, depth=8)
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=10000)
 
 
 
 
 
126
  return mesh
127
 
128
  except Exception as e:
 
131
 
132
  def process_image(self, input_image):
133
  """
134
+ Complete pipeline: Depth estimation -> Point Cloud -> Mesh.
135
  """
 
 
 
 
136
  try:
 
 
 
 
137
  depth_map = self.estimate_depth(input_image)
138
  if depth_map is None:
139
+ raise ValueError("Depth estimation failed.")
140
+
141
+ refined_depth = self.refine_depth(depth_map)
142
+ point_cloud = self.create_point_cloud(input_image, refined_depth)
143
 
 
 
144
  if point_cloud is None:
145
+ raise ValueError("Point cloud generation failed.")
146
 
 
147
  mesh = self.convert_to_mesh(point_cloud)
148
  if mesh is None:
149
+ raise ValueError("Mesh conversion failed.")
 
 
 
 
150
 
151
+ output_file = f"/tmp/3d_model_{uuid.uuid4()}.obj"
152
+ o3d.io.write_triangle_mesh(output_file, mesh)
 
153
 
154
+ return output_file
155
 
156
  except Exception as e:
157
+ print(f"Pipeline error: {e}")
158
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
# Build the reconstruction pipeline at import time; the constructor
# attempts to load the DPT and ZoeDepth models.
converter = MultiModel3DReconstruction()