Spaces:
Runtime error
Runtime error
create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
from tqdm import tqdm
|
6 |
+
def optimize_latent_vector(G, target_image, num_iterations=1000):
|
7 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
8 |
+
target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
|
9 |
+
target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
|
10 |
+
target_tensor = (target_tensor * 2) - 1 # Normalize to [-1, 1]
|
11 |
+
|
12 |
+
latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
|
13 |
+
optimizer = torch.optim.Adam([latent_vector], lr=0.1)
|
14 |
+
|
15 |
+
for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
|
16 |
+
optimizer.zero_grad()
|
17 |
+
|
18 |
+
generated_image = G(latent_vector, None)
|
19 |
+
loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
|
20 |
+
|
21 |
+
loss.backward()
|
22 |
+
optimizer.step()
|
23 |
+
|
24 |
+
if (i + 1) % 100 == 0:
|
25 |
+
print(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')
|
26 |
+
|
27 |
+
return latent_vector.detach()
|
28 |
+
|
29 |
+
def generate_from_upload(uploaded_image):
|
30 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
31 |
+
|
32 |
+
# Optimize latent vector for the uploaded image
|
33 |
+
optimized_z = optimize_latent_vector(G, uploaded_image)
|
34 |
+
|
35 |
+
# Generate variations
|
36 |
+
num_variations = 4
|
37 |
+
variation_strength = 0.1
|
38 |
+
varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength
|
39 |
+
|
40 |
+
# Generate the variations
|
41 |
+
with torch.no_grad():
|
42 |
+
imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')
|
43 |
+
|
44 |
+
imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
45 |
+
imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
|
46 |
+
|
47 |
+
# Convert the generated image tensors to PIL Images
|
48 |
+
generated_images = [Image.fromarray(img) for img in imgs]
|
49 |
+
|
50 |
+
# Return the images separately
|
51 |
+
return generated_images[0], generated_images[1], generated_images[2], generated_images[3]
|
52 |
+
|
53 |
+
# Create the Gradio interface
|
54 |
+
iface = gr.Interface(
|
55 |
+
fn=generate_from_upload,
|
56 |
+
inputs=gr.Image(type="pil"),
|
57 |
+
outputs=[gr.Image(type="pil") for _ in range(4)],
|
58 |
+
title="StyleGAN Image Variation Generator"
|
59 |
+
)
|
60 |
+
|
61 |
+
# Launch the Gradio interface
|
62 |
+
iface.launch(share=True, debug=True)
|
63 |
+
|
64 |
+
# If you want to test it without the Gradio interface:
|
65 |
+
"""
|
66 |
+
# Load an image explicitly
|
67 |
+
image_path = "path/to/your/image.jpg"
|
68 |
+
image = Image.open(image_path)
|
69 |
+
|
70 |
+
# Call the generate method explicitly
|
71 |
+
generated_images = generate_from_upload(image)
|
72 |
+
|
73 |
+
# Display the generated images
|
74 |
+
for img in generated_images:
|
75 |
+
img.show()
|
76 |
+
"""
|