# Translation_API / app.py
# (Hugging Face Space page header preserved from scrape: author "Sheemz",
#  commit 631c1d5 verified, "Update app.py", raw/history/blame, 1.94 kB)
import streamlit as st
import torch
import os
import requests
# Remote location of the pretrained checkpoint and the local filename it is
# cached under. Both names are read by the top-level download/load calls below.
model_url = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/checkpoint_8_96000.pt"
model_path = "checkpoint_8_96000.pt"
def download_model(url, file_path, timeout=30):
    """Download *url* to *file_path* unless it already exists; return the path.

    The download is streamed in 8 KiB chunks so large checkpoints are never
    held fully in memory. A timeout is applied so a stalled connection cannot
    hang the app indefinitely.

    Args:
        url: HTTP(S) URL of the file to fetch.
        file_path: Local path to write to / reuse.
        timeout: Seconds to wait for connect/read before failing (default 30).

    Returns:
        file_path, whether freshly downloaded or already present.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    if not os.path.exists(file_path):
        # stream=True keeps the response body off the heap until iterated.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(file_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    return file_path
# Pick the compute device once at import time; load_model() maps checkpoint
# tensors onto it (GPU when available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(model_path, map_location=None):
    """Deserialize a full pickled model from *model_path* and set it to eval mode.

    Args:
        model_path: Path to a checkpoint saved with ``torch.save(model, ...)``.
        map_location: Optional device override; defaults to the module-level
            ``device`` (CUDA when available, else CPU).

    Returns:
        The loaded model in eval mode.
    """
    # weights_only=False is required on torch>=2.6 (where it defaults to True)
    # to unpickle a full model object; the original code relied on the old
    # default. SECURITY: this unpickles arbitrary code — only load trusted
    # checkpoints, as is the case for this pinned HF repo file.
    model = torch.load(
        model_path,
        map_location=map_location if map_location is not None else device,
        weights_only=False,
    )
    # NOTE(review): fairseq checkpoints are often plain dicts, on which .eval()
    # would raise AttributeError — confirm the .pt really is a pickled module.
    model.eval()
    return model
def load_dictionary(dict_path):
    """Read a fairseq-style dictionary file into a token -> index mapping.

    Each line is expected to start with the token (extra columns such as
    frequencies are ignored); the token's index is its line number. Blank
    lines are skipped rather than crashing (the original indexed split()[0]
    unconditionally, which raises IndexError on an empty line).

    Args:
        dict_path: Path to a whitespace-separated dictionary file.

    Returns:
        dict mapping token -> 0-based line index.
    """
    # UTF-8 is mandatory here: the Urdu dictionary is not decodable with a
    # platform-default legacy encoding.
    with open(dict_path, 'r', encoding='utf-8') as file:
        return {
            parts[0]: idx
            for idx, parts in enumerate(map(str.split, file))
            if parts  # skip blank lines
        }
def translate(model, input_text, src_dict, tgt_dict):
    """Translate *input_text* from English to Urdu using *model*.

    Currently a stub: the real encode/decode pipeline (tokenize with
    src_dict, run the model, detokenize with tgt_dict) has not been wired
    up yet, so a fixed placeholder string is returned for any input.
    """
    # Placeholder result until real inference logic is implemented.
    return "Translated text here"
# Download the model file (no-op if the checkpoint is already cached locally)
download_model(model_url, model_path)
# Load model and dictionaries into module-level globals used by the UI below.
model = load_model(model_path)
# NOTE(review): "path/to/..." are unresolved placeholders — these calls will
# raise FileNotFoundError until real dictionary paths are supplied.
src_dict = load_dictionary("path/to/dict.en.txt")
tgt_dict = load_dictionary("path/to/dict.ur.txt")
# Streamlit interface: a title, a free-text input box, and a button that
# triggers translation of the entered text.
st.title("Translation Model Inference")
user_text = st.text_area("Enter text to translate", "")
if st.button("Translate"):
    if not user_text:
        # Guard: nothing to translate, prompt the user instead.
        st.write("Please enter text to translate.")
    else:
        result = translate(model, user_text, src_dict, tgt_dict)
        st.write(f"Translated Text: {result}")