sobir-hf commited on
Commit
f2dafec
·
1 Parent(s): b1dc1be

Added fa-tj model and language detection

Browse files
Files changed (3) hide show
  1. app.py +21 -9
  2. data.py +14 -1
  3. fa-tj.pt +3 -0
app.py CHANGED
@@ -1,23 +1,35 @@
1
  import torch
2
  import streamlit as st
3
  from model import init_model, predict
4
- from data import Tokenizer, load_config
5
 
 
 
6
 
7
- MODEL_PATH = 'tj-fa.pt'
 
 
8
 
9
- config = load_config(MODEL_PATH)
10
- print('Config:', config)
11
- tokenizer = Tokenizer(config)
12
 
13
- # Load the model
14
- model = init_model(MODEL_PATH)
15
-
16
- # Create a text area box where the user can enter their text
17
  user_input = st.text_area("Enter some text here", value="Он ҷо, ки висоли дӯстон аст,\nВ-оллоҳ, ки миёни хона саҳрост.")
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Run the model on the user's text and store the output
22
  model_output = predict(model, tokenizer, user_input, device)
23
 
 
1
  import torch
2
  import streamlit as st
3
  from model import init_model, predict
4
+ from data import Tokenizer, load_config, language_detect
5
 
6
+ MODEL_PATH_TJ_FA = 'tj-fa.pt'
7
+ MODEL_PATH_FA_TJ = 'fa-tj.pt'
8
 
9
+ config_tj_fa = load_config(MODEL_PATH_TJ_FA)
10
+ tokenizer_tj_fa = Tokenizer(config_tj_fa)
11
+ model_tj_fa = init_model(MODEL_PATH_TJ_FA)
12
 
13
+ config_fa_tj = load_config(MODEL_PATH_FA_TJ)
14
+ tokenizer_fa_tj = Tokenizer(config_fa_tj)
15
+ model_fa_tj = init_model(MODEL_PATH_FA_TJ)
16
 
 
 
 
 
17
  user_input = st.text_area("Enter some text here", value="Он ҷо, ки висоли дӯстон аст,\nВ-оллоҳ, ки миёни хона саҳрост.")
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # Detect language
22
+ detected_language = language_detect(user_input, tokenizer_tj_fa, tokenizer_fa_tj)
23
+
24
+ if detected_language == 'tj':
25
+ model = model_tj_fa
26
+ tokenizer = tokenizer_tj_fa
27
+ st.text('Detected language: Tajik (TJ) -> Transliterating to Persian (FA)')
28
+ else:
29
+ model = model_fa_tj
30
+ tokenizer = tokenizer_fa_tj
31
+ st.text('Detected language: Persian (FA) -> Transliterating to Tajik (TJ)')
32
+
33
  # Run the model on the user's text and store the output
34
  model_output = predict(model, tokenizer, user_input, device)
35
 
data.py CHANGED
@@ -21,7 +21,6 @@ class Tokenizer:
21
  self.trg_pad_idx = self.trg_char_index['<PAD>']
22
  self.trg_unk_idx = self.trg_char_index['<UNK>']
23
  self.src_unk_idx = self.src_char_index['<UNK>']
24
-
25
 
26
  def encode_src(self, text: str):
27
  src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
@@ -43,3 +42,17 @@ class Tokenizer:
43
  src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
44
  trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
45
  return src_padded, trg_padded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  self.trg_pad_idx = self.trg_char_index['<PAD>']
22
  self.trg_unk_idx = self.trg_char_index['<UNK>']
23
  self.src_unk_idx = self.src_char_index['<UNK>']
 
24
 
25
  def encode_src(self, text: str):
26
  src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
 
42
  src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
43
  trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
44
  return src_padded, trg_padded
45
+
46
+
47
+ def language_detect(text, tokenizer_tj_fa: "Tokenizer", tokenizer_fa_tj: "Tokenizer"):
48
+ # Calculate the percentage of characters in text that are present in the source vocabulary of tokenizer_tj_fa
49
+ percentage_tj_fa = sum(char in tokenizer_tj_fa.src_vocab for char in text) / len(text)
50
+
51
+ # Calculate the percentage of characters in text that are present in the source vocabulary of tokenizer_fa_tj
52
+ percentage_fa_tj = sum(char in tokenizer_fa_tj.src_vocab for char in text) / len(text)
53
+
54
+ # Return the language code of the tokenizer with the higher percentage
55
+ if percentage_tj_fa > percentage_fa_tj:
56
+ return 'tj'
57
+ else:
58
+ return 'fa'
fa-tj.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcdbc01b0630e0a01e42f5a06e1002a7cf0089ee0f32d3c21d9b359f47846aa7
3
+ size 22892367