Added fa-tj model and language detection
app.py
CHANGED
@@ -1,23 +1,35 @@
 import torch
 import streamlit as st
 from model import init_model, predict
-from data import Tokenizer, load_config
+from data import Tokenizer, load_config, language_detect
 
+MODEL_PATH_TJ_FA = 'tj-fa.pt'
+MODEL_PATH_FA_TJ = 'fa-tj.pt'
 
-
+config_tj_fa = load_config(MODEL_PATH_TJ_FA)
+tokenizer_tj_fa = Tokenizer(config_tj_fa)
+model_tj_fa = init_model(MODEL_PATH_TJ_FA)
 
-
-
-
+config_fa_tj = load_config(MODEL_PATH_FA_TJ)
+tokenizer_fa_tj = Tokenizer(config_fa_tj)
+model_fa_tj = init_model(MODEL_PATH_FA_TJ)
 
-# Load the model
-model = init_model(MODEL_PATH)
-
-# Create a text area box where the user can enter their text
 user_input = st.text_area("Enter some text here", value="Он ҷо, ки висоли дӯстон аст,\nВ-оллоҳ, ки миёни хона саҳрост.")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Detect language
+detected_language = language_detect(user_input, tokenizer_tj_fa, tokenizer_fa_tj)
+
+if detected_language == 'tj':
+    model = model_tj_fa
+    tokenizer = tokenizer_tj_fa
+    st.text('Detected language: Tajik (TJ) -> Transliterating to Persian (FA)')
+else:
+    model = model_fa_tj
+    tokenizer = tokenizer_fa_tj
+    st.text('Detected language: Persian (FA) -> Transliterating to Tajik (TJ)')
+
 # Run the model on the user's text and store the output
 model_output = predict(model, tokenizer, user_input, device)
 
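Taken together, app.py now builds both direction-specific pipelines at startup, detects the script of the user's input, and routes it to the matching model and tokenizer. A minimal sketch of the same routing outside the Streamlit UI, assuming the init_model, load_config, Tokenizer, predict, and language_detect signatures shown in this diff and that predict returns the transliterated string:

# Sketch of the new inference flow without the Streamlit front end.
# Uses the same helpers app.py imports; paths and call signatures are
# taken from the diff above.
import torch

from data import Tokenizer, load_config, language_detect
from model import init_model, predict

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tajik -> Persian pipeline
config_tj_fa = load_config('tj-fa.pt')
tokenizer_tj_fa = Tokenizer(config_tj_fa)
model_tj_fa = init_model('tj-fa.pt')

# Persian -> Tajik pipeline
config_fa_tj = load_config('fa-tj.pt')
tokenizer_fa_tj = Tokenizer(config_fa_tj)
model_fa_tj = init_model('fa-tj.pt')

def transliterate(text: str) -> str:
    # Route to the pair whose source vocabulary covers the input best.
    if language_detect(text, tokenizer_tj_fa, tokenizer_fa_tj) == 'tj':
        return predict(model_tj_fa, tokenizer_tj_fa, text, device)
    return predict(model_fa_tj, tokenizer_fa_tj, text, device)

print(transliterate("Он ҷо, ки висоли дӯстон аст"))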
data.py
CHANGED
@@ -21,7 +21,6 @@ class Tokenizer:
         self.trg_pad_idx = self.trg_char_index['<PAD>']
         self.trg_unk_idx = self.trg_char_index['<UNK>']
         self.src_unk_idx = self.src_char_index['<UNK>']
-
 
     def encode_src(self, text: str):
         src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
@@ -43,3 +42,17 @@ class Tokenizer:
         src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
         trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
         return src_padded, trg_padded
+
+
+def language_detect(text, tokenizer_tj_fa: "Tokenizer", tokenizer_fa_tj: "Tokenizer"):
+    # Calculate the percentage of characters in text that are present in the source vocabulary of tokenizer_tj_fa
+    percentage_tj_fa = sum(char in tokenizer_tj_fa.src_vocab for char in text) / len(text)
+
+    # Calculate the percentage of characters in text that are present in the source vocabulary of tokenizer_fa_tj
+    percentage_fa_tj = sum(char in tokenizer_fa_tj.src_vocab for char in text) / len(text)
+
+    # Return the language code of the tokenizer with the higher percentage
+    if percentage_tj_fa > percentage_fa_tj:
+        return 'tj'
+    else:
+        return 'fa'
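The new language_detect helper is a coverage heuristic: for each direction's tokenizer it computes the fraction of input characters found in that tokenizer's source vocabulary and returns the direction with the higher score. A self-contained illustration of the same idea with hypothetical stand-in character sets (the real code reads them from Tokenizer.src_vocab):

# Illustration of the vocabulary-coverage heuristic.
# tj_chars and fa_chars are hypothetical stand-ins for the tokenizers'
# src_vocab; the real vocabularies come from the model configs.
tj_chars = set("абвгдеёжзийклмнопрстуфхцчшщъыьэюяғӣқӯҳҷ ,.-\n")
fa_chars = set("ءآأؤئابپتثجچحخدذرزژسشصضطظعغفقکگلمنوهی ،.\u200c\n")

def detect(text: str) -> str:
    # Guard against empty input so the division below cannot fail.
    if not text:
        return 'tj'
    coverage_tj = sum(ch in tj_chars for ch in text) / len(text)
    coverage_fa = sum(ch in fa_chars for ch in text) / len(text)
    return 'tj' if coverage_tj > coverage_fa else 'fa'

print(detect("Он ҷо, ки висоли дӯстон аст"))  # tj
print(detect("آن جا که وصال دوستان است"))      # fa

The empty-input guard is an addition in this sketch; the committed function assumes non-empty text.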
fa-tj.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcdbc01b0630e0a01e42f5a06e1002a7cf0089ee0f32d3c21d9b359f47846aa7
+size 22892367
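The checkpoint itself is stored through Git LFS, so the repository only carries this pointer: the oid is the SHA-256 of the file's contents and size is its length in bytes. A small sketch, assuming the checkpoint has already been fetched (for example with git lfs pull), for checking a local fa-tj.pt against the pointer:

# Check a fetched fa-tj.pt against the LFS pointer's oid and size.
import hashlib
import os

EXPECTED_OID = "bcdbc01b0630e0a01e42f5a06e1002a7cf0089ee0f32d3c21d9b359f47846aa7"
EXPECTED_SIZE = 22892367

def checkpoint_matches_pointer(path: str = "fa-tj.pt") -> bool:
    if os.path.getsize(path) != EXPECTED_SIZE:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == EXPECTED_OID

print(checkpoint_matches_pointer())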