import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

def load_unet_model(base, repo, ckpt, device="cpu", dtype=torch.float16):
    """
    Build a UNet2DConditionModel from a base model's config and load checkpoint weights.

    The architecture comes from the ``unet`` subfolder config of ``base``; only the
    config is read here, not the base model's weights. The parameters are then
    overwritten with the safetensors state dict ``ckpt`` downloaded from ``repo``.

    Args:
        base (str): Hub repo id (or local directory) of the base model whose
            ``unet`` subfolder provides the UNet config.
        repo (str): Hub repo id that contains the checkpoint file.
        ckpt (str): Filename of the safetensors checkpoint inside ``repo``.
        device (str): Device the model is moved to and the checkpoint tensors
            are loaded onto.
        dtype (torch.dtype): Dtype the model is cast to. Defaults to
            ``torch.float16`` (the previously hard-coded value). Many CPU kernels
            have poor or no float16 support, so consider ``torch.float32`` when
            ``device`` is ``"cpu"``.

    Returns:
        UNet2DConditionModel: UNet with the checkpoint weights loaded, on
        ``device`` in ``dtype``.

    Raises:
        RuntimeError: If the checkpoint's keys or shapes do not match the UNet
            built from the config (``load_state_dict`` is strict by default).
    """
    # Imported lazily so importing this module does not require diffusers.
    from diffusers import UNet2DConditionModel

    # Passing a repo id/path straight to `from_config` is deprecated in diffusers;
    # load the config dict explicitly and build the model from it.
    config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(config).to(device, dtype)

    # Checkpoint tensors are loaded directly onto `device`; load_state_dict
    # copies them into the model's parameters (casting to the model's dtype).
    ckpt_path = hf_hub_download(repo, ckpt)
    unet.load_state_dict(load_file(ckpt_path, device=device))
    return unet