File size: 712 Bytes
9d9968c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

def load_unet_model(base, repo, ckpt, device="cpu", dtype=torch.float16):
    """Build a UNet from ``base``'s config and load weights from ``repo``/``ckpt``.

    The architecture comes from the ``unet`` subfolder config of ``base``;
    only the config is read from there. The weights come from the
    safetensors file ``ckpt`` downloaded from ``repo`` on the Hugging Face Hub.

    Args:
        base (str): Hub id (or local path) of the base pipeline whose
            ``unet`` subfolder provides the model config.
        repo (str): Hub repository id that contains the checkpoint file.
        ckpt (str): Filename of the safetensors checkpoint inside ``repo``.
        device (str | int | torch.device): Device to place the model and the
            loaded weights on. Defaults to ``"cpu"``.
        dtype (torch.dtype): Parameter dtype for the model. Defaults to
            ``torch.float16`` (the original behavior). Many float16 kernels
            are slow or unavailable on CPU, so pass ``torch.float32`` when
            running there.

    Returns:
        UNet2DConditionModel: The UNet with the checkpoint weights loaded.

    Raises:
        RuntimeError: If the checkpoint's keys/shapes do not match the model
            built from the config (``load_state_dict`` uses ``strict=True``).
    """
    # Local import, presumably to defer the diffusers dependency until used.
    from diffusers import UNet2DConditionModel

    # ``from_config`` expects a config dict; passing a repo id string is a
    # deprecated path in diffusers, so resolve the config explicitly first.
    config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(config).to(device, dtype)

    # safetensors' ``load_file`` accepts the device as str/int, not as a
    # ``torch.device``; normalize so callers can pass what ``.to()`` accepts.
    st_device = str(device) if isinstance(device, torch.device) else device
    state_dict = load_file(hf_hub_download(repo, ckpt), device=st_device)

    # ``load_state_dict`` copies into the existing parameters, casting to
    # their dtype, so the checkpoint's own dtype does not need to match.
    unet.load_state_dict(state_dict)
    return unet