Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,22 +10,23 @@ MODELS = ["facebook/nllb-200-distilled-600M","DigitalUmuganda/Finetuned-NLLB"]
|
|
10 |
# tokenizer = AutoTokenizer.from_pretrained(CKPT)
|
11 |
|
12 |
device = 0 if torch.cuda.is_available() else -1
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
# def translate(text, src_lang, tgt_lang, max_length=400):
|
16 |
def translate(text, src_lang, tgt_lang, CKPT, max_length=400):
|
17 |
|
18 |
"""
|
19 |
Translate the text from source lang to target lang
|
20 |
"""
|
21 |
-
|
22 |
-
tokenizer = AutoTokenizer.from_pretrained(CKPT)
|
23 |
-
|
24 |
translation_pipeline = pipeline(TASK,
|
25 |
tokenizer=tokenizer,
|
26 |
src_lang=src_lang,
|
27 |
tgt_lang=tgt_lang,
|
28 |
-
model =
|
29 |
max_length=max_length,
|
30 |
device=device)
|
31 |
|
|
|
10 |
# tokenizer = AutoTokenizer.from_pretrained(CKPT)
|
11 |
|
12 |
device = 0 if torch.cuda.is_available() else -1
|
13 |
+
fb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
14 |
+
du_model = AutoModelForSeq2SeqLM.from_pretrained("DigitalUmuganda/Finetuned-NLLB")
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
16 |
|
17 |
+
models = {"facebook/nllb-200-distilled-600M":fb_model,"DigitalUmuganda/Finetuned-NLLB":du_model}
|
18 |
# def translate(text, src_lang, tgt_lang, max_length=400):
|
19 |
def translate(text, src_lang, tgt_lang, CKPT, max_length=400):
|
20 |
|
21 |
"""
|
22 |
Translate the text from source lang to target lang
|
23 |
"""
|
24 |
+
|
|
|
|
|
25 |
translation_pipeline = pipeline(TASK,
|
26 |
tokenizer=tokenizer,
|
27 |
src_lang=src_lang,
|
28 |
tgt_lang=tgt_lang,
|
29 |
+
model = models[CKPT],
|
30 |
max_length=max_length,
|
31 |
device=device)
|
32 |
|