ullahi commited on
Commit
e2a7e3b
·
verified ·
1 Parent(s): 906954f

added app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from enformer_pytorch import Enformer, load_pretrained_from_url
6
+ from einops import rearrange
7
+
8
+ # Load pretrained Enformer model (or use from_pretrained if you're using HF model)
9
+ model = load_pretrained_from_url("https://dl.fbaipublicfiles.com/enformer/enformer_pytorch.pt")
10
+ model.eval()
11
+
12
+ # Helper: one-hot encode DNA (A, C, G, T)
13
+ def one_hot_encode(sequence, length=196_608):
14
+ mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
15
+ one_hot = np.zeros((length, 4), dtype=np.float32)
16
+ sequence = sequence.upper().replace("N", "A") # replace ambiguous bases
17
+ for i, base in enumerate(sequence[:length]):
18
+ if base in mapping:
19
+ one_hot[i, mapping[base]] = 1.0
20
+ return one_hot
21
+
22
+ # Prediction function
23
+ def predict_expression(dna_sequence):
24
+ encoded = one_hot_encode(dna_sequence)
25
+ input_tensor = torch.tensor(encoded).unsqueeze(0) # (1, length, 4)
26
+ input_tensor = rearrange(input_tensor, 'b l c -> b c l') # (1, 4, length)
27
+
28
+ with torch.no_grad():
29
+ output = model(input_tensor)
30
+ expression = output['human'] # shape: (1, 896, 5313)
31
+ avg_expr = expression[0].mean(dim=0).numpy() # average across sequence positions
32
+
33
+ # Plot first 10 tissues (customize as needed)
34
+ plt.figure(figsize=(12, 4))
35
+ plt.bar(range(10), avg_expr[:10])
36
+ plt.xticks(range(10), [f"Tissue {i}" for i in range(10)])
37
+ plt.ylabel("Predicted Expression")
38
+ plt.title("Gene Expression Prediction (avg across bins)")
39
+ plt.tight_layout()
40
+
41
+ return plt.gcf()
42
+
43
+ # Gradio Interface
44
+ demo = gr.Interface(
45
+ fn=predict_expression,
46
+ inputs=gr.Textbox(lines=5, label="Paste DNA Sequence (A/C/G/T only, ~200kb)"),
47
+ outputs=gr.Plot(label="Predicted Gene Expression"),
48
+ title="Gene Expression Predictor (Enformer)",
49
+ description="Paste a DNA sequence to predict tissue-specific gene expression using a pretrained Enformer model."
50
+ )
51
+
52
+ demo.launch()