# NOTE(review): removed non-code scrape residue (file-size banner, commit
# hashes, and a line-number gutter from a file-viewer export) that made the
# module invalid Python. No source content was lost.
import torch
import pandas as pd
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

class EndpointHandler:
    """Inference-endpoint handler translating Russian text to English with mBART-50.

    Optionally pre-processes the input through an external tag pipeline
    (``self.init.query_with_df``) before translation.

    NOTE(review): the original code referenced ``self.init`` without ever
    assigning it, so every tag query failed with an opaque AttributeError
    message. ``self.init`` is now initialized to ``None`` and must be set by
    external setup code before tag processing works; until then tag requests
    return an explicit, explanatory error.
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir*; use GPU when available.

        Args:
            model_dir: path (or hub id) of the mBART-50 checkpoint.
        """
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only — no dropout, no gradient tracking

        # Tag-processing backend. Must be provided by external setup code;
        # left as None here because nothing in this file supplies it.
        self.init = None

        # Whitelist of tags accepted in requests (validated in __call__).
        self.tags = [
            "[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
            "[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
            "[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
            "[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
            "[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
            "[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
            "[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
            "[DEHUMANIZATION]"
        ]

    def query_with_tags(self, text, tags):
        """Run the external tag pipeline over *text*.

        Args:
            text: input text to process.
            tags: list of tag strings (already validated against ``self.tags``).

        Returns:
            The pipeline's DataFrame on success, or a one-row DataFrame with an
            ``error`` column so callers can surface the message without crashing.
        """
        # Guard against the unconfigured backend instead of letting the
        # attribute access raise (original defect: self.init never existed).
        if self.init is None:
            return pd.DataFrame(
                [{"error": "tag backend (self.init) is not configured"}]
            )

        # Wrap the single text in a one-row DataFrame, as the pipeline expects.
        temp_df = pd.DataFrame([{"text": text}])

        try:
            result_df = self.init.query_with_df(df=temp_df, tags=tags)
            return result_df
        except Exception as e:
            # Endpoint boundary: report the failure as data, never crash.
            return pd.DataFrame([{"error": str(e)}])

    def process_single_text(self, text, tags=None):
        """Translate one Russian text to English, optionally after tag processing.

        Args:
            text: source text (assumed Russian — src_lang is hard-coded to ru_RU).
            tags: optional list of validated tags; when given, the tag pipeline
                output replaces *text* before translation.

        Returns:
            ``[{"output": translation}]`` on success, ``[{"error": message}]``
            on failure.
        """
        try:
            # When tags are supplied, run the tag pipeline first and feed its
            # output text into the translation step.
            if tags:
                tagged_result = self.query_with_tags(text, tags)
                # `in` on a DataFrame tests column membership.
                if "error" in tagged_result:
                    return [{"error": tagged_result["error"].iloc[0]}]
                text = tagged_result["text"].iloc[0]

            # Source language is fixed: this endpoint translates ru -> en.
            self.tokenizer.src_lang = "ru_RU"

            # Tokenize; truncate long inputs to the model's 1024-token window.
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(self.device)

            # Deterministic beam-search generation, forced to emit English.
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **inputs,
                    forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                    max_length=1024,
                    num_beams=4,
                    length_penalty=1.0,
                    do_sample=False
                )

            # Decode the single (batch of one) output sequence.
            translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

            return [{"output": translation}]

        except Exception as e:
            # Endpoint boundary: return the error payload instead of raising.
            return [{"error": str(e)}]

    def __call__(self, data):
        """Entry point for the inference endpoint.

        Args:
            data: request dict; must contain ``inputs`` (a string, or a list
                whose first element is ``{"input": text}``) and may contain
                ``tags`` (a list of tag strings).

        Returns:
            A one-element list with either an ``output`` or ``error`` dict.
        """
        try:
            if "inputs" not in data:
                return [{"error": "Request must contain 'inputs' field"}]

            inputs = data["inputs"]
            tags = data.get("tags", None)  # tags are optional

            # Reject any tag not in the declared whitelist.
            if tags:
                invalid_tags = [tag for tag in tags if tag not in self.tags]
                if invalid_tags:
                    return [{"error": f"Invalid tags: {invalid_tags}"}]

            # Accept either a bare string or [{"input": text}, ...] (only the
            # first element of a list is processed).
            if isinstance(inputs, str):
                return self.process_single_text(inputs, tags)
            elif isinstance(inputs, list) and len(inputs) > 0:
                if isinstance(inputs[0], dict) and "input" in inputs[0]:
                    return self.process_single_text(inputs[0]["input"], tags)
                else:
                    return [{"error": "Invalid input format"}]
            else:
                return [{"error": "Invalid input format"}]

        except Exception as e:
            # Last-resort boundary: never let the endpoint raise.
            return [{"error": str(e)}]

# Exemple d'utilisation
def init_handler(model_dir):
    """Factory: build and return an EndpointHandler for *model_dir*."""
    return EndpointHandler(model_dir)