Update modeling_fastesm.py
modeling_fastesm.py  (+129 -13)
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
 from transformers import PreTrainedModel, PretrainedConfig
@@ -20,11 +21,11 @@ from transformers.models.esm.modeling_esm import (
     EsmClassificationHead,
     create_position_ids_from_input_ids,
 )
+from tqdm.auto import tqdm
 
 
 class FastEsmConfig(PretrainedConfig):
     model_type = "fast_esm"
-
     def __init__(
         self,
         vocab_size=None,
@@ -141,14 +142,6 @@ class EsmEmbeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
-        self.padding_idx = config.pad_token_id
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
-        )
-        # Token dropout does not work correctly so we disable it
-        # self.token_dropout = config.token_dropout
-        self.mask_token_id = config.mask_token_id
-
     def forward(
         self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
@@ -164,10 +157,6 @@ class EsmEmbeddings(nn.Module):
 
         embeddings = inputs_embeds
 
-        if self.position_embedding_type == "absolute":
-            position_embeddings = self.position_embeddings(position_ids)
-            embeddings = embeddings + position_embeddings
-
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
         if attention_mask is not None:
@@ -336,6 +325,19 @@ class EsmEncoder(nn.Module):
         )
 
 
+### Dataset for Embedding
+class ProteinDataset(Dataset):
+    """Simple dataset for protein sequences."""
+    def __init__(self, sequences: list[str]):
+        self.sequences = sequences
+
+    def __len__(self) -> int:
+        return len(self.sequences)
+
+    def __getitem__(self, idx: int) -> str:
+        return self.sequences[idx]
+
+
 class FastEsmPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -364,6 +366,120 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         except AttributeError:
             return self.esm.embeddings.word_embeddings
 
+    @property
+    def device(self) -> torch.device:
+        """Get the device of the model."""
+        return next(self.parameters()).device
+
+    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Apply mean pooling to sequence outputs."""
+        if attention_mask is None:
+            return x.mean(dim=1)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
+
+    def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Collate function for batching sequences."""
+        return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
+
+    def _read_sequences_from_db(self, db_path: str) -> set[str]:
+        """Read sequences from SQLite database."""
+        import sqlite3
+        sequences = []
+        with sqlite3.connect(db_path) as conn:
+            c = conn.cursor()
+            c.execute("SELECT sequence FROM embeddings")
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                sequences.append(row[0])
+        return set(sequences)
+
+    def embed_dataset(
+        self,
+        sequences: list[str],
+        batch_size: int = 2,
+        max_len: int = 512,
+        full_embeddings: bool = False,
+        full_precision: bool = False,
+        pooling_type: str = 'mean',
+        num_workers: int = 0,
+        sql: bool = False,
+        sql_db_path: str = 'embeddings.db',
+    ) -> Optional[dict[str, torch.Tensor]]:
+        """Embed a dataset of protein sequences.
+
+        Args:
+            sequences: List of protein sequences
+            batch_size: Batch size for processing
+            max_len: Maximum sequence length
+            full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
+            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
+            pooling_type: Type of pooling ('mean' or 'cls')
+            num_workers: Number of workers for data loading, 0 for the main process
+            sql: Whether to store embeddings in SQLite database - will be stored in float32
+            sql_db_path: Path to SQLite database
+
+        Returns:
+            Dictionary mapping sequences to embeddings, or None if sql=True
+        """
+        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
+        device = self.device
+
+        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+            if full_embeddings:
+                return residue_embeddings
+            elif pooling_type == 'mean':
+                return self.mean_pooling(residue_embeddings, attention_mask)
+            else:
+                return residue_embeddings[:, 0, :]
+
+        if sql:
+            import sqlite3
+            conn = sqlite3.connect(sql_db_path)
+            c = conn.cursor()
+            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
+            already_embedded = self._read_sequences_from_db(sql_db_path)
+            to_embed = [seq for seq in sequences if seq not in already_embedded]
+            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
+            print(f"Embedding {len(to_embed)} new sequences")
+
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                    seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                    residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()  # required for sql
+                    embeddings = get_embeddings(residue_embeddings, attention_mask)
+
+                    for seq, emb in zip(seqs, embeddings):
+                        c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+                                  (seq, emb.cpu().numpy().tobytes()))
+
+                    if (i + 1) % 100 == 0:
+                        conn.commit()
+
+            conn.commit()
+            conn.close()
+            return None
+
+        embeddings_dict = {}
+        with torch.no_grad():
+            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()
+                if full_precision:
+                    residue_embeddings = residue_embeddings.float()
+                embeddings = get_embeddings(residue_embeddings, attention_mask)
+                for seq, emb in zip(seqs, embeddings):
+                    embeddings_dict[seq] = emb
+
+        return embeddings_dict
 
 class FastEsmModel(FastEsmPreTrainedModel):
     def __init__(self, config, add_pooling_layer=True):
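For orientation, below is a minimal sketch of how the new embed_dataset helper could be called once this file is available locally. It is not part of the commit: the checkpoint path is a placeholder, the device handling is an assumption, and the sketch attaches a tokenizer to the model instance because _collate_fn reads self.tokenizer, which the diff itself does not set up.

# Usage sketch only -- not part of the commit.
import torch
from transformers import AutoTokenizer
from modeling_fastesm import FastEsmModel

model_path = "path/to/fastesm-checkpoint"  # placeholder, not from the commit
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = FastEsmModel.from_pretrained(model_path, torch_dtype=torch.float16).eval().to(device)
model.tokenizer = tokenizer  # required because _collate_fn calls self.tokenizer

sequences = ["MPRTEIN", "MSEQWENCE"]  # toy amino-acid strings

# In-memory dict of mean-pooled embeddings, one tensor per unique sequence
embeddings = model.embed_dataset(sequences, batch_size=2, pooling_type="mean")

# Or stream per-residue embeddings into SQLite instead of holding them in memory
model.embed_dataset(sequences, full_embeddings=True, sql=True, sql_db_path="embeddings.db")

# Embeddings are stored as raw float32 bytes; read them back with numpy
import sqlite3
import numpy as np
with sqlite3.connect("embeddings.db") as conn:
    seq, blob = conn.execute("SELECT sequence, embedding FROM embeddings").fetchone()
    emb = torch.from_numpy(np.frombuffer(blob, dtype=np.float32).copy())

Passing sql=True trades the returned dict for incremental SQLite storage (committed every 100 batches and skipping sequences already in the database), which keeps memory flat for large sequence sets; the dict path keeps everything on device and is simpler for small datasets.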