edia_lmodels_en / modules /module_connection.py
LMartinezEXEX's picture
Minimal code refactor of connector classes
524b9ae
from abc import ABC
from modules.module_rankSents import RankSents
from modules.module_crowsPairs import CrowsPairs
from typing import List, Tuple
class Connector(ABC):
    """Shared parsing and error-formatting helpers for explorer connectors.

    No abstract methods are declared, so subclasses only need to add their
    own constructor and entry points.
    """

    def parse_word(self, word: str) -> str:
        """Normalize a single word: lowercase it, then trim surrounding whitespace."""
        lowered = word.lower()
        return lowered.strip()

    def parse_words(self, array_in_string: str) -> List[str]:
        """Split a comma-separated string into normalized words.

        Chunks that are empty or whitespace-only are dropped; every other
        chunk is passed through ``parse_word``. An empty or all-whitespace
        input yields an empty list.
        """
        trimmed = array_in_string.strip()
        if trimmed == "":
            return []

        result: List[str] = []
        for chunk in trimmed.split(','):
            # Skip blanks produced by ",," or trailing commas.
            if not chunk.strip():
                continue
            result.append(self.parse_word(chunk))
        return result

    def process_error(self, err: str) -> str:
        """Wrap a non-empty error message in centered ``<h3>`` HTML.

        A falsy ``err`` (e.g. ``""`` or ``None``) is returned unchanged.
        """
        if not err:
            return err
        return "<center><h3>" + err + "</h3></center>"
class PhraseBiasExplorerConnector(Connector):
    """Connector that feeds raw user input into a ``RankSents`` explorer.

    Normalizes the sentence, parses the comma-separated word lists with the
    inherited helpers, delegates ranking to ``RankSents`` and formats any
    validation error as HTML.
    """

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Create the underlying ``RankSents`` explorer.

        Keyword Args:
            language_model: forwarded to ``RankSents``; required (not None).
            lang: forwarded to ``RankSents``; required (not None).

        Raises:
            KeyError: if ``language_model`` or ``lang`` is missing or None.
                The message names the offending argument(s).
        """
        language_model = kwargs.get('language_model', None)
        lang = kwargs.get('lang', None)

        # Report which argument is missing instead of raising an empty KeyError.
        missing = [
            name
            for name, value in (('language_model', language_model), ('lang', lang))
            if value is None
        ]
        if missing:
            raise KeyError(
                "PhraseBiasExplorerConnector requires keyword argument(s): "
                + ", ".join(missing)
            )

        self.phrase_bias_explorer = RankSents(
            language_model=language_model,
            lang=lang
        )

    def rank_sentence_options(
        self,
        sent: str,
        word_list: str,
        banned_word_list: str,
        useArticles: bool,
        usePrepositions: bool,
        useConjunctions: bool
    ) -> Tuple:
        """Validate *sent* and rank it via ``RankSents.rank``.

        Args:
            sent: input sentence. Every ``*`` is surrounded by spaces and runs
                of whitespace are collapsed before validation (presumably ``*``
                marks the slot to be filled — confirm against ``RankSents``).
            word_list: comma-separated words, parsed with ``parse_words``.
            banned_word_list: comma-separated words, parsed with ``parse_words``.
            useArticles, usePrepositions, useConjunctions: flags forwarded
                unchanged to ``RankSents.rank``.

        Returns:
            A 3-tuple.
            - If ``errorChecking`` reports an error, the tuple is
              ``(error_html, "", "")``.
            - Otherwise the first item is the falsy value returned by
              ``errorChecking``, unchanged. The second item is the output of
              ``Label.compute`` on the ranking scores. The third item is
              always ``""``.
        """
        # Isolate "*" as its own token and normalize whitespace.
        sent = " ".join(sent.strip().replace("*", " * ").split())

        err = self.phrase_bias_explorer.errorChecking(sent)
        if err:
            return self.process_error(err), "", ""

        word_list = self.parse_words(word_list)
        banned_word_list = self.parse_words(banned_word_list)

        all_plls_scores = self.phrase_bias_explorer.rank(
            sent,
            word_list,
            banned_word_list,
            useArticles,
            usePrepositions,
            useConjunctions
        )
        all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error passes it through unchanged.
        return self.process_error(err), all_plls_scores, ""
class CrowsPairsExplorerConnector(Connector):
    """Connector that feeds up to six user sentences into a ``CrowsPairs`` explorer.

    Validates the sentences, delegates ranking to ``CrowsPairs`` and formats
    any validation error as HTML.
    """

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Create the underlying ``CrowsPairs`` explorer.

        Keyword Args:
            language_model: forwarded to ``CrowsPairs``; required (not None).

        Raises:
            KeyError: if ``language_model`` is missing or None. The message
                names the offending argument.
        """
        language_model = kwargs.get('language_model', None)
        if language_model is None:
            # Name the missing argument instead of raising an empty KeyError.
            raise KeyError(
                "CrowsPairsExplorerConnector requires keyword argument(s): "
                "language_model"
            )

        self.crows_pairs_explorer = CrowsPairs(
            language_model=language_model
        )

    def compare_sentences(
        self,
        sent0: str,
        sent1: str,
        sent2: str,
        sent3: str,
        sent4: str,
        sent5: str
    ) -> Tuple:
        """Validate the six sentences and rank them via ``CrowsPairs.rank``.

        Args:
            sent0 ... sent5: sentences passed, in order, as one list to
                ``errorChecking`` and ``rank``. Which of them may be empty is
                decided by ``CrowsPairs.errorChecking`` — not visible here.

        Returns:
            A 3-tuple.
            - If ``errorChecking`` reports an error, the tuple is
              ``(error_html, "", "")``.
            - Otherwise the first item is the falsy value returned by
              ``errorChecking``, unchanged. The second item is the output of
              ``Label.compute`` on the ranking scores. The third item is
              always ``""``.
        """
        sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]

        err = self.crows_pairs_explorer.errorChecking(
            sent_list
        )
        if err:
            return self.process_error(err), "", ""

        all_plls_scores = self.crows_pairs_explorer.rank(
            sent_list
        )
        all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error passes it through unchanged.
        return self.process_error(err), all_plls_scores, ""