ruslanmv committed on
Commit
e18c93b
·
verified ·
1 Parent(s): dabcfca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -34
app.py CHANGED
@@ -19,6 +19,7 @@ import matplotlib.pyplot as plt
19
  import gc # Import the garbage collector
20
  from audio import *
21
  import os
 
22
  # Define a fallback for environments without GPU
23
  if os.environ.get("SPACES_ZERO_GPU") is not None:
24
  import spaces
@@ -29,8 +30,6 @@ else:
29
  def wrapper(*args, **kwargs):
30
  return func(*args, **kwargs)
31
  return wrapper
32
-
33
-
34
  # Download necessary NLTK data
35
  try:
36
  nltk.data.find('tokenizers/punkt')
@@ -57,41 +56,31 @@ def log_gpu_memory():
57
  print("CUDA is not available. Cannot log GPU memory.")
58
 
59
  # --------- MinDalle Image Generation Functions ---------
60
- # Load MinDalle model once
61
- # Dynamically determine device and precision
62
- @spaces.GPU(duration=60 * 3)
63
- def load_min_dalle_model(models_root: str = 'pretrained'):
64
- """
65
- Load the MinDalle model, automatically selecting device and precision.
66
-
67
- Args:
68
- models_root: Path to the directory containing MinDalle models.
69
-
70
- Returns:
71
- An instance of the MinDalle model.
72
- """
73
- print("DEBUG: Loading MinDalle model...")
74
-
75
  if torch.cuda.is_available():
76
- device = 'cuda'
77
- dtype = torch.float16
78
- print("DEBUG: Using GPU with float16 precision.")
79
  else:
80
- device = 'cpu'
81
- dtype = torch.float32
82
- print("DEBUG: Using CPU with float32 precision.")
83
-
84
- return MinDalle(
85
- is_mega=True,
86
- models_root=models_root,
87
- is_reusable=False,
88
- is_verbose=True,
89
- dtype=dtype,
90
- device=device
91
- )
 
 
 
92
 
93
- # Initialize the MinDalle model (will now automatically use GPU if available)
94
- min_dalle_model = load_min_dalle_model()
95
 
96
  def generate_image_with_min_dalle(
97
  model: MinDalle,
 
19
  import gc # Import the garbage collector
20
  from audio import *
21
  import os
22
+ multiprocessing.set_start_method("spawn")
23
  # Define a fallback for environments without GPU
24
  if os.environ.get("SPACES_ZERO_GPU") is not None:
25
  import spaces
 
30
  def wrapper(*args, **kwargs):
31
  return func(*args, **kwargs)
32
  return wrapper
 
 
33
  # Download necessary NLTK data
34
  try:
35
  nltk.data.find('tokenizers/punkt')
 
56
  print("CUDA is not available. Cannot log GPU memory.")
57
 
58
  # --------- MinDalle Image Generation Functions ---------
59
+ # Check for GPU availability
60
+ def check_gpu_availability():
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if torch.cuda.is_available():
62
+ print(f"CUDA devices: {torch.cuda.device_count()}")
63
+ print(f"Current device: {torch.cuda.current_device()}")
64
+ print(torch.cuda.get_device_properties(torch.cuda.current_device()))
65
  else:
66
+ print("CUDA is not available. Running on CPU.")
67
+ check_gpu_availability()
68
+ # GPU-safe model loading
69
+ def initialize_min_dalle_with_gpu():
70
+ @spaces.GPU(duration=60 * 3)
71
+ def load_model():
72
+ return MinDalle(
73
+ is_mega=True,
74
+ models_root='pretrained',
75
+ is_reusable=False,
76
+ is_verbose=True,
77
+ dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
78
+ device='cuda' if torch.cuda.is_available() else 'cpu'
79
+ )
80
+ return load_model()
81
 
82
+ # Initialize MinDalle model
83
+ min_dalle_model = initialize_min_dalle_with_gpu()
84
 
85
  def generate_image_with_min_dalle(
86
  model: MinDalle,