nsfwalex commited on
Commit
eb385d9
·
verified ·
1 Parent(s): aada3b3

Update inference_manager.py

Browse files
Files changed (1) hide show
  1. inference_manager.py +1 -0
inference_manager.py CHANGED
@@ -514,6 +514,7 @@ class ModelManager:
514
  faceid_all_embeds.append(faceid_embed)
515
 
516
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
 
517
 
518
  print("start inference...")
519
  style_selection = ""
 
514
  faceid_all_embeds.append(faceid_embed)
515
 
516
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
517
+ average_embedding = average_embedding.to("cuda")
518
 
519
  print("start inference...")
520
  style_selection = ""