Update models/pairwise_model.py
Browse files- models/pairwise_model.py +4 -2
models/pairwise_model.py
CHANGED
@@ -4,9 +4,11 @@ import torch.nn as nn
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from transformers import AutoModel, AutoConfig
|
6 |
from transformers import AutoTokenizer
|
|
|
7 |
import pandas as pd
|
|
|
8 |
|
9 |
-
AUTH_TOKEN = "
|
10 |
|
11 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
12 |
use_auth_token=AUTH_TOKEN)
|
@@ -19,7 +21,7 @@ class PairwiseModel(nn.Module):
|
|
19 |
self.max_length = max_length
|
20 |
self.batch_size = batch_size
|
21 |
self.device = device
|
22 |
-
self.model =
|
23 |
self.model.to(self.device)
|
24 |
self.model.eval()
|
25 |
self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
|
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from transformers import AutoModel, AutoConfig
|
6 |
from transformers import AutoTokenizer
|
7 |
+
from optimum.onnxruntime import ORTModelForQuestionAnswering
|
8 |
import pandas as pd
|
9 |
+
import os
|
10 |
|
11 |
+
AUTH_TOKEN = os.getenv("AUTH_TOKEN")
|
12 |
|
13 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
14 |
use_auth_token=AUTH_TOKEN)
|
|
|
21 |
self.max_length = max_length
|
22 |
self.batch_size = batch_size
|
23 |
self.device = device
|
24 |
+
self.model = ORTModelForQuestionAnswering.from_pretrained(model_name, use_auth_token=AUTH_TOKEN, from_transformers=True)
|
25 |
self.model.to(self.device)
|
26 |
self.model.eval()
|
27 |
self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
|