edemana committed · verified
Commit dc4ac4a · Parent: 73e4fea

Update app.py

Files changed (1): app.py (+19, -20)
app.py CHANGED
@@ -1,8 +1,26 @@
import torch
- import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
+ import pickle
+ from torchvision import transforms
from tqdm import tqdm
+
+ # Load the fine-tuned model
+ def load_model(model_path='fine_tuned_stylegan.pth'):
+     with open('ffhq.pkl', 'rb') as f:
+         data = pickle.load(f)
+     G = data['G_ema']
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     G = G.to(device)
+
+     # Load the fine-tuned weights
+     G.load_state_dict(torch.load(model_path))
+     G.eval()
+
+     return G
+
+ G = load_model('fine_tuned_stylegan.pth')
+
def optimize_latent_vector(G, target_image, num_iterations=1000):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
@@ -29,25 +47,20 @@ def optimize_latent_vector(G, target_image, num_iterations=1000):
def generate_from_upload(uploaded_image):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-     # Optimize latent vector for the uploaded image
    optimized_z = optimize_latent_vector(G, uploaded_image)

-     # Generate variations
    num_variations = 4
    variation_strength = 0.1
    varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength

-     # Generate the variations
    with torch.no_grad():
        imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')

    imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

-     # Convert the generated image tensors to PIL Images
    generated_images = [Image.fromarray(img) for img in imgs]

-     # Return the images separately
    return generated_images[0], generated_images[1], generated_images[2], generated_images[3]

# Create the Gradio interface
@@ -60,17 +73,3 @@ iface = gr.Interface(

# Launch the Gradio interface
iface.launch(share=True, debug=True)
-
- # If you want to test it without the Gradio interface:
- """
- # Load an image explicitly
- image_path = "path/to/your/image.jpg"
- image = Image.open(image_path)
-
- # Call the generate method explicitly
- generated_images = generate_from_upload(image)
-
- # Display the generated images
- for img in generated_images:
-     img.show()
- """
 
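The body of optimize_latent_vector lies outside the changed hunks, so only its opening lines appear above. Its name, the num_iterations argument, and the tqdm import suggest a standard latent-projection loop; the sketch below is a guess at that shape rather than the committed code, and the pixel-wise MSE loss and learning rate are assumptions:

import torch
from torchvision import transforms
from tqdm import tqdm

def project_latent(G, target_image, num_iterations=1000, lr=0.01):
    # Hypothetical projection loop: fit a z vector so that G(z) matches the target.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    target = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
    target = transforms.ToTensor()(target).unsqueeze(0).to(device) * 2 - 1  # scale to [-1, 1]

    z = torch.randn((1, G.z_dim), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([z], lr=lr)

    for _ in tqdm(range(num_iterations)):
        optimizer.zero_grad()
        img = G(z, c=None, truncation_psi=0.7, noise_mode='const')
        loss = torch.nn.functional.mse_loss(img, target)
        loss.backward()
        optimizer.step()

    return z.detach()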
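
The gr.Interface( call is likewise only visible as hunk context, so its arguments are not recorded in this diff. Given that generate_from_upload takes one uploaded image and returns four PIL images, a plausible wiring, with all names and labels illustrative rather than taken from the commit, would be:

import gradio as gr

# Illustrative interface definition; the real arguments are outside the diff.
iface = gr.Interface(
    fn=generate_from_upload,  # defined in app.py: one image in, four variations out
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[gr.Image(type="pil", label=f"Variation {i + 1}") for i in range(4)],
    title="StyleGAN variations",
)

# Matches the launch call shown in the diff.
iface.launch(share=True, debug=True)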