File size: 3,074 Bytes
0a7346e
b546526
 
 
 
 
 
 
 
 
0a7346e
b546526
 
 
 
 
 
 
 
 
 
0a7346e
b546526
 
 
 
 
 
 
 
 
 
 
 
 
0a7346e
b546526
 
 
 
 
 
 
524b9ae
 
 
b546526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524b9ae
 
b546526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a7346e
b546526
0a7346e
b546526
 
 
 
 
 
0a7346e
b546526
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from abc import ABC
from modules.module_rankSents import RankSents
from modules.module_crowsPairs import CrowsPairs
from typing import List, Tuple

class Connector(ABC):
    """Common helpers shared by the explorer connectors.

    Provides word normalization, parsing of comma-separated word lists,
    and wrapping of error messages in a small HTML snippet.
    """

    def parse_word(self, word: str) -> str:
        """Return ``word`` lowercased, with surrounding whitespace removed."""
        lowered = word.lower()
        return lowered.strip()

    def parse_words(self, array_in_string: str) -> List[str]:
        """Split a comma-separated string into normalized words.

        Empty or whitespace-only segments are dropped. Every kept segment
        is normalized with :meth:`parse_word`.

        Args:
            array_in_string: text such as ``"a, b,c"``.

        Returns:
            The list of normalized words; ``[]`` when the input is blank.
        """
        trimmed = array_in_string.strip()
        if not trimmed:
            return []

        non_blank = (piece for piece in trimmed.split(',') if piece.strip())
        return [self.parse_word(piece) for piece in non_blank]

    def process_error(self, err: str) -> str:
        """Wrap a non-empty error message in centered ``<h3>`` HTML.

        A falsy ``err`` (for example ``""`` or ``None``) is returned unchanged.
        """
        if not err:
            return err
        return "<center><h3>" + err + "</h3></center>"

class PhraseBiasExplorerConnector(Connector):
    """Connector that ranks word options for a sentence using ``RankSents``.

    Raw string inputs are normalized here before being forwarded to the
    ranking module.
    """

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Create the underlying ``RankSents`` explorer.

        Keyword Args:
            language_model: forwarded to ``RankSents(language_model=...)``.
            lang: forwarded to ``RankSents(lang=...)``.

        Raises:
            KeyError: if ``language_model`` or ``lang`` is missing or None.
                The message names the missing key(s).
        """
        language_model = kwargs.get('language_model', None)
        lang = kwargs.get('lang', None)

        missing = [
            name
            for name, value in (('language_model', language_model), ('lang', lang))
            if value is None
        ]
        if missing:
            # A bare KeyError gave no hint about which argument was absent.
            raise KeyError(
                "missing required argument(s): " + ", ".join(missing)
            )

        self.phrase_bias_explorer = RankSents(
            language_model=language_model,
            lang=lang
        )

    def rank_sentence_options(
        self,
        sent: str,
        word_list: str,
        banned_word_list: str,
        useArticles: bool,
        usePrepositions: bool,
        useConjunctions: bool
    ) -> Tuple:
        """Rank word options for ``sent`` and return labelled scores.

        Args:
            sent: input sentence. Runs of whitespace collapse to single
                spaces, and every ``*`` is surrounded by single spaces
                (presumably ``*`` marks the slot to fill; confirm against
                ``RankSents``).
            word_list: comma-separated words, parsed with ``parse_words``.
            banned_word_list: comma-separated words, parsed with
                ``parse_words``.
            useArticles: flag forwarded to ``RankSents.rank``.
            usePrepositions: flag forwarded to ``RankSents.rank``.
            useConjunctions: flag forwarded to ``RankSents.rank``.

        Returns:
            A 3-tuple. If ``errorChecking`` reports an error, the tuple is
            ``(html_error, "", "")``. Otherwise it is
            ``(err, labelled_scores, "")``, where ``err`` is the falsy value
            returned by ``errorChecking``.
        """
        # Isolate each '*' as its own token and collapse repeated whitespace.
        sent = " ".join(sent.strip().replace("*", " * ").split())

        err = self.phrase_bias_explorer.errorChecking(sent)
        if err:
            return self.process_error(err), "", ""

        words = self.parse_words(word_list)
        banned_words = self.parse_words(banned_word_list)

        all_plls_scores = self.phrase_bias_explorer.rank(
            sent,
            words,
            banned_words,
            useArticles,
            usePrepositions,
            useConjunctions
        )

        all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
        # err is falsy here; process_error passes it through unchanged.
        return self.process_error(err), all_plls_scores, ""

class CrowsPairsExplorerConnector(Connector):
    """Connector that compares a fixed set of sentences using ``CrowsPairs``."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Create the underlying ``CrowsPairs`` explorer.

        Keyword Args:
            language_model: forwarded to ``CrowsPairs(language_model=...)``.

        Raises:
            KeyError: if ``language_model`` is missing or None. The message
                names the missing key.
        """
        language_model = kwargs.get('language_model', None)
        if language_model is None:
            # A bare KeyError gave no hint about which argument was absent.
            raise KeyError("missing required argument(s): language_model")

        self.crows_pairs_explorer = CrowsPairs(
            language_model=language_model
        )

    def compare_sentences(
        self,
        sent0: str,
        sent1: str,
        sent2: str,
        sent3: str,
        sent4: str,
        sent5: str
    ) -> Tuple:
        """Validate and rank six sentences with ``CrowsPairs``.

        The six sentences are passed, in order, as one list to
        ``CrowsPairs.errorChecking`` and then to ``CrowsPairs.rank``. Whether
        some of them may be empty is decided by ``errorChecking``, not here.

        Returns:
            A 3-tuple. If ``errorChecking`` reports an error, the tuple is
            ``(html_error, "", "")``. Otherwise it is
            ``(err, labelled_scores, "")``, where ``err`` is the falsy value
            returned by ``errorChecking``.
        """
        sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]

        err = self.crows_pairs_explorer.errorChecking(sent_list)
        if err:
            return self.process_error(err), "", ""

        all_plls_scores = self.crows_pairs_explorer.rank(sent_list)

        all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
        # err is falsy here; process_error passes it through unchanged.
        return self.process_error(err), all_plls_scores, ""