lhallee committed
Commit fb1a199 · verified · 1 Parent(s): b5d932c

Update modeling_fastesm.py

Files changed (1)
  1. modeling_fastesm.py +129 -13
modeling_fastesm.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
 from transformers import PreTrainedModel, PretrainedConfig
@@ -20,11 +21,11 @@ from transformers.models.esm.modeling_esm import (
     EsmClassificationHead,
     create_position_ids_from_input_ids,
 )
+from tqdm.auto import tqdm
 
 
 class FastEsmConfig(PretrainedConfig):
     model_type = "fast_esm"
-
     def __init__(
         self,
         vocab_size=None,
@@ -141,14 +142,6 @@ class EsmEmbeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
-        self.padding_idx = config.pad_token_id
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
-        )
-        # Token dropout does not work correctly so we disable it
-        # self.token_dropout = config.token_dropout
-        self.mask_token_id = config.mask_token_id
-
     def forward(
         self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
@@ -164,10 +157,6 @@
 
         embeddings = inputs_embeds
 
-        if self.position_embedding_type == "absolute":
-            position_embeddings = self.position_embeddings(position_ids)
-            embeddings = embeddings + position_embeddings
-
        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
@@ -336,6 +325,19 @@ class EsmEncoder(nn.Module):
         )
 
 
+### Dataset for Embedding
+class ProteinDataset(Dataset):
+    """Simple dataset for protein sequences."""
+    def __init__(self, sequences: list[str]):
+        self.sequences = sequences
+
+    def __len__(self) -> int:
+        return len(self.sequences)
+
+    def __getitem__(self, idx: int) -> str:
+        return self.sequences[idx]
+
+
 class FastEsmPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -364,6 +366,120 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         except AttributeError:
             return self.esm.embeddings.word_embeddings
 
+    @property
+    def device(self) -> torch.device:
+        """Get the device of the model."""
+        return next(self.parameters()).device
+
+    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Apply mean pooling to sequence outputs."""
+        if attention_mask is None:
+            return x.mean(dim=1)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
+
+    def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Collate function for batching sequences."""
+        return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
+
+    def _read_sequences_from_db(self, db_path: str) -> set[str]:
+        """Read sequences from SQLite database."""
+        import sqlite3
+        sequences = []
+        with sqlite3.connect(db_path) as conn:
+            c = conn.cursor()
+            c.execute("SELECT sequence FROM embeddings")
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                sequences.append(row[0])
+        return set(sequences)
+
+    def embed_dataset(
+        self,
+        sequences: list[str],
+        batch_size: int = 2,
+        max_len: int = 512,
+        full_embeddings: bool = False,
+        full_precision: bool = False,
+        pooling_type: str = 'mean',
+        num_workers: int = 0,
+        sql: bool = False,
+        sql_db_path: str = 'embeddings.db',
+    ) -> Optional[dict[str, torch.Tensor]]:
+        """Embed a dataset of protein sequences.
+
+        Args:
+            sequences: List of protein sequences
+            batch_size: Batch size for processing
+            max_len: Maximum sequence length
+            full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
+            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
+            pooling_type: Type of pooling ('mean' or 'cls')
+            num_workers: Number of workers for data loading, 0 for the main process
+            sql: Whether to store embeddings in SQLite database - will be stored in float32
+            sql_db_path: Path to SQLite database
+
+        Returns:
+            Dictionary mapping sequences to embeddings, or None if sql=True
+        """
+        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
+        device = self.device
+
+        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+            if full_embeddings:
+                return residue_embeddings
+            elif pooling_type == 'mean':
+                return self.mean_pooling(residue_embeddings, attention_mask)
+            else:
+                return residue_embeddings[:, 0, :]
+
+        if sql:
+            import sqlite3
+            conn = sqlite3.connect(sql_db_path)
+            c = conn.cursor()
+            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
+            already_embedded = self._read_sequences_from_db(sql_db_path)
+            to_embed = [seq for seq in sequences if seq not in already_embedded]
+            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
+            print(f"Embedding {len(to_embed)} new sequences")
+
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                    seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                    residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()  # required for sql
+                    embeddings = get_embeddings(residue_embeddings, attention_mask)
+
+                    for seq, emb in zip(seqs, embeddings):
+                        c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+                                  (seq, emb.cpu().numpy().tobytes()))
+
+                    if (i + 1) % 100 == 0:
+                        conn.commit()
+
+            conn.commit()
+            conn.close()
+            return None
+
+        embeddings_dict = {}
+        with torch.no_grad():
+            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()
+                if full_precision:
+                    residue_embeddings = residue_embeddings.float()
+                embeddings = get_embeddings(residue_embeddings, attention_mask)
+                for seq, emb in zip(seqs, embeddings):
+                    embeddings_dict[seq] = emb
+
+        return embeddings_dict
 
 class FastEsmModel(FastEsmPreTrainedModel):
     def __init__(self, config, add_pooling_layer=True):
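
A minimal usage sketch for the new embed_dataset helper added by this commit. The checkpoint name and the explicit tokenizer assignment are assumptions for illustration: the diff does not show where self.tokenizer (used by _collate_fn) is attached, so the sketch sets it manually on a model loaded with trust_remote_code.

# Sketch only: "Synthyra/FastESM2_650" is an assumed checkpoint path; swap in any
# repo that ships this modeling_fastesm.py via trust_remote_code.
import torch
from transformers import AutoModel, AutoTokenizer

model_path = "Synthyra/FastESM2_650"  # hypothetical checkpoint for illustration
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer  # _collate_fn tokenizes batches through self.tokenizer

sequences = ["MALWMRLLPLLALLALWGPDPAAA", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]
embeddings = model.embed_dataset(
    sequences,
    batch_size=2,
    max_len=512,
    full_embeddings=False,   # pooled per-sequence vectors
    pooling_type="mean",
)
for seq, emb in embeddings.items():
    print(seq[:10], emb.shape)  # one (hidden_size,) tensor per unique sequence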
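A sketch of reading embeddings back out of the SQLite database written by embed_dataset(sql=True). The float32 dtype matches the .float() cast applied before storage in the diff; the reshape hint for residue-level blobs assumes you know the model's hidden_size.

import sqlite3
import numpy as np

db_path = "embeddings.db"  # default sql_db_path from embed_dataset
with sqlite3.connect(db_path) as conn:
    c = conn.cursor()
    c.execute("SELECT sequence, embedding FROM embeddings")
    for sequence, blob in c.fetchall():
        emb = np.frombuffer(blob, dtype=np.float32)  # pooled: shape (hidden_size,)
        # If stored with full_embeddings=True, the blob is a flattened
        # (seq_len, hidden_size) array; recover it with emb.reshape(-1, hidden_size).
        print(sequence[:10], emb.shape)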