huzefa11 commited on
Commit
637bebb
·
verified ·
1 Parent(s): 8740caa

Update load_models_utils.py

Browse files
Files changed (1) hide show
  1. load_models_utils.py +61 -32
load_models_utils.py CHANGED
@@ -1,52 +1,81 @@
1
  import yaml
2
  import torch
 
3
  from diffusers import StableDiffusionXLPipeline
4
  from utils import PhotoMakerStableDiffusionXLPipeline
5
- import os
6
 
7
- def get_models_dict():
8
- with open('config/models.yaml', 'r') as stream:
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  try:
10
  data = yaml.safe_load(stream)
11
- print(data)
 
12
  return data
13
-
14
  except yaml.YAMLError as exc:
15
- print(exc)
16
 
17
- def load_models(model_info,device,photomaker_path):
18
- path = model_info["path"]
19
- single_files = model_info["single_files"]
20
- use_safetensors = model_info["use_safetensors"]
21
- model_type = model_info["model_type"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  if model_type == "original":
24
- if single_files:
25
- pipe = StableDiffusionXLPipeline.from_single_file(
26
- path,
27
- torch_dtype=torch.float16
28
- )
29
- else:
30
- pipe = StableDiffusionXLPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=use_safetensors)
31
- pipe = pipe.to(device)
32
  elif model_type == "Photomaker":
33
- if single_files:
34
- print("loading from a single_files")
35
- pipe = PhotoMakerStableDiffusionXLPipeline.from_single_file(
36
- path,
37
- torch_dtype=torch.float16
38
- )
39
- else:
40
- pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
41
- path, torch_dtype=torch.float16, use_safetensors=use_safetensors)
42
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
43
  pipe.load_photomaker_adapter(
44
  os.path.dirname(photomaker_path),
45
  subfolder="",
46
  weight_name=os.path.basename(photomaker_path),
47
- trigger_word="img" # define the trigger word
48
  )
49
  pipe.fuse_lora()
50
- else:
51
- raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {model_type}")
52
  return pipe
 
1
  import yaml
2
  import torch
3
+ import os
4
  from diffusers import StableDiffusionXLPipeline
5
  from utils import PhotoMakerStableDiffusionXLPipeline
 
6
 
7
+ def get_models_dict(config_path='config/models.yaml', verbose=False):
8
+ """
9
+ Loads model configuration from a YAML file.
10
+
11
+ Args:
12
+ config_path (str): Path to the YAML configuration file.
13
+ verbose (bool): If True, prints the loaded configuration.
14
+
15
+ Returns:
16
+ dict: Parsed YAML data.
17
+ """
18
+ if not os.path.exists(config_path):
19
+ raise FileNotFoundError(f"Config file '{config_path}' not found.")
20
+
21
+ with open(config_path, 'r') as stream:
22
  try:
23
  data = yaml.safe_load(stream)
24
+ if verbose:
25
+ print("Loaded model configuration:", data)
26
  return data
 
27
  except yaml.YAMLError as exc:
28
+ raise RuntimeError(f"Error parsing YAML file: {exc}")
29
 
30
+ def load_models(model_info, device="cuda", photomaker_path=None):
31
+ """
32
+ Loads a Stable Diffusion XL model or a PhotoMaker variant based on the provided info.
33
+
34
+ Args:
35
+ model_info (dict): Model configuration dictionary.
36
+ device (str): Target device ('cuda' or 'cpu').
37
+ photomaker_path (str, optional): Path to PhotoMaker adapter weights if using Photomaker.
38
+
39
+ Returns:
40
+ DiffusionPipeline: Loaded diffusion pipeline.
41
+ """
42
+ path = model_info.get("path")
43
+ single_file = model_info.get("single_files", False)
44
+ use_safetensors = model_info.get("use_safetensors", True)
45
+ model_type = model_info.get("model_type", "original")
46
+
47
+ if not path:
48
+ raise ValueError("Model path must be specified in the model_info.")
49
 
50
  if model_type == "original":
51
+ pipeline_cls = StableDiffusionXLPipeline
 
 
 
 
 
 
 
52
  elif model_type == "Photomaker":
53
+ pipeline_cls = PhotoMakerStableDiffusionXLPipeline
54
+ else:
55
+ raise NotImplementedError(
56
+ f"Unsupported model type '{model_type}'. Choose either 'original' or 'Photomaker'."
57
+ )
58
+
59
+ # Load model
60
+ if single_file:
61
+ print(f"Loading model from a single file: {path}")
62
+ pipe = pipeline_cls.from_single_file(path, torch_dtype=torch.float16)
63
+ else:
64
+ print(f"Loading model from a directory: {path}")
65
+ pipe = pipeline_cls.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=use_safetensors)
66
+
67
+ pipe = pipe.to(device)
68
+
69
+ # Load PhotoMaker adapter if needed
70
+ if model_type == "Photomaker":
71
+ if not photomaker_path:
72
+ raise ValueError("Photomaker model type requires a valid 'photomaker_path'.")
73
  pipe.load_photomaker_adapter(
74
  os.path.dirname(photomaker_path),
75
  subfolder="",
76
  weight_name=os.path.basename(photomaker_path),
77
+ trigger_word="img"
78
  )
79
  pipe.fuse_lora()
80
+
 
81
  return pipe