Spaces:
Sleeping
Sleeping
File size: 6,374 Bytes
dbaa71b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import pytest
from obsei.analyzer.classification_analyzer import ClassificationAnalyzerConfig
from obsei.payload import TextPayload
from obsei.postprocessor.inference_aggregator import InferenceAggregatorConfig
from obsei.postprocessor.inference_aggregator_function import (
ClassificationAverageScore,
ClassificationMaxCategories,
)
from obsei.preprocessor.text_splitter import TextSplitterConfig
# Sample inputs shared by the analyzer tests in this module.

# Favourably worded sample (a glowing recommendation of hosts and lodging).
GOOD_TEXT = """If anyone is interested... these are our hosts. I can’t recommend them enough, Abc & Pbc.
The unit is just lovely, you go to sleep & wake up to this incredible place, and you have use of a Weber grill and a ridiculously indulgent hot-tub under the stars"""
# Unfavourably worded sample (a complaint about a ride service).
BAD_TEXT = """I had the worst experience ever with XYZ in Egypt. Bad Cars, asking to pay in cash, do not have enough fuel, do not open AC, wait far away from my location until the trip is cancelled, call and ask about the destination then cancel, and more. Worst service."""
# Very short sample with no clearly favourable or unfavourable wording.
MIXED_TEXT = """I am mixed"""
# Inputs for the zero-shot, splitter/aggregator and VADER tests.
TEXTS = [GOOD_TEXT, BAD_TEXT, MIXED_TEXT]
# Sample phrased as a request to purchase (asks for a quotation).
BUY_INTENT = """I am interested in this style of PGN-ES-D-6150 /Direct drive energy saving servo motor price and in doing business with you. Could you please send me the quotation"""
# Sample phrased as an offer to sell.
SELL_INTENT = """Black full body massage chair for sale."""
# Inputs for the text-classification (buy/sell) test.
BUY_SELL_TEXTS = [BUY_INTENT, SELL_INTENT]
def test_zero_shot_analyzer(zero_shot_analyzer):
    """Verify that zero-shot classification reports every candidate label.

    Each sample in ``TEXTS`` is wrapped in a ``TextPayload`` and handed to
    ``zero_shot_analyzer`` (presumably a pytest fixture defined outside this
    file). One response is expected per input, and each response's
    ``classifier_data`` must contain exactly as many entries as there are
    labels, including "positive" and "negative".
    """
    candidate_labels = ["facility", "food", "comfortable", "positive", "negative"]

    payloads = []
    for sample in TEXTS:
        payloads.append(TextPayload(processed_text=sample, source_name="sample"))
    config = ClassificationAnalyzerConfig(labels=candidate_labels)

    results = zero_shot_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=config,
    )

    assert len(results) == len(TEXTS)
    for result in results:
        scores = result.segmented_data["classifier_data"]
        assert len(scores) == len(candidate_labels)
        assert "positive" in scores
        assert "negative" in scores
@pytest.mark.parametrize(
    "label_map, expected",
    [
        (None, ["LABEL_1", "LABEL_0"]),
        ({"LABEL_1": "Buy", "LABEL_0": "Sell"}, ["Buy", "Sell"]),
    ],
)
def test_text_classification_analyzer(
    text_classification_analyzer, label_map, expected
):
    """Verify the label names emitted by the text-classification analyzer.

    Runs once with no ``label_map`` (keys must come from ``LABEL_1`` /
    ``LABEL_0``) and once with a map that renames them to ``Buy`` / ``Sell``.
    For every response, ``classifier_data`` must be present and its keys must
    be a subset of ``expected``.
    """
    payloads = []
    for sample in BUY_SELL_TEXTS:
        payloads.append(TextPayload(processed_text=sample, source_name="sample"))
    config = ClassificationAnalyzerConfig(label_map=label_map)

    results = text_classification_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=config,
    )

    assert len(results) == len(BUY_SELL_TEXTS)
    allowed_labels = set(expected)
    for result in results:
        assert result.segmented_data["classifier_data"] is not None
        # Subset check: the analyzer may emit fewer labels, never unknown ones.
        assert result.segmented_data["classifier_data"].keys() <= allowed_labels
