Spaces:
Sleeping
Sleeping
from transformers import T5ForConditionalGeneration, T5TokenizerFast | |
from torch.utils.data import DataLoader | |
import streamlit as st | |
import pandas as pd | |
import torch | |
import os | |
# Main page heading
st.markdown("Translation page 🔠")

# Sidebar dropdown selecting the translation direction
translation_type = st.sidebar.selectbox(
    "Translation Type",
    options=["French ➡️ Wolof", "Wolof ➡️ French"],
)
# Registry of model versions.
# Layout: version label -> translation direction -> settings, where
#   "checkpoints" is the directory holding the trained weights,
#   "tokenizer"   is the tokenizer JSON file,
#   "max_len"     is the maximum sequence length (None when not set yet).
models = {
    "Version ✌️": {
        "French ➡️ Wolof": {
            "checkpoints": "wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_fw_v4",
            "tokenizer": "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v4.json",
            "max_len": None,
        },
    },
    "Version ☝️": {
        "French ➡️ Wolof": {
            "checkpoints": "wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_fw_v3",
            "tokenizer": "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json",
            "max_len": 51,
        },
        "Wolof ➡️ French": {
            "checkpoints": "wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_wf_v3",
            # fixed: the directory was misspelled "trokenizers"
            "tokenizer": "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json",
            "max_len": 51,
        },
    },
}
# Table of Wolof special characters (used for the sidebar character picker)
sp_wolof_chars = pd.read_csv(
    'wolof-translate/wolof_translate/data/wolof_writing/wolof_special_chars.csv'
)

# Table of Wolof words with their French definitions
sp_wolof_words = pd.read_csv(
    'wolof-translate/wolof_translate/data/wolof_writing/definitions.csv'
)
# Callback functions that change the input text
def add_symbol_to_text():
    """Append the character chosen in the "symbol" selectbox to the input text.

    Reads ``st.session_state.symbol`` and writes ``st.session_state.input_text``.
    """
    # The "input_text" text box is only created further down the script, after
    # the model has loaded; when that step failed the key is absent, so start
    # from an empty sentence instead of raising AttributeError.
    current_text = st.session_state.get("input_text", "")
    st.session_state.input_text = current_text + st.session_state.symbol
def add_word_to_text():
    """Append the Wolof word chosen in the "word" selectbox to the input text.

    The selectbox entries have the form ``"<wolof> / <french>"``; only the part
    before the first ``/`` is kept, with surrounding whitespace removed.
    Reads ``st.session_state.word`` and writes ``st.session_state.input_text``.
    """
    word = st.session_state.word.split('/')[0].strip()
    # The "input_text" text box is only created further down the script, after
    # the model has loaded; when that step failed the key is absent, so start
    # from an empty sentence instead of raising AttributeError.
    current_text = st.session_state.get("input_text", "")
    st.session_state.input_text = current_text + word
# Typing helpers, shown only when the source language is Wolof
if translation_type == "Wolof ➡️ French":
    # picking a character appends it to the input text through the callback
    symbol = st.sidebar.selectbox(
        "Wolof characters",
        key="symbol",
        options=sp_wolof_chars['wolof_special_chars'],
        on_change=add_symbol_to_text,
    )

    # one "<wolof> / <french>" entry per row; pairing the two columns directly
    # avoids depending on the frame's index labels being 0..n-1
    word_options = [
        wolof + " / " + french
        for wolof, french in zip(sp_wolof_words['wolof'], sp_wolof_words['french'])
    ]
    word = st.sidebar.selectbox(
        "Wolof words/Definitions",
        key="word",
        options=word_options,
        on_change=add_word_to_text,
    )

# Dropdown for the model version
version = st.sidebar.selectbox("Model version", options=["Version ☝️", "Version ✌️"])

