igorithm committed on
Commit 0b8334d · verified · 1 Parent(s): 30fc256

Add attribute "name" to model.py

category_classification/models/allenai__scibert_sci_vocab_uncased/model.py CHANGED
@@ -1,45 +1,46 @@
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
- import torch
-
-
- class SciBertPaperClassifier:
-     def __init__(self, model_path="trained_model"):
-         self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
-         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model.to(self.device)
-         self.model.eval()
-
-     def __call__(self, inputs):
-         texts = [
-             f"AUTHORS: {' '.join(authors) if isinstance(authors, list) else authors} "
-             f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
-             for paper in inputs
-             for authors in [paper.get("authors", "")]
-         ]
-
-         inputs = self.tokenizer(
-             texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
-         ).to(self.device)
-
-         with torch.no_grad():
-             outputs = self.model(**inputs)
-
-         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
-         scores, labels = torch.max(probs, dim=1)
-
-         return [
-             [{"label": self.model.config.id2label[label.item()], "score": score.item()}]
-             for label, score in zip(labels, scores)
-         ]
-
-     def __getstate__(self):
-         return self.__dict__
-
-     def __setstate__(self, state):
-         self.__dict__ = state
-         self.model.to(self.device)
-
-
- def get_model():
-     return SciBertPaperClassifier()
 
 
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import torch
+
+ name = "allenai/scibert_scivocab_uncased"
+
+ class SciBertPaperClassifier:
+     def __init__(self, model_path="trained_model"):
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def __call__(self, inputs):
+         texts = [
+             f"AUTHORS: {' '.join(authors) if isinstance(authors, list) else authors} "
+             f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
+             for paper in inputs
+             for authors in [paper.get("authors", "")]
+         ]
+
+         inputs = self.tokenizer(
+             texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
+         ).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+         scores, labels = torch.max(probs, dim=1)
+
+         return [
+             [{"label": self.model.config.id2label[label.item()], "score": score.item()}]
+             for label, score in zip(labels, scores)
+         ]
+
+     def __getstate__(self):
+         return self.__dict__
+
+     def __setstate__(self, state):
+         self.__dict__ = state
+         self.model.to(self.device)
+
+
+ def get_model():
+     return SciBertPaperClassifier()
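
Below is a minimal usage sketch, not part of the commit, showing how the updated module might be consumed. It assumes the file is importable as "model", that fine-tuned weights exist under the default "trained_model" path, and that each input paper is a dict with "title", "abstract", and an optional "authors" entry, as __call__ expects; the example paper contents are purely illustrative.

# Usage sketch (illustrative assumption): load the classifier via get_model()
# and classify a hypothetical paper.
from model import get_model, name

print(name)  # "allenai/scibert_scivocab_uncased"

classifier = get_model()
papers = [
    {
        "title": "Pretrained Language Models for Scientific Text",
        "abstract": "We study transformer encoders trained on scientific corpora ...",
        "authors": ["First Author", "Second Author"],
    },
]

# classifier(papers) returns one [{"label": ..., "score": ...}] list per input paper.
for result in classifier(papers):
    print(result[0]["label"], round(result[0]["score"], 3))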