@pytest.mark.parametrize(
    "aggregate_function",
    [ClassificationAverageScore(), ClassificationMaxCategories()],
)
def test_classification_analyzer_with_splitter_aggregator(
    aggregate_function, zero_shot_analyzer
):
    """Verify zero-shot analysis with text splitting and aggregation enabled.

    Parametrized over both aggregate functions. The analyzer is configured
    with ``use_splitter_and_aggregator=True``, a splitter limited to
    ``max_split_length=50`` and the given aggregate function. One response is
    expected per input text, each exposing an ``aggregator_data`` entry in
    its ``segmented_data``.
    """
    candidate_labels = ["facility", "food", "comfortable", "positive", "negative"]

    payloads = []
    for sample in TEXTS:
        payloads.append(TextPayload(processed_text=sample, source_name="sample"))

    splitter = TextSplitterConfig(max_split_length=50)
    aggregator = InferenceAggregatorConfig(aggregate_function=aggregate_function)
    config = ClassificationAnalyzerConfig(
        labels=candidate_labels,
        use_splitter_and_aggregator=True,
        splitter_config=splitter,
        aggregator_config=aggregator,
    )

    results = zero_shot_analyzer.analyze_input(
        source_response_list=payloads,
        analyzer_config=config,
    )

    assert len(results) == len(TEXTS)
    for result in results:
        assert "aggregator_data" in result.segmented_data
def test_vader_analyzer(vader_analyzer):
    """Verify that the VADER analyzer yields positive and negative entries.

    Every sample in ``TEXTS`` is analyzed without an explicit analyzer
    config. Each response's ``classifier_data`` must hold exactly two
    entries, "positive" and "negative".
    """
    payloads = []
    for sample in TEXTS:
        payloads.append(TextPayload(processed_text=sample, source_name="sample"))

    results = vader_analyzer.analyze_input(source_response_list=payloads)

    assert len(results) == len(TEXTS)
    for result in results:
        scores = result.segmented_data["classifier_data"]
        assert len(scores) == 2
        assert "positive" in scores
        assert "negative" in scores
def test_trf_ner_analyzer(trf_ner_analyzer):
    """Verify the entities found by ``trf_ner_analyzer`` in a fixed sentence.

    The single input sentence must yield one response whose ``ner_data``
    contains exactly three entries matching "Sam" as ``PER`` and "Berlin" /
    "Germany" as ``LOC``.
    """
    # (word, entity_group) pairs that count as a match.
    wanted = (("Sam", "PER"), ("Berlin", "LOC"), ("Germany", "LOC"))

    payload = TextPayload(
        processed_text="My name is Sam and I live in Berlin, Germany.",
        source_name="sample",
    )
    results = trf_ner_analyzer.analyze_input(source_response_list=[payload])

    assert len(results) == 1
    found = results[0].segmented_data["ner_data"]

    # Each entity contributes at most one hit, whichever pair it matches.
    hits = sum(
        1
        for item in found
        if any(
            item["word"] == word and item["entity_group"] == group
            for word, group in wanted
        )
    )
    assert hits == 3
def test_spacy_ner_analyzer(spacy_ner_analyzer):
    """Verify the entities found by ``spacy_ner_analyzer`` in a fixed sentence.

    The single input sentence must yield one response whose ``ner_data``
    contains exactly three entries matching "Sam" as ``PERSON`` and
    "Berlin" / "Germany" as ``GPE``.
    """
    # (word, entity_group) pairs that count as a match.
    wanted = (("Sam", "PERSON"), ("Berlin", "GPE"), ("Germany", "GPE"))

    payload = TextPayload(
        processed_text="My name is Sam and I live in Berlin, Germany.",
        source_name="sample",
    )
    results = spacy_ner_analyzer.analyze_input(source_response_list=[payload])

    assert len(results) == 1
    found = results[0].segmented_data["ner_data"]

    # Each entity contributes at most one hit, whichever pair it matches.
    hits = sum(
        1
        for item in found
        if any(
            item["word"] == word and item["entity_group"] == group
            for word, group in wanted
        )
    )
    assert hits == 3
|