import os

# Configure cache locations before importing torch and transformers:
# transformers and huggingface_hub resolve these variables at import time,
# so setting them after the import has no effect.
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"  # writable numba cache dir
os.environ["NUMBA_DISABLE_JIT"] = "1"               # disable numba JIT compilation
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache/huggingface"  # legacy alias for older transformers versions

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def nllb():
    """
    Load and return the NLLB (No Language Left Behind) model and tokenizer.

    This function loads the NLLB-200-distilled-1.3B model and tokenizer from Hugging Face's Transformers library.
    The model is configured to use a GPU if available, otherwise it defaults to CPU.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
            - model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
            - tokenizer (transformers.AutoTokenizer): The loaded tokenizer.
            
    Example usage:
        model, tokenizer = nllb()
    """
    # Use the GPU when available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Download (or load from cache) the tokenizer and model, then move the
    # model to the selected device. Cache locations are configured at the
    # top of the module, before transformers is imported.
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)
    
    return model, tokenizer

def nllb_translate(model, tokenizer, article, language):
    """
    Translate an article using the NLLB model and tokenizer.

    Args:
        model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for translation.
            Example: model, tokenizer = nllb()
        tokenizer (transformers.AutoTokenizer): The tokenizer to use with the NLLB model.
            Example: model, tokenizer = nllb()
        article (str): The article text to be translated.
            Example: "This is a sample article."
        language (str): The target language code. Must be either 'es' or 'en';
            the source text is assumed to be in the other language.
            Example: "es"

    Returns:
        str: The translated text, or "Translation failed" if an error occurred.
            Example: "Este es un artículo de muestra."
    """
    try:
        # Resolve the NLLB source/target language codes. This assumes the
        # translation direction is English <-> Spanish: translating to 'es'
        # implies English input, and vice versa.
        if language == "es":
            src_code, tgt_code = "eng_Latn", "spa_Latn"
        elif language == "en":
            src_code, tgt_code = "spa_Latn", "eng_Latn"
        else:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # Tokenize the text; src_lang sets the source-language tag that the
        # NLLB tokenizer prepends to the input
        tokenizer.src_lang = src_code
        inputs = tokenizer(article, return_tensors="pt")

        # Move the tokenized inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Force the decoder to start with the target-language token.
        # tokenizer.lang_code_to_id was removed in recent transformers
        # releases; convert_tokens_to_ids works in both old and new versions.
        # max_length=30 caps the output, so long inputs are cut short.
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
            max_length=30,
        )

        # Decode the translation
        text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return text
    
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation failed"
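
if __name__ == "__main__":
    # Minimal usage sketch (an assumed entry point, not part of the original
    # module): load the model once, translate a sample sentence to Spanish,
    # then translate the result back to English. The 1.3B checkpoint is a
    # multi-gigabyte download and is slow on CPU.
    model, tokenizer = nllb()

    sample = "This is a sample article."
    spanish = nllb_translate(model, tokenizer, sample, "es")
    print(f"en -> es: {spanish}")

    english = nllb_translate(model, tokenizer, spanish, "en")
    print(f"es -> en: {english}")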