|
|
|
|
|
|
|
|
|
|
|
|
|
from examples.speech_recognition.criterions.cross_entropy_acc import ( |
|
CrossEntropyWithAccCriterion, |
|
) |
|
|
|
from .asr_test_base import CrossEntropyCriterionTestBase |
|
|
|
|
|
class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
    """Check the accuracy counters reported by ``CrossEntropyWithAccCriterion``.

    Each test asks the shared base class for a hard-target, non-aggregated
    sample, runs the criterion with ``"sum"`` reduction on log-probabilities,
    and inspects the counters written into the returned logging output.
    """

    def setUp(self):
        # Set the criterion class before delegating. The base setUp
        # presumably reads criterion_cls to build self.criterion and
        # self.model (TODO confirm in asr_test_base).
        self.criterion_cls = CrossEntropyWithAccCriterion
        super().setUp()

    def _logging_output_for(self, correct):
        """Run the criterion on a test sample and return its logging output.

        Args:
            correct: forwarded to ``get_test_sample``. It selects whether the
                model's predictions match the targets.
        """
        sample = self.get_test_sample(
            correct=correct, soft_target=False, aggregate=False
        )
        # Only the logging output is inspected. The loss and the
        # sample-size return values are intentionally ignored.
        _loss, _sample_size, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        return logging_output

    def _assert_counts(self, logging_output, expected_correct):
        """Assert the accuracy counters for the 20-token test sample."""
        # The counters are checked in a fixed order: correct, total,
        # sample_size, then ntokens.
        expected = {
            "correct": expected_correct,
            "total": 20,
            "sample_size": 20,
            "ntokens": 20,
        }
        for key, value in expected.items():
            assert logging_output[key] == value

    def test_cross_entropy_all_correct(self):
        # Every prediction matches its target, so all 20 tokens count as correct.
        self._assert_counts(self._logging_output_for(True), expected_correct=20)

    def test_cross_entropy_all_wrong(self):
        # No prediction matches its target, but total/sample_size/ntokens
        # still cover all 20 tokens.
        self._assert_counts(self._logging_output_for(False), expected_correct=0)
|
|