Spaces:
Running
on
Zero
Running
on
Zero
Update style_transfer.py
Browse files- style_transfer.py +6 -1
style_transfer.py
CHANGED
@@ -24,7 +24,7 @@ class DogStyleTransfer:
|
|
24 |
"""
|
25 |
def __init__(self):
|
26 |
self.models = {}
|
27 |
-
self.device =
|
28 |
|
29 |
# Check xformers availability
|
30 |
self.xformers_available = False
|
@@ -120,6 +120,11 @@ class DogStyleTransfer:
|
|
120 |
|
121 |
def load_model(self, style_name):
|
122 |
"""Load the appropriate model based on style, handling xformers compatibility"""
|
|
|
|
|
|
|
|
|
|
|
123 |
# Get model ID for the style
|
124 |
model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
|
125 |
|
|
|
24 |
"""
|
25 |
def __init__(self):
|
26 |
self.models = {}
|
27 |
+
self.device = cpu
|
28 |
|
29 |
# Check xformers availability
|
30 |
self.xformers_available = False
|
|
|
120 |
|
121 |
def load_model(self, style_name):
|
122 |
"""Load the appropriate model based on style, handling xformers compatibility"""
|
123 |
+
|
124 |
+
if not hasattr(self, '_cuda_initialized'):
|
125 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
126 |
+
self._cuda_initialized = True
|
127 |
+
|
128 |
# Get model ID for the style
|
129 |
model_id = self.style_model_mapping.get(style_name, "runwayml/stable-diffusion-v1-5")
|
130 |
|