OVAWARE committed
Commit cf238b7 · verified · 1 parent: d47023e

Update app.py

Files changed (1)
  1. app.py +55 -89
app.py CHANGED
@@ -3,22 +3,17 @@ import torch.nn as nn
 from torchvision import transforms
 from PIL import Image
 from transformers import BertTokenizer, BertModel
-import argparse
+import gradio as gr
 import numpy as np
 import os
-import time  # Import the time module
+import time

-# Import the model architecture from train.py
 from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM

 # Initialize the BERT tokenizer
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

 def clean_image(image, threshold=0.75):
-    """
-    Clean up the image by setting pixels with opacity <= threshold to 0% opacity
-    and pixels above the threshold to 100% visibility.
-    """
     np_image = np.array(image)
     alpha_channel = np_image[:, :, 3]
     alpha_channel[alpha_channel <= int(threshold * 255)] = 0
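
Note: the docstring deleted in this hunk documented behavior clean_image still has: alpha values at or below threshold * 255 become fully transparent, and, per the removed text, values above it become fully opaque (that assignment sits in unchanged context just below this hunk). A minimal standalone sketch of the binarization, with illustrative values:

    import numpy as np

    alpha = np.array([10, 120, 200, 255], dtype=np.uint8)  # sample alpha values
    cutoff = int(0.75 * 255)                               # 191, the default threshold
    alpha[alpha <= cutoff] = 0                             # fully transparent
    alpha[alpha > cutoff] = 255                            # fully opaque
    print(alpha)                                           # [  0   0 255 255]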
@@ -26,19 +21,15 @@ def clean_image(image, threshold=0.75):
     return Image.fromarray(np_image)

 def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
-    # Encode text prompt using BERT tokenizer
     encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
     input_ids = encoded_input['input_ids'].to(device)
     attention_mask = encoded_input['attention_mask'].to(device)

-    # Generate text encoding
     with torch.no_grad():
         text_encoding = model.text_encoder(input_ids, attention_mask)

-    # Sample from the latent space
     z = torch.randn(1, LATENT_DIM).to(device)

-    # Generate image
     with torch.no_grad():
         generated_image = model.decode(z, text_encoding)

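Note: generation samples the latent unconditionally and conditions only through the BERT text encoding: z is drawn from the standard normal prior and handed to model.decode together with the encoding. A shape-flow sketch under placeholder dimensions (the real LATENT_DIM and encoding width come from train.py; conditioning by concatenation is the usual CVAE construction and is assumed here, since decode's internals also live in train.py):

    import torch

    LATENT_DIM = 128                       # placeholder; the real value is defined in train.py
    z = torch.randn(1, LATENT_DIM)         # one draw from the N(0, I) prior
    text_encoding = torch.zeros(1, 256)    # placeholder BERT-derived encoding
    # Typical CVAE conditioning: concatenate latent and condition (assumed).
    decoder_input = torch.cat([z, text_encoding], dim=1)
    print(decoder_input.shape)             # torch.Size([1, 384])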
@@ -47,7 +38,6 @@ def generate_image(model, text_prompt, device, input_image=None, img_control=0.5
         input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
         generated_image = img_control * input_image + (1 - img_control) * generated_image

-    # Convert the generated tensor to a PIL Image
     generated_image = generated_image.squeeze(0).cpu()
     generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
     generated_image = generated_image.clamp(0, 1)
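
Note: the img2img path is a plain linear interpolation between the input image and the decoder output, followed by mapping the decoder's [-1, 1] range into [0, 1]. Both steps in isolation (a sketch, not the committed code):

    import torch

    def blend(input_img: torch.Tensor, generated: torch.Tensor, img_control: float) -> torch.Tensor:
        # img_control=1.0 returns the input unchanged; 0.0 ignores it entirely
        return img_control * input_img + (1 - img_control) * generated

    x = torch.tensor([-1.0, 0.0, 1.0])     # decoder-range samples
    print((x + 1) / 2)                     # tensor([0.0000, 0.5000, 1.0000])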
@@ -55,86 +45,62 @@ def generate_image(model, text_prompt, device, input_image=None, img_control=0.5

     return generated_image

-def main():
-    parser = argparse.ArgumentParser(description="Generate an image from a text prompt using the trained CVAE model(s).")
-    parser.add_argument("--prompt", type=str, help="Text prompt for image generation")
-    parser.add_argument("--prompt_file", type=str, help="File containing prompts, one per line")
-    parser.add_argument("--output", type=str, default="generated_images", help="Output directory or file for generated images")
-    parser.add_argument("--model_paths", type=str, nargs='*', help="Paths to the trained model(s)")
-    parser.add_argument("--model_path", type=str, help="Path to a single trained model")
-    parser.add_argument("--clean", action="store_true", help="Clean up the image by removing low opacity pixels")
-    parser.add_argument("--size", type=int, default=16, help="Size of the generated image")
-    parser.add_argument("--input_image", type=str, help="Path to the input image for img2img generation")
-    parser.add_argument("--img_control", type=float, default=0.5, help="Control how much the input image influences the output (0 to 1)")
-    args = parser.parse_args()
-
-    if not args.prompt and not args.prompt_file:
-        parser.error("Either --prompt or --prompt_file must be provided")
-
-    if args.model_paths and args.model_path:
-        parser.error("Specify either --model_paths or --model_path, not both")
-
-    model_paths = args.model_paths if args.model_paths else [args.model_path]
-
+def process(prompt, model_path, clean, size, input_image, img_control, output_dir):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Check if --output is a file or directory
-    is_folder_output = os.path.isdir(args.output)
-
-    if is_folder_output:
-        # Ensure output directory exists if it's not a file
-        os.makedirs(args.output, exist_ok=True)
-
+
     # Load input image if provided
-    input_image = None
-    if args.input_image:
-        input_image = Image.open(args.input_image).convert("RGBA")
-
-    # Process single prompt or batch of prompts
-    if args.prompt:
-        prompts = [args.prompt]
-    else:
-        with open(args.prompt_file, 'r') as f:
-            prompts = [line.strip() for line in f if line.strip()]
-
-    for model_path in model_paths:
-        # Initialize model
-        text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
-        model = CVAE(text_encoder).to(device)
-
-        # Load the trained model
-        model.load_state_dict(torch.load(model_path, map_location=device))
-        model.eval()
-
-        model_name = os.path.splitext(os.path.basename(model_path))[0]
+    input_image = Image.open(input_image).convert("RGBA") if input_image else None
+
+    # Initialize model
+    text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
+    model = CVAE(text_encoder).to(device)
+
+    # Load the trained model
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()

-        for i, prompt in enumerate(prompts):
-            start_time = time.time()  # Start timing the generation
-
-            # Generate image from prompt
-            generated_image = generate_image(model, prompt, device, input_image, args.img_control)
-
-            # End timing the generation
-            end_time = time.time()
-            generation_time = end_time - start_time  # Calculate the generation time
+    start_time = time.time()
+
+    # Generate image from prompt
+    generated_image = generate_image(model, prompt, device, input_image, img_control)
+
+    end_time = time.time()
+    generation_time = end_time - start_time
+
+    # Clean up the image if the flag is set
+    if clean:
+        generated_image = clean_image(generated_image)
+
+    # Resize the generated image
+    generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
+
+    # Save the generated image to the specified directory
+    model_name = os.path.splitext(os.path.basename(model_path))[0]
+    output_file = os.path.join(output_dir, f"{model_name}_{prompt}.png")
+    os.makedirs(output_dir, exist_ok=True)
+    generated_image.save(output_file)
+
+    print(f"Generated image saved as {output_file}")
+    print(f"Generation time: {generation_time:.10f} seconds")
+
+    return generated_image

-            # Clean up the image if the flag is set
-            if args.clean:
-                generated_image = clean_image(generated_image)
-
-            # Resize the generated image
-            generated_image = generated_image.resize((args.size, args.size), resample=Image.NEAREST)
-
-            if not is_folder_output:
-                # Save the generated image to the specified file
-                output_file = args.output
-            else:
-                # Save the generated image to the output directory
-                output_file = os.path.join(args.output, f"{model_name}_{prompt}_{i:03d}.png")
-
-            generated_image.save(output_file)
-            print(f"Generated image for prompt '{prompt}' using model '{model_name}' saved as {output_file}")
-            print(f"Generation time: {generation_time:.10f} seconds")  # Print the generation time
+# Gradio Interface
+interface = gr.Interface(
+    fn=process,
+    inputs=[
+        gr.Textbox(label="Text Prompt"),
+        gr.File(label="Model Path (.pth file)", file_types=['.pth']),
+        gr.Checkbox(label="Clean Image (Remove Low Opacity Pixels)", value=False),
+        gr.Slider(label="Image Size", minimum=16, maximum=512, step=16, value=16),
+        gr.File(label="Input Image (Optional)", file_types=["image"]),
+        gr.Slider(label="Image Control (0-1)", minimum=0.0, maximum=1.0, step=0.01, value=0.5),
+        gr.Textbox(label="Output Directory", value="generated_images")
+    ],
+    outputs=gr.Image(label="Generated Image"),
+    title="Text-to-Image Generator",
+    description="Generate an image from a text prompt using a trained CVAE model."
+)

 if __name__ == "__main__":
-    main()
+    interface.launch()
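
Note: this hunk removes the entire argparse CLI, including batch prompts (--prompt_file) and multi-model runs (--model_paths); the Gradio app handles one prompt and one checkpoint per request, and rebuilds and reloads the model on every call. process() can still be exercised without the UI; a hypothetical smoke test, with the checkpoint path and prompt as placeholders:

    from app import process

    image = process(
        prompt="a red sword",             # placeholder prompt
        model_path="models/cvae.pth",     # placeholder checkpoint path
        clean=True,
        size=16,
        input_image=None,
        img_control=0.5,
        output_dir="generated_images",
    )
    image.show()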
 
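Note: interface.launch() with no arguments serves the app locally on Gradio's default port. Standard launch parameters (not used in this commit) cover the usual deployment tweaks:

    # interface.launch(share=True)                               # temporary public link
    # interface.launch(server_name="0.0.0.0", server_port=7860)  # bind all interfaces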