Valeriy Sinyukov commited on
Commit
5c5407c
·
1 Parent(s): bf2b17b

Add translation models

Browse files
category_classification/models/models.py CHANGED
@@ -6,6 +6,8 @@ import warnings
6
  from pathlib import Path
7
 
8
  from . import pipeline
 
 
9
 
10
  def import_model_module(file_path: os.PathLike):
11
  module_name = str(Path(file_path).relative_to(os.getcwd())).replace(
@@ -76,6 +78,10 @@ for path in file_dir.glob("*"):
76
  language_to_models.setdefault(lang, {})
77
  language_to_models[lang][name] = get_model
78
 
 
 
 
 
79
 
80
  def get_model(name: str):
81
  if name not in models:
@@ -86,6 +92,7 @@ def get_model(name: str):
86
  def get_all_model_names():
87
  return list(models.keys())
88
 
 
89
  def get_model_names_by_lang(lang):
90
  if lang not in language_to_models:
91
  return []
 
6
  from pathlib import Path
7
 
8
  from . import pipeline
9
+ from .translation import create_translation_models
10
+
11
 
12
  def import_model_module(file_path: os.PathLike):
13
  module_name = str(Path(file_path).relative_to(os.getcwd())).replace(
 
78
  language_to_models.setdefault(lang, {})
79
  language_to_models[lang][name] = get_model
80
 
81
+ translation_models = create_translation_models(language_to_models["en"])
82
+ language_to_models.setdefault("ru", {}).update(translation_models)
83
+ models.update(translation_models)
84
+
85
 
86
  def get_model(name: str):
87
  if name not in models:
 
92
  def get_all_model_names():
93
  return list(models.keys())
94
 
95
+
96
  def get_model_names_by_lang(lang):
97
  if lang not in language_to_models:
98
  return []
category_classification/models/pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  import typing as tp
 
2
 
3
  import torch
4
 
 
1
  import typing as tp
2
+ from collections import namedtuple
3
 
4
  import torch
5
 
category_classification/models/translation.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from functools import partial
3
+
4
+ import torch
5
+ from transformers import pipeline
6
+
7
+
8
+ def get_translator():
9
+ return pipeline(
10
+ "translation_en_to_ru",
11
+ model="Helsinki-NLP/opus-mt-ru-en",
12
+ device="cuda" if torch.cuda.is_available() else "cpu",
13
+ torch_dtype="auto",
14
+ )
15
+
16
+ class Input:
17
+ def __init__(self, title, abstract, authors):
18
+ self.title = title
19
+ self.abstract = abstract
20
+ self.authors = authors
21
+
22
+ class TranslationModel:
23
+ def __init__(self, get_model):
24
+ self.translator = get_translator()
25
+ self.model = get_model()
26
+
27
+ def __call__(self, input):
28
+ def translate(text):
29
+ if text is None or text.strip() == "":
30
+ return ""
31
+ text = str(text).strip()
32
+ translated = self.translator(text)[0]['translation_text']
33
+ return translated
34
+ title = translate(input.title)
35
+ abstract = translate(input.abstract)
36
+ authors = translate(input.authors)
37
+ out = self.model(Input(title, abstract, authors))
38
+ return out
39
+
40
+
41
+ def create_translation_models(models):
42
+ return {
43
+ f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
44
+ for name, get_model in models.items()
45
+ }
46
+
requirements.txt CHANGED
@@ -130,6 +130,7 @@ safetensors==0.5.3
130
  scikit-learn==1.6.1
131
  scipy==1.15.2
132
  Send2Trash==1.8.3
 
133
  six==1.17.0
134
  smmap==5.0.2
135
  sniffio==1.3.1
 
130
  scikit-learn==1.6.1
131
  scipy==1.15.2
132
  Send2Trash==1.8.3
133
+ sentencepiece==0.2.0
134
  six==1.17.0
135
  smmap==5.0.2
136
  sniffio==1.3.1