# Imagemodel / app.py
# Author: jiten6555
# Last commit: e486d98 ("Update app.py", verified)
import logging
import os
import tempfile

import torch
import gradio as gr
import numpy as np
from PIL import Image
import trimesh
import cv2
import open3d as o3d
# Critical Model Imports
from transformers import pipeline, AutoFeatureExtractor, AutoModelForImageToImage
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from huggingface_hub import hf_hub_download
class CompleteMeshGenerator:
    """Single-image -> 3D mesh pipeline with a Gradio front end.

    Stages: depth estimation (Intel/dpt-large) -> depth-conditioned ControlNet
    img2img "view" synthesis -> depth for each synthesized view -> merged point
    cloud -> Poisson surface reconstruction -> PLY/OBJ/STL files.

    If a model fails to load, the error is logged and the corresponding
    attribute is left as ``None``; the stage that needs it then raises
    ``ValueError``.
    """

    # (angle in degrees, prompt description) for each synthesized view, in order.
    # NOTE: the img2img pipeline takes no camera pose, so the viewpoint is only
    # requested through the text prompt; the angle is used for logging.
    VIEW_ANGLES = (
        (30, "Side view"),
        (150, "Opposite side"),
        (90, "Top view"),
        (270, "Bottom view"),
    )

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Load the depth estimator and the ControlNet img2img pipeline."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Depth estimation model
        try:
            self.depth_estimator = pipeline(
                "depth-estimation",
                model="Intel/dpt-large",
                device=self.device,
            )
        except Exception as e:
            self._logger.exception("Depth Estimation Model Load Error: %s", e)
            self.depth_estimator = None

        # Multi-view generation: SD 1.5 img2img steered by a depth ControlNet.
        # Defined up front so the attribute exists even if loading fails.
        self.controlnet = None
        try:
            self.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11f1p_sd15_depth",
                torch_dtype=torch.float32,
            )
            self.multi_view_pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                controlnet=self.controlnet,
                torch_dtype=torch.float32,
            ).to(self.device)
        except Exception as e:
            self._logger.exception("Multi-View Generation Model Load Error: %s", e)
            self.multi_view_pipeline = None

    @staticmethod
    def _normalize_depth_to_uint8(depth_map) -> np.ndarray:
        """Min-max scale ``depth_map`` into a uint8 array spanning [0, 255].

        Works for integer and floating-point input alike. An empty or constant
        map yields all zeros instead of dividing by zero.
        """
        depth = np.asarray(depth_map, dtype=np.float64)
        if depth.size == 0:
            return np.zeros(depth.shape, dtype=np.uint8)
        lo, hi = depth.min(), depth.max()
        if hi == lo:
            return np.zeros(depth.shape, dtype=np.uint8)
        scaled = (depth - lo) / (hi - lo) * 255.0
        # Round rather than truncate, so e.g. 127.9999 maps to 128, not 127.
        return np.rint(scaled).astype(np.uint8)

    @staticmethod
    def _depth_map_to_points(depth_map) -> np.ndarray:
        """Turn a 2-D depth map into an (H*W, 3) array of (x, y, depth) rows.

        x is the column index and y the row index; rows follow row-major
        (C-order) pixel order.
        """
        depth = np.asarray(depth_map, dtype=np.float64)
        height, width = depth.shape
        xx, yy = np.meshgrid(
            np.arange(width, dtype=np.float64),
            np.arange(height, dtype=np.float64),
        )
        return np.column_stack([xx.ravel(), yy.ravel(), depth.ravel()])

    def generate_depth_map(self, image) -> np.ndarray:
        """Estimate a relative depth map for ``image``.

        Args:
            image: PIL image, or NumPy array (converted with ``Image.fromarray``).

        Returns:
            ``np.array`` of the pipeline's ``"depth"`` image. For DPT this is
            presumably an 8-bit single-channel map -- TODO confirm against the
            installed transformers version.

        Raises:
            ValueError: if the depth model failed to load.
        """
        if self.depth_estimator is None:
            raise ValueError("Depth estimation model not loaded")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        depth_result = self.depth_estimator(image)
        return np.array(depth_result["depth"])

    def generate_multi_view_images(self, input_image, num_views: int = 4):
        """Synthesize up to ``num_views`` depth-guided variations of ``input_image``.

        Each entry of ``VIEW_ANGLES`` (truncated to ``num_views``) produces one
        img2img call whose prompt names the desired view. Failures are logged
        and skipped, so the result may contain fewer than ``num_views`` images.

        Returns:
            List of PIL images, one per view that generated successfully.

        Raises:
            ValueError: if the generation or depth pipeline failed to load.
        """
        if self.multi_view_pipeline is None:
            raise ValueError("Multi-view generation pipeline not loaded")

        # The SD image processor does not convert RGBA/greyscale to RGB itself.
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")

        depth_map = self.generate_depth_map(input_image)
        # Min-max scaling is safe for 8-bit or float maps; multiplying an 8-bit
        # array by 255 would wrap around in uint8 arithmetic.
        depth_image = Image.fromarray(self._normalize_depth_to_uint8(depth_map)).convert("RGB")

        multi_view_images = []
        for angle, description in self.VIEW_ANGLES[:num_views]:
            try:
                result = self.multi_view_pipeline(
                    prompt=f"3D object view from {description}",
                    image=input_image,
                    control_image=depth_image,
                    controlnet_conditioning_scale=1.0,
                    guidance_scale=7.5,
                )
            except Exception as e:
                # Best effort: one failed view should not abort the others.
                self._logger.exception("View generation error for angle %s: %s", angle, e)
                continue
            multi_view_images.append(result.images[0])
        return multi_view_images

    def advanced_point_cloud_reconstruction(self, depth_maps):
        """Merge 2-D depth maps into a single Open3D point cloud.

        Every map becomes (x, y, depth) points in its own pixel frame. No
        per-view rotation is applied, so all views overlap in that frame.

        Raises:
            ValueError: if ``depth_maps`` is empty.
        """
        depth_maps = list(depth_maps)
        if not depth_maps:
            raise ValueError("No depth maps available for point cloud reconstruction")

        all_points = np.concatenate([self._depth_map_to_points(d) for d in depth_maps])
        merged_pcd = o3d.geometry.PointCloud()
        merged_pcd.points = o3d.utility.Vector3dVector(all_points)
        return merged_pcd

    def mesh_reconstruction(self, point_cloud, depth: int = 9):
        """Build a smoothed triangle mesh from ``point_cloud`` via Poisson reconstruction.

        Poisson reconstruction requires normals. When the cloud has none, they
        are estimated in place (this mutates ``point_cloud``) and oriented
        towards +z.

        Args:
            point_cloud: Open3D point cloud.
            depth: octree depth passed to the Poisson solver (higher = finer).
        """
        if not point_cloud.has_normals():
            point_cloud.estimate_normals()
            # Assumes larger depth values are nearer the viewer (DPT predicts
            # relative inverse depth) -- TODO confirm orientation on real data.
            point_cloud.orient_normals_to_align_with_direction()

        mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            point_cloud, depth=depth
        )
        mesh = mesh.filter_smooth_laplacian(number_of_iterations=10)
        # Compute normals after smoothing so they match the final geometry.
        mesh.compute_vertex_normals()
        return mesh

    def create_3d_model(self, input_image):
        """Run the full pipeline for one uploaded image (the Gradio callback).

        Returns:
            ``(status message, generated view images, [ply, obj, stl] paths)`` on
            success, or ``(error message, None, None)`` if any stage raises.
        """
        try:
            if input_image is None:
                raise ValueError("No input image provided")

            multi_view_images = self.generate_multi_view_images(input_image)
            if not multi_view_images:
                raise ValueError("No views could be generated from the input image")

            depth_maps = [self.generate_depth_map(img) for img in multi_view_images]
            point_cloud = self.advanced_point_cloud_reconstruction(depth_maps)
            mesh = self.mesh_reconstruction(point_cloud)

            # Per-request directory so concurrent sessions never overwrite each
            # other's files. It is not deleted here.
            output_dir = tempfile.mkdtemp(prefix="mesh_")
            output_path = os.path.join(output_dir, "reconstructed_3d_model")
            ply_path = f"{output_path}.ply"
            obj_path = f"{output_path}.obj"
            stl_path = f"{output_path}.stl"

            if not o3d.io.write_triangle_mesh(ply_path, mesh):
                raise IOError(f"Failed to write {ply_path}")

            # trimesh handles the OBJ and STL exports.
            trimesh_mesh = trimesh.Trimesh(
                vertices=np.asarray(mesh.vertices),
                faces=np.asarray(mesh.triangles),
            )
            trimesh_mesh.export(obj_path)
            trimesh_mesh.export(stl_path)

            return (
                "3D Model Generated Successfully!",
                multi_view_images,
                [ply_path, obj_path, stl_path],
            )
        except Exception as e:
            self._logger.exception("3D model generation failed")
            return f"3D Model Generation Error: {str(e)}", None, None

    def create_gradio_interface(self):
        """Wire ``create_3d_model`` into a Gradio Interface (image in; status, gallery, files out)."""
        interface = gr.Interface(
            fn=self.create_3d_model,
            inputs=gr.Image(type="pil", label="Upload Image for 3D Reconstruction"),
            outputs=[
                gr.Textbox(label="Generation Status"),
                gr.Gallery(label="Generated Multi-View Images"),
                # create_3d_model returns a list of three paths.
                gr.File(label="Reconstructed 3D Model Files", file_count="multiple"),
            ],
            title="Advanced 3D Model Generator",
            description=(
                "\n"
                "Generate comprehensive 3D models from single images using:\n"
                "- Multi-view image generation\n"
                "- Advanced depth estimation\n"
                "- Point cloud reconstruction\n"
                "- Mesh generation\n"
            ),
            allow_flagging="never",
        )
        return interface
def main():
    """Entry point: load the models, build the Gradio UI and serve it in debug mode."""
    app = CompleteMeshGenerator().create_gradio_interface()
    app.launch(debug=True)


if __name__ == "__main__":
    main()