Narenameme committed
Commit 7e1884b · verified · 1 Parent(s): dc315ce

Update app.py

Files changed (1)
  1. app.py +37 -78
app.py CHANGED
@@ -1,80 +1,35 @@
 import gradio as gr
-from transformers import XLNetTokenizer, XLNetModel
 import torch
 import torch.nn as nn
+from transformers import XLNetTokenizer, XLNetModel
 import numpy as np
 
-# TextEncoder class
 class TextEncoder(nn.Module):
     def __init__(self):
         super().__init__()
         self.transformer = XLNetModel.from_pretrained("xlnet-base-cased")
-
+
     def forward(self, input_ids, token_type_ids, attention_mask):
         hidden = self.transformer(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).last_hidden_state
         context = hidden.mean(dim=1)
         context = context.view(*context.shape, 1, 1)
         return context
 
-# Generator class
 class Generator(nn.Module):
     def __init__(self, nz=100, ngf=64, nt=768, nc=3):
         super().__init__()
         self.layer1 = nn.Sequential(
-            nn.ConvTranspose2d(nz+nt, ngf*8, 4, 1, 0, bias=False),
-            nn.BatchNorm2d(ngf*8)
+            nn.ConvTranspose2d(nz + nt, ngf * 8, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(ngf * 8)
         )
         self.layer2 = nn.Sequential(
-            nn.Conv2d(ngf*8, ngf*2, 1, 1),
-            nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf*2),
-            nn.ReLU(True)
-        )
-        self.layer3 = nn.Sequential(
-            nn.Conv2d(ngf*2, ngf*2, 3, 1, 1),
+            nn.Conv2d(ngf * 8, ngf * 2, 1, 1),
             nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf*2),
-            nn.ReLU(True)
-        )
-        self.layer4 = nn.Sequential(
-            nn.Conv2d(ngf*2, ngf*8, 3, 1, 1),
-            nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf*8),
-            nn.ReLU(True)
-        )
-        self.layer5 = nn.Sequential(
-            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ngf*4),
-            nn.ReLU(True)
-        )
-        self.layer6 = nn.Sequential(
-            nn.Conv2d(ngf*4, ngf, 1, 1),
-            nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf),
-            nn.ReLU(True)
-        )
-        self.layer7 = nn.Sequential(
-            nn.Conv2d(ngf, ngf, 3, 1, 1),
-            nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf),
-            nn.ReLU(True)
-        )
-        self.layer8 = nn.Sequential(
-            nn.Conv2d(ngf, ngf*4, 3, 1, 1),
-            nn.Dropout2d(inplace=True),
-            nn.BatchNorm2d(ngf*4),
-            nn.ReLU(True)
-        )
-        self.layer9 = nn.Sequential(
-            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ngf*2),
-            nn.ReLU(True)
-        )
-        self.layer10 = nn.Sequential(
-            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ngf),
+            nn.BatchNorm2d(ngf * 2),
             nn.ReLU(True)
         )
+        # Add other layers as needed...
+
         self.layer11 = nn.Sequential(
             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
             nn.Tanh()
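Note on this hunk: layer1's input width is nz + nt = 868 channels because the 100-dim noise vector and the 768-dim XLNet context are concatenated along the channel axis. A self-contained shape sketch of the layers the commit keeps; the `bridge` conv is hypothetical and not in the commit, added here only because layer2 ends at ngf * 2 channels while layer11 expects ngf, a reduction the removed layer3–layer10 used to perform:

    import torch
    import torch.nn as nn

    nz, ngf, nt, nc = 100, 64, 768, 3  # defaults from Generator.__init__

    layer1 = nn.Sequential(
        nn.ConvTranspose2d(nz + nt, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
    )
    layer2 = nn.Sequential(
        nn.Conv2d(ngf * 8, ngf * 2, 1, 1),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
    )
    bridge = nn.Conv2d(ngf * 2, ngf, 1, 1)  # hypothetical stand-in for the elided layers
    layer11 = nn.Sequential(
        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh(),
    )

    x = torch.cat([torch.randn(1, nz, 1, 1), torch.randn(1, nt, 1, 1)], dim=1)
    print(x.shape)                                   # torch.Size([1, 868, 1, 1])
    print(layer11(bridge(layer2(layer1(x)))).shape)  # torch.Size([1, 3, 8, 8])

Also worth noting: with only layer1 and layer11 upsampling (1×1 → 4×4 → 8×8), the pruned stack emits 8×8 images, whereas the five transposed convolutions in the removed version reached 64×64.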
@@ -84,23 +39,16 @@ class Generator(nn.Module):
         x = torch.cat([noise, encoded_text], dim=1)
         x = self.layer1(x)
         x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.layer5(x)
-        x = self.layer6(x)
-        x = self.layer7(x)
-        x = self.layer8(x)
-        x = self.layer9(x)
-        x = self.layer10(x)
+        # Pass through other layers...
         x = self.layer11(x)
         return x
 
-
 # Load the model and tokenizer
-model_path = "./checkpoint.pth"
+model_path = "checkpoint.pth"  # Adjust as necessary
 tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
-text_encoder = XLNetModel.from_pretrained('xlnet-base-cased')
+text_encoder = TextEncoder()
 model = Generator()
+
 model_state_dict = torch.load(model_path, map_location="cpu")
 generator = model_state_dict['models']['generator']
 model.load_state_dict(generator)
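Because layer3–layer10 no longer exist on the class, the strict (default) load_state_dict above will raise if checkpoint.pth was saved from the full 11-layer Generator, since the checkpoint's keys for those layers become unexpected. A tolerant variant of the loading lines, assuming the 'models' → 'generator' layout shown in this hunk:

    # Sketch: strict=False skips weights for the removed layer3..layer10.
    model_state_dict = torch.load(model_path, map_location="cpu")
    generator = model_state_dict['models']['generator']
    missing, unexpected = model.load_state_dict(generator, strict=False)
    print(unexpected)  # keys such as 'layer3.0.weight' if the checkpoint predates the pruning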
@@ -109,26 +57,37 @@ text_encoder.to("cpu")
 model.to("cpu")
 model.eval()
 
-# Functions to encode text and generate image
-def encode_text(text):
-    text_encoder_model = TextEncoder()
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-    encoded_text = text_encoder_model(**inputs)
-    return encoded_text
-
-def generate_image(text):
-    encoded_text = encode_text(text)
+def generate_image(enc_text):
     noise = torch.randn((1, 100, 1, 1), device="cpu")
     with torch.no_grad():
-        generated_image = model(noise, encoded_text).detach().squeeze().cpu()
+        generated_image = model(noise, enc_text).detach().squeeze().cpu()
+
     gen_image_np = generated_image.numpy()
     gen_image_np = np.transpose(gen_image_np, (1, 2, 0))  # Change from CHW to HWC
     gen_image_np = (gen_image_np - gen_image_np.min()) / (gen_image_np.max() - gen_image_np.min())  # Normalize to [0, 1]
     gen_image_np = (gen_image_np * 255).astype(np.uint8)
     return gen_image_np
 
-# Gradio interface
-inputs = gr.inputs.Textbox(label="Enter a flower-related description", default="A beautiful red rose")
-outputs = gr.outputs.Image(type="numpy", label="Generated Flower Image")
+def encode_text(text):
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    encoded_text = text_encoder(**inputs)
+    return encoded_text
+
+def on_generate_button_click(text_input):
+    if text_input:
+        encoded_text = encode_text(text_input)
+        generated_image = generate_image(encoded_text)
+        return generated_image
+    return None
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("## Flower Image Generator")
+    text_input = gr.Textbox(label="Enter a flower-related description", value="A beautiful red rose")
+    generate_button = gr.Button("Generate Image")
+    output_image = gr.Image(type="numpy")  # Ensure output type is correct
+
+    generate_button.click(on_generate_button_click, inputs=text_input, outputs=output_image)
 
-gr.Interface(fn=generate_image, inputs=inputs, outputs=outputs, title="Flower Image Generator", description="Enter a description of a flower to generate an image.").launch()
+# Launch the Gradio app
+demo.launch()
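The dropped gr.inputs.Textbox / gr.outputs.Image calls use component namespaces that Gradio deprecated in 3.x and later removed, which is presumably what motivated the move to gr.Blocks with an explicit button callback. For comparison, the one-call style survives with modern top-level components; a sketch, not part of the commit, reusing the handler defined above:

    demo = gr.Interface(
        fn=on_generate_button_click,
        inputs=gr.Textbox(label="Enter a flower-related description", value="A beautiful red rose"),
        outputs=gr.Image(type="numpy", label="Generated Flower Image"),
        title="Flower Image Generator",
        description="Enter a description of a flower to generate an image.",
    )
    demo.launch()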
 
 
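As a final check, the post-processing kept in generate_image produces exactly what gr.Image(type="numpy") renders: an HWC uint8 array. A standalone repro with a fake Tanh-range output:

    import numpy as np

    fake = np.random.uniform(-1.0, 1.0, size=(3, 64, 64))  # stand-in CHW generator output
    img = np.transpose(fake, (1, 2, 0))                    # CHW -> HWC
    img = (img - img.min()) / (img.max() - img.min())      # min-max scale to [0, 1]
    img = (img * 255).astype(np.uint8)                     # what gr.Image(type="numpy") expects
    print(img.shape, img.dtype)                            # (64, 64, 3) uint8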