# Sampling randomness requested by the user, as a percentage (0 to 100)
temperature = st.sidebar.slider(
    "How randomly need you the translated sentences to be from 0% to 100%",
    min_value=0,
    max_value=100,
)
# make the process
def _load_t5_model(checkpoint_dir, tokenizer_file):
    """Build a T5-small model with trained weights, plus its tokenizer.

    Args:
        checkpoint_dir: Directory containing ``best_checkpoints.pth``.
        tokenizer_file: Path of the tokenizer JSON file.

    Returns:
        A ``(model, tokenizer)`` tuple. The model weights come from the
        ``model_state_dict`` entry of the checkpoint, loaded on CPU.
    """
    # NOTE(review): torch.load relies on pickle; only load trusted checkpoint files.
    checkpoints = torch.load(
        os.path.join(checkpoint_dir, "best_checkpoints.pth"),
        map_location=torch.device('cpu'),
    )
    tokenizer = T5TokenizerFast(tokenizer_file=tokenizer_file)
    model = T5ForConditionalGeneration.from_pretrained('t5-small')
    # resize the embeddings to the tokenizer's vocabulary size before loading the weights
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(checkpoints['model_state_dict'])
    return model, tokenizer


def get_modelfw_v3():
    """Return the version 3 French-to-Wolof ``(model, tokenizer)`` pair."""
    return _load_t5_model(
        'wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_fw_v3',
        "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json",
    )


def get_modelwf_v3():
    """Return the version 3 Wolof-to-French ``(model, tokenizer)`` pair."""
    return _load_t5_model(
        'wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_wf_v3',
        "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json",
    )


try:
    # maximum sequence length; a missing version/direction entry raises
    # KeyError, which is reported by the handler at the bottom
    max_len = models[version][translation_type]['max_len']

    is_french_to_wolof = translation_type == "French ➡️ Wolof"

    # loaders exist only for the combinations that have trained weights
    model_loaders = {
        ("Version ☝️", "French ➡️ Wolof"): get_modelfw_v3,
        ("Version ☝️", "Wolof ➡️ French"): get_modelwf_v3,
    }
    load_model = model_loaders.get((version, translation_type))
    if load_model is None:
        # explicit error instead of failing later on an undefined `model`
        raise LookupError(f"No trained model for {version} ({translation_type})")
    model, tokenizer = load_model()

    # set the model to eval mode
    _ = model.eval()

    # `language` is the target language, `source_language` the one typed by the user
    source_language, language = (
        ("French", "Wolof") if is_french_to_wolof else ("Wolof", "French")
    )

    # Title: name the actual source language (it used to always say "French")
    st.header(f"Translate {source_language} sentences to {language} 👌")

    # Two columns: input on the left, translation on the right
    left, right = st.columns(2)
    left.subheader(f'Give me some sentences in {source_language}: ')
    left.text_input("- Sentence", key="input_text")
    right.subheader(f"Translation to {language}:")

    text = st.session_state["input_text"]
    if text:
        # the end-of-sequence token is appended before tokenizing
        sentence = text + tokenizer.eos_token
        encoding = tokenizer(
            [sentence],
            return_tensors='pt',
            max_length=max_len,
            padding='max_length',
            truncation=True,
        )

        generation_kwargs = {
            "max_length": max_len,
            "attention_mask": encoding.attention_mask,
            "pad_token_id": tokenizer.pad_token_id,
            # exactly one translation is displayed (0 is not a valid count)
            "num_return_sequences": 1,
        }
        if temperature > 0:
            # the slider is a percentage: sample only when randomness was
            # requested, so the default position (0) stays greedy
            generation_kwargs.update(
                do_sample=True,
                top_k=50,
                top_p=0.90,
                temperature=temperature / 100,
            )
        else:
            generation_kwargs["do_sample"] = False

        # perform prediction and decode it
        predictions = model.generate(encoding.input_ids, **generation_kwargs)
        predicted_sentence = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        right.write(f"Translation: {predicted_sentence[0]}")
    else:
        # nothing typed yet: show an empty translation
        right.write("Translation: ")
except Exception as e:
    # page-level boundary: report the failure in the page instead of crashing
    st.warning("The chosen model is not available yet !", icon="⚠️")
    st.write(e)