|
import gradio as gr |
|
from transformers import ( |
|
MBart50TokenizerFast, |
|
MBartForConditionalGeneration, |
|
AutoTokenizer, |
|
AutoModelForSequenceClassification, |
|
) |
|
import torch |
|
|
|
|
|
# Language-identification model: a sequence classifier whose output class
# indices map to language names via `reverse_mapper` below.
# NOTE: loaded eagerly at import time — first run downloads the checkpoint.
lang_detector_name = "Aesopskenya/LanguageDetector"

lang_tokenizer = AutoTokenizer.from_pretrained(lang_detector_name)

lang_model = AutoModelForSequenceClassification.from_pretrained(lang_detector_name)
|
|
|
|
|
# Supported source languages and the Hugging Face translation checkpoint
# used for each one. Languages absent here cannot be translated.
lang_to_model = dict(
    Gikuyu="Aesopskenya/translator",
    Kalenjin="Aesopskenya/KalenjinTranslator",
    Kamba="Aesopskenya/KambaTranslation",
    Luo="Aesopskenya/LuoTranslator",
    Sheng="Aesopskenya/ShengTranslation",
)
|
|
|
|
|
# Classifier output index -> language name, in the order of the detector's
# class head. Positions must match the model's label ordering exactly.
_DETECTOR_LABELS = (
    "English",
    "Sheng",
    "Other",
    "Luhya",
    "Kamba",
    "Gikuyu",
    "Kalenjin",
    "Luo",
)
reverse_mapper = dict(enumerate(_DETECTOR_LABELS))
|
|
|
|
|
def detect_language(text):
    """Classify *text* with the language detector and return the language name.

    The returned string is one of the values in ``reverse_mapper`` (e.g.
    "English", "Luo", "Other").
    """
    encoded = lang_tokenizer(
        text,
        max_length=128,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = lang_model(**encoded).logits
    label_id = int(logits.argmax(dim=-1))
    return reverse_mapper[label_id]
|
|
|
|
|
# Cache of already-loaded (tokenizer, model) pairs keyed by language name,
# so each translation checkpoint is downloaded/instantiated at most once
# per process instead of on every request.
_translation_cache = {}


def load_model_and_tokenizer(language):
    """Return ``(tokenizer, model)`` for *language*, or ``(None, None)``.

    Args:
        language: A language name; must be a key of ``lang_to_model`` to be
            supported.

    Returns:
        A ``(MBart50TokenizerFast, MBartForConditionalGeneration)`` pair for
        supported languages, ``(None, None)`` otherwise. Results are
        memoized in ``_translation_cache`` — the previous implementation
        re-loaded the checkpoint from disk/hub on every call.
    """
    model_name = lang_to_model.get(language)
    if model_name is None:
        return None, None
    if language not in _translation_cache:
        tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        model = MBartForConditionalGeneration.from_pretrained(model_name)
        _translation_cache[language] = (tokenizer, model)
    return _translation_cache[language]
|
|
|
|
|
def translate_text(text):
    """Detect the language of *text* and translate it into English.

    Args:
        text: Raw user input from the Gradio text box.

    Returns:
        A user-facing string: either the detected language plus its
        translation, or an explanatory message when the language is
        unsupported or the model fails to load.
    """
    detected_language = detect_language(text)
    print(f"Detected Language: {detected_language}")
    if detected_language not in lang_to_model:
        return f"Detected Language: {detected_language}. Language not supported for translation."

    tokenizer, model = load_model_and_tokenizer(detected_language)
    # Explicit `is None` — truthiness of HF model objects is not a contract.
    if tokenizer is None or model is None:
        return "Error loading the translation model."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)

    # Fix: pass the full encoding (input_ids AND attention_mask) to
    # generate(); the original passed only input_ids, letting padding
    # tokens influence generation. Also skip gradient tracking.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=128)

    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return f"Detected Language: {detected_language}\nTranslation: {translation}"
|
|
|
|
|
|
|
# Gradio UI: one text box in, one text box out, backed by translate_text.
iface = gr.Interface(
    fn=translate_text,
    inputs="text",
    outputs="text",
    title="Multi-Language Translator",
    description="Enter a sentence, and the model will detect its language and translate it into English.",
)
|
|
|
|
|
# Script entry point: serve the app on all interfaces, port 7860
# (the port Gradio apps conventionally expose, e.g. on Hugging Face Spaces).
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
|
|