import shlex
import subprocess
import os
import sys
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
import torch
import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo
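# NOTE: torch, gradio and the gradio_app modules are imported at load time, so the
# environment must already provide them; setup_dependencies() below only (re)installs
# the pinned versions at startup.
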
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def install_requirements():
"""Install requirements from requirements.txt"""
    requirements = """
pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
ort_nightly_gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/d3daa2b0-aa56-45ac-8145-2c3dc0661c87/pypi/download/ort-nightly-gpu/1.17.dev20240118002/ort_nightly_gpu-1.17.0.dev20240118002-cp310-cp310-manylinux_2_28_x86_64.whl
onnxruntime_gpu @ https://pkgs.dev.azure.com/onnxruntime/2a773b67-e88b-4c7f-9fc0-87d31fea8ef2/_packaging/7fa31e42-5da1-4e84-a664-f2b4129c7d45/pypi/download/onnxruntime-gpu/1.17/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl
torch==2.2.0
accelerate
datasets
diffusers>=0.26.3
fire
gradio
jax
typing
numba
numpy<2
omegaconf>=2.3.0
opencv_python
opencv_python_headless
peft
Pillow
pygltflib
pymeshlab>=2023.12
rembg[gpu]
torch_scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl
tqdm
transformers
trimesh
typeguard
wandb
xformers
ninja
""".strip()
    # Write requirements to file
    with open("requirements.txt", "w") as f:
        f.write(requirements)

    try:
        logger.info("Installing requirements...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
            check=True
        )
        logger.info("Requirements installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install requirements: {str(e)}")
        raise

def setup_dependencies():
"""Install required packages with error handling"""
try:
logger.info("Installing dependencies...")
install_requirements()
# Define package paths
packages = [
"package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
"package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl"
]
# Check if package files exist
for package in packages:
if not Path(package).exists():
raise FileNotFoundError(f"Package file not found: {package}")
logger.info(f"Installing {package}")
            subprocess.run(
                shlex.split(f"pip install {package} --force-reinstall --no-deps"),
                check=True
            )

        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to install dependencies: {str(e)}")
        raise
    except FileNotFoundError as e:
        logger.error(str(e))
        raise

def setup_model():
"""Download and set up model with offline support"""
try:
logger.info("Setting up model checkpoints...")
# Create checkpoint directory if it doesn't exist
ckpt_dir = Path("./ckpt")
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Check for offline mode first
model_path = ckpt_dir / "img2mvimg"
if not model_path.exists():
logger.info("Model not found locally, attempting download...")
try:
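                # snapshot_download fetches the Unique3D checkpoints from the Hub into
                # ./ckpt; HF_TOKEN is only required if the repository is gated/private.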
                snapshot_download(
                    "public-data/Unique3D",
                    repo_type="model",
                    local_dir=str(ckpt_dir),
                    token=os.getenv("HF_TOKEN"),
                    local_files_only=False
                )
            except Exception as e:
                logger.error(f"Failed to download model: {str(e)}")
                logger.info("Checking for local model files...")
                if not (model_path / "config.json").exists():
                    raise RuntimeError(
                        "Model not found locally and download failed. "
                        "Please ensure you have the model files in ./ckpt/img2mvimg"
                    )

        # Configure PyTorch settings
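        # 'medium' allows float32 matmuls to use lower-precision (TF32/bfloat16)
        # kernels, allow_tf32 enables TF32 for CUDA matmuls, and autograd is disabled
        # globally since the app only runs inference.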
        torch.set_float32_matmul_precision('medium')
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_grad_enabled(False)

        logger.info("Model setup completed successfully")
    except Exception as e:
        logger.error(f"Error during model setup: {str(e)}")
        raise

# Application title
_TITLE = 'Text to 3D'
def launch():
"""Launch the Gradio interface"""
try:
logger.info("Initializing models...")
model_zoo.init_models()
logger.info("Creating Gradio interface...")
with gr.Blocks(title=_TITLE) as demo:
with gr.Row():
with gr.Column(scale=1):
gr.Markdown('# ' + _TITLE)
create_3d_ui("wkl")
        demo.queue().launch(share=True)
    except Exception as e:
        logger.error(f"Error launching application: {str(e)}")
        raise

if __name__ == '__main__':
    try:
        logger.info("Starting application setup...")
        setup_dependencies()
        sys.path.append(os.curdir)
        setup_model()
        fire.Fire(launch)
    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        sys.exit(1)