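"""Convert informal (colloquial) Persian text into formal Persian.

The script downloads the informal2formal assets into a local cache,
builds a FormalityTransformer, and prints the formal rendering of a
sample colloquial sentence.
"""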
import itertools
import os
from pathlib import Path
import yaml
from download_utils import download_dataset
import utils
from formality_transformer import FormalityTransformer
from hazm import SentenceTokenizer



def translate_short_sent(model, sent):
    """Convert one informal sentence to formal Persian.

    Candidate formal forms are collected per token; the language model
    then scores the combinations and returns the best-scoring sequence.
    """
    out_dict = {}
    txt = utils.cleanify(sent)
    # A token is considered transformable if the one-shot transformer
    # knows at least one formal form for it.
    is_valid = lambda w: model.oneshot_transformer.transform(w, None)
    cnd_tokens = model.informal_tokenizer.tokenize(txt, is_valid)
    # Only the first candidate tokenization is scored: the loop body
    # ends in a return, so later tokenizations are never reached.
    for tokens in cnd_tokens:
        tokens = [t for t in tokens if t != '']
        # Flatten multi-word tokens so each element is a single word.
        new_tokens = []
        for t in tokens:
            new_tokens.extend(t.split())
        txt = ' '.join(new_tokens)
        tokens = txt.split()
        candidates = []
        for tok in tokens:
            pos = 'VERB' if model.verb_handler.informal_to_formal(tok) else None
            f_words_lemma = list(model.oneshot_transformer.transform(tok, pos))
            for i, (word, lemma) in enumerate(f_words_lemma):
                # Discard suggestions rejected by the bigram filter and
                # fall back to the original token.
                if pos != 'VERB' and tok not in model.mapper and model.should_filtered_by_one_bigram(lemma, word, tok):
                    f_words_lemma[i] = (tok, tok)
                else:
                    word_repr = ' '.join(word.split())
                    # Note: the method name below is spelled this way in FormalityTransformer.
                    word_repr = model.repalce_for_gpt2(word_repr)
                    f_words_lemma[i] = (word, word_repr)
            # If no suggestion exists, keep the token unchanged.
            cnd = set(f_words_lemma) if f_words_lemma else {(tok, tok)}
            candidates.append(cnd)
        # Enumerate all combinations of per-token candidates; out_dict keeps
        # the (surface form, LM representation) pairs but is not used below.
        all_combinations = list(itertools.product(*candidates))
        for idx, cnd in enumerate(all_combinations):
            normal_seq = ' '.join(c[0] for c in cnd)
            lemma_seq = utils.clean_text_for_lm(' '.join(c[1] for c in cnd))
            out_dict[idx] = (normal_seq, lemma_seq)
        # Let the language model pick the best surface-form sequence.
        candidates = [[item[0] for item in candidate_phrases] for candidate_phrases in candidates]
        return model.lm_obj.get_best(candidates)
    # No candidate tokenization was produced: return the cleaned input as-is.
    return txt


def translate(model, sentence_tokenizer, txt):
    """Split `txt` into sentences and formalize each one."""
    sents = sentence_tokenizer.tokenize(txt)
    formal_sentences = [translate_short_sent(model, sentence) for sentence in sents]
    return ' '.join(formal_sentences)

def load_config(config_file):
    """Read the YAML configuration file."""
    with open(config_file, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return config
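
# Illustrative config.yml layout (key names and URLs below are assumptions,
# not the project's actual values); the script only requires a `files`
# mapping whose values are downloadable asset URLs:
#
#   files:
#     assets: https://example.com/informal2formal/assets.pkl
#     verbs: https://example.com/informal2formal/verbs.csv
#     irregular_verbs: https://example.com/informal2formal/irregular_verb_mapper.csv
#     lm: https://example.com/informal2formal/3gram.bin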




if __name__ == '__main__':
    
    # Download the asset files, or load them from the local cache.
    DEFAULT_CACHE_DIR = os.path.join(str(Path.home()), '.dadmatools', 'informal2formal')
    config = load_config('config.yml')
    file_urls = config['files'].values()
    download_dataset(file_urls, DEFAULT_CACHE_DIR, filename=None)
    
    # Resolve the paths of the downloaded asset files.
    verbs_csv_addr = os.path.join(DEFAULT_CACHE_DIR, 'verbs.csv')
    irregular_verbs_mapper = os.path.join(DEFAULT_CACHE_DIR, 'irregular_verb_mapper.csv')
    lm_addr = os.path.join(DEFAULT_CACHE_DIR,'3gram.bin')
    assets_file_addr = os.path.join(DEFAULT_CACHE_DIR,'assets.pkl')
    
    # Try the model on a sample informal sentence.
    sentence_tokenizer = SentenceTokenizer()
    model = FormalityTransformer(asset_file_addr=assets_file_addr,
                                 irregular_verbs_mapper_addr=irregular_verbs_mapper,
                                 verbs_csv_addr=verbs_csv_addr,
                                 lm_addr=lm_addr)
    # The sample below is informal Persian for: "You can use this to
    # convert all colloquial sentences, if you want."
    print(translate(model, sentence_tokenizer, 'اینو میشه واسه تبدیل تموم جملات محاوره استفاده کرد اگه خواستین'))
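
    # A minimal sketch of batch use under the same setup (the input file
    # name below is hypothetical):
    #
    #   with open('informal_sentences.txt', encoding='utf-8') as f:
    #       for line in f:
    #           print(translate(model, sentence_tokenizer, line.strip()))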