edemana commited on
Commit
8ff7c61
·
verified ·
1 Parent(s): 91acb3d

create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from tqdm import tqdm
6
+ def optimize_latent_vector(G, target_image, num_iterations=1000):
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
+ target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
9
+ target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
10
+ target_tensor = (target_tensor * 2) - 1 # Normalize to [-1, 1]
11
+
12
+ latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
13
+ optimizer = torch.optim.Adam([latent_vector], lr=0.1)
14
+
15
+ for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
16
+ optimizer.zero_grad()
17
+
18
+ generated_image = G(latent_vector, None)
19
+ loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
20
+
21
+ loss.backward()
22
+ optimizer.step()
23
+
24
+ if (i + 1) % 100 == 0:
25
+ print(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')
26
+
27
+ return latent_vector.detach()
28
+
29
+ def generate_from_upload(uploaded_image):
30
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+
32
+ # Optimize latent vector for the uploaded image
33
+ optimized_z = optimize_latent_vector(G, uploaded_image)
34
+
35
+ # Generate variations
36
+ num_variations = 4
37
+ variation_strength = 0.1
38
+ varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength
39
+
40
+ # Generate the variations
41
+ with torch.no_grad():
42
+ imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')
43
+
44
+ imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
45
+ imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
46
+
47
+ # Convert the generated image tensors to PIL Images
48
+ generated_images = [Image.fromarray(img) for img in imgs]
49
+
50
+ # Return the images separately
51
+ return generated_images[0], generated_images[1], generated_images[2], generated_images[3]
52
+
53
+ # Create the Gradio interface
54
+ iface = gr.Interface(
55
+ fn=generate_from_upload,
56
+ inputs=gr.Image(type="pil"),
57
+ outputs=[gr.Image(type="pil") for _ in range(4)],
58
+ title="StyleGAN Image Variation Generator"
59
+ )
60
+
61
+ # Launch the Gradio interface
62
+ iface.launch(share=True, debug=True)
63
+
64
+ # If you want to test it without the Gradio interface:
65
+ """
66
+ # Load an image explicitly
67
+ image_path = "path/to/your/image.jpg"
68
+ image = Image.open(image_path)
69
+
70
+ # Call the generate method explicitly
71
+ generated_images = generate_from_upload(image)
72
+
73
+ # Display the generated images
74
+ for img in generated_images:
75
+ img.show()
76
+ """