niknikita commited on
Commit
a2a33af
·
1 Parent(s): a331b68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -53,7 +53,13 @@ class DistillBERTClass(torch.nn.Module):
53
  pooler = self.dropout(pooler)
54
  output = self.classifier(pooler)
55
  return output
56
- model = torch.load("model.pt", map_location='cpu').eval()
 
 
 
 
 
 
57
 
58
  # print(model)
59
  # model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
 
53
  pooler = self.dropout(pooler)
54
  output = self.classifier(pooler)
55
  return output
56
+
57
+
58
+ model = DistillBERTClass()
59
+ optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
60
+ checkpoint = torch.load("pytorch_distilbert_news.bin")
61
+ model.load_state_dict(checkpoint['model'])
62
+ optimizer.load_state_dict(checkpoint['opt'])
63
 
64
  # print(model)
65
  # model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)