data-silence commited on
Commit
1ce344c
·
verified ·
1 Parent(s): 78e4c04

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +63 -0
  2. requirements.txt +4 -0
inference.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import BertModel
5
+ from transformers import AutoTokenizer
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ class BiLSTMClassifier(nn.Module):
9
+ def __init__(self, hidden_dim, output_dim, n_layers, dropout):
10
+ super(BiLSTMClassifier, self).__init__()
11
+ self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
12
+ self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers,
13
+ bidirectional=True, dropout=dropout, batch_first=True)
14
+ self.fc = nn.Linear(hidden_dim * 2, output_dim)
15
+ self.dropout = nn.Dropout(dropout)
16
+
17
+ def forward(self, input_ids, attention_mask, labels=None):
18
+ with torch.no_grad():
19
+ embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
20
+ lstm_out, _ = self.lstm(embedded)
21
+ pooled = torch.mean(lstm_out, dim=1)
22
+ logits = self.fc(self.dropout(pooled))
23
+
24
+ if labels is not None:
25
+ loss_fn = nn.CrossEntropyLoss()
26
+ loss = loss_fn(logits, labels)
27
+ return {"loss": loss, "logits": logits} # Возвращаем словарь
28
+ return logits # Возвращаем логиты, если метки не переданы
29
+
30
+
31
+ categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
32
+ 'politics', 'science', 'society', 'sports', 'travel']
33
+
34
+ repo_id = "data-silence/lstm-news-classifier"
35
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
36
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")
37
+
38
+ model = torch.load(model_path)
39
+
40
+
41
+ def predict(news: str) -> str:
42
+ with torch.no_grad():
43
+ inputs = tokenizer(news, return_tensors="pt")
44
+ del inputs['token_type_ids']
45
+ output = model.forward(**inputs)
46
+ id_best_label = torch.argmax(output[0, :], dim=-1).detach().cpu().numpy()
47
+ prediction = categories[id_best_label]
48
+ return prediction
49
+
50
+
51
+ # Создание интерфейса Gradio
52
+ iface = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Textbox(lines=5, label="Enter news text | Введите текст новости"),
55
+ outputs=[
56
+ gr.Label(label="Predicted category | Предсказанная категория"),
57
+ gr.Label(label="Category probabilities | Вероятности категорий")
58
+ ],
59
+ title="News Classifier | Классификатор новостей",
60
+ description="Enter the news text in any language and the model will predict its category. | Введите текст новости на любом языке, и модель предскажет её категорию"
61
+ )
62
+
63
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ huggingface_hub