Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -53,7 +53,13 @@ class DistillBERTClass(torch.nn.Module):
|
|
53 |
pooler = self.dropout(pooler)
|
54 |
output = self.classifier(pooler)
|
55 |
return output
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
# print(model)
|
59 |
# model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
|
|
|
53 |
pooler = self.dropout(pooler)
|
54 |
output = self.classifier(pooler)
|
55 |
return output
|
56 |
+
|
57 |
+
|
58 |
+
model = DistillBERTClass()
|
59 |
+
optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
|
60 |
+
checkpoint = torch.load("pytorch_distilbert_news.bin")
|
61 |
+
model.load_state_dict(checkpoint['model'])
|
62 |
+
optimizer.load_state_dict(checkpoint['opt'])
|
63 |
|
64 |
# print(model)
|
65 |
# model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
|