import torch
import torch.nn
import torch.nn.functional as F
from torch import Tensor
from typing import List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertOnlyMLMHead
from transformers.models.opt.modeling_opt import OPTModel, OPTPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaModel, RobertaPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel, LlamaModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    SequenceClassifierOutputWithPast,
    BaseModelOutput,
    Seq2SeqLMOutput,
    CausalLMOutputWithPast,
)
from .prefix_encoder import PrefixEncoder
from . import utils
import hashlib

def hash_nn(model):
    """Return an MD5 digest over all parameters of `model`, useful for checking
    whether weights changed between two points in time."""
    md5 = hashlib.md5()
    for arg in model.parameters():
        x = arg.data
        if hasattr(x, "cpu"):
            md5.update(x.cpu().numpy().data.tobytes())
        elif hasattr(x, "numpy"):
            md5.update(x.numpy().data.tobytes())
        elif hasattr(x, "data"):
            md5.update(x.data.tobytes())
        else:
            try:
                md5.update(x.encode("utf-8"))
            except AttributeError:
                md5.update(str(x).encode("utf-8"))
    return md5.hexdigest()
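
# Illustrative usage (a minimal sketch; `torch.nn.Linear` is just a stand-in for
# any module whose weight drift we want to detect):
#
#     >>> net = torch.nn.Linear(4, 2)
#     >>> before = hash_nn(net)
#     >>> net.weight.data.add_(1.0)
#     >>> hash_nn(net) == before
#     False
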
class OPTPrefixForMaskedLM(OPTPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
        self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.dropout = torch.nn.Dropout(0.1)

        # Freeze the base model; only the prefix encoder and LM head are trained.
        for param in self.model.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        base_param = 0
        for name, param in self.model.named_parameters():
            base_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - base_param
        print('-> OPT param: {:0.2f}M, P-tuning-v2 param: {}'.format(base_param / 1000000, total_param))

        self.embedding = self.get_input_embeddings()
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        # (batch_size, pre_seq_len, 2 * n_layer * n_head * n_embd)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # -> n_layer (key, value) pairs, each (batch_size, n_head, pre_seq_len, n_embd)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
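
    # Shape walkthrough for the prefix cache above (a sketch with illustrative
    # numbers: batch_size=2, pre_seq_len=8, n_layer=24, n_head=16, n_embd=64):
    #   prefix_encoder output: (2, 8, 2 * 24 * 16 * 64)
    #   after .view(...):      (2, 8, 48, 16, 64)
    #   after .permute(...):   (48, 2, 16, 8, 64)
    #   after .split(2):       24 tensors of shape (2, 2, 16, 8, 64), i.e. one
    #                          (key, value) pair per decoder layer, matching the
    #                          layout OPT expects in `past_key_values`.
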
    def use_grad(self, transformer, use_grad):
        if use_grad:
            for param in transformer.parameters():
                param.requires_grad = True
            transformer.train()
        else:
            for param in transformer.parameters():
                param.requires_grad = False
            transformer.eval()
        # The LM head and prefix encoder stay trainable either way.
        for param in self.lm_head.parameters():
            param.requires_grad = True
        for param in self.prefix_encoder.parameters():
            param.requires_grad = True

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_base_grad=False,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
        ```"""
        utils.use_grad(self.model, use_base_grad)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the trained prefix as past key/values and extend the mask accordingly.
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Score the last non-padding token of each sequence with the LM head.
        sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
        cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier: each class keeps the maximum vocabulary
        # logit over its verbalizer tokens in `clean_labels`
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )
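
    # Sketch of the verbalizer step above, with hypothetical numbers: if
    # clean_labels == [[2204, 1372], [6587]] (vocab ids verbalizing class 0 and
    # class 1) and `attentions` has shape (batch, vocab_size), then
    #   logits[:, 0] = attentions[:, [2204, 1372]].max(dim=1).values
    #   logits[:, 1] = attentions[:, [6587]].max(dim=1).values
    # giving a (batch, num_classes) tensor that downstream code can treat as an
    # ordinary classifier output.
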
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

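
# Instantiation sketch (hedged: `pre_seq_len`, `prefix_projection`,
# `prefix_hidden_size`, and `clean_labels` are non-standard config fields that
# this file and `PrefixEncoder` assume; the values below are illustrative only):
#
#     >>> from transformers import OPTConfig
#     >>> config = OPTConfig.from_pretrained("facebook/opt-350m")
#     >>> config.pre_seq_len = 8              # number of prefix positions
#     >>> config.prefix_projection = False    # embedding table vs. MLP reparameterization
#     >>> config.prefix_hidden_size = 512     # only read when prefix_projection is True
#     >>> config.clean_labels = [[2204], [6587]]  # verbalizer token ids per class
#     >>> model = OPTPrefixForMaskedLM.from_pretrained("facebook/opt-350m", config=config)
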
class OPTPromptForMaskedLM(OPTPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        self.score = torch.nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
        self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.dropout = torch.nn.Dropout(0.1)

        # Freeze the base model; only the prompt embeddings and LM head are trained.
        for param in self.model.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        model_param = 0
        for name, param in self.model.named_parameters():
            model_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - model_param
        print('-> OPT param: {:0.2f}M, P-tuning-v2 param: {}'.format(model_param / 1000000, total_param))

        self.embedding = self.model.decoder.embed_tokens
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts
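
    # Unlike the prefix model above, this class implements prompt tuning: the
    # encoder is a plain embedding table and `prompts` has shape
    # (batch_size, pre_seq_len, hidden_size); it is concatenated with the token
    # embeddings in `forward` rather than injected per layer as key/value pairs.
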
    def use_grad(self, transformer, use_grad):
        if use_grad:
            for param in transformer.parameters():
                param.requires_grad = True
            transformer.train()
        else:
            for param in transformer.parameters():
                param.requires_grad = False
            transformer.eval()
        for param in self.lm_head.parameters():
            param.requires_grad = True
        for param in self.prefix_encoder.parameters():
            param.requires_grad = True

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_base_grad=False,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
        ```"""
        utils.use_grad(self.model, use_base_grad)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the soft prompt embeddings to the token embeddings and extend the mask.
        raw_embedding = self.model.decoder.embed_tokens(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # Drop the prompt positions before pooling.
        sequence_output = sequence_output[:, self.pre_seq_len:, :]
        sequence_output = self.dropout(sequence_output)

        # Score the last non-padding token of each sequence with the LM head.
        sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
        cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()

        # compute loss
        loss = None
        if token_labels is not None:
            loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T
        # loss = torch.nn.functional.nll_loss(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

class LlamaPrefixForMaskedLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.dropout = torch.nn.Dropout(0.1)

        # Freeze the base model; only the prefix encoder and LM head are trained.
        for param in self.model.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        base_param = 0
        for name, param in self.model.named_parameters():
            base_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - base_param
        print('-> Llama param: {:0.2f}M, P-tuning-v2 param: {:0.2f}M'.format(base_param / 1000000, total_param / 1000000))

        self.embedding = self.model.embed_tokens
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        device = next(self.prefix_encoder.parameters()).device
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        # (batch_size, pre_seq_len, 2 * n_layer * n_head * n_embd)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
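
    # Note: unlike the OPT variant, the device is taken from the prefix
    # encoder's own parameters rather than `self.model.device`, which
    # presumably keeps `get_prompt` working when the frozen base model is
    # sharded or offloaded across devices.
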
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def use_grad(self, base_model, use_grad):
        if use_grad:
            for param in base_model.parameters():
                param.requires_grad = True
            base_model.train()
        else:
            for param in base_model.parameters():
                param.requires_grad = False
            base_model.eval()
        for param in self.prefix_encoder.parameters():
            param.requires_grad = True
        for param in self.lm_head.parameters():
            param.requires_grad = True

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        token_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.model, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the trained prefix as past key/values and extend the mask accordingly.
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # sequence_output = torch.clamp(sequence_output, min=-1, max=1)

        # Score the last non-padding token of each sequence with the LM head.
        sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(sequence_output.device)
        cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

class LlamaPromptForMaskedLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.dropout = torch.nn.Dropout(0.1)

        # Freeze the base model; only the prompt embeddings are trained (see use_grad).
        for param in self.model.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        model_param = 0
        for name, param in self.model.named_parameters():
            model_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - model_param
        print('-> Llama param: {:0.2f}M, P-tuning-v2 param: {:0.2f}M'.format(model_param / 1000000, total_param / 1000000))

        self.pad_token_id = 2  # hard-coded: LLaMA's eos token id, used here as the padding id
        self.embedding = self.model.embed_tokens
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        device = next(self.prefix_encoder.parameters()).device
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def use_grad(self, base_model, use_grad):
        if use_grad:
            for param in base_model.parameters():
                param.requires_grad = True
            for param in self.lm_head.parameters():
                param.requires_grad = True
            base_model.train()
        else:
            for param in base_model.parameters():
                param.requires_grad = False
            for param in self.lm_head.parameters():
                param.requires_grad = False
            base_model.eval()
        for param in self.prefix_encoder.parameters():
            param.requires_grad = True

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_base_grad: Optional[bool] = False,
    ):
        self.use_grad(self.model, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the soft prompt embeddings to the token embeddings and extend the mask.
        raw_embedding = self.model.embed_tokens(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding.to(prompts.device)), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.model(
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # Drop the prompt positions before pooling.
        sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()

        # Score the last non-padding token of each sequence with the LM head.
        sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(sequence_output.device)
        cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous().float()

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

class BertPrefixForMaskedLM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        # Freeze the base model; only the prefix encoder is trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        base_param = 0
        for name, param in self.bert.named_parameters():
            base_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - base_param
        print('-> BERT param: {:0.2f}M, P-tuning-v2 param: {}'.format(base_param / 1000000, total_param))

        # bert.embeddings.word_embeddings
        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        # (batch_size, pre_seq_len, 2 * n_layer * n_head * n_embd)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        token_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the trained prefix as past key/values and extend the mask accordingly.
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )
        sequence_output = outputs[0]
        # Pool at the [CLS] position and score it with the MLM head.
        cls_token = sequence_output[:, 0]
        cls_token = self.dropout(cls_token)
        attentions = self.cls(cls_token).view(-1, self.config.vocab_size)

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

class BertPromptForMaskedLM(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        # Freeze the base model; only the prompt embeddings are trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('-> BERT param: {:0.2f}M, P-tuning-v2 param: {}'.format(bert_param / 1000000, total_param))

        # bert.embeddings.word_embeddings
        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        token_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the soft prompt embeddings to the token embeddings and extend the mask.
        raw_embedding = self.bert.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # Drop the prompt positions, then pool at the [CLS] position.
        sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
        cls_token = sequence_output[:, 0]
        cls_token = self.dropout(cls_token)
        attentions = self.cls(cls_token).view(-1, self.config.vocab_size)

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

class RobertaPrefixForMaskedLM(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        # Freeze the base model; only the prefix encoder is trained.
        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        base_param = 0
        for name, param in self.roberta.named_parameters():
            base_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - base_param
        print('-> total param is {}'.format(total_param))  # 9860105

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        token_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the trained prefix as past key/values and extend the mask accordingly.
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )
        sequence_output = outputs[0]
        # Pool at the <s> (CLS-equivalent) position and score it with the LM head.
        cls_token = sequence_output[:, 0]
        cls_token = self.dropout(cls_token)
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )

class RobertaPromptForMaskedLM(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        self.embeddings = self.roberta.embeddings
        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
        self.clean_labels = torch.tensor(config.clean_labels).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        token_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size = input_ids.shape[0]

        # Prepend the soft prompt embeddings to the token embeddings and extend the mask.
        raw_embedding = self.roberta.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # Drop the prompt positions, then pool at the <s> (CLS-equivalent) position.
        sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
        cls_token = sequence_output[:, 0]
        attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)

        # compute loss
        masked_lm_loss = None
        if token_labels is not None:
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
        elif labels is not None:
            token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
            masked_lm_loss = utils.get_loss(attentions, token_labels).sum()

        # convert to binary classifier
        probs = []
        for y in self.clean_labels:
            probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
        logits = torch.stack(probs).T

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )