# Translation_API / app.py — Streamlit inference app for the SLPG
# English→Urdu unsupervised MT checkpoint (commit 6db5958).
import streamlit as st
import torch
import os
# Location of the model checkpoint on the Hugging Face Hub.
# NOTE: files must be fetched through the /resolve/ endpoint — the /tree/
# endpoint returns an HTML directory listing, not the checkpoint bytes.
model_path = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/checkpoint_8_96000.pt"
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(model_path, map_location=None):
    """Load a pickled PyTorch model from a path or URL and set eval mode.

    Args:
        model_path: Local filesystem path or HTTP(S) URL of a checkpoint
            produced by ``torch.save``.
        map_location: Optional device to map tensors onto; defaults to CUDA
            when available, else CPU (matching the module-level ``device``).

    Returns:
        The deserialized model with ``eval()`` applied (dropout/batch-norm
        switched to inference behavior).
    """
    if map_location is None:
        map_location = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_path.startswith(("http://", "https://")):
        # torch.load cannot read URLs directly; download into the local
        # torch hub cache first.
        model = torch.hub.load_state_dict_from_url(
            model_path, map_location=map_location, weights_only=False
        )
    else:
        # weights_only=False is required to unpickle a full nn.Module object
        # (the default flipped to True in torch >= 2.6). Only load checkpoints
        # from a trusted source — unpickling executes arbitrary code.
        model = torch.load(model_path, map_location=map_location, weights_only=False)
    # NOTE(review): a fairseq checkpoint is often a plain dict, not an
    # nn.Module — if so, .eval() will fail; confirm what the file contains.
    model.eval()
    return model
def load_dictionary(dict_path):
    """Read a fairseq-style vocabulary file into a token -> index mapping.

    Each non-blank line is expected to start with the token (typically
    followed by a frequency count); the token maps to its 0-based line
    number. Blank lines are skipped but still consume an index.

    Args:
        dict_path: Path to the vocabulary file (UTF-8 text).

    Returns:
        dict mapping token (str) to line index (int).
    """
    dictionary = {}
    # Urdu vocabularies are UTF-8: pass the encoding explicitly instead of
    # relying on the platform default, and stream the file rather than
    # materializing it with readlines().
    with open(dict_path, "r", encoding="utf-8") as file:
        for index, line in enumerate(file):
            parts = line.split()
            if parts:  # a blank line would raise IndexError on split()[0]
                dictionary[parts[0]] = index
    return dictionary
def translate(model, input_text, src_dict, tgt_dict):
    """Translate ``input_text`` with ``model`` using the given vocabularies.

    Placeholder implementation: returns a fixed string until the real
    encode/decode logic for the checkpoint is wired in.
    """
    # TODO: encode input_text via src_dict, run the model, decode via tgt_dict.
    placeholder = "Translated text here"
    return placeholder
# Load the model and both vocabularies once, at module import time.
# NOTE(review): these hard-coded repo-relative paths assume the SLPG files
# exist exactly as written in the deployment environment — verify they
# resolve, otherwise the app crashes on startup before any UI renders.
model = load_model(model_path)
src_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.en.txt")
tgt_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.ur.txt")
# Streamlit interface: a title, a text box, and a button that triggers
# translation (or a prompt when the box is empty).
st.title("Translation Model Inference")
input_text = st.text_area("Enter text to translate", "")
if st.button("Translate"):
    if not input_text:
        # Guard clause: nothing to translate.
        st.write("Please enter text to translate.")
    else:
        result = translate(model, input_text, src_dict, tgt_dict)
        st.write(f"Translated Text: {result}")