File size: 1,473 Bytes
6db5958
 
 
a95bd26
6db5958
 
 
a95bd26
6db5958
 
 
 
 
a95bd26
6db5958
 
 
 
 
a95bd26
6db5958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import streamlit as st
import torch
import os

# Location of the model checkpoint on the Hugging Face Hub.
# NOTE: the raw-file download endpoint is `/resolve/<revision>/<file>`;
# the original `/tree/main/...` URL points at the HTML directory listing
# and can never be loaded as a checkpoint.
model_path = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/checkpoint_8_96000.pt"

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a function to load the model
def load_model(model_path):
    """Load a serialized checkpoint onto `device` and return it.

    Fixes two defects in the original:
    * ``torch.load`` cannot open URLs — remote checkpoints are fetched
      (and cached) via ``torch.hub.load_state_dict_from_url`` first.
    * Fairseq-style ``.pt`` checkpoints deserialize to a plain dict
      (``{'model': state_dict, ...}``), which has no ``.eval()`` — the
      eval switch is applied only when the loaded object supports it.

    Parameters:
        model_path: local filesystem path or HTTP(S) URL of a ``.pt`` file.

    Returns:
        The deserialized checkpoint — an ``nn.Module`` in eval mode, or
        the raw checkpoint dict for state-dict-style files.
    """
    if isinstance(model_path, str) and model_path.startswith(("http://", "https://")):
        # Downloads to the torch hub cache dir, then loads from disk.
        checkpoint = torch.hub.load_state_dict_from_url(model_path, map_location=device)
    else:
        checkpoint = torch.load(model_path, map_location=device)
    if hasattr(checkpoint, "eval"):
        checkpoint.eval()
    return checkpoint

# Load the dictionaries
def load_dictionary(dict_path):
    """Build a token -> line-index mapping from a fairseq-style dict file.

    Each line is expected to start with a token (first whitespace-separated
    field); the token is mapped to its 0-based line number. Fixes in this
    version:
    * blank lines no longer raise ``IndexError`` (``line.split()[0]`` on
      ``""``) — they are skipped, but line numbering is preserved so the
      indices of the remaining tokens are unchanged;
    * the file is opened with an explicit UTF-8 encoding (the Urdu side is
      non-ASCII, so the platform default codec is not safe);
    * the file object is iterated lazily instead of ``readlines()``.

    Parameters:
        dict_path: path to the dictionary text file.

    Returns:
        dict mapping token string to its line index; on duplicate tokens
        the last occurrence wins (same as the original).
    """
    dictionary = {}
    with open(dict_path, 'r', encoding='utf-8') as file:
        for index, line in enumerate(file):
            fields = line.split()
            if fields:
                dictionary[fields[0]] = index
    return dictionary

# Translation function
def translate(model, input_text, src_dict, tgt_dict):
    """Translate `input_text` from the source to the target language.

    Placeholder hook: the real encode/decode pipeline for the model has
    not been wired in yet, so a fixed string is returned regardless of
    the arguments.

    Parameters:
        model: the loaded translation model (unused for now).
        input_text: source-language text to translate (unused for now).
        src_dict: source-language token dictionary (unused for now).
        tgt_dict: target-language token dictionary (unused for now).

    Returns:
        A constant placeholder string.
    """
    # TODO: replace with actual model inference.
    return "Translated text here"

# Load model and dictionaries
# NOTE(review): these run once at module import (every Streamlit rerun of the
# script). `model_path` is a remote URL while the dictionary paths are
# repo-relative — confirm the checkpoint URL is reachable and that the app's
# working directory actually contains `SLPG/English_to_Urdu_Unsupervised_MT/`,
# otherwise the app crashes before rendering any UI.
model = load_model(model_path)
src_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.en.txt")
tgt_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.ur.txt")

# Streamlit interface: title, a free-text input box, and a button that
# triggers translation of whatever is currently in the box.
st.title("Translation Model Inference")
input_text = st.text_area("Enter text to translate", "")

if st.button("Translate"):
    # Guard clause: an empty box gets a prompt instead of a translation.
    if not input_text:
        st.write("Please enter text to translate.")
    else:
        st.write(f"Translated Text: {translate(model, input_text, src_dict, tgt_dict)}")