SatAT commited on
Commit
96f8331
·
1 Parent(s): 1409048

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -21
app.py CHANGED
@@ -7,11 +7,12 @@ from sklearn.preprocessing import LabelEncoder
7
  from keras.utils import pad_sequences
8
  from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
9
 
10
- st.markdown("### Hello, world!")
11
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
12
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
13
 
14
- text = st.text_area("TEXT HERE")
 
15
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
16
 
17
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -20,63 +21,84 @@ model = BertForSequenceClassification.from_pretrained(
20
  num_labels = 44,)
21
  model.load_state_dict(torch.load("model_last_version.pt", map_location=torch.device('cpu')))
22
  MAX_LEN = 64
23
- tokens = tokenizer.encode_plus(text, add_special_tokens=True, max_length=MAX_LEN, truncation=True, padding='max_length')
24
- input_ids = torch.tensor(tokens['input_ids']).unsqueeze(0)
25
- attention_mask = torch.tensor(tokens['attention_mask']).unsqueeze(0)
26
 
 
 
 
27
  logits = model(input_ids, attention_mask)[0]
28
- probs = torch.softmax(logits, dim=1)
29
-
30
- predicted_category = torch.argmax(probs).item()
31
 
32
  tags_names = ['acc-phys',
33
  'adap-org',
34
- "adap-org'",
35
  'alg-geom',
36
  'astro-ph',
37
- "astro-ph'",
38
  'chao-dyn',
39
  'chem-ph',
40
  'cmp-lg',
41
- "cmp-lg'",
42
  'comp-gas',
43
  'cond-mat',
44
- "cond-mat'",
45
  'cs',
46
  'dg-ga',
47
  'econ',
48
  'eess',
49
  'funct-an',
50
  'gr-qc',
51
- "gr-qc'",
52
  'hep-ex',
53
- "hep-ex'",
54
  'hep-lat',
55
- "hep-lat'",
56
  'hep-ph',
57
- "hep-ph'",
58
  'hep-th',
59
- "hep-th'",
60
  'math',
61
  'math-ph',
62
  'mtrl-th',
63
  'nlin',
64
  'nucl-ex',
65
  'nucl-th',
66
- "nucl-th'",
67
  'patt-sol',
68
  'physics',
69
  'q-alg',
70
  'q-bio',
71
  'q-fin',
72
  'quant-ph',
73
- "quant-ph'",
74
  'solv-int',
75
  'stat']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # from transformers import pipeline
77
  # pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
78
- raw_predictions = tags_names[predicted_category]#le.inverse_transform(prediction)#pipe(text)
79
  # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
80
 
81
- st.markdown(f"{raw_predictions}")
82
  # выводим результаты модели в текстовое поле, на потеху пользователю
 
7
  from keras.utils import pad_sequences
8
  from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
9
 
10
+ st.markdown("### Paper category classification")
11
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
12
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
13
 
14
+ title = st.text_area("TITLE HERE")
15
+ abstract = st.text_area("ABSTRACT HERE")
16
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
17
 
18
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
21
  num_labels = 44,)
22
  model.load_state_dict(torch.load("model_last_version.pt", map_location=torch.device('cpu')))
23
  MAX_LEN = 64
24
+ # Преобразуем название статьи в токены
25
+ tokens = tokenizer(title, padding=True, truncation=True, return_tensors="pt")
 
26
 
27
+ # Получаем предсказание модели для названия статьи и абстракта (если есть)
28
+ input_ids = tokens['input_ids']
29
+ attention_mask = tokens['attention_mask']
30
  logits = model(input_ids, attention_mask)[0]
 
 
 
31
 
32
  tags_names = ['acc-phys',
33
  'adap-org',
34
+ "adap-org",
35
  'alg-geom',
36
  'astro-ph',
37
+ "astro-ph",
38
  'chao-dyn',
39
  'chem-ph',
40
  'cmp-lg',
41
+ "cmp-lg",
42
  'comp-gas',
43
  'cond-mat',
44
+ "cond-mat",
45
  'cs',
46
  'dg-ga',
47
  'econ',
48
  'eess',
49
  'funct-an',
50
  'gr-qc',
51
+ "gr-qc",
52
  'hep-ex',
53
+ "hep-ex",
54
  'hep-lat',
55
+ "hep-lat",
56
  'hep-ph',
57
+ "hep-ph",
58
  'hep-th',
59
+ "hep-th",
60
  'math',
61
  'math-ph',
62
  'mtrl-th',
63
  'nlin',
64
  'nucl-ex',
65
  'nucl-th',
66
+ "nucl-th",
67
  'patt-sol',
68
  'physics',
69
  'q-alg',
70
  'q-bio',
71
  'q-fin',
72
  'quant-ph',
73
+ "quant-ph",
74
  'solv-int',
75
  'stat']
76
+
77
+ if abstract:
78
+ abstract_tokens = tokenizer(abstract, padding=True, truncation=True, return_tensors="pt")
79
+ abstract_input_ids = abstract_tokens['input_ids']
80
+ abstract_attention_mask = abstract_tokens['attention_mask']
81
+ abstract_logits = model(abstract_input_ids, abstract_attention_mask)[0]
82
+ logits += abstract_logits
83
+
84
+ # Получаем вероятности и сортируем их в порядке убывания
85
+ probs = torch.softmax(logits, dim=-1).squeeze()
86
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
87
+
88
+ # Считаем сумму вероятностей
89
+ sum_probs = 0.0
90
+ top_classes = []
91
+ for i in range(len(sorted_probs)):
92
+ sum_probs += sorted_probs[i]
93
+ if sum_probs > 0.95 or sorted_probs[i] < 0.001:
94
+ break
95
+ top_classes.append((tag_names[sorted_indices[i].item()], sorted_probs[i].item()))
96
+
97
+ # Выводим список тем с их вероятностями
98
  # from transformers import pipeline
99
  # pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
100
+ raw_predictions = top_classes#le.inverse_transform(prediction)#pipe(text)
101
  # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
102
 
103
+ st.markdown(f"Possible categories for this article: {raw_predictions}")
104
  # выводим результаты модели в текстовое поле, на потеху пользователю