|
import torch |
|
import unittest |
|
import math |
|
import re |
|
import warnings |
|
from dataclasses import asdict, dataclass, field |
|
from enum import Enum |
|
from typing import List, Optional, Tuple, Union |
|
import itertools |
|
import copy |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers.pytorch_utils import Conv1D |
|
from .gating import GATING_TO_MODEL_MAPPING |
|
|
|
from ..import_utils import is_bnb_4bit_available, is_bnb_available |
|
from ..utils import ( |
|
COMMON_LAYERS_PATTERN, |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
ModulesToSaveWrapper, |
|
PeftConfig, |
|
PeftType, |
|
_freeze_adapter, |
|
_get_submodules, |
|
transpose, |
|
) |
|
|
|
if is_bnb_available(): |
|
import bitsandbytes as bnb |
|
from moelora import * |
|
|
|
# The TestMoELoRA suite is defined once, further down this module.
# (A duplicate class of the same name used to live here; it was immediately
# rebound by the later definition -- which contains all of its tests -- so it
# was never collected by the unittest loader and has been removed.)
|
# The script entry point (``unittest.main()``) lives at the very end of this
# module. Calling it here would exit the process before the test case below
# is defined, so no premature entry point is kept at this position.
|
import unittest |
|
|
|
class TestMoELoRA(unittest.TestCase):
    """Shape checks for ``MoELoRA.forward`` under different adapter states.

    Each test draws a random ``(10, 20, 30)`` input, optionally configures
    ``active_adapter`` / ``global_user_embeds`` on the model, calls
    ``forward`` and asserts that the output shape equals ``(10, 20, 30)``.
    """

    # Size of the random input handed to forward(), and the output shape
    # every test expects back.
    _SHAPE = (10, 20, 30)

    def setUp(self):
        # Build a fresh model before each test so that attribute changes made
        # by one test cannot leak into another.
        # NOTE(review): MoELoRA is presumably provided by the module-level
        # ``from moelora import *``; it is constructed with no arguments.
        self.model = MoELoRA()

    def _forward_preserves_shape(self, inputs):
        """Call ``self.model.forward(inputs)`` and assert the result has ``_SHAPE``."""
        # NOTE(review): forward() is invoked directly instead of via
        # ``self.model(inputs)``; if MoELoRA is an nn.Module this skips its
        # registered hooks -- confirm that is intended.
        result = self.model.forward(inputs)
        self.assertEqual(result.shape, self._SHAPE)

    def test_forward_no_adapters(self):
        # The model is used exactly as constructed in setUp.
        inputs = torch.randn(*self._SHAPE)
        self._forward_preserves_shape(inputs)

    def test_forward_with_adapters(self):
        inputs = torch.randn(*self._SHAPE)
        self.model.active_adapter = 'adapter1'
        self._forward_preserves_shape(inputs)

    def test_forward_with_global_user_embeds(self):
        # The input is drawn before the embedding so the RNG stream is
        # consumed in a fixed order.
        inputs = torch.randn(*self._SHAPE)
        self.model.active_adapter = 'adapter1'
        self.model.global_user_embeds = [torch.randn(10, 30)]
        self._forward_preserves_shape(inputs)

    def test_forward_with_global_user_embeds_exception(self):
        inputs = torch.randn(*self._SHAPE)
        self.model.active_adapter = 'adapter1'
        # The embedding's leading dimension (5) differs from the input's (10),
        # presumably on purpose.
        # NOTE(review): despite this test's name, nothing asserts that an
        # exception is raised -- it only checks that forward() still returns
        # the input shape. Confirm whether ``assertRaises`` was intended.
        self.model.global_user_embeds = [torch.randn(5, 30)]
        self._forward_preserves_shape(inputs)

    def test_forward_no_global_user_embeds(self):
        inputs = torch.randn(*self._SHAPE)
        self.model.active_adapter = 'adapter1'
        # An explicitly empty list, rather than leaving the attribute unset.
        self.model.global_user_embeds = []
        self._forward_preserves_shape(inputs)
|
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()