jiten6555 commited on
Commit
3bfc3b0
·
verified ·
1 Parent(s): 347327e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -91
app.py CHANGED
@@ -1,144 +1,188 @@
1
  import torch
 
2
  import gradio as gr
3
  import numpy as np
4
  import open3d as o3d
5
  from PIL import Image
6
  import cv2
7
 
8
- class CPUFriendlyAIDepthTo3DConverter:
9
  def __init__(self):
10
- # Load MiDaS depth estimation model with explicit CPU configuration
11
- self.model = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small', force_reload=False)
12
- self.model.to('cpu') # Ensure model runs on CPU
 
13
  self.model.eval()
14
 
15
- # Preprocessing transforms
16
- self.transform = torch.hub.load('intel-isl/MiDaS', 'transforms').small_transform
 
 
 
 
 
 
17
 
18
- def estimate_depth(self, input_image):
19
  """
20
- CPU-optimized depth estimation
21
  """
22
- # Convert PIL Image to numpy
23
- img = np.array(input_image)
24
-
25
- # Ensure image is in RGB
26
- if img.shape[-1] == 4: # If RGBA, convert to RGB
27
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
28
-
29
- # Preprocess image
30
- input_batch = self.transform(img).unsqueeze(0).to('cpu')
31
 
32
- # Estimate depth with minimal memory usage
33
- with torch.no_grad():
34
- prediction = self.model(input_batch)
35
- depth = prediction.squeeze().cpu().numpy()
36
 
37
- # Free up memory
38
- torch.cuda.empty_cache()
 
39
 
40
- # Normalize depth
41
- depth_normalized = cv2.normalize(depth, None, 0, 255,
42
- norm_type=cv2.NORM_MINMAX,
43
- dtype=cv2.CV_8U)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- return depth_normalized
 
 
46
 
47
  def create_point_cloud(self, image, depth_map):
48
  """
49
- Efficient point cloud creation
50
  """
51
- img_array = np.array(image)
52
- height, width = img_array.shape[:2]
53
 
54
- # More aggressive downsampling for memory efficiency
55
- step = 4
56
- points = []
57
- colors = []
58
-
59
- for y in range(0, height, step):
60
- for x in range(0, width, step):
61
- # Use depth as Z coordinate
62
- z = depth_map[y, x] / 255.0 * 5 # Scaled depth
63
- points.append([x, y, z])
64
-
65
- # Safely get color
66
- try:
67
- color = img_array[y, x] / 255.0
 
 
 
68
  colors.append(color)
69
- except IndexError:
70
- colors.append([0.5, 0.5, 0.5]) # Default color if out of bounds
71
-
72
- pcd = o3d.geometry.PointCloud()
73
- pcd.points = o3d.utility.Vector3dVector(points)
74
- pcd.colors = o3d.utility.Vector3dVector(colors)
75
 
76
- return pcd
 
 
77
 
78
  def convert_to_mesh(self, point_cloud):
79
  """
80
- Memory-efficient mesh conversion
81
  """
82
- # Estimate and orient normals
83
- point_cloud.estimate_normals()
84
- point_cloud.orient_normals_consistent_tangent_plane(100)
85
-
86
- # Lower depth for less memory consumption
87
- mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
88
- point_cloud, depth=7 # Reduced from previous version
89
- )
90
 
91
- # Simplified color handling
92
- mesh.vertex_colors = point_cloud.colors
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- return mesh
 
 
95
 
96
  def process_image(self, input_image):
97
  """
98
- CPU-friendly full pipeline
99
  """
100
- # Resize image to reduce memory usage
101
- max_size = (800, 800)
102
- input_image.thumbnail(max_size, Image.LANCZOS)
103
-
104
- # Estimate depth using AI model
105
- depth_map = self.estimate_depth(input_image)
106
-
107
- # Create point cloud
108
- point_cloud = self.create_point_cloud(np.array(input_image), depth_map)
109
-
110
- # Convert to mesh
111
- mesh = self.convert_to_mesh(point_cloud)
112
-
113
- # Save mesh
114
- output_path = "/tmp/cpu_optimized_3d_model.obj"
115
- o3d.io.write_triangle_mesh(output_path, mesh)
 
 
 
 
 
 
 
 
116
 
117
- return output_path
 
 
118
 
119
  def create_huggingface_space():
120
  # Initialize converter
121
- converter = CPUFriendlyAIDepthTo3DConverter()
122
 
123
  def convert_image(input_image):
124
  try:
125
- # Ensure image is in PIL format
126
- if not isinstance(input_image, Image.Image):
127
- input_image = Image.fromarray(input_image)
128
-
129
- # Process image
130
  output_model = converter.process_image(input_image)
131
  return output_model
132
  except Exception as e:
133
- return f"Error during conversion: {str(e)}"
134
 
135
  # Gradio Interface
136
  iface = gr.Interface(
137
  fn=convert_image,
138
  inputs=gr.Image(type="pil", label="Input Image"),
139
  outputs=gr.File(label="3D Model (OBJ)"),
140
- title="CPU-Friendly AI Image to 3D Converter",
141
- description="Convert images to 3D models using lightweight AI depth estimation."
142
  )
143
 
144
  return iface
 
1
  import torch
2
+ import torchvision.transforms as transforms
3
  import gradio as gr
4
  import numpy as np
5
  import open3d as o3d
6
  from PIL import Image
7
  import cv2
8
 
9
+ class RobustDepthTo3DConverter:
10
  def __init__(self):
11
+ # Load MiDaS model with explicit configuration
12
+ self.device = torch.device("cpu")
13
+ self.model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", pretrained=True)
14
+ self.model.to(self.device)
15
  self.model.eval()
