Spaces:
Runtime error
Runtime error
Added type hinting and config file
Browse files
- .gitignore +1 -1
- app.py +9 -4
- examples/examples_en.py +32 -1
- interfaces/{interface_sesgoEnFrases.py → interface_biasPhrase.py} +8 -2
- interfaces/interface_crowsPairs.py +8 -2
- language/.gitignore +1 -1
- language/{english.json → en.json} +50 -1
- modules/module_connection.py +8 -15
- modules/module_crowsPairs.py +6 -18
- modules/module_languageModel.py +6 -5
- modules/module_rankSents.py +7 -7
- tool.cfg +11 -0
.gitignore
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
__pycache__/
|
2 |
.env
|
3 |
-
|
|
|
1 |
__pycache__/
|
2 |
.env
|
3 |
+
logs_edia_lmodels_en/
|
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
# --- Imports libs ---
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
|
|
4 |
|
5 |
|
6 |
# --- Imports modules ---
|
@@ -8,14 +9,18 @@ from modules.module_languageModel import LanguageModel
|
|
8 |
|
9 |
|
10 |
# --- Imports interfaces ---
|
11 |
-
from interfaces.interface_sesgoEnFrases import interface as interface_sesgoEnFrases
|
12 |
from interfaces.interface_crowsPairs import interface as interface_crowsPairs
|
13 |
|
14 |
|
15 |
# --- Tool config ---
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
# --- Init classes ---
|
|
|
1 |
# --- Imports libs ---
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
+
import configparser
|
5 |
|
6 |
|
7 |
# --- Imports modules ---
|
|
|
9 |
|
10 |
|
11 |
# --- Imports interfaces ---
|
12 |
+
from interfaces.interface_biasPhrase import interface as interface_sesgoEnFrases
|
13 |
from interfaces.interface_crowsPairs import interface as interface_crowsPairs
|
14 |
|
15 |
|
16 |
# --- Tool config ---
|
17 |
+
cfg = configparser.ConfigParser()
|
18 |
+
cfg.read('tool.cfg')
|
19 |
+
|
20 |
+
LANGUAGE = cfg['INTERFACE']['language']
|
21 |
+
LANGUAGE_MODEL = cfg['LMODEL']['language_model']
|
22 |
+
AVAILABLE_LOGS = cfg['LOGS'].getboolean('available_logs')
|
23 |
+
|
24 |
|
25 |
|
26 |
# --- Init classes ---
|
examples/examples_en.py
CHANGED
@@ -1,4 +1,35 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
examples_sesgos_frases = [
|
3 |
["The * people are better basketball players.", "black, white"],
|
4 |
["Beethoven is a *.", "musician, dog"],
|
|
|
1 |
+
fem_words = 'woman,girl,her,mother,daughter,feminine'
|
2 |
+
masc_words = 'man,boy,him,father,son,masculine'
|
3 |
+
young_words = 'young,immature,kid,child'
|
4 |
+
old_words = 'old,mature,elderly,adult'
|
5 |
+
diagnose_words = 'education,cook,wash,sports,clothes,fight,nurse,win,leader,saw,nurse,surgeon,king,queen'
|
6 |
+
|
7 |
+
positive_money_words = 'save,economize,administer,manage,business,benefits'
|
8 |
+
negative_money_words = 'waste,squander'
|
9 |
+
diagnose_money = 'german,australian,argentinian,millionaire,rich,poor'
|
10 |
+
|
11 |
+
lazy_words = 'lazy, sluggish, slacker'
|
12 |
+
active_words = 'active, effort , worker'
|
13 |
+
|
14 |
+
examples1_explorar_sesgo_en_palabras = [
|
15 |
+
[fem_words, masc_words, diagnose_words],
|
16 |
+
[old_words, young_words, diagnose_words],
|
17 |
+
[positive_money_words, negative_money_words, diagnose_money],
|
18 |
+
[lazy_words, active_words, diagnose_money]
|
19 |
+
]
|
20 |
+
|
21 |
+
examples2_explorar_sesgo_en_palabras = [
|
22 |
+
[fem_words, masc_words, young_words, old_words, diagnose_words],
|
23 |
+
[lazy_words, active_words, positive_money_words, negative_money_words,diagnose_money],
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
examples_explorar_relaciones_entre_palabras = [
|
28 |
+
[diagnose_words, fem_words, masc_words, young_words, old_words],
|
29 |
+
[diagnose_money, lazy_words, active_words, positive_money_words, negative_money_words],
|
30 |
+
]
|
31 |
+
|
32 |
+
# Examples bias phrase
|
33 |
examples_sesgos_frases = [
|
34 |
["The * people are better basketball players.", "black, white"],
|
35 |
["Beethoven is a *.", "musician, dog"],
|
interfaces/{interface_sesgoEnFrases.py → interface_biasPhrase.py}
RENAMED
@@ -3,15 +3,21 @@ import pandas as pd
|
|
3 |
from tool_info import TOOL_INFO
|
4 |
from modules.module_logsManager import HuggingFaceDatasetSaver
|
5 |
from modules.module_connection import PhraseBiasExplorerConnector
|
6 |
-
|
7 |
|
8 |
|
9 |
def interface(
|
10 |
language_model: str,
|
11 |
available_logs: bool,
|
12 |
-
lang: str="
|
13 |
) -> gr.Blocks:
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# --- Init logs ---
|
16 |
log_callback = HuggingFaceDatasetSaver(
|
17 |
available_logs=available_logs,
|
|
|
3 |
from tool_info import TOOL_INFO
|
4 |
from modules.module_logsManager import HuggingFaceDatasetSaver
|
5 |
from modules.module_connection import PhraseBiasExplorerConnector
|
6 |
+
|
7 |
|
8 |
|
9 |
def interface(
|
10 |
language_model: str,
|
11 |
available_logs: bool,
|
12 |
+
lang: str="es"
|
13 |
) -> gr.Blocks:
|
14 |
|
15 |
+
# -- Load examples --
|
16 |
+
if lang == 'es':
|
17 |
+
from examples.examples_es import examples_sesgos_frases
|
18 |
+
elif lang == 'en':
|
19 |
+
from examples.examples_en import examples_sesgos_frases
|
20 |
+
|
21 |
# --- Init logs ---
|
22 |
log_callback = HuggingFaceDatasetSaver(
|
23 |
available_logs=available_logs,
|
interfaces/interface_crowsPairs.py
CHANGED
@@ -3,15 +3,21 @@ import pandas as pd
|
|
3 |
from tool_info import TOOL_INFO
|
4 |
from modules.module_logsManager import HuggingFaceDatasetSaver
|
5 |
from modules.module_connection import CrowsPairsExplorerConnector
|
6 |
-
|
7 |
|
8 |
|
9 |
def interface(
|
10 |
language_model: str,
|
11 |
available_logs: bool,
|
12 |
-
lang: str="
|
13 |
) -> gr.Blocks:
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# --- Init logs ---
|
16 |
log_callback = HuggingFaceDatasetSaver(
|
17 |
available_logs=available_logs,
|
|
|
3 |
from tool_info import TOOL_INFO
|
4 |
from modules.module_logsManager import HuggingFaceDatasetSaver
|
5 |
from modules.module_connection import CrowsPairsExplorerConnector
|
6 |
+
|
7 |
|
8 |
|
9 |
def interface(
|
10 |
language_model: str,
|
11 |
available_logs: bool,
|
12 |
+
lang: str="es"
|
13 |
) -> gr.Blocks:
|
14 |
|
15 |
+
# -- Load examples --
|
16 |
+
if lang == 'es':
|
17 |
+
from examples.examples_es import examples_crows_pairs
|
18 |
+
elif lang == 'en':
|
19 |
+
from examples.examples_en import examples_crows_pairs
|
20 |
+
|
21 |
# --- Init logs ---
|
22 |
log_callback = HuggingFaceDatasetSaver(
|
23 |
available_logs=available_logs,
|
language/.gitignore
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
__pycache__
|
2 |
-
|
|
|
1 |
__pycache__
|
2 |
+
es.json
|
language/{english.json → en.json}
RENAMED
@@ -1,8 +1,43 @@
|
|
1 |
{
|
2 |
"app": {
|
|
|
|
|
|
|
3 |
"phraseExplorer": "Phrase bias",
|
4 |
"crowsPairsExplorer": "Crows-Pairs"
|
5 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
"PhraseExplorer_interface": {
|
7 |
"step1": "1. Enter a sentence",
|
8 |
"step2": "2. Enter words of interest (Optional)",
|
@@ -12,7 +47,7 @@
|
|
12 |
"placeholder": "Use * to mask the word of interest."
|
13 |
},
|
14 |
"wordList": {
|
15 |
-
"title": "
|
16 |
"placeholder": "The words in the list must be comma separated"
|
17 |
},
|
18 |
"bannedWordList": {
|
@@ -26,6 +61,20 @@
|
|
26 |
"plot": "Display of proportions",
|
27 |
"examples": "Examples"
|
28 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
"CrowsPairs_interface": {
|
30 |
"title": "1. Enter sentences to compare",
|
31 |
"sent0": "Sentence Nº 1 (*)",
|
|
|
1 |
{
|
2 |
"app": {
|
3 |
+
"wordExplorer": "Word explorer",
|
4 |
+
"biasWordExplorer": "Word bias",
|
5 |
+
"dataExplorer": "Data",
|
6 |
"phraseExplorer": "Phrase bias",
|
7 |
"crowsPairsExplorer": "Crows-Pairs"
|
8 |
},
|
9 |
+
"WordExplorer_interface": {
|
10 |
+
"title": "Write some words to visualize their related ones",
|
11 |
+
"wordList1": "Word list 1",
|
12 |
+
"wordList2": "Word list 2",
|
13 |
+
"wordList3": "Word list 3",
|
14 |
+
"wordList4": "Word list 4",
|
15 |
+
"wordListToDiagnose": "List of words to be diagnosed",
|
16 |
+
"plotNeighbours": {
|
17 |
+
"title": "Plot neighbours words",
|
18 |
+
"quantity": "Quantity"
|
19 |
+
},
|
20 |
+
"options": {
|
21 |
+
"font-size": "Font size",
|
22 |
+
"transparency": "Transparency"
|
23 |
+
},
|
24 |
+
"plot_button": "Plot in the space!",
|
25 |
+
"examples": "Examples"
|
26 |
+
},
|
27 |
+
"BiasWordExplorer_interface": {
|
28 |
+
"step1": "1. Write comma separated words to be diagnosed",
|
29 |
+
"step2&2Spaces": "2. For plotting 2 spaces, fill in the following lists:",
|
30 |
+
"step2&4Spaces": "2. For plotting 4 spaces, also fill in the following lists:",
|
31 |
+
"plot2SpacesButton": "Plot 2 stereotypes!",
|
32 |
+
"plot4SpacesButton": "Plot 4 stereotypes!",
|
33 |
+
"wordList1": "Word list 1",
|
34 |
+
"wordList2": "Word list 2",
|
35 |
+
"wordList3": "Word list 3",
|
36 |
+
"wordList4": "Word list 4",
|
37 |
+
"wordListToDiagnose": "List of words to be diagnosed",
|
38 |
+
"examples2Spaces": "Examples in 2 spaces",
|
39 |
+
"examples4Spaces": "Examples in 4 spaces"
|
40 |
+
},
|
41 |
"PhraseExplorer_interface": {
|
42 |
"step1": "1. Enter a sentence",
|
43 |
"step2": "2. Enter words of interest (Optional)",
|
|
|
47 |
"placeholder": "Use * to mask the word of interest."
|
48 |
},
|
49 |
"wordList": {
|
50 |
+
"title": "Word List",
|
51 |
"placeholder": "The words in the list must be comma separated"
|
52 |
},
|
53 |
"bannedWordList": {
|
|
|
61 |
"plot": "Display of proportions",
|
62 |
"examples": "Examples"
|
63 |
},
|
64 |
+
"DataExplorer_interface": {
|
65 |
+
"step1": "1. Enter a word of interest",
|
66 |
+
"step2": "2. Select maximum number of contexts to retrieve",
|
67 |
+
"step3": "3. Select sets of interest",
|
68 |
+
"inputWord": {
|
69 |
+
"title": "Word",
|
70 |
+
"placeholder": "Enter the word ..."
|
71 |
+
},
|
72 |
+
"wordInfoButton": "Get word information",
|
73 |
+
"wordContextButton": "Search contexts",
|
74 |
+
"wordDistributionTitle": "Word distribution in vocabulary",
|
75 |
+
"frequencyPerSetTitle": "Frequencies of occurrence per set",
|
76 |
+
"contextList": "Context list"
|
77 |
+
},
|
78 |
"CrowsPairs_interface": {
|
79 |
"title": "1. Enter sentences to compare",
|
80 |
"sent0": "Sentence Nº 1 (*)",
|
modules/module_connection.py
CHANGED
@@ -1,15 +1,14 @@
|
|
|
|
1 |
from modules.module_rankSents import RankSents
|
2 |
from modules.module_crowsPairs import CrowsPairs
|
3 |
from typing import List, Tuple
|
4 |
-
from abc import ABC
|
5 |
-
|
6 |
|
7 |
class Connector(ABC):
|
8 |
def parse_word(
|
9 |
self,
|
10 |
word: str
|
11 |
) -> str:
|
12 |
-
|
13 |
return word.lower().strip()
|
14 |
|
15 |
def parse_words(
|
@@ -20,6 +19,7 @@ class Connector(ABC):
|
|
20 |
words = array_in_string.strip()
|
21 |
if not words:
|
22 |
return []
|
|
|
23 |
words = [
|
24 |
self.parse_word(word)
|
25 |
for word in words.split(',') if word.strip() != ''
|
@@ -31,11 +31,9 @@ class Connector(ABC):
|
|
31 |
err: str
|
32 |
) -> str:
|
33 |
|
34 |
-
# Mod
|
35 |
if err:
|
36 |
err = "<center><h3>" + err + "</h3></center>"
|
37 |
-
return err
|
38 |
-
|
39 |
|
40 |
class PhraseBiasExplorerConnector(Connector):
|
41 |
def __init__(
|
@@ -43,13 +41,8 @@ class PhraseBiasExplorerConnector(Connector):
|
|
43 |
**kwargs
|
44 |
) -> None:
|
45 |
|
46 |
-
|
47 |
-
if 'language_model' in kwargs:
|
48 |
language_model = kwargs.get('language_model')
|
49 |
-
else:
|
50 |
-
raise KeyError
|
51 |
-
|
52 |
-
if 'lang' in kwargs:
|
53 |
lang = kwargs.get('lang')
|
54 |
else:
|
55 |
raise KeyError
|
@@ -90,7 +83,6 @@ class PhraseBiasExplorerConnector(Connector):
|
|
90 |
all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
|
91 |
return self.process_error(err), all_plls_scores, ""
|
92 |
|
93 |
-
|
94 |
class CrowsPairsExplorerConnector(Connector):
|
95 |
def __init__(
|
96 |
self,
|
@@ -116,15 +108,16 @@ class CrowsPairsExplorerConnector(Connector):
|
|
116 |
sent5: str
|
117 |
) -> Tuple:
|
118 |
|
|
|
119 |
err = self.crows_pairs_explorer.errorChecking(
|
120 |
-
|
121 |
)
|
122 |
|
123 |
if err:
|
124 |
return self.process_error(err), "", ""
|
125 |
|
126 |
all_plls_scores = self.crows_pairs_explorer.rank(
|
127 |
-
|
128 |
)
|
129 |
|
130 |
all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
|
|
|
1 |
+
from abc import ABC
|
2 |
from modules.module_rankSents import RankSents
|
3 |
from modules.module_crowsPairs import CrowsPairs
|
4 |
from typing import List, Tuple
|
|
|
|
|
5 |
|
6 |
class Connector(ABC):
|
7 |
def parse_word(
|
8 |
self,
|
9 |
word: str
|
10 |
) -> str:
|
11 |
+
|
12 |
return word.lower().strip()
|
13 |
|
14 |
def parse_words(
|
|
|
19 |
words = array_in_string.strip()
|
20 |
if not words:
|
21 |
return []
|
22 |
+
|
23 |
words = [
|
24 |
self.parse_word(word)
|
25 |
for word in words.split(',') if word.strip() != ''
|
|
|
31 |
err: str
|
32 |
) -> str:
|
33 |
|
|
|
34 |
if err:
|
35 |
err = "<center><h3>" + err + "</h3></center>"
|
36 |
+
return err
|
|
|
37 |
|
38 |
class PhraseBiasExplorerConnector(Connector):
|
39 |
def __init__(
|
|
|
41 |
**kwargs
|
42 |
) -> None:
|
43 |
|
44 |
+
if 'language_model' in kwargs and 'lang' in kwargs:
|
|
|
45 |
language_model = kwargs.get('language_model')
|
|
|
|
|
|
|
|
|
46 |
lang = kwargs.get('lang')
|
47 |
else:
|
48 |
raise KeyError
|
|
|
83 |
all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
|
84 |
return self.process_error(err), all_plls_scores, ""
|
85 |
|
|
|
86 |
class CrowsPairsExplorerConnector(Connector):
|
87 |
def __init__(
|
88 |
self,
|
|
|
108 |
sent5: str
|
109 |
) -> Tuple:
|
110 |
|
111 |
+
sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]
|
112 |
err = self.crows_pairs_explorer.errorChecking(
|
113 |
+
sent_list
|
114 |
)
|
115 |
|
116 |
if err:
|
117 |
return self.process_error(err), "", ""
|
118 |
|
119 |
all_plls_scores = self.crows_pairs_explorer.rank(
|
120 |
+
sent_list
|
121 |
)
|
122 |
|
123 |
all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
|
modules/module_crowsPairs.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from modules.module_customPllLabel import CustomPllLabel
|
2 |
from modules.module_pllScore import PllScore
|
3 |
-
from typing import Dict
|
4 |
|
5 |
class CrowsPairs:
|
6 |
def __init__(
|
@@ -15,19 +15,13 @@ class CrowsPairs:
|
|
15 |
|
16 |
def errorChecking(
|
17 |
self,
|
18 |
-
|
19 |
-
sent1: str,
|
20 |
-
sent2: str,
|
21 |
-
sent3: str,
|
22 |
-
sent4: str,
|
23 |
-
sent5: str
|
24 |
) -> str:
|
25 |
|
26 |
out_msj = ""
|
27 |
-
all_sents = [sent0, sent1, sent2, sent3, sent4, sent5]
|
28 |
|
29 |
mandatory_sents = [0,1]
|
30 |
-
for sent_id, sent in enumerate(all_sents):
|
31 |
c_sent = sent.strip()
|
32 |
if c_sent:
|
33 |
if not self.pllScore.sentIsCorrect(c_sent):
|
@@ -42,21 +36,15 @@ class CrowsPairs:
|
|
42 |
|
43 |
def rank(
|
44 |
self,
|
45 |
-
|
46 |
-
sent1: str,
|
47 |
-
sent2: str,
|
48 |
-
sent3: str,
|
49 |
-
sent4: str,
|
50 |
-
sent5: str
|
51 |
) -> Dict[str, float]:
|
52 |
|
53 |
-
err = self.errorChecking(sent0, sent1, sent2, sent3, sent4, sent5)
|
54 |
if err:
|
55 |
raise Exception(err)
|
56 |
|
57 |
-
all_sents = [sent0, sent1, sent2, sent3, sent4, sent5]
|
58 |
all_plls_scores = {}
|
59 |
-
for sent in all_sents:
|
60 |
if sent:
|
61 |
all_plls_scores[sent] = self.pllScore.compute(sent)
|
62 |
|
|
|
1 |
from modules.module_customPllLabel import CustomPllLabel
|
2 |
from modules.module_pllScore import PllScore
|
3 |
+
from typing import Dict, List
|
4 |
|
5 |
class CrowsPairs:
|
6 |
def __init__(
|
|
|
15 |
|
16 |
def errorChecking(
|
17 |
self,
|
18 |
+
sent_list: List[str],
|
|
|
|
|
|
|
|
|
|
|
19 |
) -> str:
|
20 |
|
21 |
out_msj = ""
|
|
|
22 |
|
23 |
mandatory_sents = [0,1]
|
24 |
+
for sent_id, sent in enumerate(sent_list):
|
25 |
c_sent = sent.strip()
|
26 |
if c_sent:
|
27 |
if not self.pllScore.sentIsCorrect(c_sent):
|
|
|
36 |
|
37 |
def rank(
|
38 |
self,
|
39 |
+
sent_list: List[str],
|
|
|
|
|
|
|
|
|
|
|
40 |
) -> Dict[str, float]:
|
41 |
|
42 |
+
err = self.errorChecking(sent_list)
|
43 |
if err:
|
44 |
raise Exception(err)
|
45 |
|
|
|
46 |
all_plls_scores = {}
|
47 |
+
for sent in sent_list:
|
48 |
if sent:
|
49 |
all_plls_scores[sent] = self.pllScore.compute(sent)
|
50 |
|
modules/module_languageModel.py
CHANGED
@@ -1,22 +1,23 @@
|
|
1 |
-
# --- Imports libs ---
|
2 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
3 |
|
4 |
class LanguageModel:
|
5 |
def __init__(
|
6 |
self,
|
7 |
-
model_name
|
8 |
) -> None:
|
9 |
-
|
10 |
print("Downloading language model...")
|
11 |
self.__tokenizer = AutoTokenizer.from_pretrained(model_name)
|
12 |
self.__model = AutoModelForMaskedLM.from_pretrained(model_name)
|
13 |
|
14 |
def initTokenizer(
|
15 |
self
|
16 |
-
):
|
|
|
17 |
return self.__tokenizer
|
18 |
|
19 |
def initModel(
|
20 |
self
|
21 |
-
):
|
|
|
22 |
return self.__model
|
|
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
2 |
|
3 |
class LanguageModel:
|
4 |
def __init__(
|
5 |
self,
|
6 |
+
model_name
|
7 |
) -> None:
|
8 |
+
|
9 |
print("Downloading language model...")
|
10 |
self.__tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
self.__model = AutoModelForMaskedLM.from_pretrained(model_name)
|
12 |
|
13 |
def initTokenizer(
|
14 |
self
|
15 |
+
) -> AutoTokenizer:
|
16 |
+
|
17 |
return self.__tokenizer
|
18 |
|
19 |
def initModel(
|
20 |
self
|
21 |
+
) -> AutoModelForMaskedLM:
|
22 |
+
|
23 |
return self.__model
|
modules/module_rankSents.py
CHANGED
@@ -21,7 +21,7 @@ class RankSents:
|
|
21 |
)
|
22 |
self.softmax = torch.nn.Softmax(dim=-1)
|
23 |
|
24 |
-
if lang == "
|
25 |
self.articles = [
|
26 |
'un','una','unos','unas','el','los','la','las','lo'
|
27 |
]
|
@@ -32,7 +32,7 @@ class RankSents:
|
|
32 |
'y','o','ni','que','pero','si'
|
33 |
]
|
34 |
|
35 |
-
elif lang == "
|
36 |
self.articles = [
|
37 |
'a','an', 'the'
|
38 |
]
|
@@ -135,11 +135,11 @@ class RankSents:
|
|
135 |
|
136 |
def rank(self,
|
137 |
sent: str,
|
138 |
-
word_list: List[str],
|
139 |
-
banned_word_list: List[str],
|
140 |
-
articles: bool,
|
141 |
-
prepositions: bool,
|
142 |
-
conjunctions: bool
|
143 |
) -> Dict[str, float]:
|
144 |
|
145 |
err = self.errorChecking(sent)
|
|
|
21 |
)
|
22 |
self.softmax = torch.nn.Softmax(dim=-1)
|
23 |
|
24 |
+
if lang == "es":
|
25 |
self.articles = [
|
26 |
'un','una','unos','unas','el','los','la','las','lo'
|
27 |
]
|
|
|
32 |
'y','o','ni','que','pero','si'
|
33 |
]
|
34 |
|
35 |
+
elif lang == "en":
|
36 |
self.articles = [
|
37 |
'a','an', 'the'
|
38 |
]
|
|
|
135 |
|
136 |
def rank(self,
|
137 |
sent: str,
|
138 |
+
word_list: List[str]=[],
|
139 |
+
banned_word_list: List[str]=[],
|
140 |
+
articles: bool=False,
|
141 |
+
prepositions: bool=False,
|
142 |
+
conjunctions: bool=False
|
143 |
) -> Dict[str, float]:
|
144 |
|
145 |
err = self.errorChecking(sent)
|
tool.cfg
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[INTERFACE]
|
2 |
+
# ['es' | 'en']
|
3 |
+
language = en
|
4 |
+
|
5 |
+
[LMODEL]
|
6 |
+
# [bert-base-uncased | dccuchile/bert-base-spanish-wwm-uncased]
|
7 |
+
language_model = bert-base-uncased
|
8 |
+
|
9 |
+
[LOGS]
|
10 |
+
# [True | False]
|
11 |
+
available_logs = True
|