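# Provenance (hosting-page residue, kept as a comment so the module parses):
# uploaded by qinfeng722 in commit "Upload 322 files"
# commit 5caedb4 (verified)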
import unittest
from unittest.mock import MagicMock
import torch
from llm_studio.src.augmentations.nlp_aug import BaseNLPAug
class TestBaseNLPAug(unittest.TestCase):
    """Unit tests for ``BaseNLPAug`` token masking.

    The configuration is a ``MagicMock`` whose tokenizer mask token id is
    fixed, so the tests can recognise masked positions in the output.
    """

    def setUp(self):
        """Build a mocked config with a known tokenizer mask token id."""
        self.cfg = MagicMock()
        self.cfg.tokenizer._tokenizer_mask_token_id = 1337

    @staticmethod
    def _make_batch():
        """Return a fresh 2x10 batch with input ids, attention mask and labels.

        The labels start out identical to the input ids and the attention
        mask is all ones, with the same integer dtype as the input ids.
        """
        token_rows = [
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        ]
        input_ids = torch.tensor(token_rows)
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": torch.tensor(token_rows),
        }

    @staticmethod
    def _clone_batch(batch):
        """Return a copy of ``batch`` in which every tensor is cloned.

        ``dict.copy()`` is shallow: the copied dict would share its tensors
        with the reference batch, so an in-place modification made by
        ``forward`` would also change the reference and go unnoticed.
        """
        return {key: tensor.clone() for key, tensor in batch.items()}

    def test_init(self):
        """The augmentation object keeps the config it was constructed with."""
        aug = BaseNLPAug(self.cfg)
        self.assertEqual(aug.cfg, self.cfg)

    def test_forward_no_augmentation(self):
        """With mask probability 0.0 the batch must come back unchanged."""
        aug = BaseNLPAug(self.cfg)
        self.cfg.augmentation.token_mask_probability = 0.0
        batch = self._make_batch()

        # Hand forward() independent tensors so that `batch` stays a
        # trustworthy reference for the comparisons below.
        result = aug.forward(self._clone_batch(batch))

        for key in ("input_ids", "attention_mask", "labels"):
            with self.subTest(key=key):
                self.assertTrue(torch.equal(result[key], batch[key]))

    def test_forward_with_augmentation(self):
        """With mask probability 0.5 some tokens are masked consistently."""
        aug = BaseNLPAug(self.cfg)
        self.cfg.augmentation.token_mask_probability = 0.5
        torch.manual_seed(42)  # For reproducibility
        batch = self._make_batch()

        result = aug.forward(self._clone_batch(batch))

        # Check that some tokens have been masked
        self.assertFalse(torch.equal(result["input_ids"], batch["input_ids"]))

        # Check that masked tokens are replaced with mask token ID
        mask = result["input_ids"] == self.cfg.tokenizer._tokenizer_mask_token_id
        self.assertTrue(mask.any())

        # Check that attention mask is updated for masked tokens.
        # The expected tensor is built with *_like so dtype and shape match
        # the actual values and the check does not depend on how
        # torch.equal treats mixed dtypes.
        masked_attention = result["attention_mask"][mask]
        self.assertTrue(
            torch.equal(masked_attention, torch.zeros_like(masked_attention))
        )

        # Check that labels are updated to -100 for masked tokens
        masked_labels = result["labels"][mask]
        self.assertTrue(
            torch.equal(masked_labels, torch.full_like(masked_labels, -100))
        )