ginipick commited on
Commit
9bc6ef8
·
verified ·
1 Parent(s): dc9efd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -19
app.py CHANGED
@@ -19,11 +19,60 @@ logging.basicConfig(
19
  )
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def setup_dependencies():
23
  """Install required packages with error handling"""
24
  try:
25
  logger.info("Installing dependencies...")
26
- subprocess.run(shlex.split("pip install pip==24.0"), check=True)
27
 
28
  # Define package paths
29
  packages = [
@@ -51,43 +100,41 @@ def setup_dependencies():
51
  raise
52
 
53
  def setup_model():
54
- """Download and set up model with error handling"""
55
  try:
56
- logger.info("Downloading model checkpoints...")
57
 
58
  # Create checkpoint directory if it doesn't exist
59
  ckpt_dir = Path("./ckpt")
60
  ckpt_dir.mkdir(parents=True, exist_ok=True)
61
 
62
- # Download model with retry mechanism
63
- max_retries = 3
64
- for attempt in range(max_retries):
 
65
  try:
66
  snapshot_download(
67
  "public-data/Unique3D",
68
  repo_type="model",
69
  local_dir=str(ckpt_dir),
70
- token=os.getenv("HF_TOKEN") # Add token if needed
 
71
  )
72
- break
73
- except RepositoryNotFoundError as e:
74
- logger.error(f"Repository not found: {str(e)}")
75
- raise
76
  except Exception as e:
77
- if attempt == max_retries - 1:
78
- logger.error(f"Failed to download model after {max_retries} attempts: {str(e)}")
79
- raise
80
- logger.warning(f"Download attempt {attempt + 1} failed, retrying...")
81
- continue
82
-
83
- logger.info("Model checkpoints downloaded successfully")
84
 
85
  # Configure PyTorch settings
86
  torch.set_float32_matmul_precision('medium')
87
  torch.backends.cuda.matmul.allow_tf32 = True
88
  torch.set_grad_enabled(False)
89
 
90
- logger.info("PyTorch configured successfully")
91
  except Exception as e:
92
  logger.error(f"Error during model setup: {str(e)}")
93
  raise
 
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
+ def install_requirements():
23
+ """Install requirements from requirements.txt"""
24
+ requirements = """
25
+ pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
26
+ ort_nightly_gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/d3daa2b0-aa56-45ac-8145-2c3dc0661c87/pypi/download/ort-nightly-gpu/1.17.dev20240118002/ort_nightly_gpu-1.17.0.dev20240118002-cp310-cp310-manylinux_2_28_x86_64.whl
27
+ onnxruntime_gpu @ https://pkgs.dev.azure.com/onnxruntime/2a773b67-e88b-4c7f-9fc0-87d31fea8ef2/_packaging/7fa31e42-5da1-4e84-a664-f2b4129c7d45/pypi/download/onnxruntime-gpu/1.17/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl
28
+ torch==2.2.0
29
+ accelerate
30
+ datasets
31
+ diffusers>=0.26.3
32
+ fire
33
+ gradio
34
+ jax
35
+ typing
36
+ numba
37
+ numpy<2
38
+ omegaconf>=2.3.0
39
+ opencv_python
40
+ opencv_python_headless
41
+ peft
42
+ Pillow
43
+ pygltflib
44
+ pymeshlab>=2023.12
45
+ rembg[gpu]
46
+ torch_scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl
47
+ tqdm
48
+ transformers
49
+ trimesh
50
+ typeguard
51
+ wandb
52
+ xformers
53
+ ninja
54
+ """.strip()
55
+
56
+ # Write requirements to file
57
+ with open("requirements.txt", "w") as f:
58
+ f.write(requirements)
59
+
60
+ try:
61
+ logger.info("Installing requirements...")
62
+ subprocess.run(
63
+ [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
64
+ check=True
65
+ )
66
+ logger.info("Requirements installed successfully")
67
+ except subprocess.CalledProcessError as e:
68
+ logger.error(f"Failed to install requirements: {str(e)}")
69
+ raise
70
+
71
  def setup_dependencies():
72
  """Install required packages with error handling"""
73
  try:
74
  logger.info("Installing dependencies...")
75
+ install_requirements()
76
 
77
  # Define package paths
78
  packages = [
 
100
  raise
101
 
102
  def setup_model():
103
+ """Download and set up model with offline support"""
104
  try:
105
+ logger.info("Setting up model checkpoints...")
106
 
107
  # Create checkpoint directory if it doesn't exist
108
  ckpt_dir = Path("./ckpt")
109
  ckpt_dir.mkdir(parents=True, exist_ok=True)
110
 
111
+ # Check for offline mode first
112
+ model_path = ckpt_dir / "img2mvimg"
113
+ if not model_path.exists():
114
+ logger.info("Model not found locally, attempting download...")
115
  try:
116
  snapshot_download(
117
  "public-data/Unique3D",
118
  repo_type="model",
119
  local_dir=str(ckpt_dir),
120
+ token=os.getenv("HF_TOKEN"),
121
+ local_files_only=False
122
  )
 
 
 
 
123
  except Exception as e:
124
+ logger.error(f"Failed to download model: {str(e)}")
125
+ logger.info("Checking for local model files...")
126
+ if not (model_path / "config.json").exists():
127
+ raise RuntimeError(
128
+ "Model not found locally and download failed. "
129
+ "Please ensure you have the model files in ./ckpt/img2mvimg"
130
+ )
131
 
132
  # Configure PyTorch settings
133
  torch.set_float32_matmul_precision('medium')
134
  torch.backends.cuda.matmul.allow_tf32 = True
135
  torch.set_grad_enabled(False)
136
 
137
+ logger.info("Model setup completed successfully")
138
  except Exception as e:
139
  logger.error(f"Error during model setup: {str(e)}")
140
  raise