File size: 3,964 Bytes
1774ce2
 
75a1da3
 
1219780
 
 
75a1da3
1774ce2
 
 
 
 
1219780
 
 
 
 
 
 
75a1da3
1219780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75a1da3
 
1219780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1774ce2
75a1da3
 
1774ce2
 
1219780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1774ce2
 
1219780
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path

import fire
import gradio as gr
import torch
from huggingface_hub import snapshot_download
# huggingface_hub exports no `HfHubError`; bind that name to the library's real
# HTTP error base class so the `except HfHubError` retry handler keeps working.
from huggingface_hub.utils import HfHubHTTPError as HfHubError

from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo

# Root logging configuration for this script: INFO and above, each record
# prefixed with a timestamp and level (no-op if the root logger already has
# handlers, per logging.basicConfig semantics).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function below.
logger = logging.getLogger(__name__)

def setup_dependencies():
    """Install the pinned pip release and the bundled local wheels.

    Every wheel file is checked for existence *before* any pip command runs,
    so a missing file aborts setup without leaving the environment partially
    updated. pip is invoked as ``sys.executable -m pip`` with an argument list.
    That installs into the interpreter running this script (the bundled wheels
    are tagged cp310) rather than whatever ``pip`` is first on ``PATH``. It
    also keeps paths containing spaces intact.

    Raises:
        FileNotFoundError: If one or more wheel files are missing.
        subprocess.CalledProcessError: If any pip invocation exits non-zero.
    """
    # Wheel paths are resolved relative to the current working directory.
    packages = [
        "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
        "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
    ]
    pip_install = [sys.executable, "-m", "pip", "install"]
    try:
        # Validate all inputs up front so nothing is installed on failure.
        missing = [package for package in packages if not Path(package).exists()]
        if missing:
            raise FileNotFoundError(f"Package file not found: {', '.join(missing)}")

        logger.info("Installing dependencies...")
        subprocess.run([*pip_install, "pip==24.0"], check=True)

        for package in packages:
            logger.info(f"Installing {package}")
            # --no-deps: install exactly this wheel without resolving its requirements.
            subprocess.run(
                [*pip_install, package, "--force-reinstall", "--no-deps"],
                check=True,
            )

        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install dependencies: {str(e)}")
        raise
    except FileNotFoundError as e:
        logger.error(str(e))
        raise

def setup_model():
    """Download the Unique3D checkpoints into ``./ckpt`` and configure PyTorch.

    The ``public-data/Unique3D`` snapshot download gets up to three attempts.
    Between failed attempts the function sleeps with exponential backoff
    (1 s, then 2 s) so a transient Hub/network failure has time to clear
    instead of being retried immediately.

    Raises:
        HfHubError: If every download attempt fails (``HfHubError`` is the
            Hub error class imported at the top of this file).
        Exception: Any other setup error, logged and re-raised unchanged.
    """
    import time  # local import: only needed for the retry backoff below

    try:
        logger.info("Downloading model checkpoints...")

        ckpt_dir = Path("./ckpt")
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                snapshot_download(
                    "public-data/Unique3D",
                    repo_type="model",
                    local_dir=str(ckpt_dir),
                    # HF_TOKEN from the environment if set, otherwise None.
                    token=os.getenv("HF_TOKEN"),
                )
                break
            except HfHubError as e:
                if attempt == max_retries:
                    logger.error(f"Failed to download model after {max_retries} attempts: {str(e)}")
                    raise
                delay = 2 ** (attempt - 1)
                logger.warning(
                    f"Download attempt {attempt} failed ({e}), retrying in {delay}s..."
                )
                time.sleep(delay)

        logger.info("Model checkpoints downloaded successfully")

        # Permit reduced-precision internal arithmetic for float32 matmuls and
        # TF32 for CUDA matmuls: faster on supporting GPUs, slightly less precise.
        torch.set_float32_matmul_precision('medium')
        torch.backends.cuda.matmul.allow_tf32 = True
        # Disable autograd tracking.
        # NOTE(review): PyTorch grad mode is thread-local, so this only affects
        # the calling thread; if Gradio runs handlers in worker threads they may
        # still track gradients unless the model code wraps inference in
        # torch.no_grad()/inference_mode itself — confirm.
        torch.set_grad_enabled(False)

        logger.info("PyTorch configured successfully")
    except Exception as e:
        logger.error(f"Error during model setup: {str(e)}")
        raise

# Used both as the browser-tab title (gr.Blocks) and as the page heading.
_TITLE = "Text to 3D"

def launch(share: bool = True):
    """Initialise the models, build the Gradio UI, and start serving it.

    Args:
        share: Forwarded to Gradio's ``Blocks.launch``. When True (the
            default, matching previous behaviour), Gradio also requests a
            public share link. Pass False to serve locally only; when invoked
            through ``fire`` this is exposed as the ``--share`` CLI flag.

    Raises:
        Exception: Any error from model initialisation, UI construction, or
            server start-up is logged and re-raised.
    """
    try:
        logger.info("Initializing models...")
        model_zoo.init_models()

        logger.info("Creating Gradio interface...")
        with gr.Blocks(title=_TITLE) as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown('# ' + _TITLE)
                    create_3d_ui("wkl")

        # Launch only after the `with` block has exited, so the Blocks layout
        # is finalised before the (blocking) server starts.
        demo.queue().launch(share=share)
    except Exception as e:
        logger.error(f"Error launching application: {str(e)}")
        raise

if __name__ == '__main__':
    try:
        logger.info("Starting application setup...")
        setup_dependencies()
        # NOTE(review): this runs after the top-of-file `gradio_app` imports, so
        # it cannot affect them; it only matters for modules imported later
        # (presumably lazily, e.g. inside model_zoo.init_models) — confirm.
        sys.path.append(os.curdir)
        setup_model()
        # `fire` maps CLI arguments onto launch()'s parameters.
        fire.Fire(launch)
    except Exception as e:
        # Top-level boundary: keep the full traceback before exiting non-zero.
        logger.exception(f"Application startup failed: {str(e)}")
        sys.exit(1)