# Tests for trlx's MiniBatchIterator over dataclass batches, prompt pipelines,
# and ILQL rollout storages.
import unittest | |
from dataclasses import dataclass, is_dataclass | |
import torch | |
from torch.utils.data import DataLoader, Dataset | |
from transformers import AutoTokenizer | |
from trlx.pipeline import MiniBatchIterator | |
from trlx.pipeline.offline_pipeline import ( | |
ILQLRolloutStorage, | |
ILQLSeq2SeqRolloutStorage, | |
PromptPipeline, | |
) | |
@dataclass
class DataclassBatch:
    """Container for one batch of rollout tensors.

    NOTE(review): the ``@dataclass`` decorator was missing, so the class had
    no generated ``__init__`` — every ``DataclassBatch(...)`` construction in
    this module raised ``TypeError`` and ``is_dataclass`` assertions failed.
    ``dataclass`` was already imported but unused, confirming the intent.
    """

    query_tensors: torch.Tensor
    response_tensors: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor
class DummyDataset(Dataset, DataclassBatch):
    """Synthetic dataset of random tensors shaped like rollout batches."""

    # Field names match DataclassBatch; tests index the dataset by these names.
    _FIELDS = ("query_tensors", "response_tensors", "logprobs", "values", "rewards")

    def __init__(self, num_samples):
        # Contents are random — only shapes and ordering matter for the tests.
        self.query_tensors = torch.randn(num_samples, 64)
        self.response_tensors = torch.randn(num_samples, 64)
        self.logprobs = torch.randn(num_samples, 1)
        self.values = torch.randn(num_samples, 1)
        self.rewards = torch.randn(num_samples, 1)

    def __len__(self):
        return len(self.query_tensors)

    def __getitem__(self, idx) -> DataclassBatch:
        # Slice every field at the same row and wrap in a DataclassBatch.
        return DataclassBatch(**{name: getattr(self, name)[idx] for name in self._FIELDS})
def collate_fn(batch):
    """Stack a list of per-sample DataclassBatch objects into one batched DataclassBatch."""
    field_names = ("query_tensors", "response_tensors", "logprobs", "values", "rewards")
    stacked = {
        name: torch.stack([getattr(sample, name) for sample in batch])
        for name in field_names
    }
    return DataclassBatch(**stacked)
class BaseTestMiniBatchIterator(unittest.TestCase):
    """Shared assertion helpers for MiniBatchIterator tests."""

    def check_mini_batch(self, mb, expected_mini_batch_size):
        """Assert every tensor in *mb* has the expected leading (batch) dimension."""
        # Accept both dataclass batches and mapping-style batches (e.g. tokenizer output).
        fields = vars(mb) if is_dataclass(mb) else mb
        for tensor in fields.values():
            self.assertEqual(tensor.size(0), expected_mini_batch_size)
