---
license: mit
tags:
- pytorch
- diffusers
- class-conditional-image-generation
- diffusion-models-class
---

# Overview

This model is a diffusion model for conditional image generation of clothes from the FashionMNIST dataset. It is a class-conditioned UNet that generates images of clothing conditioned on the class label.
The code for this model can be found in this [GitHub repository](https://github.com/Huertas97/GenAI-FashionMNIST).

## Usage

Since the model is a custom class built on top of the Diffusers library, it can be used as follows:

Setup
```python
import json
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from torch import nn
from diffusers import UNet2DModel, DDPMScheduler
import safetensors.torch  # the torch submodule must be imported explicitly
from huggingface_hub import hf_hub_download
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
```
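
Optionally, fix the random seed so that the noise drawn during sampling (and therefore the generated images) is reproducible; the seed value below is arbitrary:

```python
# Optional: seed PyTorch's RNG so the sampling below is reproducible
torch.manual_seed(0)
```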
Define the `ClassConditionedUnet` class and load its weights from the safetensors file:
```python
# Custom Class
class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=10):
    super().__init__()

    # The embedding layer will map the class label to a vector of size class_emb_size
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # self.model is an unconditional UNet with extra input channels
    # to accept the conditioning information (the class embedding)
    self.model = UNet2DModel(
        sample_size=28,           # output image resolution. Equal to input resolution
        in_channels=1 + class_emb_size, # Additional input channels for class cond
        out_channels=1,           # the number of output channels. Equal to input
        layers_per_block=3,       # three ResNet layers per UNet block
        block_out_channels=(128, 256, 512), # number of output channels per block; reversed for upsampling
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "AttnDownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        ),
        up_block_types=(
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "AttnUpBlock2D",
            "UpBlock2D",  # a regular ResNet upsampling block
        ),
        dropout=0.1,  # dropout probability between Conv1 and Conv2 in a block, from the Improved DDPM paper
    )

  # Forward method takes the class labels as an additional argument
  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape # x is shape (bs, 1, 28, 28)

    # class conditioning embedding to add as additional input channels
    class_cond = self.class_emb(class_labels) # Map to embedding dimension
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # class_cond final shape: (bs, class_emb_size, 28, 28), i.e. (bs, 10, 28, 28) with the defaults

    # Model input is now x and class cond concatenated together along dimension 1
    # We need to provide the additional information (the class label)
    # to every spatial location (pixel) of the image, not by changing the
    # original pixels but by adding new channels.
    net_input = torch.cat((x, class_cond), 1) # (bs, 1 + class_emb_size, 28, 28)

    # Feed this to the UNet alongside the timestep and return the prediction
    # with image output size
    return self.model(net_input, t).sample # (bs, 1, 28, 28)

# Hub repository that hosts the model weights and scheduler config
repo_name = "Huertas97/conditioned-unet-fashion-mnist-non-ema"

# Download the safetensors model file
model_file_path = hf_hub_download(repo_id=repo_name, filename="fashion_class_cond_unet_model_best.safetensors")

# Load the Class Conditioned UNet model state dictionary
state_dict = safetensors.torch.load_file(model_file_path)
model_classcond_native = ClassConditionedUnet()
model_classcond_native.load_state_dict(state_dict)
model_classcond_native = model_classcond_native.to(device)
model_classcond_native.eval()  # disable dropout for inference
```
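
As an optional sanity check (not part of the original pipeline), a dummy forward pass should return a tensor with the same spatial shape as the input batch; the timestep and labels below are arbitrary:

```python
# Optional sanity check: output shape should match the input image batch
dummy_x = torch.randn(4, 1, 28, 28, device=device)
dummy_y = torch.randint(0, 10, (4,), device=device)  # arbitrary class labels
with torch.no_grad():
    out = model_classcond_native(dummy_x, 10, dummy_y)  # t=10 is arbitrary
print(out.shape)  # expected: torch.Size([4, 1, 28, 28])
```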

Load the DDPMScheduler:
```python
# Download and load the scheduler configuration file
scheduler_file_path = hf_hub_download(repo_id=repo_name, filename="scheduler_config.json")

with open(scheduler_file_path, 'r') as f:
    scheduler_config = json.load(f)

noise_scheduler = DDPMScheduler.from_config(scheduler_config)
```
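
To confirm what was loaded, the scheduler exposes its configuration on the `config` attribute; for a standard DDPM setup you would typically see 1000 training timesteps, but the exact values depend on the uploaded `scheduler_config.json`:

```python
# Inspect the loaded scheduler; exact values depend on the uploaded config
print(noise_scheduler.config.num_train_timesteps)
print(noise_scheduler.config.beta_schedule)
```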


Use the model to generate images:

```python
desired_class = [7] # desired class label(s), from 0 to 9
num_samples = 2  # number of images to generate per class

# Prepare random x to start from
x = torch.randn(num_samples*len(desired_class), 1, 28, 28).to(device)

# Prepare the desired classes
y = torch.tensor([[i]*num_samples for i in desired_class]).flatten().to(device)

model_classcond_native = model_classcond_native.to(device)

# Sampling loop
for t in tqdm(noise_scheduler.timesteps):
  # Get model pred
  with torch.no_grad():
      residual = model_classcond_native(x, t, y)

  # Update sample with step
  x = noise_scheduler.step(residual, t, x).prev_sample

# Show the results
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
plt.show()
```
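
To keep the samples instead of only displaying them, torchvision's `save_image` utility can write the same grid to disk; the filename below is arbitrary, and the samples are mapped from [-1, 1] to [0, 1] first:

```python
from torchvision.utils import save_image

# Rescale from [-1, 1] to [0, 1] before saving; the filename is arbitrary
save_image((x.detach().cpu().clip(-1, 1) + 1) / 2, "samples.png", nrow=8)
```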