Spaces:
Runtime error
Runtime error
from datasets import load_dataset, interleave_datasets | |
from modules.module_segmentedWordCloud import SegmentedWordCloud | |
from modules.module_customSubsetsLabel import CustomSubsetsLabel | |
from random import sample as random_sample | |
from typing import Tuple, List, Dict | |
import re | |
import matplotlib as mpl | |
mpl.use('Agg') | |
import matplotlib.pyplot as plt | |
class Word2Context: | |
def __init__( | |
self, | |
context_ds_name: str, # Context dataset HF name | path | |
vocabulary, # Vocabulary class instance | |
errorManager # ErrorManager class instance | |
) -> None: | |
self.context_ds_name = context_ds_name | |
# Vocabulary class | |
self.vocab = vocabulary | |
# Custom Label component | |
self.Label = CustomSubsetsLabel() | |
self.errorManager = errorManager | |
def errorChecking( | |
self, | |
word: str | |
) -> str: | |
out_msj = "" | |
if not word: | |
out_msj = ['EMBEDDING_NO_WORD_PROVIDED'] | |
else: | |
if word not in self.vocab: | |
out_msj = ['EMBEDDING_WORD_OOV', word] | |
return self.errorManager.process(out_msj) | |
def genWebLink( | |
self, | |
text: str | |
) -> str: | |
text = text.replace("\"", "'") | |
text = text.replace("<u><b>", "") | |
text = text.replace("</b></u>", "") | |
url = "https://www.google.com.tr/search?q={}".format(text) | |
return '<a href="{}" rel="noopener noreferrer" target="_blank"><center>ππ</center></a>'.format(url) | |
def genWordCloudPlot( | |
self, | |
word: str, | |
figsize: Tuple[int,int]=(9,3) | |
) -> plt.Figure: | |
err = self.errorChecking(word) | |
if err: | |
raise Exception(err) | |
freq_dic, l_group, g_group = self.vocab.getWordNeighbors(word, n_neighbors=10) | |
wc = SegmentedWordCloud(freq_dic, l_group, g_group) | |
return wc.plot(figsize) | |
def genDistributionPlot( | |
self, | |
word: str, | |
figsize: Tuple[int,int]=(6,1) | |
) -> plt.Figure: | |
err = self.errorChecking(word) | |
if err: | |
raise Exception(err) | |
x_values, y_values = self.vocab.distribution() | |
w_percentile = self.vocab.getPercentile(word) | |
w_freq = self.vocab.getFreq(word) | |
fig, ax = plt.subplots(figsize=figsize) | |
ax.plot(x_values, y_values, color='green') | |
ax.fill_between(x_values, y_values, color='lightgreen',) | |
ax.axvline(x=max(0,w_percentile-.01), | |
color='blue', | |
linewidth=7, | |
alpha=.1, | |
linestyle='-' | |
) | |
ax.axvline(x=min(100,w_percentile+.01), | |
color='black', | |
linewidth=7, | |
alpha=.1, | |
linestyle='-' | |
) | |
ax.axvline(x=w_percentile, | |
color='#d35400', | |
linewidth=2, | |
linestyle='--', | |
label=f'{w_freq}\n(frecuencia total)' | |
) | |
ax.axis('off') | |
plt.legend(loc='upper left', prop={'size': 7}) | |
return fig | |
def findSplits( | |
self, | |
word: str, | |
subsets_list: List[str] | |
): | |
err = self.errorChecking(word) | |
if err: | |
raise Exception(err) | |
w_splits = self.vocab.getSplits(word) | |
splits_list = [] | |
for subset in subsets_list: | |
current_split_list = [] | |
for s in w_splits: | |
if (subset == s.split("_")[0]): | |
current_split_list.append(s) | |
if current_split_list: | |
splits_list.append(current_split_list) | |
splits_list = [random_sample(s_list, 1)[0] for s_list in splits_list] | |
ds_list = [ | |
load_dataset(path=self.context_ds_name, name=split, streaming=True, split='all') | |
for split in splits_list | |
] | |
datasets = ds_list[0] | |
if len(ds_list) > 1: | |
datasets = interleave_datasets(ds_list, probabilities=None) | |
return datasets | |
def findContexts( | |
self, | |
sample: str, | |
word: str | |
) -> Dict[str,str]: | |
sample = sample['text'].strip() | |
context = "" | |
m = re.search(r'\b{}\b'.format(word), sample) | |
if m: | |
init = m.span()[0] | |
end = init+len(word) | |
context = sample[:init]+"<u><b>"+word+"</b></u>"+sample[end:] | |
return {'context':context} | |
def getSubsetsInfo( | |
self, | |
word: str | |
) -> Tuple: | |
err = self.errorChecking(word) | |
if err: | |
raise Exception(err) | |
total_freq = self.vocab.getFreq(word) | |
subsets_name_list = list(self.vocab.getSubsets(word).keys()) | |
subsets_freq_list = list(self.vocab.getSubsets(word).values()) | |
# Create subset frequency dict to subset_freq component | |
subsets_info = { | |
s_name + f" ({s_freq})": s_freq/total_freq | |
for s_name, s_freq in zip(subsets_name_list, subsets_freq_list) | |
} | |
subsets_origin_info = dict(sorted(subsets_info.items(), key=lambda x: x[1], reverse=True)) | |
subsets_info = self.Label.compute(subsets_origin_info) | |
return subsets_info, subsets_origin_info | |
def getContexts( | |
self, | |
word: str, | |
n_context: int, | |
ds | |
) -> List[Tuple]: | |
err = self.errorChecking(word) | |
if err: | |
raise Exception(err) | |
ds_w_contexts = ds.map(lambda sample: self.findContexts(sample, word)) | |
only_contexts = ds_w_contexts.filter(lambda sample: sample['context'] != "") | |
shuffle_contexts = only_contexts.shuffle(buffer_size=10) | |
list_of_dict = list(shuffle_contexts.take(n_context)) | |
list_of_contexts = [ | |
(i, dic['context'], dic['subset']) | |
for i,dic in enumerate(list_of_dict) | |
] | |
return list_of_contexts |