File size: 5,535 Bytes
b1b52ab
16dfcc8
3057b36
cd41f5f
 
 
 
 
 
 
 
b1b52ab
 
db6a3b7
b1b52ab
db6a3b7
cd41f5f
b1b52ab
258ea5a
b1b52ab
c260ece
a481d7a
b1b52ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a481d7a
b1b52ab
 
a481d7a
b1b52ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db6a3b7
b1b52ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db6a3b7
b1b52ab
 
 
 
 
 
 
 
 
 
 
 
a481d7a
b1b52ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db6a3b7
b1b52ab
db6a3b7
b1b52ab
db6a3b7
b1b52ab
 
 
 
 
db6a3b7
 
b1b52ab
c666caf
b1b52ab
 
 
 
 
258ea5a
 
b1b52ab
258ea5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# ... (previous imports remain the same) ...

@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
    progress: gr.Progress = gr.Progress()
) -> Tuple[dict, str, str, str]:
    """
    Convert an image to a 3D model with memory management and progress tracking.

    Args:
        image: Input image (already preprocessed; ``preprocess_image=False`` below).
        seed: RNG seed passed to the generation pipeline.
        ss_guidance_strength: CFG strength for the sparse-structure sampler.
        ss_sampling_steps: Number of sparse-structure sampling steps.
        slat_guidance_strength: CFG strength for the SLAT sampler.
        slat_sampling_steps: Number of SLAT sampling steps.
        req: Gradio request; ``session_hash`` keys the per-user temp directory.
        progress: Gradio progress reporter.

    Returns:
        Tuple of (packed model state, video preview path, GLB path, GLB path) —
        the GLB path is repeated for the two UI outputs that consume it.

    Raises:
        gr.Error: Wrapping any exception raised during generation.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # FIX: the per-session directory may not exist yet on a fresh session;
    # without this, imageio.mimsave / glb.export below fail with FileNotFoundError.
    os.makedirs(user_dir, exist_ok=True)
    progress(0, desc="Initializing...")

    # Clear CUDA cache before starting
    torch.cuda.empty_cache()

    try:
        # Generate 3D model with progress updates
        progress(0.1, desc="Running 3D generation pipeline...")
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )

        progress(0.4, desc="Generating video preview...")
        # Render in batches so peak VRAM stays bounded; cache is cleared per batch.
        # NOTE(review): assumes render_utils.render_video accepts a start_frame
        # keyword — confirm against the render_utils API.
        batch_size = 30  # Process 30 frames at a time
        num_frames = 120
        video = []
        video_geo = []

        for i in range(0, num_frames, batch_size):
            end_idx = min(i + batch_size, num_frames)
            batch_frames = render_utils.render_video(
                outputs['gaussian'][0],
                num_frames=end_idx - i,
                start_frame=i
            )['color']
            batch_geo = render_utils.render_video(
                outputs['mesh'][0],
                num_frames=end_idx - i,
                start_frame=i
            )['normal']

            video.extend(batch_frames)
            video_geo.extend(batch_geo)

            # Clear cache after each batch
            torch.cuda.empty_cache()
            progress(0.4 + (0.3 * i / num_frames), desc=f"Rendering frames {i} to {end_idx}...")

        # Side-by-side: color render on the left, geometry normals on the right.
        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]

        # Generate unique ID and save video
        trial_id = str(uuid.uuid4())
        video_path = os.path.join(user_dir, f"{trial_id}.mp4")
        progress(0.7, desc="Saving video...")
        imageio.mimsave(video_path, video, fps=15)

        # Clear video data from memory before the texture-baking step below
        del video
        del video_geo
        torch.cuda.empty_cache()

        # Generate and save full-quality GLB (no simplification, 2048px textures)
        progress(0.8, desc="Generating full-quality GLB...")
        glb = postprocessing_utils.to_glb(
            outputs['gaussian'][0],
            outputs['mesh'][0],
            simplify=0.0,
            texture_size=2048,
            verbose=False
        )
        glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
        progress(0.9, desc="Saving GLB file...")
        glb.export(glb_path)

        # Pack state so extract_reduced_glb can later rebuild a reduced version
        progress(0.95, desc="Finalizing...")
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)

        # Final cleanup
        torch.cuda.empty_cache()
        progress(1.0, desc="Complete!")

        return state, video_path, glb_path, glb_path

    except Exception as e:
        # Clean up on error, then surface a user-visible Gradio error.
        torch.cuda.empty_cache()
        # FIX: chain the original exception so the traceback is preserved in logs.
        raise gr.Error(f"Processing failed: {str(e)}") from e

@spaces.GPU
def extract_reduced_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
    progress: gr.Progress = gr.Progress()
) -> Tuple[str, str]:
    """
    Extract a reduced-quality GLB file from a packed model state.

    Args:
        state: Packed state produced by image_to_3d (via pack_state).
        mesh_simplify: Mesh simplification ratio passed to to_glb.
        texture_size: Baked texture resolution in pixels.
        req: Gradio request; ``session_hash`` keys the per-user temp directory.
        progress: Gradio progress reporter.

    Returns:
        Tuple of (glb_path, glb_path) — repeated for the two UI outputs.

    Raises:
        gr.Error: Wrapping any exception raised during extraction.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # FIX: ensure the per-session directory exists before exporting into it;
    # glb.export would otherwise fail if the session's temp dir was cleaned up.
    os.makedirs(user_dir, exist_ok=True)

    try:
        progress(0.1, desc="Unpacking model state...")
        gs, mesh, trial_id = unpack_state(state)

        progress(0.3, desc="Generating reduced GLB...")
        glb = postprocessing_utils.to_glb(
            gs, mesh,
            simplify=mesh_simplify,
            texture_size=texture_size,
            verbose=False
        )

        progress(0.8, desc="Saving reduced GLB...")
        glb_path = os.path.join(user_dir, f"{trial_id}_reduced.glb")
        glb.export(glb_path)

        progress(0.9, desc="Cleaning up...")
        torch.cuda.empty_cache()

        progress(1.0, desc="Complete!")
        return glb_path, glb_path

    except Exception as e:
        torch.cuda.empty_cache()
        # FIX: chain the original exception so the traceback is preserved in logs.
        raise gr.Error(f"GLB reduction failed: {str(e)}") from e

# ... (rest of the UI code remains the same) ...

# Add some memory optimization settings at startup
if __name__ == "__main__":
    # Set some CUDA memory management options
    torch.cuda.empty_cache()
    # Let cuDNN auto-tune kernels for the (fixed) input shapes used here.
    torch.backends.cudnn.benchmark = True

    # Initialize pipeline and move it to the GPU
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()

    try:
        # Warm up the preprocessing path (rembg model load) with a small
        # dummy image so the first real request doesn't pay the cost.
        test_img = np.zeros((256, 256, 3), dtype=np.uint8)  # Smaller test image
        pipeline.preprocess_image(Image.fromarray(test_img))
        del test_img
        torch.cuda.empty_cache()
    except Exception:
        # FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit;
        # the warm-up remains deliberately best-effort, so failures are ignored.
        pass

    demo.launch()