File size: 5,738 Bytes
1774ce2
 
75a1da3
 
1219780
 
d821d06
dc9efd8
75a1da3
1774ce2
 
 
 
 
1219780
 
 
 
 
 
 
9bc6ef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75a1da3
1219780
 
 
9bc6ef8
1219780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75a1da3
 
9bc6ef8
1219780
9bc6ef8
1219780
 
 
 
 
9bc6ef8
 
 
 
1219780
 
 
 
 
9bc6ef8
 
1219780
d821d06
9bc6ef8
 
 
 
 
 
 
1219780
 
 
 
 
 
9bc6ef8
1219780
 
 
1774ce2
75a1da3
 
1774ce2
 
1219780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1774ce2
 
1219780
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import shlex
import subprocess
import os
import sys
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
import torch
import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo

# Configure root logging for the whole script: INFO level and above, each
# record prefixed with a timestamp and level name. (basicConfig is a no-op if
# the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function below.
logger = logging.getLogger(__name__)

def install_requirements(requirements_path="requirements.txt"):
    """Write the pinned requirement list to disk and install it with pip.

    The list below is written to *requirements_path*, overwriting any
    existing file at that path. It is then installed with the pip that
    belongs to the running interpreter (``sys.executable -m pip``).

    Note: this module already imports ``torch`` at top level. Any torch
    version change made here therefore only takes effect the next time the
    interpreter starts.

    Args:
        requirements_path: Destination of the generated requirements file.
            Defaults to ``"requirements.txt"`` in the current working
            directory (the previous hard-coded location).

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits non-zero.
    """
    # NOTE(review): several entries look questionable and should be confirmed:
    #   - opencv_python and opencv_python_headless both provide `cv2`.
    #   - ort_nightly_gpu and onnxruntime_gpu overlap; onnxruntime_gpu is also
    #     force-reinstalled from a local wheel by setup_dependencies().
    #   - `typing` is a PyPI backport meant for old Pythons; it is likely
    #     unnecessary here.
    requirements = """
pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
ort_nightly_gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/d3daa2b0-aa56-45ac-8145-2c3dc0661c87/pypi/download/ort-nightly-gpu/1.17.dev20240118002/ort_nightly_gpu-1.17.0.dev20240118002-cp310-cp310-manylinux_2_28_x86_64.whl
onnxruntime_gpu @ https://pkgs.dev.azure.com/onnxruntime/2a773b67-e88b-4c7f-9fc0-87d31fea8ef2/_packaging/7fa31e42-5da1-4e84-a664-f2b4129c7d45/pypi/download/onnxruntime-gpu/1.17/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl
torch==2.2.0
accelerate
datasets
diffusers>=0.26.3
fire
gradio
jax
typing
numba
numpy<2
omegaconf>=2.3.0
opencv_python
opencv_python_headless
peft
Pillow
pygltflib
pymeshlab>=2023.12
rembg[gpu]
torch_scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl
tqdm
transformers
trimesh
typeguard
wandb
xformers
ninja
    """.strip()

    # Write with an explicit encoding so the result does not depend on the
    # platform locale. Terminate the last requirement line with a newline.
    req_file = Path(requirements_path)
    req_file.write_text(requirements + "\n", encoding="utf-8")

    try:
        logger.info("Installing requirements...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", str(req_file)],
            check=True,
        )
        logger.info("Requirements installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install requirements: {str(e)}")
        raise

def setup_dependencies():
    """Install the Python requirements plus the bundled local wheels.

    First verifies that every local wheel under ``package/`` exists, so a
    missing file is reported before any slow pip work starts. Then installs
    the requirement list via ``install_requirements()``. Finally it
    force-reinstalls each local wheel with ``--no-deps``, using the same
    interpreter's pip.

    Raises:
        FileNotFoundError: If one or more local wheel files are missing.
        subprocess.CalledProcessError: If any pip invocation fails.
    """
    # Local wheels. They are installed after the requirement list, presumably
    # so they replace whatever versions pip resolved from it.
    packages = [
        "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
        "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
    ]

    try:
        logger.info("Installing dependencies...")

        # Fail fast: check every wheel before installing anything. Otherwise
        # a missing file is only noticed after the full requirements install
        # (and after earlier wheels were already installed).
        missing = [package for package in packages if not Path(package).exists()]
        if missing:
            raise FileNotFoundError(
                f"Package file not found: {', '.join(missing)}"
            )

        install_requirements()

        for package in packages:
            logger.info(f"Installing {package}")
            # Use the running interpreter's pip (as install_requirements does)
            # rather than whatever `pip` is first on PATH. An argv list also
            # keeps paths containing spaces intact.
            subprocess.run(
                [
                    sys.executable, "-m", "pip", "install", package,
                    "--force-reinstall", "--no-deps",
                ],
                check=True,
            )

        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install dependencies: {str(e)}")
        raise
    except FileNotFoundError as e:
        logger.error(str(e))
        raise

