Kseniia-Kholina committed on
Commit 3025e1c · verified · 1 Parent(s): dff19c2

Update app.py

Files changed (1)
  1. app.py +31 -20
app.py CHANGED
@@ -29,26 +29,37 @@ def process_sequence(sequence, domain_bounds, n):
     all_logits = []
 
     for i in range(len(sequence)):
-        if start_index <= i <= (end_index - 1):
-            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
-            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-            with torch.no_grad():
-                logits = model(**inputs).logits
-            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
-            mask_token_logits = logits[0, mask_token_index, :]
+        if start_index <= i <= (end_index - 1):
+            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
+            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+            with torch.no_grad():
+                logits = model(**inputs).logits
+            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+            mask_token_logits = logits[0, mask_token_index, :]
+
+            # Define amino acid tokens
+            AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
+            all_tokens_logits = mask_token_logits.squeeze(0)
+            top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
+            top_tokens_logits = all_tokens_logits[top_tokens_indices]
+            mutation = []
+            # make sure we don't include non-AA tokens
+            for token_index in top_tokens_indices:
+                decoded_token = tokenizer.decode([token_index.item()])
+                if decoded_token in AAs_tokens:
+                    mutation.append(decoded_token)
+                if len(mutation) == n:
+                    break
+            top_n_mutations[(sequence[i], i)] = mutation
+
+            # collecting logits for the heatmap
+            logits_array = mask_token_logits.cpu().numpy()
+            # filter out non-amino acid tokens
+            filtered_indices = list(range(4, 23 + 1))
+            filtered_logits = logits_array[:, filtered_indices]
+            all_logits.append(filtered_logits)
 
-            # filter out non-amino acid tokens
-            filtered_indices = list(range(4, 23 + 1))
-            filtered_logits = mask_token_logits[:, filtered_indices]
-            # Decode top n tokens
-            top_n_tokens = torch.topk(filtered_logits, n, dim=1).indices[0].tolist()
-            mutation = [tokenizer.decode([token]) for token in top_n_tokens]
-            top_n_mutations[(sequence[i], i)] = mutation
-
-            logits_array = mask_token_logits.cpu().numpy()
-            filtered_logits = logits_array[:, filtered_indices]
-            all_logits.append(filtered_logits)
 
     token_indices = torch.arange(logits.size(-1))
     tokens = [tokenizer.decode([idx]) for idx in token_indices]
@@ -63,7 +74,7 @@ def process_sequence(sequence, domain_bounds, n):
     x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
 
     plt.figure(figsize=(15, 8))
-    plt.rcParams.update({'font.size': 16})
+    plt.rcParams.update({'font.size': 18})
 
     sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
     plt.title('Token Probability Heatmap')
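
For reference, a minimal, self-contained sketch of the selection strategy the new version adopts: instead of slicing a hard-coded range of vocabulary indices and taking a top-k over that slice, it sorts the full vocabulary logits at the masked position and keeps the first n decoded tokens that are standard amino acids. The function name, the ESM-2 checkpoint in the usage comment, and the example sequence are assumptions for illustration; only the inner loop mirrors the committed code.

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# The 20 standard amino acids, matching the AAs_tokens list in the commit.
AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']

def top_n_amino_acids(model, tokenizer, sequence, i, n, device='cpu'):
    """Return the n highest-scoring amino-acid substitutions at position i."""
    # Mask position i and run the masked-language model.
    masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
    inputs = tokenizer(masked_seq, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits over the whole vocabulary at the masked position.
    mask_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
    mask_logits = logits[0, mask_index, :].squeeze(0)
    # Walk the vocabulary from highest to lowest logit, keeping only amino acids.
    mutations = []
    for token_index in torch.argsort(mask_logits, descending=True):
        decoded = tokenizer.decode([token_index.item()])
        if decoded in AA_TOKENS:
            mutations.append(decoded)
        if len(mutations) == n:
            break
    return mutations

# Hypothetical usage; the checkpoint name and sequence are illustrative only.
# tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
# model = AutoModelForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D')
# print(top_n_amino_acids(model, tokenizer, 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ', i=3, n=5))

Matching on decoded tokens avoids depending on the amino-acid rows occupying a fixed index range in the tokenizer vocabulary; the heatmap path in the commit still uses the fixed range filtered_indices = list(range(4, 23 + 1)).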