charleselena
commited on
gpu for cuda
Browse files- handler.py +3 -4
handler.py
CHANGED
@@ -20,9 +20,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
20 |
if device.type != 'cuda':
|
21 |
raise ValueError("need to run on GPU")
|
22 |
# set mixed precision dtype
|
23 |
-
|
24 |
-
|
25 |
-
dtype = torch.float32
|
26 |
|
27 |
# controlnet mapping for controlnet id and control hinter
|
28 |
CONTROLNET_MAPPING = {
|
@@ -87,7 +85,8 @@ class EndpointHandler():
|
|
87 |
|
88 |
|
89 |
# Define Generator with seed
|
90 |
-
self.generator = torch.Generator(device="cpu").manual_seed(3)
|
|
|
91 |
|
92 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
93 |
"""
|
|
|
20 |
if device.type != 'cuda':
|
21 |
raise ValueError("need to run on GPU")
|
22 |
# set mixed precision dtype
|
23 |
+
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
|
|
|
|
|
24 |
|
25 |
# controlnet mapping for controlnet id and control hinter
|
26 |
CONTROLNET_MAPPING = {
|
|
|
85 |
|
86 |
|
87 |
# Define Generator with seed
|
88 |
+
# self.generator = torch.Generator(device="cpu").manual_seed(3)
|
89 |
+
self.generator = torch.Generator(device="cuda").manual_seed(3)
|
90 |
|
91 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
92 |
"""
|