Update app.py
app.py CHANGED
@@ -10,6 +10,7 @@ print(torch.__version__)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print(device)
 
+print('importing tokenizer')
 from transformers import GPT2Tokenizer,GPT2LMHeadModel,DataCollatorWithPadding
 
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
@@ -132,12 +133,9 @@ class ClipCaptionModel(nn.Module):
 
         # prepare mask
         if mask.shape[1] != embedding_cat.shape[1]:
-            dummy_mask = torch.ones(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=
+            dummy_mask = torch.ones(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=mask.device)
             mask = torch.cat([dummy_mask,mask],dim=1)
 
-        if labels is not None:
-            dummy_token = torch.zeros(tokens.shape[0],self.prefix_length, dtype=torch.int64, device=device)
-            labels = torch.cat((dummy_token, tokens), dim=1)
 
         return self.gpt(inputs_embeds=embedding_cat,
                         labels=labels,
@@ -167,6 +165,7 @@ class ClipCaptionModel(nn.Module):
                               dropout_rate = dropout_rate)
 
 
+print('loading model')
 ## Prepare Model
 CliPGPT = ClipCaptionModel()
 path = "model_epoch_1.pt"
@@ -176,6 +175,7 @@ state_dict = torch.load(path)
 CliPGPT.load_state_dict(state_dict)
 CliPGPT.to(device)
 
+print('importing CLIP')
 from transformers import CLIPProcessor, CLIPModel
 
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
@@ -247,6 +247,7 @@ def generate(image,
     return tokens[0].replace('#','').strip()
 
 
+print('app starts')
 st.title("CLIP GPT2 Image Captionning")
 st.write("This is a web app for generating captions for images using a model built with CLIP & GPT2.")
 
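For context, here is a minimal standalone sketch of what the rewritten dummy_mask line does. It is an illustration, not the app's actual forward() method: pad_mask_for_prefix is a hypothetical helper name, while tokens, mask, embedding_cat and prefix_length are the names that appear in the diff above.

# Minimal sketch (assumption: this mirrors the mask handling in ClipCaptionModel.forward).
import torch

def pad_mask_for_prefix(tokens: torch.Tensor,
                        mask: torch.Tensor,
                        embedding_cat: torch.Tensor,
                        prefix_length: int) -> torch.Tensor:
    # The prefix embeddings are concatenated in front of the token embeddings,
    # so the attention mask must grow by prefix_length columns to match.
    if mask.shape[1] != embedding_cat.shape[1]:
        # device=mask.device keeps the padding on the same device as the
        # incoming batch rather than relying on a module-level `device` global.
        dummy_mask = torch.ones(tokens.shape[0], prefix_length,
                                dtype=torch.int64, device=mask.device)
        mask = torch.cat([dummy_mask, mask], dim=1)
    return mask

Inside the model, the equivalent call would be mask = pad_mask_for_prefix(tokens, mask, embedding_cat, self.prefix_length) just before the self.gpt(...) call shown in the diff.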