osanseviero committed
Commit 560a4cb
1 Parent(s): 4912c62

Create pipeline.py

Files changed (1)
  1. pipeline.py +59 -0
pipeline.py ADDED
import torch
import torch.nn as nn
from torchvision import transforms  # needed by PreTrainedPipeline.__call__ below
from huggingface_hub import PyTorchModelHubMixin

class Generator(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_channels=3, latent_dim=100, hidden_size=64):
        super().__init__()
        self.model = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.ReLU(True),
            # state size. (hidden_size*8) x 4 x 4
            nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 4),
            nn.ReLU(True),
            # state size. (hidden_size*4) x 8 x 8
            nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size * 2),
            nn.ReLU(True),
            # state size. (hidden_size*2) x 16 x 16
            nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            # state size. (hidden_size) x 32 x 32
            nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (num_channels) x 64 x 64
        )

    def forward(self, noise):
        pixel_values = self.model(noise)
        return pixel_values

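# Shape walkthrough (illustrative, for the default arguments): a latent batch
# of shape (N, latent_dim, 1, 1) = (N, 100, 1, 1) passes through five
# transposed convolutions; the first projects the latent vector to 4x4, and
# each later one doubles the spatial size while halving the channel count:
#   (N, 100, 1, 1) -> (N, 512, 4, 4) -> (N, 256, 8, 8)
#   -> (N, 128, 16, 16) -> (N, 64, 32, 32) -> (N, 3, 64, 64)
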
class PreTrainedPipeline():
    def __init__(self, path=""):
        """
        Initialize the model. PyTorchModelHubMixin provides from_pretrained,
        which downloads the weights from the Hub and loads them into Generator.
        """
        self.model = Generator.from_pretrained("huggan/dcgan-mnist")
        self.model.eval()

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text; it is ignored, since
                generation is unconditional
        Return:
            A :obj:`PIL.Image.Image` with the raw image representation as PIL.
        """
        # Sample a latent vector on the same device as the model weights
        device = next(self.model.parameters()).device
        noise = torch.randn(1, 100, 1, 1, device=device)
        with torch.no_grad():
            output = self.model(noise)

        # Rescale from the tanh range [-1, 1] to [0, 1] and drop the batch
        # dimension before converting to a PIL image
        img = transforms.ToPILImage()(output[0].mul(0.5).add(0.5).clamp(0, 1))
        return img
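
# Usage sketch (an assumption, not part of the original commit): requires
# torchvision and network access to the "huggan/dcgan-mnist" checkpoint.
# The guard keeps it from running when this file is imported as a module.
if __name__ == "__main__":
    pipe = PreTrainedPipeline()
    image = pipe("this text is ignored")
    image.save("sample.png")  # hypothetical output path, for illustration
    print(f"Saved a {image.size[0]}x{image.size[1]} sample to sample.png")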