basharatwali commited on
Commit
b98188b
·
1 Parent(s): 0645da3

Add application and model

Browse files
Files changed (3) hide show
  1. app.py +69 -0
  2. generator_final.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torchvision.transforms as transforms
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ latent_dim = 100
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ class Generator(nn.Module):
11
+ def __init__(self, latent_dim=100, img_channels=3, feature_map_size=32):
12
+ super(Generator, self).__init__()
13
+ self.net = nn.Sequential(
14
+ nn.ConvTranspose2d(latent_dim, feature_map_size * 8, 4, 1, 0, bias=False),
15
+ nn.BatchNorm2d(feature_map_size * 8),
16
+ nn.ReLU(True),
17
+ nn.ConvTranspose2d(feature_map_size * 8, feature_map_size * 4, 4, 2, 1, bias=False),
18
+ nn.BatchNorm2d(feature_map_size * 4),
19
+ nn.ReLU(True),
20
+ nn.ConvTranspose2d(feature_map_size * 4, feature_map_size * 2, 4, 2, 1, bias=False),
21
+ nn.BatchNorm2d(feature_map_size * 2),
22
+ nn.ReLU(True),
23
+ nn.ConvTranspose2d(feature_map_size * 2, feature_map_size, 4, 2, 1, bias=False),
24
+ nn.BatchNorm2d(feature_map_size),
25
+ nn.ReLU(True),
26
+ nn.ConvTranspose2d(feature_map_size, img_channels, 4, 2, 1, bias=False),
27
+ nn.Tanh()
28
+ )
29
+
30
+ def forward(self, x):
31
+ return self.net(x)
32
+
33
+ def generate_artwork(generator, latent_dim=latent_dim, device=device, num_images=1):
34
+ generator.eval()
35
+ with torch.no_grad():
36
+ noise = torch.randn(num_images, latent_dim, 1, 1, device=device)
37
+ fake_images = generator(noise)
38
+ fake_images = fake_images * 0.5 + 0.5
39
+ return fake_images.detach().cpu()
40
+
41
+ def inference_interface(latent_dim=latent_dim, device=device):
42
+ # Create model and load weights
43
+ generator = Generator(latent_dim=latent_dim)
44
+ generator = nn.DataParallel(generator)
45
+ generator.load_state_dict(torch.load("generator_final.pth", map_location=device))
46
+
47
+ if isinstance(generator, nn.DataParallel):
48
+ generator = generator.module
49
+ generator.to(device)
50
+
51
+ def generate(num_images):
52
+ fake_images = generate_artwork(generator, latent_dim=latent_dim, device=device, num_images=num_images)
53
+ images = [transforms.ToPILImage()(img) for img in fake_images]
54
+ upscaled_images = [img.resize((256, 256), resample=Image.LANCZOS) for img in images]
55
+ return upscaled_images
56
+
57
+ demo = gr.Interface(
58
+ fn=generate,
59
+ inputs=gr.Slider(minimum=1, maximum=9, step=1, default=1, label="Number of Images"),
60
+ outputs=gr.Gallery(label="Generated Artwork").style(grid=[3], height="auto"),
61
+ title="Art Generation with GAN",
62
+ description="Generate artwork using a trained GAN model."
63
+ )
64
+ return demo
65
+
66
+ # The key part: launch the Gradio interface when app.py is run
67
+ if __name__ == "__main__":
68
+ demo = inference_interface()
69
+ demo.launch()
generator_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dc949f7130cc293d5c6a2b37b2a838aa82cc5848bbea014596b768829271b63
3
+ size 4413491
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow