|
import torch |
|
import pandas as pd |
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast |
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that translates Russian text to English
    with an MBart-50 model, optionally running a tag-based pre-processing
    step on the input first."""

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and move the model to
        GPU when one is available.

        Args:
            model_dir: Path or hub id accepted by ``from_pretrained``.
        """
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: switch off dropout so beam-search generation is
        # deterministic across calls.
        self.model.eval()

        # Tags a request may ask for. Kept as a list so external readers of
        # ``self.tags`` see the same type as before.
        self.tags = [
            "[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
            "[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
            "[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
            "[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
            "[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
            "[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
            "[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
            "[DEHUMANIZATION]"
        ]
        # Frozen copy for O(1) membership checks during request validation.
        self._tag_set = frozenset(self.tags)

    def query_with_tags(self, text, tags):
        """Run the tag pre-processing step over a one-row DataFrame.

        Args:
            text: The raw input text.
            tags: Validated list of tag strings.

        Returns:
            The processed DataFrame, or a one-row DataFrame with an
            ``error`` column if the step fails.
        """
        temp_df = pd.DataFrame([{"text": text}])
        try:
            # BUG(review): ``self.init`` is never assigned anywhere in this
            # class, so this call always raises AttributeError and every
            # tagged request falls into the except branch below. The intended
            # helper object must be wired up in __init__ — confirm with the
            # original author what ``query_with_df`` was meant to live on.
            result_df = self.init.query_with_df(df=temp_df, tags=tags)
            return result_df
        except Exception as e:
            return pd.DataFrame([{"error": str(e)}])

    def process_single_text(self, text, tags=None):
        """Translate *text* (Russian) to English.

        If *tags* is a non-empty list, the tag pre-processing step runs
        first and its output text is translated instead.

        Returns:
            A one-element list: ``[{"output": translation}]`` on success,
            ``[{"error": message}]`` on any failure.
        """
        try:
            if tags:
                tagged_result = self.query_with_tags(text, tags)
                # A DataFrame with an "error" column signals a failed
                # pre-processing step; surface the first message.
                if "error" in tagged_result:
                    return [{"error": tagged_result["error"].iloc[0]}]
                text = tagged_result["text"].iloc[0]

            # Source language is fixed to Russian for this endpoint.
            self.tokenizer.src_lang = "ru_RU"

            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,  # inputs beyond 1024 tokens are cut off
                max_length=1024
            ).to(self.device)

            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **inputs,
                    # Force English as the first generated (target-language)
                    # token, per the MBart-50 many-to-many convention.
                    forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                    max_length=1024,
                    num_beams=4,
                    length_penalty=1.0,
                    do_sample=False
                )

            translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            return [{"output": translation}]

        except Exception as e:
            return [{"error": str(e)}]

    def __call__(self, data):
        """Endpoint entry point.

        Args:
            data: Request payload. Must contain ``"inputs"`` — either a
                string, or a list whose first element is ``{"input": text}``
                (only the first element is processed, preserving the
                original behavior). May contain ``"tags"``, a list drawn
                from ``self.tags``.

        Returns:
            A one-element list holding ``{"output": ...}`` or
            ``{"error": ...}``.
        """
        try:
            if "inputs" not in data:
                return [{"error": "Request must contain 'inputs' field"}]

            inputs = data["inputs"]
            tags = data.get("tags")

            if tags:
                # Reject unknown tags up front with an explicit message.
                invalid_tags = [tag for tag in tags if tag not in self._tag_set]
                if invalid_tags:
                    return [{"error": f"Invalid tags: {invalid_tags}"}]

            if isinstance(inputs, str):
                return self.process_single_text(inputs, tags)

            # Legacy list shape: [{"input": <text>}, ...]
            if (isinstance(inputs, list) and inputs
                    and isinstance(inputs[0], dict) and "input" in inputs[0]):
                return self.process_single_text(inputs[0]["input"], tags)

            return [{"error": "Invalid input format"}]

        except Exception as e:
            return [{"error": str(e)}]
|
|
|
|
|
def init_handler(model_dir):
    """Factory used by the serving runtime to build the endpoint handler.

    Args:
        model_dir: Path or hub id forwarded to ``EndpointHandler``.

    Returns:
        A ready-to-call ``EndpointHandler`` instance.
    """
    return EndpointHandler(model_dir)