Improvement in the display of the graph axis labels. Generalization of the RankSents class. Minor fixes.
Files changed:
- modules/module_BiasExplorer.py  +22 -10
- modules/module_connection.py    +12 -10
- modules/module_rankSents.py     +29 -25
- modules/utils.py                +64 -3
modules/module_BiasExplorer.py

@@ -5,7 +5,7 @@ import seaborn as sns
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
 from typing import List, Dict, Tuple, Optional, Any
-from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted
+from modules.utils import normalize, cosine_similarity, project_params, take_two_sides_extreme_sorted, axes_labels_format
 
 __all__ = ['WordBiasExplorer', 'WEBiasExplorer2Spaces', 'WEBiasExplorer4Spaces']
 
@@ -371,9 +371,14 @@ class WEBiasExplorer2Spaces(WordBiasExplorer):
         plt.xticks(np.arange(-most_extream_projection,
                              most_extream_projection + axis_projection_step,
                              axis_projection_step))
-        xlabel = '← {} {} {} →'.format(self.negative_end,
-                                       ' ' * 20,
-                                       self.positive_end)
+
+
+        xlabel = axes_labels_format(
+            left=self.negative_end,
+            right=self.positive_end,
+            sep=' ' * 20,
+            word_wrap=3
+        )
 
         plt.xlabel(xlabel)
         plt.ylabel('Words')
@@ -515,13 +520,20 @@ class WEBiasExplorer4Spaces(WordBiasExplorer):
         for _, row in (projections_df.iterrows()):
             ax.annotate(
                 row['word'], (row['projection_x'], row['projection_y']))
-        x_label = '← {} {} {} →'.format(name_left,
-                                        ' ' * 20,
-                                        name_right)
 
-        y_label = '← {} {} {} →'.format(name_top,
-                                        ' ' * 20,
-                                        name_bottom)
+
+        x_label = axes_labels_format(
+            left=name_left,
+            right=name_right,
+            sep=' ' * 20,
+            word_wrap=3
+        )
+        y_label = axes_labels_format(
+            left=name_top,
+            right=name_bottom,
+            sep=' ' * 20,
+            word_wrap=3
+        )
 
         plt.xlabel(x_label)
         ax.xaxis.set_label_position('bottom')
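Both explorer classes now build their axis labels through the shared axes_labels_format helper instead of hand-formatted strings, so long comma-separated word lists wrap instead of overflowing the figure. A minimal sketch of the new call (the word lists below are hypothetical; in the class they come from self.negative_end and self.positive_end):

from modules.utils import axes_labels_format

# Hypothetical inputs; mirrors the call added in WEBiasExplorer2Spaces.
xlabel = axes_labels_format(
    left='mujer, chica, ella',    # words shown under the negative end
    right='hombre, chico, el',    # words shown under the positive end
    sep=' ' * 20,                 # gap between the two ends of the axis
    word_wrap=3                   # at most three words per label row
)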
modules/module_connection.py

@@ -422,11 +422,12 @@ class PhraseBiasExplorerConnector(Connector):
     def rank_sentence_options(
         self,
         sent: str,
-
+        interest_word_list: str,
         banned_word_list: str,
-
-
-
+        exclude_articles: bool,
+        exclude_prepositions: bool,
+        exclude_conjunctions: bool,
+        n_predictions: int=5
     ) -> Tuple:
 
         sent = " ".join(sent.strip().replace("*"," * ").split())
@@ -435,7 +436,7 @@ class PhraseBiasExplorerConnector(Connector):
         if err:
            return err, "", ""
 
-
+        interest_word_list = self.parse_words(interest_word_list)
         banned_word_list = self.parse_words(banned_word_list)
 
         # Save inputs in logs file
@@ -443,16 +444,17 @@ class PhraseBiasExplorerConnector(Connector):
             self.logs_file_name,
             self.headers,
             sent,
-
+            interest_word_list
         )
 
         all_plls_scores = self.phrase_bias_explorer.rank(
             sent,
-
+            interest_word_list,
             banned_word_list,
-
-
-
+            exclude_articles,
+            exclude_prepositions,
+            exclude_conjunctions,
+            n_predictions
         )
 
         all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
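rank_sentence_options now accepts the interest-word list and the exclusion flags and forwards them to the ranker. A hypothetical call (the connector instance and the exact contents of the returned tuple are assumptions; the error path above shows it returns three values):

# Hypothetical driver code; 'connector' is a PhraseBiasExplorerConnector
# constructed elsewhere in the app.
err, *outputs = connector.rank_sentence_options(
    sent="La persona * trabaja en el hospital.",  # '*' marks the slot to fill
    interest_word_list="",        # empty string: candidates come from the model
    banned_word_list="enfermera", # comma-separated, parsed by parse_words
    exclude_articles=True,
    exclude_prepositions=True,
    exclude_conjunctions=True,
    n_predictions=5
)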
modules/module_rankSents.py

@@ -66,13 +66,14 @@ class RankSents:
 
         return self.errorManager.process(out_msj)
 
-    def
+    def getTopPredictions(
         self,
+        n: int,
         sent: str,
-
-
-
-
+        banned_word_list: List[str],
+        exclude_articles: bool,
+        exclude_prepositions: bool,
+        exclude_conjunctions: bool,
    ) -> List[str]:
 
         sent_masked = sent.replace("*", self.tokenizer.mask_token)
@@ -80,7 +81,8 @@ class RankSents:
             sent_masked,
             add_special_tokens=True,
             return_tensors='pt',
-            return_attention_mask=True,
+            return_attention_mask=True,
+            truncation=True
         )
 
         tk_position_mask = torch.where(inputs['input_ids'][0] == self.tokenizer.mask_token_id)[0].item()
@@ -94,26 +96,26 @@ class RankSents:
         probabilities = outputs[tk_position_mask]
         first_tk_id = torch.argsort(probabilities, descending=True)
 
-
+        top_tks_pred = []
         for tk_id in first_tk_id:
             tk_string = self.tokenizer.decode([tk_id])
 
-            tk_is_banned = tk_string in
+            tk_is_banned = tk_string in banned_word_list
             tk_is_punctuation = not tk_string.isalnum()
             tk_is_substring = tk_string.startswith("##")
             tk_is_special = (tk_string in self.tokenizer.all_special_tokens)
 
-            if
+            if exclude_articles:
                 tk_is_article = tk_string in self.articles
             else:
                 tk_is_article = False
 
-            if
+            if exclude_prepositions:
                 tk_is_prepositions = tk_string in self.prepositions
             else:
                 tk_is_prepositions = False
 
-            if
+            if exclude_conjunctions:
                 tk_is_conjunctions = tk_string in self.conjunctions
             else:
                 tk_is_conjunctions = False
@@ -128,39 +130,41 @@ class RankSents:
                 tk_is_conjunctions
             ])
 
-            if predictions_is_dessire and len(
-
+            if predictions_is_dessire and len(top_tks_pred) < n:
+                top_tks_pred.append(tk_string)
 
-            elif len(
+            elif len(top_tks_pred) >= n:
                 break
 
-            return
+        return top_tks_pred
 
     def rank(self,
              sent: str,
-
+             interest_word_list: List[str]=[],
             banned_word_list: List[str]=[],
-
-
-
+             exclude_articles: bool=False,
+             exclude_prepositions: bool=False,
+             exclude_conjunctions: bool=False,
+             n_predictions: int=5
    ) -> Dict[str, float]:
 
         err = self.errorChecking(sent)
         if err:
             raise Exception(err)
 
-        if not
-
+        if not interest_word_list:
+            interest_word_list = self.getTopPredictions(
+                n_predictions,
                 sent,
                 banned_word_list,
-
-
-
+                exclude_articles,
+                exclude_prepositions,
+                exclude_conjunctions
            )
 
         sent_list = []
         sent_list2print = []
-        for word in
+        for word in interest_word_list:
            sent_list.append(sent.replace("*", "<"+word+">"))
            sent_list2print.append(sent.replace("*", "<"+word+">"))
 
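The generalized rank only queries the masked language model when the caller supplies no interest words: in that case getTopPredictions walks the vocabulary in descending probability and keeps the first n tokens that pass the banned/punctuation/subword/special-token checks and the optional article, preposition, and conjunction filters. A sketch of the fallback path, assuming an already-initialized instance:

# Sketch, not repo code: 'ranker' is assumed to be an initialized RankSents.
scores = ranker.rank(
    "El * atiende a los pacientes.",
    interest_word_list=[],        # empty list triggers getTopPredictions
    banned_word_list=["doctor"],  # never propose this token
    exclude_articles=True,
    exclude_prepositions=True,
    exclude_conjunctions=True,
    n_predictions=5
)
# Per the signature, 'scores' maps each "<word>"-filled sentence to a score.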
modules/utils.py

@@ -1,13 +1,15 @@
 import numpy as np
 import pandas as pd
-from datetime import datetime
 import pytz
+from datetime import datetime
+from typing import List
+
 
 
 class DateLogs:
     def __init__(
         self,
-        zone: str="America/Argentina/Cordoba"
+        zone: str = "America/Argentina/Cordoba"
     ) -> None:
 
         self.time_zone = pytz.timezone(zone)
@@ -80,4 +82,63 @@ def cosine_similarity(
     v_norm = np.linalg.norm(v)
     u_norm = np.linalg.norm(u)
     similarity = v @ u / (v_norm * u_norm)
-    return similarity
+    return similarity
+
+
+def axes_labels_format(
+    left: str,
+    right: str,
+    sep: str,
+    word_wrap: int = 4
+) -> str:
+
+    def sparse(
+        word: str,
+        max_len: int
+    ) -> str:
+
+        diff = max_len - len(word)
+        rest = diff if diff > 0 else 0
+        return word + " " * rest
+
+    def gen_block(
+        list_: List[str],
+        n_rows: int,
+        n_cols: int
+    ) -> List[str]:
+
+        block = []
+        block_row = []
+        for r in range(n_rows):
+            for c in range(n_cols):
+                i = r * n_cols + c
+                w = list_[i] if i <= len(list_) - 1 else ""
+                block_row.append(w)
+                if (i+1) % n_cols == 0:
+                    block.append(block_row)
+                    block_row = []
+        return block
+
+    # Transform 'string' to list of string
+    l_list = [word.strip() for word in left.split(",") if word.strip() != ""]
+    r_list = [word.strip() for word in right.split(",") if word.strip() != ""]
+
+    # Get longest word, and longest list
+    longest_list = max(len(l_list), len(r_list))
+    longest_word = len(max(max(l_list, key=len), max(r_list, key=len)))
+
+    # Creation of word blocks for each list
+    n_rows = (longest_list // word_wrap) if longest_list % word_wrap == 0 else (longest_list // word_wrap) + 1
+    n_cols = word_wrap
+
+    l_block = gen_block(l_list, n_rows, n_cols)
+    r_block = gen_block(r_list, n_rows, n_cols)
+
+    # Transform list of list to sparse string
+    labels = ""
+    for i, (l, r) in enumerate(zip(l_block, r_block)):
+        line = ' '.join([sparse(w, longest_word) for w in l]) + sep + \
+               ' '.join([sparse(w, longest_word) for w in r])
+        labels += f"← {line} →\n" if i == 0 else f"  {line}  \n"
+
+    return labels
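axes_labels_format turns each comma-separated string into a padded grid: the words are split on commas, laid out word_wrap per row (gen_block), padded to a common width (sparse), and the left and right grids are joined by sep, with arrows framing only the first row. A small behavioral sketch with hypothetical inputs:

from modules.utils import axes_labels_format

# Three words per side with word_wrap=2 wrap onto ceil(3/2) = 2 rows.
label = axes_labels_format("she, her, woman", "he, him, man",
                           sep=' ' * 8, word_wrap=2)
print(label)
# Row 1 is framed by the arrows ("← she   her  ...  he    him   →" style);
# row 2 carries the overflow ("woman" / "man") padded into the same columns.

One detail worth noting: the padding width comes from max(max(l_list, key=len), max(r_list, key=len)), where the outer max compares the two per-side longest words as strings, so the width is taken from whichever compares greater lexicographically, which is not necessarily the longer of the two.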