# dcgan-mnist / pipeline.py
# Hugging Face Hub file view: last commit "Update pipeline.py" (0fb3f0b) by osanseviero; 2.14 kB.
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from torchvision import transforms
class Generator(nn.Module, PyTorchModelHubMixin):
    """DCGAN generator that turns a latent noise tensor into an image.

    A stack of transposed convolutions upsamples a ``(batch, latent_dim, 1, 1)``
    input to ``(batch, num_channels, 64, 64)``. The last layer is a ``Tanh``, so
    pixel values lie in ``[-1, 1]``.

    Args:
        num_channels: Number of channels in the generated image.
        latent_dim: Number of channels of the input noise tensor.
        hidden_size: Base feature width; the hidden stages use
            ``hidden_size * 8``, ``* 4``, ``* 2`` and ``* 1`` channels.
    """

    def __init__(self, num_channels=3, latent_dim=100, hidden_size=64):
        super().__init__()

        # Stem: 1 x 1 latent -> (hidden_size * 8) x 4 x 4.
        layers = [
            nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.ReLU(inplace=True),
        ]

        # Upsampling stages: each one halves the channel count and doubles the
        # spatial size (4x4 -> 8x8 -> 16x16 -> 32x32).
        for in_mult, out_mult in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(
                    hidden_size * in_mult, hidden_size * out_mult, 4, 2, 1, bias=False
                ),
                nn.BatchNorm2d(hidden_size * out_mult),
                nn.ReLU(inplace=True),
            ]

        # Head: (hidden_size) x 32 x 32 -> (num_channels) x 64 x 64, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        # One flat nn.Sequential with the same layer order as writing every layer
        # inline, so parameter names (model.0.weight, model.1.weight, ...) are
        # unchanged and from_pretrained checkpoints still load.
        self.model = nn.Sequential(*layers)

    def forward(self, noise):
        """Map a batch of latent noise tensors to images in ``[-1, 1]``."""
        return self.model(noise)
class PreTrainedPipeline:
    """Unconditional image-generation pipeline wrapping a pretrained DCGAN ``Generator``."""

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`, optional):
                Currently unused. Weights are always loaded from the Hub repo
                ``"huggan/dcgan-mnist"``.
        """
        self.model = Generator.from_pretrained("huggan/dcgan-mnist")
        # Inference only: force eval mode explicitly instead of relying on
        # from_pretrained. In train mode BatchNorm would normalise with batch
        # statistics and keep updating its running stats even under no_grad().
        self.model.eval()

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                Ignored. The generator is unconditional, so every call samples a
                fresh random image regardless of the text passed in.
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        """
        # Read the latent size and device from the loaded weights instead of
        # hard-coding 100 / CPU. Checkpoints saved with a different latent_dim,
        # or a model moved to another device, then still work.
        first_layer = self.model.model[0]
        noise = torch.randn(
            1, first_layer.in_channels, 1, 1, device=first_layer.weight.device
        )
        with torch.no_grad():
            output = self.model(noise)
        # Scale image: the generator ends in Tanh ([-1, 1]), so map to [0, 1].
        # The clamp keeps the range strict. Move to CPU before PIL conversion.
        img = ((output[0] + 1) / 2).clamp(0, 1).cpu()
        return transforms.ToPILImage()(img)