Sheemz committed on
Commit 6db5958
1 Parent(s): c105317

Update app.py

Files changed (1)
  1. app.py +38 -19
app.py CHANGED
@@ -1,23 +1,42 @@
- import gradio as gr
- from transformers import pipeline
-
- # Load your model from the Hugging Face repository
- model_name = "SLPG/English_to_Urdu_Unsupervised_MT"
- translator = pipeline("translation_en_to_ur", model=model_name)
-
- # Define the translation function
- def translate_text(input_text):
-     result = translator(input_text, max_length=400)
-     return result[0]['translation_text']
-
- # Create the Gradio interface
- interface = gr.Interface(
-     fn=translate_text,
-     inputs=gr.inputs.Textbox(lines=2, placeholder="Enter text to translate"),
-     outputs="text",
-     title="Translation Model Inference",
-     description="Translate text using the Hugging Face translation model."
- )
-
- # Launch the interface
- interface.launch()
+ import streamlit as st
+ import torch
+ import os
+
+ # Load the model checkpoint
+ model_path = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/tree/main/checkpoint_8_96000.pt"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define a function to load the model
+ def load_model(model_path):
+     model = torch.load(model_path, map_location=device)
+     model.eval()
+     return model
+
+ # Load the dictionaries
+ def load_dictionary(dict_path):
+     with open(dict_path, 'r') as file:
+         dictionary = {line.split()[0]: i for i, line in enumerate(file.readlines())}
+     return dictionary
+
+ # Translation function
+ def translate(model, input_text, src_dict, tgt_dict):
+     # Implement the logic to translate using your model
+     # This is a placeholder, modify according to your model's requirements
+     translated_text = "Translated text here"
+     return translated_text
+
+ # Load model and dictionaries
+ model = load_model(model_path)
+ src_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.en.txt")
+ tgt_dict = load_dictionary("SLPG/English_to_Urdu_Unsupervised_MT/dict.ur.txt")
+
+ # Streamlit interface
+ st.title("Translation Model Inference")
+ input_text = st.text_area("Enter text to translate", "")
+
+ if st.button("Translate"):
+     if input_text:
+         translated_text = translate(model, input_text, src_dict, tgt_dict)
+         st.write(f"Translated Text: {translated_text}")
+     else:
+         st.write("Please enter text to translate.")