Kseniia-Kholina commited on
Commit
926ee5d
·
verified ·
1 Parent(s): 7c8655d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -37,14 +37,16 @@ def process_sequence(sequence, domain_bounds, n):
37
  logits = model(**inputs).logits
38
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
39
  mask_token_logits = logits[0, mask_token_index, :]
 
 
 
 
40
  # Decode top n tokens
41
- top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
42
  mutation = [tokenizer.decode([token]) for token in top_n_tokens]
43
  top_n_mutations[(sequence[i], i)] = mutation
44
 
45
  logits_array = mask_token_logits.cpu().numpy()
46
- # filter out non-amino acid tokens
47
- filtered_indices = list(range(4, 23 + 1))
48
  filtered_logits = logits_array[:, filtered_indices]
49
  all_logits.append(filtered_logits)
50
 
@@ -66,7 +68,7 @@ def process_sequence(sequence, domain_bounds, n):
66
  plt.ylabel('Token')
67
  plt.xlabel('Residue Index')
68
  plt.yticks(rotation=0)
69
- plt.xticks(x_tick_positions - start_index, x_tick_labels, rotation=0)
70
 
71
  # Save the figure to a BytesIO object
72
  buf = BytesIO()
 
37
  logits = model(**inputs).logits
38
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
39
  mask_token_logits = logits[0, mask_token_index, :]
40
+
41
+ # filter out non-amino acid tokens
42
+ filtered_indices = list(range(4, 23 + 1))
43
+ filtered_logits = mask_token_logits[:, filtered_indices]
44
  # Decode top n tokens
45
+ top_n_tokens = torch.topk(filtered_logits, n, dim=1).indices[0].tolist()
46
  mutation = [tokenizer.decode([token]) for token in top_n_tokens]
47
  top_n_mutations[(sequence[i], i)] = mutation
48
 
49
  logits_array = mask_token_logits.cpu().numpy()
 
 
50
  filtered_logits = logits_array[:, filtered_indices]
51
  all_logits.append(filtered_logits)
52
 
 
68
  plt.ylabel('Token')
69
  plt.xlabel('Residue Index')
70
  plt.yticks(rotation=0)
71
+ plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
72
 
73
  # Save the figure to a BytesIO object
74
  buf = BytesIO()