16
 
17
+ # Create transformation pipeline
18
+ self.transform = transforms.Compose([
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
 
26
+ def preprocess_image(self, input_image):
27
  """
28
+ Standardize image input
29
  """
30
+ # Ensure input is PIL Image
31
+ if not isinstance(input_image, Image.Image):
32
+ input_image = Image.fromarray(input_image)
 
 
 
 
 
 
33
 
34
+ # Resize image
35
+ max_size = (800, 800)
36
+ input_image.thumbnail(max_size, Image.LANCZOS)
 
37
 
38
+ # Convert to RGB if needed
39
+ if input_image.mode != 'RGB':
40
+ input_image = input_image.convert('RGB')
41
 
42
+ return input_image
43
+
44
+ def estimate_depth(self, input_image):
45
+ """
46
+ Robust depth estimation
47
+ """
48
+ try:
49
+ # Preprocess image
50
+ img = self.preprocess_image(input_image)
51
+
52
+ # Convert to tensor
53
+ img_tensor = self.transform(img).unsqueeze(0).to(self.device)
54
+
55
+ # Estimate depth
56
+ with torch.no_grad():
57
+ prediction = self.model(img_tensor)
58
+ depth = prediction.squeeze().cpu().numpy()
59
+
60
+ # Normalize depth
61
+ depth_normalized = cv2.normalize(
62
+ depth, None, 0, 255,
63
+ norm_type=cv2.NORM_MINMAX,
64
+ dtype=cv2.CV_8U
65
+ )
66
+
67
+ return depth_normalized
68
 
69
+ except Exception as e:
70
+ print(f"Depth estimation error: {e}")
71
+ return None
72
 
73
  def create_point_cloud(self, image, depth_map):
74
  """
75
+ Create point cloud with error handling
76
  """
77
+ if depth_map is None:
78
+ return None
79
 
80
+ try:
81
+ img_array = np.array(image)
82
+ height, width = img_array.shape[:2]
83
+
84
+ # Adaptive sampling based on image size
85
+ step = max(1, min(height, width) // 200)
86
+
87
+ points = []
88
+ colors = []
89
+
90
+ for y in range(0, height, step):
91
+ for x in range(0, width, step):
92
+ z = depth_map[y, x] / 255.0 * 5 # Scaled depth
93
+ points.append([x, y, z])
94
+
95
+ # Safe color extraction
96
+ color = img_array[y, x][:3] / 255.0 if len(img_array[y, x]) >= 3 else [0.5, 0.5, 0.5]
97
  colors.append(color)
98
+
99
+ pcd = o3d.geometry.PointCloud()
100
+ pcd.points = o3d.utility.Vector3dVector(points)
101
+ pcd.colors = o3d.utility.Vector3dVector(colors)
102
+
103
+ return pcd
104
 
105
+ except Exception as e:
106
+ print(f"Point cloud creation error: {e}")
107
+ return None
108
 
109
  def convert_to_mesh(self, point_cloud):
110
  """
111
+ Mesh conversion with error handling
112
  """
113
+ if point_cloud is None:
114
+ return None
 
 
 
 
 
 
115
 
116
+ try:
117
+ # Estimate normals
118
+ point_cloud.estimate_normals()
119
+ point_cloud.orient_normals_consistent_tangent_plane(100)
120
+
121
+ # Mesh reconstruction
122
+ mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
123
+ point_cloud, depth=7
124
+ )
125
+
126
+ # Color the mesh
127
+ mesh.vertex_colors = point_cloud.colors
128
+
129
+ return mesh
130
 
131
+ except Exception as e:
132
+ print(f"Mesh conversion error: {e}")
133
+ return None
134
 
135
  def process_image(self, input_image):
136
  """
137
+ Full pipeline with comprehensive error handling
138
  """
139
+ try:
140
+ # Preprocess and validate input
141
+ input_image = self.preprocess_image(input_image)
142
+
143
+ # Estimate depth
144
+ depth_map = self.estimate_depth(input_image)
145
+ if depth_map is None:
146
+ raise ValueError("Depth estimation failed")
147
+
148
+ # Create point cloud
149
+ point_cloud = self.create_point_cloud(input_image, depth_map)
150
+ if point_cloud is None:
151
+ raise ValueError("Point cloud creation failed")
152
+
153
+ # Convert to mesh
154
+ mesh = self.convert_to_mesh(point_cloud)
155
+ if mesh is None:
156
+ raise ValueError("Mesh conversion failed")
157
+
158
+ # Save mesh
159
+ output_path = "/tmp/robust_3d_model.obj"
160
+ o3d.io.write_triangle_mesh(output_path, mesh)
161
+
162
+ return output_path
163
 
164
+ except Exception as e:
165
+ print(f"Full pipeline error: {e}")
166
+ return f"Error during conversion: {str(e)}"
167
 
168
  def create_huggingface_space():
169
  # Initialize converter
170
+ converter = RobustDepthTo3DConverter()
171
 
172
  def convert_image(input_image):
173
  try:
 
 
 
 
 
174
  output_model = converter.process_image(input_image)
175
  return output_model
176
  except Exception as e:
177
+ return f"Conversion failed: {str(e)}"
178
 
179
  # Gradio Interface
180
  iface = gr.Interface(
181
  fn=convert_image,
182
  inputs=gr.Image(type="pil", label="Input Image"),
183
  outputs=gr.File(label="3D Model (OBJ)"),
184
+ title="Robust AI Image to 3D Converter",
185
+ description="Convert images to 3D models with advanced error handling."
186
  )
187
 
188
  return iface