import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaModel,
    LlamaPreTrainedModel,
)
from transformers.utils import logging
from sdlm.models.mixins.modeling_mixin import (
    CausalLMForSeq2SeqMixin,
    DiffusionModelMixin,
)

logger = logging.get_logger(__name__)


class LlamaForDiffusionLM(DiffusionModelMixin, LlamaPreTrainedModel):
    _keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.bias"]
    _keys_to_ignore_on_load_missing = [
        r"lm_head.weight",
        r"lm_head.bias",
    ]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        if config.is_decoder:
            logger.warning(
                "If you want to use `LlamaForDiffusionLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Projects a scalar diffusion timestep into the hidden dimension.
        self.timestep_embed = nn.Linear(1, config.hidden_size, bias=False)
        self.post_init()

    def post_init(self):
        super().post_init()
        # (Un)toggle causal attention on every decoder layer.
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.is_causal = self.config.is_causal

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def vocab_to_hidden_dim_embed(self, input_data):
        # Map vocabulary-sized vectors (e.g. token distributions) into the hidden space
        # by multiplying with the transposed input embedding matrix.
        return F.linear(input_data, self.get_input_embeddings().weight.data.T)


class LlamaForSeq2SeqLM(CausalLMForSeq2SeqMixin, LlamaForCausalLM):
    pass
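

if __name__ == "__main__":
    # Minimal instantiation sketch (illustrative only, not part of the library's API):
    # it builds LlamaForDiffusionLM from a small LlamaConfig and exercises
    # vocab_to_hidden_dim_embed. The tiny config values below are assumptions chosen
    # for a quick smoke test; `is_causal` is the extra flag read by post_init above,
    # and the diffusion forward pass itself comes from DiffusionModelMixin (not shown
    # in this file), which is assumed to need no additional configuration.
    import torch
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    config.is_causal = False  # disable causal masking for bi-directional denoising
    config.is_decoder = False  # avoid the decoder warning in __init__

    model = LlamaForDiffusionLM(config)
    print(model.lm_head)  # Linear(in_features=128, out_features=1000, bias=False)

    # Map a (batch, seq, vocab) token distribution into hidden states.
    probs = torch.softmax(torch.randn(1, 4, config.vocab_size), dim=-1)
    hidden = model.vocab_to_hidden_dim_embed(probs)
    print(hidden.shape)  # torch.Size([1, 4, 128])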