class TestMiniBatchDL(BaseTestMiniBatchIterator):
    """MiniBatchIterator behavior over a plain DataLoader of DataclassBatch samples."""

    @staticmethod
    def _make_loader(num_samples, batch_size, shuffle=True):
        # Common fixture: dataset plus a single-process, collated DataLoader.
        dataset = DummyDataset(num_samples)
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, collate_fn=collate_fn
        )
        return dataset, loader

    def _assert_tensor_batch(self, minibatch, size):
        # A minibatch must be a DataclassBatch of tensors with leading dim == size.
        self.assertIsInstance(minibatch, DataclassBatch)
        self.assertTrue(all(isinstance(v, torch.Tensor) for v in vars(minibatch).values()))
        self.check_mini_batch(minibatch, size)

    def test_batch(self):
        tensors = [torch.tensor([n]) for n in (1, 2, 3, 4, 5)]
        batch = DataclassBatch(*tensors)
        self.assertTrue(is_dataclass(batch))
        self.assertTrue(all(isinstance(v, torch.Tensor) for v in vars(batch).values()))

    def test_minibatch_iterator(self):
        _, loader = self._make_loader(32, 8)
        # 8-sample batches split evenly into 2 minibatches of 4.
        for minibatches in MiniBatchIterator(loader, mb_size=4, num_mb=2):
            for mb in minibatches:
                self._assert_tensor_batch(mb, 4)

    def test_minibatch_iterator_with_undivisible_mbsize(self):
        _, loader = self._make_loader(32, 8)
        # 8 samples / mb_size 3 -> two minibatches of 3 and a trailing one of 2.
        for minibatches in MiniBatchIterator(loader, mb_size=3, num_mb=3):
            for mb in minibatches[:-1]:
                self._assert_tensor_batch(mb, 3)
            self._assert_tensor_batch(minibatches[-1], 2)

    def test_minibatch_iterator_with_remainder(self):
        _, loader = self._make_loader(36, 8)
        iterator = MiniBatchIterator(loader, mb_size=2, num_mb=4)
        for _ in range(4):
            for mb in next(iterator)[:-1]:
                self._assert_tensor_batch(mb, 2)
        # Final DataLoader batch carries 36 - 32 = 4 samples -> only 2 minibatches.
        minibatches = next(iterator)
        self.assertEqual(len(minibatches), 2)
        for mb in minibatches:
            self._assert_tensor_batch(mb, 2)

    def test_minibatch_iterator_with_smaller_dataset(self):
        # Dataset (6 samples) is smaller than the DataLoader batch size (8).
        _, loader = self._make_loader(6, 8)
        iterator = MiniBatchIterator(loader, mb_size=2, num_mb=4)
        for mb in next(iterator):
            self.assertIsInstance(mb, DataclassBatch)
            self.assertTrue(all(isinstance(v, torch.Tensor) for v in vars(mb).values()))
        with self.assertRaises(StopIteration):
            next(iterator)

    def test_minibatch_content(self):
        dataset, loader = self._make_loader(32, 8, shuffle=False)
        iterator = MiniBatchIterator(loader, mb_size=4, num_mb=2)
        seen = 0
        for minibatches in iterator:
            for mb in minibatches:
                for name, tensor in vars(mb).items():
                    start = seen * tensor.size(0)
                    stop = start + tensor.size(0)
                    # Without shuffling, each minibatch must be a contiguous
                    # slice of the original dataset, in order.
                    expected = getattr(dataset, name)[start:stop]
                    self.assertTrue(torch.all(torch.eq(tensor, expected)))
                seen += 1
        # The iterator must have covered every sample in the dataset.
        self.assertEqual(seen * iterator.mb_size, len(dataset))
class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
    """MiniBatchIterator over tokenized prompt batches from PromptPipeline."""

    def test_minibatch_iterator_with_prompt_pipeline(self):
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        pipeline = PromptPipeline(
            ["This is a test prompt."] * 32, max_prompt_length=20, tokenizer=tokenizer
        )
        loader = pipeline.create_loader(batch_size=8, shuffle=True)
        for minibatches in MiniBatchIterator(loader, mb_size=4, num_mb=2):
            for mb in minibatches:
                # Tokenizer batches are mapping-style with the standard HF keys.
                for key in ("input_ids", "attention_mask"):
                    self.assertTrue(key in mb)
                    self.assertTrue(isinstance(mb[key], torch.Tensor))
                self.check_mini_batch(mb, 4)
class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
    """MiniBatchIterator over ILQL rollout storages (decoder-only and seq2seq)."""

    def create_dummy_tensors(self, num_samples):
        """Return (input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)."""
        # Shapes mirror what the ILQL rollout storages accept; values are random.
        return (
            torch.randint(0, 100, (num_samples, 10)),
            torch.randint(0, 2, (num_samples, 10)),
            torch.randn(num_samples, 1),
            torch.randint(0, 100, (num_samples, 1)),
            torch.randint(0, 100, (num_samples, 1)),
            torch.randint(0, 2, (num_samples, 1), dtype=torch.bool),
        )

    def _run_iterator_checks(self, dataloader):
        # Every 8-sample batch should yield exactly 2 minibatches of size 4.
        for minibatches in MiniBatchIterator(dataloader, mb_size=4, num_mb=2):
            self.assertEqual(len(minibatches), 2)
            for mb in minibatches:
                self.check_mini_batch(mb, expected_mini_batch_size=4)

    def test_minibatch_iterator_with_ilql_rollout_storage(self):
        storage = ILQLRolloutStorage(*self.create_dummy_tensors(32))
        self._run_iterator_checks(storage.create_loader(batch_size=8))

    def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
        decoder_input_ids = torch.randint(0, 100, (32, 10))
        storage = ILQLSeq2SeqRolloutStorage(
            input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
        )
        self._run_iterator_checks(storage.create_loader(batch_size=8))
if __name__ == "__main__":
    # Allow running this test module directly from the command line.
    unittest.main()