# Spaces status: Runtime error
import logging
import os
import pickle

# Must be set BEFORE TensorFlow is imported: the native runtime reads it at
# load time. '3' filters TensorFlow's C++ INFO, WARNING and ERROR log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
class Translator:
    """English-to-French translator wrapping a saved Keras model.

    The model, both pickled tokenizers and the input padding length are loaded
    once at construction time; :meth:`predict` then translates one sentence.
    """

    def __init__(
        self,
        model_path,
        tokenizer_eng_path='tokenizer_eng.pkl',
        tokenizer_fr_path='tokenizer_fr.pkl',
        config_path='config.pkl',
    ):
        """Load the model, the English/French tokenizers and ``max_eng``.

        Args:
            model_path: Path handed to ``load_model``.
            tokenizer_eng_path: Pickle holding the English (input) tokenizer.
                Defaults to the previously hard-coded file name, resolved
                relative to the current working directory.
            tokenizer_fr_path: Pickle holding the French (output) tokenizer.
            config_path: Pickle holding a mapping with a ``'max_eng'`` key,
                used as ``maxlen`` when padding input sequences.

        Raises:
            OSError: if one of the pickle files cannot be opened.
            KeyError: if the config mapping has no ``'max_eng'`` entry.
        """
        logging.info("Translator class initialized")
        self.model = load_model(model_path)
        logging.info("Model is loaded!")
        # NOTE(review): pickle.load can execute arbitrary code while
        # unpickling; only point these paths at files from a trusted source.
        self.tokenizer_eng = self._load_pickle(tokenizer_eng_path)
        self.tokenizer_fr = self._load_pickle(tokenizer_fr_path)
        config = self._load_pickle(config_path)
        self.max_eng = config['max_eng']
        # Use the same logging channel as the other progress messages.
        logging.info("Max English Length: %s", self.max_eng)

    @staticmethod
    def _load_pickle(path):
        """Return the object unpickled from the file at *path*."""
        with open(path, 'rb') as file:
            return pickle.load(file)

    def predict(self, input_sentence):
        """Translate one English sentence and return the French text.

        Args:
            input_sentence: Raw English sentence.

        Returns:
            The predicted French words joined by single spaces. Indices with no
            entry in the French tokenizer's ``index_word`` are dropped, so the
            result can be an empty string.
        """
        padded_sequence = self.process_text(input_sentence)
        predicted_sequence = self.model.predict(padded_sequence)
        # Greedy decoding of the first (only) batch item: the most probable
        # token id at each time step. Assumes the model output is shaped
        # (batch, time, vocab) -- TODO confirm against the model definition.
        token_ids = np.argmax(predicted_sequence[0], axis=-1)
        index_word = self.tokenizer_fr.index_word
        words = (index_word.get(int(token_id), '<OOV>') for token_id in token_ids)
        # Unknown ids map to the '<OOV>' placeholder and are skipped. This
        # includes padding id 0 if, as with a Keras Tokenizer, index_word has no
        # entry for it, and also a literal '<OOV>' word if the tokenizer
        # defines one.
        return ' '.join(word for word in words if word != '<OOV>')

    def process_text(self, input_sentence):
        """Tokenize *input_sentence* and post-pad it to ``self.max_eng`` ids.

        Returns the padded batch of one sequence produced by ``pad_sequences``.
        """
        input_sequence = self.tokenizer_eng.texts_to_sequences([input_sentence])
        # NOTE(review): pad_sequences truncates from the front by default, so
        # sentences longer than max_eng lose their first words -- confirm
        # this matches how the model was trained.
        return pad_sequences(input_sequence, maxlen=self.max_eng, padding='post')
def main():
    """Translate a fixed demo sentence and log the result.

    ``Translator`` is built from ``model.h5``, resolved relative to the current
    working directory.
    """
    translator = Translator('model.h5')
    translated_sentence = translator.predict("she is driving the truck")
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("The translation is %s", translated_sentence)
if __name__ == "__main__":
    # Script entry point only (not on import). The root logger is configured
    # at INFO first so the progress and result messages logged by main() are
    # actually emitted.
    logging.basicConfig(level=logging.INFO)
    main()