import re
import warnings
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
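# Per-architecture map of HF config attribute names and the default pooler type.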
_HF_ARCH_DICT = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
'roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
'xlm-roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/bert
'bert': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
},
'pooler': 'cls_pooler',
},
}
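# Registry of pooler classes, populated by the register_pooler decorator below.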
_POOLERS = {}
def _camel2snake(s):
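    """Convert a CamelCase class name to snake_case."""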
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
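    """Mean pooling over non-padded token embeddings."""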
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
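    """Max pooling over token embeddings, with padded positions masked out."""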
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        # fill padded positions (attention_mask == 0) with -inf so they never win the max
        masked_output = x.last_hidden_state.masked_fill(
            attention_mask.unsqueeze(-1) == 0, -torch.inf
        )
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
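    """CLS token pooling; returns the HF pooler_output when enabled and available,
    otherwise the CLS token's last hidden state."""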
def __init__(self, use_pooler_output: bool = True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, _: torch.Tensor):
if (
self.use_pooler_output
and isinstance(
x,
(
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
),
)
and (x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
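    """HuggingFace transformer adapter that adds pooling and an output projection
    on top of a HF text model."""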
output_tokens: torch.jit.Final[bool]
def __init__(
self,
model_name_or_path: str,
output_dim: int,
        config: Optional[PretrainedConfig] = None,
        pooler_type: Optional[str] = None,
        proj_type: Optional[str] = None,
proj_bias: bool = False,
pretrained: bool = True,
output_tokens: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
default_instruction_task: Optional[str] = None,
default_lora_task: Optional[str] = None,
model_config_kwargs: Optional[Dict] = None,
):
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
model_config_kwargs = model_config_kwargs or {}
if config is None:
if pretrained:
self.transformer = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
revision=revision,
add_pooling_layer=False,
code_revision=code_revision,
**model_config_kwargs,
)
self.config = self.transformer.config
else:
self.config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
code_revision=code_revision,
)
self.config.update(model_config_kwargs)
self.transformer = AutoModel.from_config(
self.config,
trust_remote_code=trust_remote_code,
add_pooling_layer=False,
code_revision=code_revision,
)
if (
hasattr(self.config, 'is_encoder_decoder')
and self.config.is_encoder_decoder
):
self.transformer = self.transformer.encoder
else:
self.config = config
self.config.update(model_config_kwargs)
self.transformer = AutoModel.from_config(
self.config,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
)
self.vocab_size = getattr(self.config, 'vocab_size', 0)
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(
self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
)
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
self.proj = nn.Identity()
elif (d_model != output_dim) or proj_type == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
elif proj_type == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=proj_bias),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=proj_bias),
)
self._task_instructions = {}
self._lora_adaptation_map = {}
self._supports_task_instructions = False
self._supports_lora = False
if (
hasattr(self.transformer, '_adaptation_map')
and len(self.transformer._adaptation_map) > 0
):
self._lora_adaptation_map = self.transformer._adaptation_map
self._supports_lora = True
if (
hasattr(self.transformer, '_task_instructions')
and len(self.transformer._task_instructions) > 0
):
self._task_instructions = self.transformer._task_instructions
self._supports_task_instructions = True
self._default_instruction_task = None
self._default_lora_task = None
self._default_instruction = None
self._default_loraid = None
if default_instruction_task is not None:
self._default_instruction_task = default_instruction_task
self._default_instruction = self.get_instruction_from_task(
default_instruction_task
)
if default_lora_task is not None:
self._default_lora_task = default_lora_task
self._default_loraid = self.get_loraid_from_task(default_lora_task)
@property
def supports_task_instructions(self) -> bool:
return self._supports_task_instructions
@property
def supports_lora(self) -> bool:
return self._supports_lora
@property
def task_instructions(self) -> Dict[str, str]:
return self._task_instructions
@property
def lora_adaptation_map(self) -> Dict[str, int]:
return self._lora_adaptation_map
@property
def default_instruction(self) -> Optional[str]:
return self._default_instruction
@property
def default_loraid(self) -> Optional[int]:
return self._default_loraid
def get_instruction_from_task(self, task: Optional[str]) -> Optional[str]:
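        """Return the instruction string for a task, falling back to the default
        instruction when task is None."""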
if self._supports_task_instructions:
if task is None:
return self._default_instruction
if task not in self._task_instructions:
raise ValueError(
f'Unsupported task \'{task}\'. Choose one of the following: '
f'{", ".join(self._task_instructions)} or set to None to disable '
f'task instructions completely'
)
return self._task_instructions[task]
else:
if task is not None:
warnings.warn(
'Model does not support task instructions, ignoring instruction '
f"task '{task}'"
)
return None
def get_loraid_from_task(self, task: Optional[str]) -> Optional[int]:
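        """Return the LoRA adapter id for a task, falling back to the default
        LoRA id when task is None."""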
if self._supports_lora:
if task is None:
return self._default_loraid
if task not in self._lora_adaptation_map:
raise ValueError(
f'Unsupported task \'{task}\'. Choose one of the following: '
                f'{", ".join(self._lora_adaptation_map)} or set to None to disable '
f'the LoRA adapters completely'
)
return self._lora_adaptation_map[task]
else:
if task is not None:
warnings.warn(
f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
)
return None
@staticmethod
def get_adapter_mask_from_loraid(
batch_size: int, loraid: int, device: Union[str, torch.device]
):
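        """Build an adapter mask that assigns the same LoRA id to every sample in the batch."""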
return torch.full((batch_size,), loraid, dtype=torch.int32, device=device)
@torch.jit.ignore
def set_grad_checkpointing(self, _=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass
def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
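        """Encode input token ids into pooled, projected embeddings; also returns
        per-token embeddings when output_tokens is set."""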
if adapter_mask is None:
default_loraid = self.default_loraid
if default_loraid is not None:
adapter_mask = self.get_adapter_mask_from_loraid(
x.shape[0], default_loraid, x.device
)
else:
if not self.supports_lora:
warnings.warn(
'Model does not support LoRA adapters, setting adapter_mask to None'
)
adapter_mask = None
attention_mask = (x != self.config.pad_token_id).long()
lora_kwargs = {}
if adapter_mask is not None:
lora_kwargs['adapter_mask'] = adapter_mask
out = self.transformer(
input_ids=x, attention_mask=attention_mask, **lora_kwargs
)
pooled_out = self.pooler(out, attention_mask)
projected = self.proj(pooled_out)
seqlen = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[
:, torch.arange(seqlen) != self.pooler.cls_token_position, :
]
if isinstance(self.pooler, ClsPooler)
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
return projected
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
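        """Freeze the transformer, leaving the last unlocked_layers layers trainable;
        LayerNorm parameters in frozen modules stay trainable if freeze_layer_norm is False."""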
if not unlocked_layers:
for n, p in self.transformer.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)
return
encoder = (
self.transformer.encoder
if hasattr(self.transformer, 'encoder')
else self.transformer
)
layer_list = getattr(
encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
)
print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
embeddings = getattr(
self.transformer,
_HF_ARCH_DICT[self.config.model_type]['config_names'][
'token_embeddings_attr'
],
)
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)