"""Model adapter registration."""
import math
import sys
from typing import List, Optional
import warnings
if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache
import accelerate
import psutil
import torch
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
)
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
)
from fastchat.utils import get_gpu_memory
class BaseModelAdapter:
"""The base and the default model adapter."""
use_fast_tokenizer = True
def match(self, model_path: str):
return True
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=self.use_fast_tokenizer,
revision=revision,
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
return model, tokenizer
def load_compress_model(self, model_path, device, torch_dtype, revision="main"):
return load_compress_model(
model_path,
device,
torch_dtype,
use_fast=self.use_fast_tokenizer,
revision=revision,
)
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("one_shot")
# A global registry for all model adapters
# TODO (lmzheng): make it a priority queue.
model_adapters: List[BaseModelAdapter] = []
def register_model_adapter(cls):
"""Register a model adapter."""
model_adapters.append(cls())
@cache
def get_model_adapter(model_path: str) -> BaseModelAdapter:
"""Get a model adapter for a model_path."""
for adapter in model_adapters:
if adapter.match(model_path):
return adapter
raise ValueError(f"No valid model adapter for {model_path}")
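# Illustrative sketch (not executed): support for a new model family is added by
# subclassing BaseModelAdapter and registering it. get_model_adapter() returns the first
# adapter whose match() succeeds, so the register_model_adapter() call must appear before
# the catch-all BaseModelAdapter registration at the bottom of this file. The class name
# "MyLlamaAdapter" and the path "my-org/my-llama-7b" are hypothetical.
#
#   class MyLlamaAdapter(BaseModelAdapter):
#       """The model adapter for a hypothetical my-org/my-llama-7b checkpoint."""
#
#       def match(self, model_path: str):
#           return "my-llama" in model_path.lower()
#
#       def get_default_conv_template(self, model_path: str) -> Conversation:
#           return get_conv_template("vicuna_v1.1")
#
#   register_model_adapter(MyLlamaAdapter)  # before register_model_adapter(BaseModelAdapter)
#   get_model_adapter("my-org/my-llama-7b")  # -> MyLlamaAdapter instance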
def raise_warning_for_incompatible_cpu_offloading_configuration(
device: str, load_8bit: bool, cpu_offloading: bool
):
if cpu_offloading:
if not load_8bit:
            warnings.warn(
                "CPU offloading can only be used together with 8-bit quantization.\n"
                "Use '--load-8bit' to enable 8-bit quantization.\n"
                "Continuing without CPU offloading.\n"
)
return False
        if "linux" not in sys.platform:
            warnings.warn(
                "CPU offloading is only supported on Linux systems due to limited compatibility with the bitsandbytes package.\n"
                "Continuing without CPU offloading.\n"
)
return False
if device != "cuda":
            warnings.warn(
                "CPU offloading is only enabled when using CUDA devices.\n"
                "Continuing without CPU offloading.\n"
)
return False
return cpu_offloading
def load_model(
model_path: str,
device: str,
num_gpus: int,
max_gpu_memory: Optional[str] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
revision: str = "main",
debug: bool = False,
):
"""Load a model from Hugging Face."""
# get model adapter
adapter = get_model_adapter(model_path)
# Handle device mapping
cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
device, load_8bit, cpu_offloading
)
if device == "cpu":
kwargs = {"torch_dtype": torch.float32}
elif device == "cuda":
kwargs = {"torch_dtype": torch.float16}
if num_gpus != 1:
kwargs["device_map"] = "auto"
if max_gpu_memory is None:
kwargs[
"device_map"
                ] = "sequential"  # needed when GPUs have different VRAM sizes
available_gpu_memory = get_gpu_memory(num_gpus)
kwargs["max_memory"] = {
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
for i in range(num_gpus)
}
else:
kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
elif device == "mps":
kwargs = {"torch_dtype": torch.float16}
# Avoid bugs in mps backend by not using in-place operations.
replace_llama_attn_with_non_inplace_operations()
elif device == "xpu":
kwargs = {"torch_dtype": torch.bfloat16}
        # Try to import ipex; although it looks unused, importing it registers XPU support in torch
try:
import intel_extension_for_pytorch as ipex
except ImportError:
warnings.warn(
"Intel Extension for PyTorch is not installed, but is required for xpu inference."
)
else:
raise ValueError(f"Invalid device: {device}")
if cpu_offloading:
# raises an error on incompatible platforms
from transformers import BitsAndBytesConfig
if "max_memory" in kwargs:
kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
)
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit_fp32_cpu_offload=cpu_offloading
)
kwargs["load_in_8bit"] = load_8bit
elif load_8bit:
if num_gpus != 1:
warnings.warn(
"8-bit quantization is not supported for multi-gpu inference."
)
else:
return adapter.load_compress_model(
model_path=model_path,
device=device,
torch_dtype=kwargs["torch_dtype"],
revision=revision,
)
elif gptq_config and gptq_config.wbits < 16:
model, tokenizer = load_gptq_quantized(model_path, gptq_config)
if num_gpus != 1:
device_map = accelerate.infer_auto_device_map(
model,
max_memory=kwargs["max_memory"],
no_split_module_classes=["LlamaDecoderLayer"],
)
model = accelerate.dispatch_model(
model, device_map=device_map, offload_buffers=True
)
else:
model.to(device)
return model, tokenizer
kwargs["revision"] = revision
    # Load model (the adapter was already resolved above; get_model_adapter is cached)
    model, tokenizer = adapter.load_model(model_path, kwargs)
if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps":
model.to(device)
elif device == "xpu":
model.eval()
model = model.to("xpu")
model = torch.xpu.optimize(model, dtype=torch.bfloat16, inplace=True)
if debug:
print(model)
return model, tokenizer
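# Illustrative sketch (not executed): loading the default Vicuna checkpoint on a single
# CUDA GPU. The keyword values are examples, not required settings.
#
#   model, tokenizer = load_model(
#       "lmsys/vicuna-7b-v1.3",
#       device="cuda",
#       num_gpus=1,
#       max_gpu_memory=None,
#       load_8bit=False,
#       cpu_offloading=False,
#   )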
def get_conversation_template(model_path: str) -> Conversation:
"""Get the default conversation template."""
adapter = get_model_adapter(model_path)
return adapter.get_default_conv_template(model_path)
def get_generate_stream_function(model: torch.nn.Module, model_path: str):
"""Get the generate_stream function for inference."""
from fastchat.serve.inference import generate_stream
model_type = str(type(model)).lower()
is_chatglm = "chatglm" in model_type
is_falcon = "rwforcausallm" in model_type
is_codet5p = "codet5p" in model_type
if is_chatglm:
return generate_stream_chatglm
elif is_falcon:
return generate_stream_falcon
elif is_codet5p:
return generate_stream_codet5p
else:
return generate_stream
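# Illustrative sketch (not executed): given a model loaded with load_model() above, the
# conversation template and the streaming generation function are both resolved from the
# model path.
#
#   conv = get_conversation_template("lmsys/vicuna-7b-v1.3")  # vicuna_v1.1 template
#   stream_func = get_generate_stream_function(model, "lmsys/vicuna-7b-v1.3")
#   # -> fastchat.serve.inference.generate_stream for Llama-based models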
def add_model_args(parser):
parser.add_argument(
"--model-path",
type=str,
default="lmsys/vicuna-7b-v1.3",
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
"--revision",
type=str,
default="main",
help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "mps", "xpu"],
default="cuda",
help="The device type",
)
parser.add_argument(
"--gpus",
type=str,
default=None,
help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
"--max-gpu-memory",
type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'",
)
parser.add_argument(
"--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
"--cpu-offloading",
action="store_true",
        help="Only when using 8-bit quantization: offload weights that do not fit on the GPU to the CPU",
)
parser.add_argument(
"--gptq-ckpt",
type=str,
default=None,
help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
"--gptq-wbits",
type=int,
default=16,
choices=[2, 3, 4, 8, 16],
help="#bits to use for quantization",
)
parser.add_argument(
"--gptq-groupsize",
type=int,
default=-1,
help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
"--gptq-act-order",
action="store_true",
help="Whether to apply the activation order GPTQ heuristic",
)
def remove_parent_directory_name(model_path):
    """Remove parent directory name (e.g., "lmsys/vicuna-7b-v1.3" -> "vicuna-7b-v1.3")."""
if model_path[-1] == "/":
model_path = model_path[:-1]
return model_path.split("/")[-1]
class PeftModelAdapter:
    """Loads any "peft" model and its base model."""
def match(self, model_path: str):
"""Accepts any model path with "peft" in the name"""
return "peft" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
"""Loads the base model then the (peft) adapter weights"""
from peft import PeftConfig, PeftModel
config = PeftConfig.from_pretrained(model_path)
base_model_path = config.base_model_name_or_path
if "peft" in base_model_path:
raise ValueError(
f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
)
base_adapter = get_model_adapter(base_model_path)
base_model, tokenizer = base_adapter.load_model(
base_model_path, from_pretrained_kwargs
)
model = PeftModel.from_pretrained(base_model, model_path)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
"""Uses the conv template of the base model"""
from peft import PeftConfig, PeftModel
config = PeftConfig.from_pretrained(model_path)
if "peft" in config.base_model_name_or_path:
raise ValueError(
f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
)
base_model_path = config.base_model_name_or_path
base_adapter = get_model_adapter(base_model_path)
return base_adapter.get_default_conv_template(config.base_model_name_or_path)
class VicunaAdapter(BaseModelAdapter):
    """The model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.3)"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "vicuna" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
self.raise_warning_for_old_weights(model)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
if "v0" in remove_parent_directory_name(model_path):
return get_conv_template("one_shot")
return get_conv_template("vicuna_v1.1")
def raise_warning_for_old_weights(self, model):
if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000:
            warnings.warn(
                "\nYou are probably using the old Vicuna-v0 model, "
                "which will generate unexpected results with the "
                "current fastchat.\nYou can try one of the following methods:\n"
                "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n"
                "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template conv_one_shot`\n"
                "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n"
)
class LongChatAdapter(BaseModelAdapter):
    """The model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "longchat" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
config = AutoConfig.from_pretrained(model_path, revision=revision)
        # Apply monkey patch. TODO(Dacheng): add flash attention support
from fastchat.model.llama_condense_monkey_patch import (
replace_llama_with_condense,
)
replace_llama_with_condense(config.rope_condense_ratio)
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")
class CodeT5pAdapter(BaseModelAdapter):
"""The model adapter for Salesforce/codet5p-6b"""
def match(self, model_path: str):
return "codet5p" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
return model, tokenizer
class T5Adapter(BaseModelAdapter):
"""The model adapter for lmsys/fastchat-t5-3b-v1.0"""
def match(self, model_path: str):
return "t5" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
return model, tokenizer
class KoalaAdapter(BaseModelAdapter):
"""The model adapter for koala"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "koala" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("koala_v1")
class AlpacaAdapter(BaseModelAdapter):
"""The model adapter for alpaca"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "alpaca" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("alpaca")
class ChatGLMAdapter(BaseModelAdapter):
"""The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b"""
def match(self, model_path: str):
return "chatglm" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
model_path = model_path.lower()
if "chatglm2" in model_path:
return get_conv_template("chatglm2")
return get_conv_template("chatglm")
class DollyV2Adapter(BaseModelAdapter):
"""The model adapter for databricks/dolly-v2-12b"""
def match(self, model_path: str):
return "dolly-v2" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
        # Token id 50277 corresponds to "### End"
tokenizer.eos_token_id = 50277
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("dolly_v2")
class OasstPythiaAdapter(BaseModelAdapter):
"""The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"""
def match(self, model_path: str):
return "oasst" in model_path and "pythia" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("oasst_pythia")
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
class OasstLLaMAAdapter(BaseModelAdapter):
"""The model adapter for OpenAssistant/oasst-sft-7-llama-30b"""
use_fast_tokenizer = False
def match(self, model_path: str):
if "OpenAssistant-SFT-7-Llama-30B-HF" in model_path:
return True
return "oasst" in model_path and "pythia" not in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("oasst_llama")
class PythiaAdapter(BaseModelAdapter):
"""The model adapter for any EleutherAI/pythia model"""
def match(self, model_path: str):
return "pythia" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
class StableLMAdapter(BaseModelAdapter):
"""The model adapter for StabilityAI/stablelm-tuned-alpha-7b"""
def match(self, model_path: str):
return "stablelm" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("stablelm")
class MPTAdapter(BaseModelAdapter):
"""The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)"""
def match(self, model_path: str):
return "mpt" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
max_seq_len=8192,
**from_pretrained_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
if "mpt-7b-chat" in model_path:
return get_conv_template("mpt-7b-chat")
elif "mpt-30b-chat" in model_path:
return get_conv_template("mpt-30b-chat")
elif "mpt-30b-instruct" in model_path:
return get_conv_template("mpt-30b-instruct")
else:
raise ValueError(f"Unknown MPT model: {model_path}")
class BaizeAdapter(BaseModelAdapter):
"""The model adapter for project-baize/baize-v2-7b"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "baize" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("baize")
class RwkvAdapter(BaseModelAdapter):
"""The model adapter for BlinkDL/RWKV-4-Raven"""
def match(self, model_path: str):
return "RWKV-4" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
from fastchat.model.rwkv_model import RwkvModel
model = RwkvModel(model_path)
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/pythia-160m", revision=revision
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("rwkv")
class OpenBuddyAdapter(BaseModelAdapter):
"""The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "openbuddy" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("openbuddy")
class PhoenixAdapter(BaseModelAdapter):
"""The model adapter for FreedomIntelligence/phoenix-inst-chat-7b"""
def match(self, model_path: str):
return "phoenix" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("phoenix")
class ChatGPTAdapter(BaseModelAdapter):
"""The model adapter for ChatGPT"""
def match(self, model_path: str):
return model_path in ("gpt-3.5-turbo", "gpt-4", "gpt-3.5-turbo-16k")
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("chatgpt")
class ClaudeAdapter(BaseModelAdapter):
"""The model adapter for Claude"""
def match(self, model_path: str):
return model_path in ["claude-v1", "claude-instant-v1"]
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("claude")
class BardAdapter(BaseModelAdapter):
"""The model adapter for Bard"""
def match(self, model_path: str):
return model_path == "bard"
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("bard")
class PaLM2Adapter(BaseModelAdapter):
"""The model adapter for PaLM2"""
def match(self, model_path: str):
return model_path == "palm-2"
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("bard")
class BiLLaAdapter(BaseModelAdapter):
"""The model adapter for Neutralzz/BiLLa-7B-SFT"""
def match(self, model_path: str):
return "billa" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("billa")
class RedPajamaINCITEAdapter(BaseModelAdapter):
"""The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat"""
def match(self, model_path: str):
return "redpajama-incite" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("redpajama-incite")
class H2OGPTAdapter(BaseModelAdapter):
"""The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "h2ogpt" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("h2ogpt")
class RobinAdapter(BaseModelAdapter):
"""The model adapter for LMFlow/Full-Robin-7b-v2"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "Robin" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("Robin")
class SnoozyAdapter(BaseModelAdapter):
"""The model adapter for nomic-ai/gpt4all-13b-snoozy"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "gpt4all" in model_path and "snoozy" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("snoozy")
class WizardLMAdapter(BaseModelAdapter):
"""The model adapter for WizardLM/WizardLM-13B-V1.0"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "wizardlm" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
model_path = model_path.lower()
if "13b" in model_path or "30b" in model_path:
return get_conv_template("vicuna_v1.1")
else:
# TODO: use the recommended template for 7B
# (https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
return get_conv_template("one_shot")
class ManticoreAdapter(BaseModelAdapter):
"""The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "manticore" in model_path.lower()
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("manticore")
class GuanacoAdapter(BaseModelAdapter):
"""The model adapter for timdettmers/guanaco-33b-merged"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "guanaco" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
# Fix a bug in tokenizer config
tokenizer.eos_token_id = model.config.eos_token_id
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("zero_shot")
class ChangGPTAdapter(BaseModelAdapter):
"""The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat"""
def match(self, model_path: str):
return "polyglot" in model_path and "chang" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("polyglot_changgpt")
class CamelAdapter(BaseModelAdapter):
"""The model adapter for camel-ai/CAMEL-13B-Combined-Data"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "camel" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")
class TuluAdapter(BaseModelAdapter):
"""The model adapter for allenai/tulu-30b"""
use_fast_tokenizer = False
def match(self, model_path: str):
return "tulu" in model_path
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("tulu")
class FalconAdapter(BaseModelAdapter):
"""The model adapter for tiiuae/falcon-40b."""
def match(self, model_path: str):
return "falcon" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
        # Using bf16 is strongly recommended by the Falcon authors
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
        # The Falcon tokenizer config does not define a pad token.
        # Set `pad_token_id` to 9, which corresponds to the special token '>>SUFFIX<<'
tokenizer.pad_token_id = 9
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("falcon")
class TigerBotAdapter(BaseModelAdapter):
"""The model adapter for TigerResearch/tigerbot-7b-sft"""
def match(self, model_path: str):
return "tigerbot" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
revision=revision,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("tigerbot")
class BaichuanAdapter(BaseModelAdapter):
"""The model adapter for baichuan-inc/baichuan-7B"""
def match(self, model_path: str):
return "baichuan" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("one_shot")
class XGenAdapter(BaseModelAdapter):
"""The model adapter for Salesforce/xgen-7b"""
def match(self, model_path: str):
return "xgen" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
model.config.eos_token_id = 50256
return model, tokenizer
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("xgen")
# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
register_model_adapter(VicunaAdapter)
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
register_model_adapter(DollyV2Adapter)
register_model_adapter(OasstPythiaAdapter)
register_model_adapter(OasstLLaMAAdapter)
register_model_adapter(StableLMAdapter)
register_model_adapter(BaizeAdapter)
register_model_adapter(RwkvAdapter)
register_model_adapter(OpenBuddyAdapter)
register_model_adapter(PhoenixAdapter)
register_model_adapter(BardAdapter)
register_model_adapter(PaLM2Adapter)
register_model_adapter(ChatGPTAdapter)
register_model_adapter(ClaudeAdapter)
register_model_adapter(MPTAdapter)
register_model_adapter(BiLLaAdapter)
register_model_adapter(RedPajamaINCITEAdapter)
register_model_adapter(H2OGPTAdapter)
register_model_adapter(RobinAdapter)
register_model_adapter(SnoozyAdapter)
register_model_adapter(WizardLMAdapter)
register_model_adapter(ManticoreAdapter)
register_model_adapter(GuanacoAdapter)
register_model_adapter(CamelAdapter)
register_model_adapter(ChangGPTAdapter)
register_model_adapter(TuluAdapter)
register_model_adapter(FalconAdapter)
register_model_adapter(TigerBotAdapter)
register_model_adapter(BaichuanAdapter)
register_model_adapter(XGenAdapter)
register_model_adapter(PythiaAdapter)
# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)