ullahi committed on
Commit
b82eba8
·
verified ·
1 Parent(s): f9bce78
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -2,18 +2,27 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
- from enformer_pytorch import Enformer, load_pretrained_from_url
6
  from einops import rearrange
7
 
8
- # Load pretrained Enformer model (or use from_pretrained if you're using HF model)
9
- model = load_pretrained_from_url("https://dl.fbaipublicfiles.com/enformer/enformer_pytorch.pt")
 
 
 
 
 
 
10
  model.eval()
11
 
12
- # Helper: one-hot encode DNA (A, C, G, T)
13
- def one_hot_encode(sequence, length=196_608):
 
 
 
14
  mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
15
  one_hot = np.zeros((length, 4), dtype=np.float32)
16
- sequence = sequence.upper().replace("N", "A") # replace ambiguous bases
17
  for i, base in enumerate(sequence[:length]):
18
  if base in mapping:
19
  one_hot[i, mapping[base]] = 1.0
@@ -22,31 +31,30 @@ def one_hot_encode(sequence, length=196_608):
22
  # Prediction function
23
  def predict_expression(dna_sequence):
24
  encoded = one_hot_encode(dna_sequence)
25
- input_tensor = torch.tensor(encoded).unsqueeze(0) # (1, length, 4)
26
- input_tensor = rearrange(input_tensor, 'b l c -> b c l') # (1, 4, length)
27
-
28
  with torch.no_grad():
29
  output = model(input_tensor)
30
- expression = output['human'] # shape: (1, 896, 5313)
31
- avg_expr = expression[0].mean(dim=0).numpy() # average across sequence positions
32
 
33
- # Plot first 10 tissues (customize as needed)
34
- plt.figure(figsize=(12, 4))
35
- plt.bar(range(10), avg_expr[:10])
36
  plt.xticks(range(10), [f"Tissue {i}" for i in range(10)])
37
- plt.ylabel("Predicted Expression")
38
- plt.title("Gene Expression Prediction (avg across bins)")
39
  plt.tight_layout()
40
 
41
  return plt.gcf()
42
 
43
- # Gradio Interface
44
  demo = gr.Interface(
45
  fn=predict_expression,
46
- inputs=gr.Textbox(lines=5, label="Paste DNA Sequence (A/C/G/T only, ~200kb)"),
47
- outputs=gr.Plot(label="Predicted Gene Expression"),
48
- title="Gene Expression Predictor (Enformer)",
49
- description="Paste a DNA sequence to predict tissue-specific gene expression using a pretrained Enformer model."
50
  )
51
 
52
  demo.launch()
 
2
  import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
+ from enformer_pytorch import Enformer
6
  from einops import rearrange
7
 
8
# Initialize Enformer with the published architecture (based on
# EleutherAI/enformer-191k). enformer_pytorch's Enformer class is constructed
# from an EnformerConfig, so hyperparameters must go through
# Enformer.from_hparams instead of being passed to Enformer(...) as keywords.
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    # One output head per organism; 5313 human tracks and 1643 mouse tracks.
    output_heads=dict(human=5313, mouse=1643),
    target_length=896,
)
model.eval()

# NOTE: the weights above are randomly initialised, so predictions are not
# meaningful until a checkpoint is loaded. To use pretrained weights, upload a
# state dict to the Space and load it here, for example:
#   model.load_state_dict(torch.load("enformer-191k.pth", map_location="cpu"))
20
+
21
+ # Helper function to one-hot encode DNA
22
+ def one_hot_encode(sequence, length=196608):
23
  mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
24
  one_hot = np.zeros((length, 4), dtype=np.float32)
25
+ sequence = sequence.upper().replace("N", "A")
26
  for i, base in enumerate(sequence[:length]):
27
  if base in mapping:
28
  one_hot[i, mapping[base]] = 1.0
 
# Prediction function
def predict_expression(dna_sequence):
    """Run the model on a DNA sequence and plot the first 10 predicted tracks.

    Args:
        dna_sequence: DNA string from the Gradio textbox; it is passed to
            ``one_hot_encode`` unchanged.

    Returns:
        A matplotlib ``Figure`` with one bar per track, each bar being the
        prediction averaged over the model's output bins.
    """
    # Imported locally so the figure can be built without pyplot's global
    # figure registry, which would otherwise grow with every request.
    from matplotlib.figure import Figure

    encoded = one_hot_encode(dna_sequence)
    # enformer_pytorch takes float one-hot input as (batch, length, 4) and
    # moves the channel axis itself, so the array is only given a batch axis.
    input_tensor = torch.from_numpy(encoded).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # A model built with named output heads returns a dict keyed by head name.
    if isinstance(output, dict):
        output = output["human"]
    avg_expression = output[0].mean(dim=0).cpu().numpy()

    # Plot the first 10 expression predictions (fewer if the model has fewer).
    n_tracks = min(10, avg_expression.shape[0])
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.bar(range(n_tracks), avg_expression[:n_tracks])
    ax.set_xticks(range(n_tracks))
    ax.set_xticklabels([f"Tissue {i}" for i in range(n_tracks)])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()

    return fig
50
 
51
# Gradio app: a multi-line text box for the sequence, a plot for the result.
sequence_box = gr.Textbox(lines=6, label="Paste DNA Sequence (200k bp)")
expression_plot = gr.Plot(label="Predicted Expression Tracks (first 10 tissues)")

demo = gr.Interface(
    fn=predict_expression,
    inputs=sequence_box,
    outputs=expression_plot,
    title="Gene Expression Prediction with Enformer",
    description="Paste a 200kb DNA sequence and see predicted expression levels using Enformer.",
)

# Start the web server for the interface.
demo.launch()