def _download_checkpoints(ckpt_dir, model_path):
    """Download the ``public-data/Unique3D`` repo into *ckpt_dir*.

    If the download fails, fall back to files already on disk. That is
    accepted only when ``model_path / "config.json"`` exists.

    Raises:
        RuntimeError: If the download fails and no local config is present.
    """
    logger.info("Model not found locally, attempting download...")
    try:
        snapshot_download(
            "public-data/Unique3D",
            repo_type="model",
            local_dir=str(ckpt_dir),
            token=os.getenv("HF_TOKEN"),  # None when unset -> anonymous access
            local_files_only=False,
        )
    except Exception as e:
        logger.error(f"Failed to download model: {str(e)}")
        logger.info("Checking for local model files...")
        # A partially completed download may still have left usable files.
        if not (model_path / "config.json").exists():
            raise RuntimeError(
                "Model not found locally and download failed. "
                "Please ensure you have the model files in ./ckpt/img2mvimg"
            ) from e
        return

    # The repo is expected to contain `img2mvimg`, since that directory is the
    # "already downloaded" sentinel. Warn rather than fail if it is absent.
    if not model_path.exists():
        logger.warning(
            f"Download finished but {model_path} is still missing; "
            "model initialization may fail"
        )


def _configure_torch():
    """Apply process-wide torch settings used for inference."""
    # Trade float32 matmul precision for speed; see
    # torch.set_float32_matmul_precision for what 'medium' permits.
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    # Inference only: disable autograd globally for this process.
    torch.set_grad_enabled(False)


def setup_model():
    """Make sure the Unique3D checkpoints are available, then configure torch.

    Checkpoints live under ``./ckpt``, which is created if needed. The
    ``./ckpt/img2mvimg`` directory is used as the "already present" marker.
    If it is missing, the model repo is downloaded from the Hugging Face Hub,
    using ``HF_TOKEN`` from the environment when set. If that download fails,
    setup still proceeds when ``./ckpt/img2mvimg/config.json`` exists
    locally.

    Raises:
        RuntimeError: If the checkpoints cannot be downloaded and are not
            present locally.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        logger.info("Setting up model checkpoints...")

        ckpt_dir = Path("./ckpt")
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Prefer an existing local copy; only hit the network when it's absent.
        model_path = ckpt_dir / "img2mvimg"
        if not model_path.exists():
            _download_checkpoints(ckpt_dir, model_path)

        _configure_torch()

        logger.info("Model setup completed successfully")
    except Exception as e:
        logger.error(f"Error during model setup: {str(e)}")
        raise

# Application title, used as the Blocks page title and as the top heading.
_TITLE = 'Text to 3D'


def launch(share=True):
    """Initialize the model zoo, build the Gradio UI and serve it.

    Args:
        share: Forwarded to Gradio's ``launch``. When true, Gradio also
            creates a public share link. Defaults to ``True``, the previously
            hard-coded value. Because this function is run through
            ``fire.Fire``, it is exposed on the command line as ``--share``
            / ``--noshare``.

    Raises:
        Exception: Any error during model initialization, UI construction or
            launch is logged and re-raised.
    """
    try:
        logger.info("Initializing models...")
        model_zoo.init_models()

        logger.info("Creating Gradio interface...")
        with gr.Blocks(title=_TITLE) as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown('# ' + _TITLE)
                    # NOTE(review): the meaning of "wkl" is defined by
                    # gradio_app.gradio_3dgen.create_ui, not visible here.
                    create_3d_ui("wkl")

        # Launch only after the `with` block has exited, so the layout is
        # closed and finalized before the (blocking) server starts.
        demo.queue().launch(share=share)
    except Exception as e:
        logger.error(f"Error launching application: {str(e)}")
        raise

if __name__ == '__main__':
    # Startup sequence: install packages, fetch/verify checkpoints, then hand
    # control to `fire`, which parses CLI args into launch().
    # NOTE(review): torch and gradio_app are imported at module top, before
    # setup_dependencies() runs. Packages it (re)installs only take effect on
    # the next interpreter start.
    try:
        logger.info("Starting application setup...")
        setup_dependencies()
        # This runs after the top-level `gradio_app` imports, so it cannot
        # help them. Presumably it is for imports performed later relative to
        # the CWD — confirm it is still needed.
        sys.path.append(os.curdir)
        setup_model()
        fire.Fire(launch)
    except Exception as e:
        # logger.exception records the full traceback, not just str(e)
        # (which is empty for many exception types).
        logger.exception("Application startup failed: %s", e)
        sys.exit(1)