seddiktrk committed
Commit 4a09c08 · verified · 1 Parent(s): 4970c30

Update app.py

Files changed (1)
  1. app.py  +5 -4
app.py CHANGED
@@ -10,6 +10,7 @@ print(torch.__version__)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print(device)
 
+print('importing tokenizer')
 from transformers import GPT2Tokenizer,GPT2LMHeadModel,DataCollatorWithPadding
 
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
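The new print calls surface progress in the Space logs, but Streamlit re-executes app.py on every interaction, so module-level from_pretrained calls repeat on each rerun. A minimal sketch, assuming a hypothetical load_tokenizer helper with st.cache_resource, of loading the tokenizer once per process:

    import streamlit as st
    from transformers import GPT2Tokenizer

    @st.cache_resource  # hypothetical refactor: run once per server process, reused across reruns
    def load_tokenizer():
        print('importing tokenizer')  # same log marker as the commit's print
        return GPT2Tokenizer.from_pretrained('gpt2')

    tokenizer = load_tokenizer()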
@@ -132,12 +133,9 @@ class ClipCaptionModel(nn.Module):
 
         # prepare mask
         if mask.shape[1] != embedding_cat.shape[1]:
-            dummy_mask = torch.ones(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=self.gpt.device)
+            dummy_mask = torch.ones(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=mask.device)
             mask = torch.cat([dummy_mask,mask],dim=1)
 
-        if labels is not None:
-            dummy_token = torch.zeros(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=device)
-            labels = torch.cat((dummy_token, tokens), dim=1)
 
         return self.gpt(inputs_embeds=embedding_cat,
                         labels=labels,
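The substantive fix in this hunk: the dummy attention-mask entries for the prefix are now created on mask.device rather than self.gpt.device, so the torch.cat never mixes devices, and the labels-padding branch (a training-time concern) is dropped from this inference app. A rough sketch of the resulting mask handling as a standalone helper, with tensor names taken from the hunk and the function name pad_mask_for_prefix invented for illustration:

    import torch

    def pad_mask_for_prefix(mask, tokens, prefix_length):
        # The projected CLIP prefix has no tokens of its own, so prepend
        # `prefix_length` ones to the attention mask for each batch element.
        dummy_mask = torch.ones(tokens.shape[0], prefix_length,
                                dtype=torch.int64, device=mask.device)  # same device as the mask it joins
        return torch.cat([dummy_mask, mask], dim=1)

In the model this only runs when mask.shape[1] != embedding_cat.shape[1], i.e. when the caller passed a mask covering the tokens but not the prefix.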
@@ -167,6 +165,7 @@ class ClipCaptionModel(nn.Module):
                                    dropout_rate = dropout_rate)
 
 
+print('loading model')
 ## Prepare Model
 CliPGPT = ClipCaptionModel()
 path = "model_epoch_1.pt"
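torch.load(path) with no map_location will fail on a CPU-only Space if model_epoch_1.pt was saved from a GPU session. A hedged sketch of a more defensive load, keeping the names from the diff and adding map_location and eval() as assumptions:

    import torch

    CliPGPT = ClipCaptionModel()                  # class defined earlier in app.py
    state_dict = torch.load("model_epoch_1.pt",
                            map_location=device)  # remap GPU-saved tensors onto the active device
    CliPGPT.load_state_dict(state_dict)
    CliPGPT.to(device)
    CliPGPT.eval()                                # inference only: disable dropout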
@@ -176,6 +175,7 @@ state_dict = torch.load(path)
 CliPGPT.load_state_dict(state_dict)
 CliPGPT.to(device)
 
+print('importing CLIP')
 from transformers import CLIPProcessor, CLIPModel
 
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
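The CLIP model supplies the image embedding that the caption model consumes as a prefix. The exact call inside generate() is not part of this diff, so the following is only an illustrative sketch of the usual processor/model pairing (example.jpg is a placeholder):

    import torch
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    image = Image.open("example.jpg")                         # placeholder image
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)  # (1, 512) for ViT-B/32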
@@ -247,6 +247,7 @@ def generate(image,
     return tokens[0].replace('#','').strip()
 
 
+print('app starts')
 st.title("CLIP GPT2 Image Captionning")
 st.write("This is a web app for generating captions for images using a model built with CLIP & GPT2.")
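print('app starts') only reaches the logs; the visible UI begins at st.title. The uploader and the call into generate() below are assumptions (this hunk shows neither), sketched only to show where the new log line sits in the Streamlit flow:

    import streamlit as st
    from PIL import Image

    print('app starts')  # log marker added by this commit
    st.title("CLIP GPT2 Image Captionning")
    st.write("This is a web app for generating captions for images using a model built with CLIP & GPT2.")

    uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])  # assumed widget
    if uploaded is not None:
        image = Image.open(uploaded)
        st.image(image)
        with st.spinner("Generating caption..."):
            caption = generate(image)  # generate() is defined earlier in app.py; other args assumed to default
        st.write(caption)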