File size: 2,814 Bytes
e21ad99
 
637bebb
e21ad99
 
 
637bebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e21ad99
 
637bebb
 
e21ad99
 
637bebb
e21ad99
637bebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e21ad99
 
637bebb
e21ad99
637bebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e21ad99
 
 
 
637bebb
e21ad99
 
637bebb
e21ad99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import yaml
import torch
import os
from diffusers import StableDiffusionXLPipeline
from utils import PhotoMakerStableDiffusionXLPipeline

def get_models_dict(config_path='config/models.yaml', verbose=False):
    """
    Load the model configuration from a YAML file.

    Args:
        config_path (str): Path to the YAML configuration file.
        verbose (bool): If True, print the parsed configuration to stdout.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's content.
        This is whatever the YAML document describes (a dict for a
        mapping-style config); the value is returned as-is without any
        type check.

    Raises:
        FileNotFoundError: If ``config_path`` is not an existing regular
            file (the path is missing, or it points to a directory).
        RuntimeError: If the file cannot be parsed as YAML. The original
            ``yaml.YAMLError`` is chained as ``__cause__``.
    """
    # isfile() rather than exists(): a directory path would otherwise slip
    # through and fail inside open() with a less helpful OS-specific error.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file '{config_path}' not found.")

    # Decode as UTF-8 explicitly so parsing does not depend on the
    # platform's locale default encoding.
    with open(config_path, 'r', encoding='utf-8') as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Chain the parser error so the traceback keeps the exact
            # line/column information reported by the YAML parser.
            raise RuntimeError(f"Error parsing YAML file: {exc}") from exc

    if verbose:
        print("Loaded model configuration:", data)
    return data

def load_models(model_info, device="cuda", photomaker_path=None):
    """
    Load a Stable Diffusion XL pipeline, or its PhotoMaker variant, from a config entry.

    All arguments are validated before any weights are loaded, so a bad
    configuration fails immediately instead of after the pipeline has been
    loaded and moved to the device.

    Args:
        model_info (dict): Model configuration. Keys read here:
            - "path" (required): checkpoint file or model directory.
            - "single_files" (default False): if truthy, load with
              ``from_single_file``; otherwise with ``from_pretrained``.
            - "use_safetensors" (default True): forwarded to
              ``from_pretrained`` only.
            - "model_type" (default "original"): "original" or "Photomaker".
        device (str): Device the pipeline is moved to via ``pipe.to(device)``.
        photomaker_path (str, optional): Path to the PhotoMaker adapter
            weights file. Required when ``model_type`` is "Photomaker".

    Returns:
        The loaded pipeline instance (``StableDiffusionXLPipeline`` or
        ``PhotoMakerStableDiffusionXLPipeline``), already moved to ``device``.

    Raises:
        ValueError: If "path" is missing/empty, or if the model type is
            "Photomaker" and ``photomaker_path`` is missing/empty.
        NotImplementedError: If "model_type" is not a supported value.
    """
    path = model_info.get("path")
    # NOTE(review): the key is spelled "single_files" (plural) — presumably
    # this matches the existing config files; verify before renaming.
    single_file = model_info.get("single_files", False)
    use_safetensors = model_info.get("use_safetensors", True)
    model_type = model_info.get("model_type", "original")

    if not path:
        raise ValueError("Model path must be specified in the model_info.")

    if model_type not in ("original", "Photomaker"):
        raise NotImplementedError(
            f"Unsupported model type '{model_type}'. Choose either 'original' or 'Photomaker'."
        )

    use_photomaker = model_type == "Photomaker"

    # Fail fast: check the adapter path before the (expensive) pipeline load
    # rather than after it.
    if use_photomaker and not photomaker_path:
        raise ValueError("Photomaker model type requires a valid 'photomaker_path'.")

    pipeline_cls = (
        PhotoMakerStableDiffusionXLPipeline if use_photomaker else StableDiffusionXLPipeline
    )

    # Load model weights in half precision.
    if single_file:
        print(f"Loading model from a single file: {path}")
        pipe = pipeline_cls.from_single_file(path, torch_dtype=torch.float16)
    else:
        print(f"Loading model from a directory: {path}")
        pipe = pipeline_cls.from_pretrained(
            path, torch_dtype=torch.float16, use_safetensors=use_safetensors
        )

    pipe = pipe.to(device)

    # Attach the PhotoMaker adapter: the weights path is split into its
    # directory and file name, as load_photomaker_adapter takes them separately.
    if use_photomaker:
        pipe.load_photomaker_adapter(
            os.path.dirname(photomaker_path),
            subfolder="",
            weight_name=os.path.basename(photomaker_path),
            trigger_word="img"
        )
        pipe.fuse_lora()

    return pipe