seddiktrk committed
Commit 308ef29 · verified · 1 Parent(s): b18ff3d

Create app.py

Files changed (1)
app.py +269 -0
app.py ADDED
@@ -0,0 +1,269 @@
+ import streamlit as st
+ from PIL import Image
+ import torch
+ import time
+ 
+ from tqdm.auto import tqdm
+ import numpy as np
+ from torch import nn
+ 
+ print(torch.__version__)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(device)
+ 
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorWithPadding
+ 
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token_id = 0
+ collator = DataCollatorWithPadding(tokenizer=tokenizer)
+ 
+ class EncoderAttention(nn.Module):
+     """Self-attention block with a residual connection and LayerNorm."""
+     def __init__(self, embed_dim=768, num_heads=8, dropout=0.1):
+         super().__init__()
+         self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
+         self.layernorm = nn.LayerNorm(embed_dim)
+ 
+     def forward(self, x):
+         attn, _ = self.mha(query=x, value=x, key=x, need_weights=False)
+         x = x + attn
+         return self.layernorm(x)
+ 
+ 
+ class FeedForward(nn.Module):
+     """Position-wise feed-forward block with a residual connection and LayerNorm."""
+     def __init__(self, embed_dim=768, dropout_rate=0.1):
+         super().__init__()
+         self.seq = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim * 2),
+             nn.ReLU(),
+             nn.Linear(embed_dim * 2, embed_dim),
+             nn.Dropout(dropout_rate),
+         )
+         self.layernorm = nn.LayerNorm(embed_dim)
+ 
+     def forward(self, x):
+         x = x + self.seq(x)
+         return self.layernorm(x)
+ 
+ 
+ class MapperLayer(nn.Module):
+     """One encoder layer: self-attention followed by a feed-forward block."""
+     def __init__(self, embed_dim=768, num_heads=8, dropout_rate=0.1):
+         super().__init__()
+         self.attn = EncoderAttention(num_heads=num_heads,
+                                      embed_dim=embed_dim,
+                                      dropout=dropout_rate)
+         self.ff = FeedForward(embed_dim=embed_dim,
+                               dropout_rate=dropout_rate)
+ 
+     def forward(self, x):
+         x = self.attn(x)
+         x = self.ff(x)
+         return x
+ 
+ 
+ class Transformer(nn.Module):
+     """Stack of MapperLayer blocks."""
+     def __init__(self,
+                  num_layers=8,
+                  num_heads=8,
+                  embed_dim=768,
+                  dropout_rate=0.1):
+         super().__init__()
+         layers = [MapperLayer(embed_dim=embed_dim,
+                               num_heads=num_heads,
+                               dropout_rate=dropout_rate) for _ in range(num_layers)]
+         self.layers = nn.ModuleList(layers)
+ 
+     def forward(self, x):
+         for layer in self.layers:
+             x = layer(x)
+         return x
+ 
+ class TransformerMapper(nn.Module):
+     """Maps a CLIP image embedding to a sequence of GPT-2 prefix embeddings."""
+     def __init__(self,
+                  dim_clip=768,
+                  embed_dim=768,
+                  prefix_length=16,
+                  clip_length=10,
+                  num_layers=8,
+                  num_heads=8,
+                  dropout_rate=0.1):
+         super().__init__()
+         self.clip_length = clip_length
+         self.transformer = Transformer(num_layers=num_layers,
+                                        num_heads=num_heads,
+                                        embed_dim=embed_dim,
+                                        dropout_rate=dropout_rate)
+         # Project the CLIP embedding into clip_length prefix tokens: (B, clip_length * embed_dim)
+         self.linear = nn.Linear(dim_clip, self.clip_length * embed_dim)
+         self.prefix_const = nn.Parameter(torch.randn(prefix_length, embed_dim), requires_grad=True)
+ 
+     def forward(self, x):
+         x = self.linear(x).view(x.shape[0], self.clip_length, -1)  # (B, clip_length, embed_dim)
+         prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)  # (B, prefix_length, embed_dim)
+         prefix = torch.cat((x, prefix), dim=1)
+         return self.transformer(prefix)[:, self.clip_length:]  # (B, prefix_length, embed_dim)
+ 
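As a quick sanity check on the mapper's shapes, here is a minimal sketch using the default hyperparameters above; the dummy tensor and this snippet are illustrative and not part of the committed file:

```python
# Illustrative shape check: a batch of CLIP image embeddings (B, 768) is mapped
# to GPT-2 prefix embeddings (B, prefix_length, 768).
mapper = TransformerMapper()              # defaults: clip_length=10, prefix_length=16
dummy_clip = torch.randn(2, 768)          # stand-in for CLIP pooler_output, batch of 2
with torch.inference_mode():
    out = mapper(dummy_clip)
print(out.shape)                          # torch.Size([2, 16, 768])
```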
+ class ClipCaptionModel(nn.Module):
+     """GPT-2 language model conditioned on a CLIP-derived prefix."""
+     def __init__(self,
+                  dim_clip=768,
+                  embed_dim=768,
+                  prefix_length=16,
+                  clip_length=10,
+                  num_layers=8,
+                  num_heads=8,
+                  dropout_rate=0.1):
+         super().__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         self.clip_project = TransformerMapper(dim_clip=dim_clip,
+                                               embed_dim=self.gpt_embedding_size,
+                                               prefix_length=prefix_length,
+                                               clip_length=clip_length,
+                                               num_layers=num_layers,
+                                               num_heads=num_heads,
+                                               dropout_rate=dropout_rate)
+ 
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+ 
+     def forward(self,
+                 tokens: torch.Tensor,
+                 prefix: torch.Tensor,
+                 mask: torch.Tensor,
+                 labels=None):
+         # create embeddings for the gpt model and prepend the projected CLIP prefix
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix)
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+ 
+         # extend the attention mask to cover the prefix positions
+         if mask.shape[1] != embedding_cat.shape[1]:
+             dummy_mask = torch.ones(tokens.shape[0], self.prefix_length, dtype=torch.int64, device=self.gpt.device)
+             mask = torch.cat([dummy_mask, mask], dim=1)
+ 
+         # prepend dummy labels for the prefix positions when computing the loss
+         if labels is not None:
+             dummy_token = torch.zeros(tokens.shape[0], self.prefix_length, dtype=torch.int64, device=tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+ 
+         return self.gpt(inputs_embeds=embedding_cat,
+                         labels=labels,
+                         attention_mask=mask)
+ 
+ 
+ ## Prepare Model
+ CliPGPT = ClipCaptionModel()
+ path = "files/model_epoch_1.pt"
+ # Load the fine-tuned weights (map_location keeps this working on CPU-only machines)
+ state_dict = torch.load(path, map_location=device)
+ 
+ # Apply the weights to the model
+ CliPGPT.load_state_dict(state_dict)
+ CliPGPT.to(device)
+ 
+ from transformers import CLIPProcessor, CLIPModel
+ 
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ model.eval()
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ 
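Since Streamlit re-executes the whole script on every interaction, the checkpoint and CLIP weights above are reloaded on each rerun. A possible refinement, sketched here assuming a Streamlit version that provides `st.cache_resource` (the helper name `load_models` is hypothetical, not from the commit):

```python
# Sketch only: cache the heavy objects across Streamlit reruns.
@st.cache_resource
def load_models():
    caption_model = ClipCaptionModel()
    caption_model.load_state_dict(torch.load("files/model_epoch_1.pt", map_location=device))
    caption_model.to(device).eval()
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return caption_model, clip_model, clip_processor
```

With such a helper, the globals `CliPGPT`, `model`, and `processor` could be assigned once from its return values instead of being rebuilt on every rerun.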
+ def sample_from_logits(logits, temperature=0.3):
+     # Temperature-scaled sampling over the vocabulary distribution
+     logits = logits / temperature
+     probabilities = torch.softmax(logits, dim=-1)
+     return torch.multinomial(probabilities, 1).squeeze()
+ 
+ 
+ def generate(image,
+              device=device,
+              max_tokens=48,
+              temperature=0.3,
+              verbose=True,
+              sample=True,
+              ):
+     model.to(device)
+     CliPGPT.to(device)
+ 
+     # encode the image with the CLIP vision tower
+     with torch.inference_mode():
+         pixel_values = torch.tensor(np.stack(processor.image_processor(image).pixel_values, axis=0)).to(device)
+         embeds = model.vision_model(pixel_values)
+         embeds = embeds.pooler_output
+ 
+     CliPGPT.eval()
+     prefix_length = CliPGPT.prefix_length
+ 
+     # initial token '#' used as the start-of-caption marker
+     tokens = ['#']
+     input_ids, attention_mask = collator(tokenizer(tokens)).values()
+ 
+     # autoregressive decoding loop
+     for i in tqdm(range(max_tokens), desc='generating... '):
+         input_ids = input_ids.to(device)
+         embeds = embeds.to(device)
+         attention_mask = attention_mask.to(device)
+ 
+         with torch.inference_mode():
+             out = CliPGPT(tokens=input_ids,
+                           prefix=embeds,
+                           mask=attention_mask)
+         logits = out.logits
+         logits = logits[:, prefix_length:, :]  # drop the prefix positions
+ 
+         # sampling strategy: temperature sampling or greedy argmax
+         if sample:
+             next_token = sample_from_logits(logits[:, -1, :], temperature=temperature)
+         else:
+             next_token = torch.argmax(logits[:, -1, :], dim=-1).squeeze()
+         token = next_token.item()
+ 
+         if token == tokenizer.eos_token_id:
+             break
+         # append the decoded token to the caption string
+         tokens = [tokens[0] + tokenizer.decode(next_token)]
+         # re-tokenize the updated string for the next step
+         input_ids, attention_mask = collator(tokenizer(tokens)).values()
+ 
+         if verbose:
+             print(token)
+             print(tokens[0])
+             print()
+     return tokens[0].replace('#', '').strip()
+ 
+ 
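For reference, the caption function can also be exercised outside the Streamlit UI; a minimal sketch (the image path is hypothetical):

```python
# Illustrative only: caption a local image directly, without the web UI.
sample_image = Image.open("example.jpg")   # hypothetical path
print(generate(sample_image, verbose=False))
```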
+ st.title("CLIP GPT2 Image Captioning")
+ st.write("This is a web app for generating captions for images using a model built with CLIP & GPT2.")
+ 
+ # Image upload section
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
+ 
+ if uploaded_file is not None:
+     # Display the uploaded image
+     image = Image.open(uploaded_file)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+ 
+     # Generate caption button
+     if st.button('Submit'):
+         with st.spinner('Generating caption...'):
+             start_time = time.time()
+             caption = generate(image)
+             end_time = time.time()
+ 
+         st.text_area('Output', caption)
+         st.write(f"Inference time: {end_time - start_time:.2f} seconds")
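To reproduce the Space locally: the imports above imply the dependencies streamlit, torch, transformers, Pillow, numpy, and tqdm; the fine-tuned weights are expected at `files/model_epoch_1.pt`; and the app is launched with `streamlit run app.py`.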