# Translator / model.py
# Uploaded to Hugging Face via huggingface_hub by Taghrid (verified commit 870b562).
import os
# hide TF warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import os
import tensorflow as tf
import numpy as np
import pickle
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import logging
class Translator:
    """English-to-French translator wrapping a saved Keras model.

    The constructor loads the model, two pickled tokenizers and a pickled
    config dict. ``predict`` turns one English sentence into a French string
    by taking the arg-max token at every output step.
    """

    # Index that pad_sequences fills with by default; Keras tokenizers
    # reserve it and never map it to a word.
    _PADDING_INDEX = 0
    # Token string dropped from the decoded output (the original
    # implementation filtered this literal as well).
    _OOV_TOKEN = '<OOV>'

    def __init__(self, model_path, *, tokenizer_eng_path='tokenizer_eng.pkl',
                 tokenizer_fr_path='tokenizer_fr.pkl', config_path='config.pkl'):
        """Load the model and its preprocessing artifacts.

        Args:
            model_path: Path passed to ``load_model``.
            tokenizer_eng_path: Pickled tokenizer for the English input.
            tokenizer_fr_path: Pickled tokenizer for the French output.
            config_path: Pickled dict that must contain ``'max_eng'``.

        The three artifact paths default to the file names that were
        previously hard-coded. As relative paths, they resolve against the
        current working directory.

        Raises:
            OSError: If an artifact file cannot be opened.
            KeyError: If the config dict has no ``'max_eng'`` entry.
        """
        logging.info("Translator class initialized")
        self.model = load_model(model_path)
        logging.info("Model is loaded!")
        self.tokenizer_eng = self._load_pickle(tokenizer_eng_path)
        self.tokenizer_fr = self._load_pickle(tokenizer_fr_path)
        config = self._load_pickle(config_path)
        self.max_eng = config['max_eng']
        # Reported through logging, like the other constructor messages.
        logging.info("Max English Length: %s", self.max_eng)

    @staticmethod
    def _load_pickle(path):
        """Return the object stored in the pickle file at *path*.

        NOTE(security): pickle.load can execute arbitrary code. Only point
        this at artifacts you produced yourself.
        """
        with open(path, 'rb') as file:
            return pickle.load(file)

    def predict(self, input_sentence):
        """Translate one English sentence to French.

        Args:
            input_sentence: English text to translate.

        Returns:
            The decoded French words joined by single spaces. The following
            are omitted:
            - padding positions (index 0);
            - indices the French tokenizer has no word for;
            - the ``'<OOV>'`` token.
            Empty, ``None`` or whitespace-only input returns ``''`` without
            running the model.
        """
        if not input_sentence or not input_sentence.strip():
            return ''
        padded_sequence = self.process_text(input_sentence)
        predicted_sequence = self.model.predict(padded_sequence)
        # Assumes output shape (batch, steps, vocab) — TODO confirm. Decode the
        # single batch row, picking the most probable token at each step.
        token_indices = np.argmax(predicted_sequence[0], axis=-1)
        index_word = self.tokenizer_fr.index_word
        words = []
        for token_index in token_indices:
            token_index = int(token_index)
            if token_index == self._PADDING_INDEX:
                continue
            word = index_word.get(token_index)
            # Unknown indices and the OOV marker carry no translatable text.
            if word is None or word == self._OOV_TOKEN:
                continue
            words.append(word)
        return ' '.join(words)

    def process_text(self, input_sentence):
        """Tokenize and post-pad one English sentence for the model.

        Args:
            input_sentence: English text to encode.

        Returns:
            The ``pad_sequences`` result for a one-sentence batch, padded at
            the end to ``self.max_eng`` steps.
        """
        input_sequence = self.tokenizer_eng.texts_to_sequences([input_sentence])
        # NOTE(review): pad_sequences truncates from the front by default
        # (truncating='pre'). A sentence longer than max_eng therefore loses
        # its *leading* tokens — confirm this matches training preprocessing.
        return pad_sequences(input_sequence, maxlen=self.max_eng, padding='post')
def main():
    """Translate a fixed demo sentence and log the result.

    Loads ``model.h5`` through ``Translator``. That relative path, and the
    default artifact paths ``Translator`` uses, resolve against the current
    working directory.
    """
    translator = Translator('model.h5')
    translated_sentence = translator.predict("she is driving the truck")
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("The translation is %s", translated_sentence)
if __name__ == "__main__":
    # Script entry point. Enable INFO-level log records first, because the
    # translator reports its progress via logging.info; then run the demo.
    logging.basicConfig(level="INFO")
    main()