# 3D-LLAMA / app.py
# (Web file-viewer header captured with the source; not program text.)
# ginipick's picture
# Update app.py
# 1219780 verified
# raw
# history blame
# 3.96 kB
import shlex
import subprocess
import os
import sys
import logging
from pathlib import Path
from huggingface_hub import snapshot_download, HfHubError
import torch
import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo
# Configure root logging once for the whole process (timestamp, level, text),
# then expose a logger named after this module for the functions below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def setup_dependencies():
    """Pin pip and install the bundled wheel files into the running interpreter.

    pip is pinned to 24.0, then each local wheel is installed with
    ``--force-reinstall --no-deps``.  Every wheel path is checked before any
    pip command runs, so a missing file fails fast instead of leaving the
    environment partially installed.

    Raises:
        FileNotFoundError: If one of the wheel files does not exist.
        subprocess.CalledProcessError: If a pip command exits with a
            non-zero status.
    """
    # Run pip as a module of the interpreter executing this script, so the
    # packages land in this environment rather than in whichever ``pip``
    # happens to be first on PATH.
    pip_install = [sys.executable, "-m", "pip", "install"]

    # Wheels shipped alongside the application.
    packages = [
        "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
        "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
    ]

    try:
        logger.info("Installing dependencies...")

        # Validate every wheel up front, before touching the environment.
        for package in packages:
            if not Path(package).exists():
                raise FileNotFoundError(f"Package file not found: {package}")

        subprocess.run([*pip_install, "pip==24.0"], check=True)

        for package in packages:
            logger.info(f"Installing {package}")
            # Pass an argument list: no string splitting, so paths that
            # contain spaces or quotes are handed to pip unchanged.
            subprocess.run(
                [*pip_install, package, "--force-reinstall", "--no-deps"],
                check=True,
            )

        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install dependencies: {str(e)}")
        raise
    except FileNotFoundError as e:
        logger.error(str(e))
        raise
def setup_model():
    """Fetch the model checkpoints into ./ckpt and apply the PyTorch settings.

    The snapshot download is tried up to three times; a failed try is logged
    as a warning and repeated, and the error from the last try is logged and
    re-raised.  After a successful download, float32 matmul precision is set
    to 'medium', TF32 matmul is allowed on CUDA, and gradient tracking is
    switched off globally.

    Raises:
        Exception: Any error raised during setup, after it has been logged.
    """
    try:
        logger.info("Downloading model checkpoints...")

        # Make sure the target directory exists before downloading into it.
        checkpoint_root = Path("./ckpt")
        checkpoint_root.mkdir(parents=True, exist_ok=True)

        # NOTE(review): HfHubError comes from this file's huggingface_hub
        # import; confirm the installed version actually exports that name.
        attempts_allowed = 3
        attempts_made = 0
        while True:
            attempts_made += 1
            try:
                snapshot_download(
                    "public-data/Unique3D",
                    repo_type="model",
                    local_dir=str(checkpoint_root),
                    token=os.getenv("HF_TOKEN"),  # None when the variable is unset
                )
            except HfHubError as download_error:
                if attempts_made >= attempts_allowed:
                    logger.error(
                        f"Failed to download model after {attempts_allowed} attempts: {str(download_error)}"
                    )
                    raise
                logger.warning(f"Download attempt {attempts_made} failed, retrying...")
            else:
                # Download succeeded; stop retrying.
                break

        logger.info("Model checkpoints downloaded successfully")

        # Global PyTorch configuration.
        torch.set_float32_matmul_precision('medium')
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_grad_enabled(False)
        logger.info("PyTorch configured successfully")
    except Exception as setup_error:
        logger.error(f"Error during model setup: {str(setup_error)}")
        raise
# Application title: passed to gr.Blocks(title=...) and rendered as the
# top-level Markdown heading by launch().
_TITLE = 'Text to 3D'
def launch():
    """Initialise the models, assemble the Gradio page and start serving it.

    Calls ``model_zoo.init_models()``, builds a ``gr.Blocks`` page whose
    heading is the application title followed by the 3D-generation UI, then
    enables the queue and launches with ``share=True``.

    Raises:
        Exception: Any error raised along the way, after it has been logged.
    """
    try:
        logger.info("Initializing models...")
        model_zoo.init_models()

        logger.info("Creating Gradio interface...")
        page_heading = '# ' + _TITLE
        with gr.Blocks(title=_TITLE) as demo:
            with gr.Row(), gr.Column(scale=1):
                gr.Markdown(page_heading)
            # NOTE(review): assumed to sit directly under Blocks rather than
            # inside the Column above -- confirm the intended layout.
            create_3d_ui("wkl")

        queued_app = demo.queue()
        queued_app.launch(share=True)
    except Exception as launch_error:
        logger.error(f"Error launching application: {str(launch_error)}")
        raise
if __name__ == '__main__':
    # Script entry point: install wheels, fetch checkpoints, then hand the
    # launch function to python-fire.  Any failure is logged and turned into
    # exit status 1.
    try:
        logger.info("Starting application setup...")
        setup_dependencies()
        # Add the current directory to the module search path.
        sys.path.append(os.curdir)
        setup_model()
        fire.Fire(launch)
    except Exception as startup_error:
        logger.error(f"Application startup failed: {startup_error!s}")
        sys.exit(1)