DawnC commited on
Commit
7540236
·
verified ·
1 Parent(s): 97f6d6e

Update style_transfer.py

Browse files
Files changed (1) hide show
  1. style_transfer.py +6 -1
style_transfer.py CHANGED
@@ -24,7 +24,7 @@ class DogStyleTransfer:
24
  """
25
  def __init__(self):
26
  self.models = {}
27
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
  # Check xformers availability
30
  self.xformers_available = False
@@ -120,6 +120,11 @@ class DogStyleTransfer:
120
 
121
  def load_model(self, style_name):
122
  """Load the appropriate model based on style, handling xformers compatibility"""
 
 
 
 
 
123
  # Get model ID for the style
124
  model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
125
 
 
24
  """
25
  def __init__(self):
26
  self.models = {}
27
+ self.device = cpu
28
 
29
  # Check xformers availability
30
  self.xformers_available = False
 
120
 
121
  def load_model(self, style_name):
122
  """Load the appropriate model based on style, handling xformers compatibility"""
123
+
124
+ if not hasattr(self, '_cuda_initialized'):
125
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
126
+ self._cuda_initialized = True
127
+
128
  # Get model ID for the style
129
  model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
130