Sheemz commited on
Commit
3913596
1 Parent(s): 631c1d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -35
app.py CHANGED
@@ -1,14 +1,21 @@
1
  import streamlit as st
2
  import torch
 
3
  import os
4
  import requests
5
 
6
- # Define the URL of your model file
7
- model_url = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/checkpoint_8_96000.pt"
8
- model_path = "checkpoint_8_96000.pt"
 
9
 
10
- # Define a function to download the model file
11
- def download_model(url, file_path):
 
 
 
 
 
12
  if not os.path.exists(file_path):
13
  with requests.get(url, stream=True) as r:
14
  r.raise_for_status()
@@ -17,35 +24,17 @@ def download_model(url, file_path):
17
  f.write(chunk)
18
  return file_path
19
 
20
- # Load the model checkpoint
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
-
23
- # Define a function to load the model
24
- def load_model(model_path):
25
- model = torch.load(model_path, map_location=device)
26
- model.eval()
27
- return model
28
-
29
- # Load the dictionaries
30
- def load_dictionary(dict_path):
31
- with open(dict_path, 'r') as file:
32
- dictionary = {line.split()[0]: i for i, line in enumerate(file.readlines())}
33
- return dictionary
34
-
35
- # Translation function
36
- def translate(model, input_text, src_dict, tgt_dict):
37
- # Implement the logic to translate using your model
38
- # This is a placeholder, modify according to your model's requirements
39
- translated_text = "Translated text here"
40
- return translated_text
41
-
42
- # Download the model file
43
- download_model(model_url, model_path)
44
 
45
- # Load model and dictionaries
46
- model = load_model(model_path)
47
- src_dict = load_dictionary("path/to/dict.en.txt")
48
- tgt_dict = load_dictionary("path/to/dict.ur.txt")
 
 
49
 
50
  # Streamlit interface
51
  st.title("Translation Model Inference")
@@ -53,7 +42,7 @@ input_text = st.text_area("Enter text to translate", "")
53
 
54
  if st.button("Translate"):
55
  if input_text:
56
- translated_text = translate(model, input_text, src_dict, tgt_dict)
57
- st.write(f"Translated Text: {translated_text}")
58
  else:
59
  st.write("Please enter text to translate.")
 
1
  import streamlit as st
2
  import torch
3
+ from fairseq.models.transformer import TransformerModel
4
  import os
5
  import requests
6
 
7
+ # Define the URLs of your model and dictionary files
8
+ model_url = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/sent_iwslt-bt-enur_42.pt"
9
+ dict_en_url = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/dict.en.txt"
10
+ dict_ur_url = "https://huggingface.co/SLPG/English_to_Urdu_Unsupervised_MT/resolve/main/dict.ur.txt"
11
 
12
+ # Define the paths to save the downloaded files
13
+ model_path = "sent_iwslt-bt-enur_42.pt"
14
+ dict_en_path = "dict.en.txt"
15
+ dict_ur_path = "dict.ur.txt"
16
+
17
+ # Define a function to download files
18
+ def download_file(url, file_path):
19
  if not os.path.exists(file_path):
20
  with requests.get(url, stream=True) as r:
21
  r.raise_for_status()
 
24
  f.write(chunk)
25
  return file_path
26
 
27
+ # Download the model and dictionary files
28
+ download_file(model_url, model_path)
29
+ download_file(dict_en_url, dict_en_path)
30
+ download_file(dict_ur_url, dict_ur_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ # Load the model
33
+ en_ur_model = TransformerModel.from_pretrained(
34
+ '.',
35
+ checkpoint_file=model_path,
36
+ data_name_or_path='.'
37
+ )
38
 
39
  # Streamlit interface
40
  st.title("Translation Model Inference")
 
42
 
43
  if st.button("Translate"):
44
  if input_text:
45
+ output_text = en_ur_model.translate(input_text)
46
+ st.write(f"Translated Text: {output_text}")
47
  else:
48
  st.write("Please enter text to translate.")