data-silence commited on
Commit
bb37dc6
·
verified ·
1 Parent(s): 8ec3572

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +61 -62
inference.py CHANGED
@@ -1,63 +1,62 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from transformers import BertModel
5
- from transformers import AutoTokenizer
6
- from huggingface_hub import hf_hub_download
7
-
8
- class BiLSTMClassifier(nn.Module):
9
- def __init__(self, hidden_dim, output_dim, n_layers, dropout):
10
- super(BiLSTMClassifier, self).__init__()
11
- self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
12
- self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers,
13
- bidirectional=True, dropout=dropout, batch_first=True)
14
- self.fc = nn.Linear(hidden_dim * 2, output_dim)
15
- self.dropout = nn.Dropout(dropout)
16
-
17
- def forward(self, input_ids, attention_mask, labels=None):
18
- with torch.no_grad():
19
- embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
20
- lstm_out, _ = self.lstm(embedded)
21
- pooled = torch.mean(lstm_out, dim=1)
22
- logits = self.fc(self.dropout(pooled))
23
-
24
- if labels is not None:
25
- loss_fn = nn.CrossEntropyLoss()
26
- loss = loss_fn(logits, labels)
27
- return {"loss": loss, "logits": logits} # Возвращаем словарь
28
- return logits # Возвращаем логиты, если метки не переданы
29
-
30
-
31
- categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
32
- 'politics', 'science', 'society', 'sports', 'travel']
33
-
34
- repo_id = "data-silence/lstm-news-classifier"
35
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
36
- model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")
37
-
38
- model = torch.load(model_path)
39
-
40
-
41
- def predict(news: str) -> str:
42
- with torch.no_grad():
43
- inputs = tokenizer(news, return_tensors="pt")
44
- del inputs['token_type_ids']
45
- output = model.forward(**inputs)
46
- id_best_label = torch.argmax(output[0, :], dim=-1).detach().cpu().numpy()
47
- prediction = categories[id_best_label]
48
- return prediction
49
-
50
-
51
- # Создание интерфейса Gradio
52
- iface = gr.Interface(
53
- fn=predict,
54
- inputs=gr.Textbox(lines=5, label="Enter news text | Введите текст новости"),
55
- outputs=[
56
- gr.Label(label="Predicted category | Предсказанная категория"),
57
- gr.Label(label="Category probabilities | Вероятности категорий")
58
- ],
59
- title="News Classifier | Классификатор новостей",
60
- description="Enter the news text in any language and the model will predict its category. | Введите текст новости на любом языке, и модель предскажет её категорию"
61
- )
62
-
63
  iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import BertModel
5
+ from transformers import AutoTokenizer
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ class BiLSTMClassifier(nn.Module):
9
+ def __init__(self, hidden_dim, output_dim, n_layers, dropout):
10
+ super(BiLSTMClassifier, self).__init__()
11
+ self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
12
+ self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers,
13
+ bidirectional=True, dropout=dropout, batch_first=True)
14
+ self.fc = nn.Linear(hidden_dim * 2, output_dim)
15
+ self.dropout = nn.Dropout(dropout)
16
+
17
+ def forward(self, input_ids, attention_mask, labels=None):
18
+ with torch.no_grad():
19
+ embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
20
+ lstm_out, _ = self.lstm(embedded)
21
+ pooled = torch.mean(lstm_out, dim=1)
22
+ logits = self.fc(self.dropout(pooled))
23
+
24
+ if labels is not None:
25
+ loss_fn = nn.CrossEntropyLoss()
26
+ loss = loss_fn(logits, labels)
27
+ return {"loss": loss, "logits": logits} # Возвращаем словарь
28
+ return logits # Возвращаем логиты, если метки не переданы
29
+
30
+
31
+ categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
32
+ 'politics', 'science', 'society', 'sports', 'travel']
33
+
34
+ repo_id = "data-silence/lstm-news-classifier"
35
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
36
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")
37
+
38
+ model = torch.load(model_path)
39
+
40
+
41
+ def predict(news: str) -> str:
42
+ with torch.no_grad():
43
+ inputs = tokenizer(news, return_tensors="pt")
44
+ del inputs['token_type_ids']
45
+ output = model.forward(**inputs)
46
+ id_best_label = torch.argmax(output[0, :], dim=-1).detach().cpu().numpy()
47
+ prediction = categories[id_best_label]
48
+ return prediction
49
+
50
+
51
+ # Создание интерфейса Gradio
52
+ iface = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Textbox(lines=5, label="Enter news text | Введите текст новости"),
55
+ outputs=[
56
+ gr.Label(label="Predicted category | Предсказанная категория")
57
+ ],
58
+ title="LSTM News Classifier | LSTM Классификатор новостей",
59
+ description="Enter the news text in russian and the model will predict its category. | Введите текст русскоязычной новости, и модель предскажет её категорию"
60
+ )
61
+
 
62
  iface.launch()