# Spaces: Running on Zero
import yaml | |
import torch | |
import os | |
from diffusers import StableDiffusionXLPipeline | |
from utils import PhotoMakerStableDiffusionXLPipeline | |
def get_models_dict(config_path='config/models.yaml', verbose=False):
    """
    Load the model configuration from a YAML file.

    Args:
        config_path (str): Path to the YAML configuration file.
        verbose (bool): If True, print the parsed configuration.

    Returns:
        The data produced by ``yaml.safe_load`` -- expected to be a dict.
        Note that an empty file parses to ``None``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        RuntimeError: If the file is not valid YAML. The underlying
            ``yaml.YAMLError`` is chained as ``__cause__``.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file '{config_path}' not found.")

    # YAML is UTF-8 text; don't depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Chain the parser error so its line/column details stay in the traceback.
            raise RuntimeError(f"Error parsing YAML file: {exc}") from exc

    if verbose:
        print("Loaded model configuration:", data)
    return data
def load_models(model_info, device="cuda", photomaker_path=None, *, torch_dtype=torch.float16):
    """
    Build a Stable Diffusion XL pipeline (plain or PhotoMaker) from a config entry.

    The whole configuration is validated before any weights are loaded. A bad
    entry therefore fails immediately, not after the (potentially expensive)
    model load.

    Args:
        model_info (dict): Model configuration. Keys read:
            ``path`` (required): single checkpoint file or pretrained model location.
            ``single_files`` (default False): load with ``from_single_file``
                instead of ``from_pretrained``.
            ``use_safetensors`` (default True): forwarded to ``from_pretrained``.
            ``model_type`` (default "original"): "original" or "Photomaker".
        device (str): Device the pipeline is moved to, e.g. 'cuda' or 'cpu'.
        photomaker_path (str, optional): Full path to the PhotoMaker adapter
            weights file. Required when ``model_type`` is "Photomaker".
        torch_dtype (torch.dtype, keyword-only): dtype the weights are loaded
            in. Defaults to ``torch.float16`` (the previous hard-coded value).
            ``torch.float32`` may be preferable on CPU, where float16 kernel
            coverage is limited.

    Returns:
        The loaded pipeline, moved to ``device``. For the "Photomaker" type,
        the PhotoMaker adapter is also loaded and its LoRA fused.

    Raises:
        ValueError: If ``path`` is missing, or ``photomaker_path`` is missing
            for the "Photomaker" type.
        NotImplementedError: If ``model_type`` is not supported.
    """
    path = model_info.get("path")
    # NOTE: the config key is spelled "single_files" (plural).
    single_file = model_info.get("single_files", False)
    use_safetensors = model_info.get("use_safetensors", True)
    model_type = model_info.get("model_type", "original")

    # --- Validate everything before touching any weights. ---
    if not path:
        raise ValueError("Model path must be specified in the model_info.")
    if model_type not in ("original", "Photomaker"):
        raise NotImplementedError(
            f"Unsupported model type '{model_type}'. Choose either 'original' or 'Photomaker'."
        )
    is_photomaker = model_type == "Photomaker"
    # Checked up front: discovering this only after the base model has been
    # loaded and moved to the device would waste that whole load.
    if is_photomaker and not photomaker_path:
        raise ValueError("Photomaker model type requires a valid 'photomaker_path'.")

    pipeline_cls = PhotoMakerStableDiffusionXLPipeline if is_photomaker else StableDiffusionXLPipeline

    # --- Load the base pipeline. ---
    if single_file:
        print(f"Loading model from a single file: {path}")
        pipe = pipeline_cls.from_single_file(path, torch_dtype=torch_dtype)
    else:
        print(f"Loading model from a directory: {path}")
        pipe = pipeline_cls.from_pretrained(path, torch_dtype=torch_dtype, use_safetensors=use_safetensors)
    pipe = pipe.to(device)

    # --- Attach the PhotoMaker adapter if needed. ---
    if is_photomaker:
        # The adapter loader takes a directory plus a weight file name,
        # so split the full path into those two parts.
        pipe.load_photomaker_adapter(
            os.path.dirname(photomaker_path),
            subfolder="",
            weight_name=os.path.basename(photomaker_path),
            trigger_word="img",
        )
        # Merge the adapter's LoRA weights into the base model weights.
        pipe.fuse_lora()
    return pipe