import os
import tempfile

import torch

from mergekit.io import TensorWriter


class TestTensorWriter:
    """Smoke tests checking which files exist after ``TensorWriter.finalize()``."""

    def test_safetensors(self):
        """Save one tensor with ``safe_serialization=True`` and check the output files."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)
            tensor_writer.save_tensor("steve", torch.randn(4))
            tensor_writer.finalize()

            # Both the single shard and its index file are expected on disk.
            for filename in (
                "model-00001-of-00001.safetensors",
                "model.safetensors.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, filename))

    def test_pickle(self):
        """Save one tensor with ``safe_serialization=False`` and check the output files."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=False)
            tensor_writer.save_tensor("timothan", torch.randn(4))
            tensor_writer.finalize()

            # Both the single shard and its index file are expected on disk.
            for filename in (
                "pytorch_model-00001-of-00001.bin",
                "pytorch_model.bin.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, filename))

    def test_duplicate_tensor(self):
        """Save the same tensor object under two names and check the output files."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)

            # The very same tensor object is registered under both names.
            shared = torch.randn(4)
            for tensor_name in ("jim", "jimbo"):
                tensor_writer.save_tensor(tensor_name, shared)
            tensor_writer.finalize()

            for filename in (
                "model-00001-of-00001.safetensors",
                "model.safetensors.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, filename))