File size: 1,393 Bytes
b546526
 
0a7346e
b546526
 
 
 
 
 
 
 
 
 
 
 
 
 
0a7346e
b546526
 
 
 
 
0a7346e
b546526
 
 
 
 
 
 
 
 
 
 
 
 
 
0a7346e
b546526
 
0a7346e
b546526
 
 
 
0a7346e
b546526
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from modules.module_customPllLabel import CustomPllLabel
from modules.module_pllScore import PllScore
from typing import Dict, List

class CrowsPairs:
    """Validate a list of sentences and score each one with ``PllScore``.

    The scorer is built on the language model passed to the constructor.
    Sentences at the mandatory positions (the first two) must be present and
    non-blank; any other sentence may be left blank and is then ignored.
    """

    # 0-based indices of the sentences that must always be provided.
    _MANDATORY_SENTS = (0, 1)

    def __init__(
        self, 
        language_model # LanguageModel class instance
    ) -> None:
        """Create the scorer.

        Args:
            language_model: LanguageModel instance, forwarded to ``PllScore``.
        """
        # NOTE(review): ``self.Label`` is not used inside this class;
        # presumably callers use it to render results — confirm.
        self.Label = CustomPllLabel()
        self.pllScore = PllScore(
            language_model=language_model
        )

    def errorChecking(
        self, 
        sent_list: List[str],
    ) -> str:
        """Validate ``sent_list`` and describe the first problem found.

        Rules:
            * a non-blank sentence must satisfy ``PllScore.sentIsCorrect``,
              which is called on the whitespace-stripped text;
            * a blank or whitespace-only sentence is allowed, except at the
              mandatory positions;
            * the mandatory positions must also exist in ``sent_list``.

        Args:
            sent_list: Sentences in the order they were supplied.

        Returns:
            An error message for the first invalid sentence, or ``""`` when
            every sentence is acceptable.
        """
        for sent_id, sent in enumerate(sent_list):
            c_sent = sent.strip()
            if c_sent:
                if not self.pllScore.sentIsCorrect(c_sent):
                    return f"Error: The sentence Nº {sent_id+1} does not have the correct format!."
            elif sent_id in self._MANDATORY_SENTS:
                return f"Error: The sentence Nº{sent_id+1} can not be empty!"

        # A mandatory sentence missing from a too-short list is reported the
        # same way as one that was left empty; previously it slipped through.
        for sent_id in self._MANDATORY_SENTS:
            if sent_id >= len(sent_list):
                return f"Error: The sentence Nº{sent_id+1} can not be empty!"

        return ""

    def rank(
        self, 
        sent_list: List[str],
    ) -> Dict[str, float]:
        """Compute the PLL score of every non-blank sentence.

        Args:
            sent_list: Sentences to score. Blank or whitespace-only entries
                at non-mandatory positions are skipped.

        Returns:
            Mapping from each non-blank sentence (exactly as given, not
            stripped) to the value returned by ``PllScore.compute``.

        Raises:
            ValueError: if ``errorChecking`` reports a problem; the message
                is that error text.
        """
        err = self.errorChecking(sent_list)
        if err:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers keep working.
            raise ValueError(err)

        # Use the same blank test as errorChecking; checking plain ``sent``
        # would send whitespace-only sentences to the scorer.
        return {
            sent: self.pllScore.compute(sent)
            for sent in sent_list
            if sent.strip()
        }