ironjr commited on
Commit
59ea7db
·
verified ·
1 Parent(s): bea5c3f

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +27 -1
util.py CHANGED
@@ -20,7 +20,7 @@
20
 
21
  import concurrent.futures
22
  import time
23
- from typing import Any, Callable, List, Tuple, Union
24
 
25
  from PIL import Image
26
  import numpy as np
@@ -30,6 +30,12 @@ import torch.nn.functional as F
30
  import torchvision.transforms as T
31
  import torchvision.transforms.functional as TF
32
 
 
 
 
 
 
 
33
 
34
  def seed_everything(seed: int) -> None:
35
  torch.manual_seed(seed)
@@ -38,6 +44,26 @@ def seed_everything(seed: int) -> None:
38
  torch.backends.cudnn.benchmark = True
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def get_cutoff(cutoff: float = None, scale: float = None) -> float:
42
  if cutoff is not None:
43
  return cutoff
 
20
 
21
  import concurrent.futures
22
  import time
23
+ from typing import Any, Callable, List, Literal, Tuple, Union
24
 
25
  from PIL import Image
26
  import numpy as np
 
30
  import torchvision.transforms as T
31
  import torchvision.transforms.functional as TF
32
 
33
+ from diffusers import (
34
+ DiffusionPipeline,
35
+ StableDiffusionPipeline,
36
+ StableDiffusionXLPipeline,
37
+ )
38
+
39
 
40
  def seed_everything(seed: int) -> None:
41
  torch.manual_seed(seed)
 
44
  torch.backends.cudnn.benchmark = True
45
 
46
 
47
+ def load_model(
48
+ model_key: str,
49
+ sd_version: Literal['1.5', 'xl'],
50
+ device: torch.device,
51
+ dtype: torch.dtype,
52
+ ) -> torch.nn.Module:
53
+ if model_key.endswith('.safetensors'):
54
+ if sd_version == '1.5':
55
+ pipeline = StableDiffusionPipeline
56
+ elif sd_version == 'xl':
57
+ pipeline = StableDiffusionXLPipeline
58
+ else:
59
+ raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
60
+ return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
61
+ try:
62
+ return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
63
+ except:
64
+ return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
65
+
66
+
67
  def get_cutoff(cutoff: float = None, scale: float = None) -> float:
68
  if cutoff is not None:
69
  return cutoff