Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
import torch
|
8 |
from torch import nn
|
9 |
print(torch.__version__)
|
10 |
-
device = torch.device('
|
11 |
print(device)
|
12 |
|
13 |
print('importing tokenizer')
|
@@ -166,13 +166,14 @@ class ClipCaptionModel(nn.Module):
|
|
166 |
|
167 |
|
168 |
print('loading model')
|
|
|
169 |
## Prepare Model
|
170 |
CliPGPT = ClipCaptionModel()
|
171 |
path = "model_epoch_1.pt"
|
172 |
-
state_dict = torch.load(path)
|
173 |
|
174 |
# Apply the weights to the model
|
175 |
-
CliPGPT.load_state_dict(state_dict
|
176 |
CliPGPT.to(device)
|
177 |
|
178 |
print('importing CLIP')
|
|
|
7 |
import torch
|
8 |
from torch import nn
|
9 |
print(torch.__version__)
|
10 |
+
device = torch.device('cpu')
|
11 |
print(device)
|
12 |
|
13 |
print('importing tokenizer')
|
|
|
166 |
|
167 |
|
168 |
print('loading model')
|
169 |
+
print()
|
170 |
## Prepare Model
|
171 |
CliPGPT = ClipCaptionModel()
|
172 |
path = "model_epoch_1.pt"
|
173 |
+
state_dict = torch.load(path,map_location=torch.device('cpu')))
|
174 |
|
175 |
# Apply the weights to the model
|
176 |
+
CliPGPT.load_state_dict(state_dict)
|
177 |
CliPGPT.to(device)
|
178 |
|
179 |
print('importing CLIP')
|