import gradio as gr
import yaml  # NOTE(review): not referenced anywhere in this script; kept in case deployment relies on it
import torch
from mmtafrica import load_params, translate
from huggingface_hub import hf_hub_download

# Display name -> language code passed to `translate`.
language_map = {'English': 'en', 'Swahili': 'sw', 'Fon': 'fon', 'Igbo': 'ig',
                'Kinyarwanda': 'rw', 'Xhosa': 'xh', 'Yoruba': 'yo', 'French': 'fr'}
available_languages = list(language_map.keys())

# Load parameters and model from checkpoint.
checkpoint = hf_hub_download(repo_id="chrisjay/mmtafrica", filename="mmt_translation.pt")
# PyTorch names the GPU device type 'cuda'; 'gpu' is not a valid torch device
# string and would break model placement on CUDA machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
params = load_params({'checkpoint': checkpoint, 'device': device})


def get_translation(source_language, target_language, source_sentence=None):
    '''
    Translate a sentence between two of the supported languages.

    Args:
        source_language: display name of the input language (a key of `language_map`).
        target_language: display name of the output language (a key of `language_map`).
        source_sentence: text to translate; None or blank input is rejected.

    Returns:
        The translated text, or a human-readable message when the input is empty,
        the model returns an empty prediction, or translation raises an error.
        This function is a UI callback, so it reports problems as text instead of raising.
    '''
    if source_sentence is None or not source_sentence.strip():
        return "Please enter a sentence to translate."

    try:
        # Lookups live inside the try so an unset/unknown dropdown value is
        # reported to the user instead of crashing the callback.
        source_language_ = language_map[source_language]
        target_language_ = language_map[target_language]
        pred = translate(params, source_sentence,
                         source_lang=source_language_, target_lang=target_language_)
    except Exception as error:
        # Top-level UI boundary: surface any failure as text in the output box.
        return f"Issue with translation: \n {error}"

    if pred == '':
        return "Could not find translation"
    return pred


title = "MMTAfrica: Multilingual Machine Translation"
description = "Enjoy our MMT model that allows you to translate among 6 African languages, English and French!\n\nProfitez de notre modèle MMT qui vous permet de traduire parmi 6 langues africaines, anglais et français!"

iface = gr.Interface(
    fn=get_translation,
    inputs=[
        gr.inputs.Dropdown(choices=available_languages, default='Igbo'),
        gr.inputs.Dropdown(choices=available_languages, default='Fon'),
        gr.inputs.Textbox(label="Input"),
    ],
    outputs=gr.outputs.Textbox(type="auto", label='Translation'),
    title=title,
    description=description,
    enable_queue=True,
    theme='huggingface',
)
iface.launch()