OVAWARE committed
Commit 6bdcf63 (verified) · Parent: f9aac6f

Update app.py

Files changed (1): app.py (+48 −56)
app.py CHANGED
@@ -1,13 +1,14 @@
+import gradio as gr
 import torch
 import torch.nn as nn
 from torchvision import transforms
 from PIL import Image
 from transformers import BertTokenizer, BertModel
-import gradio as gr
 import numpy as np
 import os
 import time
 
+# Import the model architecture from train.py
 from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM
 
 # Initialize the BERT tokenizer
@@ -17,90 +18,81 @@ def clean_image(image, threshold=0.75):
     np_image = np.array(image)
     alpha_channel = np_image[:, :, 3]
     alpha_channel[alpha_channel <= int(threshold * 255)] = 0
-    alpha_channel[alpha_channel > int(threshold * 255)] = 255  # Set to 100% visibility
+    alpha_channel[alpha_channel > int(threshold * 255)] = 255
     return Image.fromarray(np_image)
 
 def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
     encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
     input_ids = encoded_input['input_ids'].to(device)
     attention_mask = encoded_input['attention_mask'].to(device)
-
+
     with torch.no_grad():
         text_encoding = model.text_encoder(input_ids, attention_mask)
-
-    z = torch.randn(1, LATENT_DIM).to(device)
-
-    with torch.no_grad():
+        z = torch.randn(1, LATENT_DIM).to(device)
         generated_image = model.decode(z, text_encoding)
-
+
     if input_image is not None:
         input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
         input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
         generated_image = img_control * input_image + (1 - img_control) * generated_image
-
+
     generated_image = generated_image.squeeze(0).cpu()
-    generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
+    generated_image = (generated_image + 1) / 2
     generated_image = generated_image.clamp(0, 1)
     generated_image = transforms.ToPILImage()(generated_image)
-
+
     return generated_image
 
-def process(prompt, model_path, clean, size, input_image, img_control, output_dir):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Load input image if provided
-    input_image = Image.open(input_image).convert("RGBA") if input_image else None
-
-    # Initialize model
+def load_model(model_path, device):
     text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
     model = CVAE(text_encoder).to(device)
-
-    # Load the trained model
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.eval()
+    return model
+
+def generate_image_gradio(prompt, model_path, clean_image_flag, size, input_image=None, img_control=0.5):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = load_model(model_path, device)
 
     start_time = time.time()
-
-    # Generate image from prompt
     generated_image = generate_image(model, prompt, device, input_image, img_control)
-
     end_time = time.time()
     generation_time = end_time - start_time
-
-    # Clean up the image if the flag is set
-    if clean:
+
+    if clean_image_flag:
         generated_image = clean_image(generated_image)
-
-    # Resize the generated image
+
     generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
-
-    # Save the generated image to the specified directory
-    model_name = os.path.splitext(os.path.basename(model_path))[0]
-    output_file = os.path.join(output_dir, f"{model_name}_{prompt}.png")
-    os.makedirs(output_dir, exist_ok=True)
-    generated_image.save(output_file)
-
-    print(f"Generated image saved as {output_file}")
-    print(f"Generation time: {generation_time:.10f} seconds")
-
-    return generated_image
 
-# Gradio Interface
-interface = gr.Interface(
-    fn=process,
-    inputs=[
-        gr.Textbox(label="Text Prompt"),
-        gr.File(label="Model Path (.pth file)", file_types=['.pth']),
-        gr.Checkbox(label="Clean Image (Remove Low Opacity Pixels)", default=False),
-        gr.Slider(label="Image Size", minimum=16, maximum=512, step=16, default=16),
-        gr.File(label="Input Image (Optional)", file_types=["image"]),
-        gr.Slider(label="Image Control (0-1)", minimum=0.0, maximum=1.0, step=0.01, default=0.5),
-        gr.Textbox(label="Output Directory", value="generated_images")
-    ],
-    outputs=gr.Image(label="Generated Image"),
-    title="Text-to-Image Generator",
-    description="Generate an image from a text prompt using a trained CVAE model."
-)
+    return generated_image, f"Generation time: {generation_time:.4f} seconds"
+
+# Gradio interface
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Image Generator from Text Prompt")
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(label="Text Prompt")
+                model_path = gr.Textbox(label="Model Path", value="path/to/your/model.pth")
+                clean_image_flag = gr.Checkbox(label="Clean Image", value=False)
+                size = gr.Slider(minimum=16, maximum=512, step=16, label="Image Size", value=16)
+                img_control = gr.Slider(minimum=0, maximum=1, step=0.1, label="Image Control", value=0.5)
+                input_image = gr.Image(label="Input Image (optional)", type="pil")
+                generate_button = gr.Button("Generate Image")
+
+            with gr.Column():
+                output_image = gr.Image(label="Generated Image")
+                generation_time = gr.Textbox(label="Generation Time")
+
+        generate_button.click(
+            generate_image_gradio,
+            inputs=[prompt, model_path, clean_image_flag, size, input_image, img_control],
+            outputs=[output_image, generation_time]
+        )
+
+    return demo
 
 if __name__ == "__main__":
-    interface.launch()
+    demo = gradio_interface()
+    demo.launch()
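Review note: because the refactor splits model loading into `load_model` and returns the image from `generate_image_gradio` instead of writing it to disk, the helpers can now be exercised without the UI. A minimal headless smoke test is sketched below; the checkpoint path and the prompt are placeholder assumptions, not part of this commit.

```python
# Headless smoke test for the refactored helpers -- an illustrative sketch.
# Assumes app.py is importable and a trained checkpoint exists at the
# placeholder path "checkpoints/model.pth".
import torch
from PIL import Image

from app import load_model, generate_image, clean_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("checkpoints/model.pth", device)  # builds the CVAE and loads weights

# Sample one 16x16 image from the prior, conditioned on a placeholder prompt.
img = generate_image(model, "a red potion bottle", device)

# Optional post-processing, mirroring what the Gradio callback does.
img = clean_image(img, threshold=0.75)  # zero out low-alpha pixels
img = img.resize((256, 256), resample=Image.NEAREST)  # keep the pixel-art look crisp
img.save("sample.png")
```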
 
 
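For reviewers without the repo checked out: app.py touches train.py only through `CVAE`, `TextEncoder`, `LATENT_DIM`, and `HIDDEN_DIM`. Inferred from the call sites in this diff, the expected surface is roughly the stub below; it is an assumption for illustration only, and the real definitions (including the actual constant values) live in train.py.

```python
# Illustrative stub of the interface app.py expects from train.py.
# Inferred from call sites in this diff; not the actual implementation.
import torch.nn as nn

LATENT_DIM = 128  # placeholder value; app.py samples z = torch.randn(1, LATENT_DIM)
HIDDEN_DIM = 256  # placeholder value; passed as hidden_size/output_size below

class TextEncoder(nn.Module):
    def __init__(self, hidden_size: int, output_size: int):
        super().__init__()
        ...  # e.g. a text backbone plus a projection head

    def forward(self, input_ids, attention_mask):
        """Called as model.text_encoder(input_ids, attention_mask); returns a text encoding."""
        ...

class CVAE(nn.Module):
    def __init__(self, text_encoder: TextEncoder):
        super().__init__()
        self.text_encoder = text_encoder  # app.py accesses this attribute directly

    def decode(self, z, text_encoding):
        """Maps (z, text_encoding) to an image tensor in [-1, 1]; app.py rescales to [0, 1]."""
        ...
```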