charleselena commited on
Commit
212c508
·
verified ·
1 Parent(s): 186873f

gpu for cuda

Browse files
Files changed (1) hide show
  1. handler.py +3 -4
handler.py CHANGED
@@ -20,9 +20,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
  if device.type != 'cuda':
21
  raise ValueError("need to run on GPU")
22
  # set mixed precision dtype
23
- #dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
24
-
25
- dtype = torch.float32
26
 
27
  # controlnet mapping for controlnet id and control hinter
28
  CONTROLNET_MAPPING = {
@@ -87,7 +85,8 @@ class EndpointHandler():
87
 
88
 
89
  # Define Generator with seed
90
- self.generator = torch.Generator(device="cpu").manual_seed(3)
 
91
 
92
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
93
  """
 
20
  if device.type != 'cuda':
21
  raise ValueError("need to run on GPU")
22
  # set mixed precision dtype
23
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 
 
24
 
25
  # controlnet mapping for controlnet id and control hinter
26
  CONTROLNET_MAPPING = {
 
85
 
86
 
87
  # Define Generator with seed
88
+ # self.generator = torch.Generator(device="cpu").manual_seed(3)
89
+ self.generator = torch.Generator(device="cuda").manual_seed(3)
90
 
91
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
92
  """