small fix

Files changed:
- README.md (+1 -1)
- modeling_lsg_albert.py (+47 -42)

README.md CHANGED

```diff
@@ -8,7 +8,7 @@ pipeline_tag: fill-mask
 ---
 
 # LSG model
-**Transformers >= 4.
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
```
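Since the README requires `trust_remote_code=True` and transformers >= 4.36.1, a minimal loading sketch might look like the following. The repository id is a placeholder, and `AutoModelForMaskedLM` is assumed only from the `fill-mask` pipeline tag; use whichever auto class the checkpoint's `AUTO_MAP` actually exposes.

```python
# Minimal sketch, assuming transformers >= 4.36.1.
# "your-namespace/lsg-albert-model" is a placeholder, not the actual repo id.
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "your-namespace/lsg-albert-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)  # loads the custom LSG modeling file

inputs = tokenizer(f"Paris is the {tokenizer.mask_token} of France.", return_tensors="pt")
outputs = model(**inputs)  # masked-LM logits over the vocabulary
```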
modeling_lsg_albert.py CHANGED

```diff
@@ -413,54 +413,54 @@ class LSGAlbertEmbeddings(AlbertEmbeddings):
         self.block_size = config.block_size
 
     def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
 
+        seq_length = input_shape[1]
 
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
 
+        n, t, d = embeddings.size()
+
+        # Add global_tokens
+        indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
+        global_embeddings = self.global_embeddings(indexes)
+        embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
+
 
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
 
 
 class LSGSelfAttention(BaseSelfAttention):
```
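The rewritten `LSGAlbertEmbeddings.forward` prepends `num_global_tokens` learned global embeddings to the token embeddings before LayerNorm and dropout. A standalone sketch of just that concatenation, with made-up sizes rather than the model's real modules:

```python
# Illustration of the global-token concatenation above.
# Sizes are made up; the real model uses nn.Embedding layers, not random tensors.
import torch

n, t, d = 2, 8, 16           # batch, sequence length, hidden size
num_global_tokens = 4

embeddings = torch.randn(n, t, d)                           # token (+ position/type) embeddings
global_embeddings = torch.randn(1, num_global_tokens, d)    # stands in for self.global_embeddings(indexes)

out = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
print(out.shape)  # torch.Size([2, 12, 16]): num_global_tokens extra positions at the front
```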
```diff
@@ -907,6 +907,11 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
         self.pooler = None
         self.pooler_activation = None
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
+            )
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1015,4 +1020,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.36.1 to fix.")
```