File size: 1,593 Bytes
217780a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import re

from m4.models.vbloom.configuration_vbloom import VBloomConfig
from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig
from m4.models.vgpt_neo.modeling_vgpt_neo import VGPTNeoForCausalLM
from m4.models.vllama.configuration_vllama import VLlamaConfig
from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
from m4.models.vopt.configuration_vopt import VOPTConfig
from m4.models.vopt.modeling_vopt import VOPTForCausalLM
from m4.models.vt5.configuration_vt5 import VT5Config
from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration


# Registry mapping a regex (searched against the lowercased model name/path)
# to its [config_class, model_class] pair. Dicts keep insertion order, and
# model_name_to_classes returns the entry for the first regex that matches,
# so the position of each pair below is significant (e.g. "gpt-neo|gptneo"
# is tried before "gpt2").
model_name2classes = dict(
    [
        (r"bloom|bigscience-small-testing", [VBloomConfig, VBloomForCausalLM]),
        (r"gpt-neo|gptneo", [VGPTNeoConfig, VGPTNeoForCausalLM]),
        (r"gpt2", [VGPT2Config, VGPT2LMHeadModel]),
        (r"opt", [VOPTConfig, VOPTForCausalLM]),
        (r"t5", [VT5Config, VT5ForConditionalGeneration]),
        (r"llama", [VLlamaConfig, VLlamaForCausalLM]),
    ]
)


def model_name_to_classes(model_name_or_path):
    """Return the ``[config_class, model_class]`` pair for a backbone LM.

    The lowercased name/path is tested with ``re.search`` against each regex
    key of ``model_name2classes`` in insertion order. Substring matches count,
    and the first matching entry wins.

    Args:
        model_name_or_path: Model name or local path, given as a ``str`` or
            any ``os.PathLike`` object (e.g. ``pathlib.Path``).

    Returns:
        The ``[config_class, model_class]`` list registered for the first
        matching regex. This is the registry's own list object, not a copy.

    Raises:
        ValueError: If none of the registered regexes matches.
        TypeError: If ``model_name_or_path`` is not a str, bytes, or
            path-like object (raised by ``os.fspath``).
    """
    # os.fspath returns str inputs unchanged and converts PathLike objects,
    # so pathlib.Path arguments no longer fail on ``.lower()``.
    model_name_lowcase = os.fspath(model_name_or_path).lower()
    for rx, classes in model_name2classes.items():
        if re.search(rx, model_name_lowcase):
            return classes
    # Reached only when no regex matched. A plain post-loop raise replaces
    # the original break-less ``for ... else``, which read as misleading.
    raise ValueError(
        f"Unknown type of backbone LM. Got {model_name_or_path}, supported regexes:"
        f" {list(model_name2classes.keys())}."
    )