Spaces:
Runtime error
Runtime error
updated
Browse files
app.py
CHANGED
@@ -2,18 +2,27 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
-
from enformer_pytorch import Enformer
|
6 |
from einops import rearrange
|
7 |
|
8 |
-
#
|
9 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
model.eval()
|
11 |
|
12 |
-
#
|
13 |
-
|
|
|
|
|
|
|
14 |
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
|
15 |
one_hot = np.zeros((length, 4), dtype=np.float32)
|
16 |
-
sequence = sequence.upper().replace("N", "A")
|
17 |
for i, base in enumerate(sequence[:length]):
|
18 |
if base in mapping:
|
19 |
one_hot[i, mapping[base]] = 1.0
|
@@ -22,31 +31,30 @@ def one_hot_encode(sequence, length=196_608):
|
|
22 |
# Prediction function
|
23 |
def predict_expression(dna_sequence):
|
24 |
encoded = one_hot_encode(dna_sequence)
|
25 |
-
input_tensor = torch.tensor(encoded).unsqueeze(0) # (1, length, 4)
|
26 |
-
input_tensor = rearrange(input_tensor, 'b l c -> b c l') # (1, 4, length)
|
27 |
-
|
28 |
with torch.no_grad():
|
29 |
output = model(input_tensor)
|
30 |
-
|
31 |
-
avg_expr = expression[0].mean(dim=0).numpy() # average across sequence positions
|
32 |
|
33 |
-
# Plot first 10
|
34 |
-
plt.figure(figsize=(
|
35 |
-
plt.bar(range(10),
|
36 |
plt.xticks(range(10), [f"Tissue {i}" for i in range(10)])
|
37 |
-
plt.
|
38 |
-
plt.
|
39 |
plt.tight_layout()
|
40 |
|
41 |
return plt.gcf()
|
42 |
|
43 |
-
# Gradio
|
44 |
demo = gr.Interface(
|
45 |
fn=predict_expression,
|
46 |
-
inputs=gr.Textbox(lines=
|
47 |
-
outputs=gr.Plot(label="Predicted
|
48 |
-
title="Gene Expression
|
49 |
-
description="Paste a DNA sequence
|
50 |
)
|
51 |
|
52 |
demo.launch()
|
|
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
from enformer_pytorch import Enformer
|
6 |
from einops import rearrange
|
7 |
|
8 |
+
# Build an Enformer with the published architecture hyperparameters.
# enformer_pytorch's `Enformer` constructor expects a config object, so plain
# hyperparameters are passed through `Enformer.from_hparams` instead.
# `output_heads` maps each head name to its number of output tracks.
# NOTE: this only builds the architecture; the weights are randomly
# initialised until a checkpoint is loaded.
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    output_heads={"human": 5313, "mouse": 1643},
    target_length=896,
)
# Inference only: put the module in evaluation mode.
model.eval()
|
17 |
|
18 |
+
# Optionally load pretrained weights if available locally or upload to HF Spaces manually
|
19 |
+
# model.load_state_dict(torch.load("enformer-191k.pth")) # optional for offline Spaces
|
20 |
+
|
21 |
+
# Helper function to one-hot encode DNA
|
22 |
+
def one_hot_encode(sequence, length=196608):
|
23 |
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
|
24 |
one_hot = np.zeros((length, 4), dtype=np.float32)
|
25 |
+
sequence = sequence.upper().replace("N", "A")
|
26 |
for i, base in enumerate(sequence[:length]):
|
27 |
if base in mapping:
|
28 |
one_hot[i, mapping[base]] = 1.0
|
|
|
31 |
# Prediction function
|
32 |
def predict_expression(dna_sequence):
    """Predict expression tracks for a DNA sequence and plot the first ten.

    The sequence is encoded with ``one_hot_encode``, passed through the
    module-level ``model`` with gradient tracking disabled, and the
    human-head predictions are averaged over the output bins.

    Args:
        dna_sequence: DNA string supplied by the Gradio textbox.

    Returns:
        The matplotlib figure containing a bar chart of the first ten
        averaged prediction tracks.
    """
    encoded = one_hot_encode(dna_sequence)
    # enformer_pytorch takes one-hot input channels-last, i.e.
    # (batch, length, 4), and moves the channel axis itself, so the encoded
    # array only needs a leading batch dimension.
    input_tensor = torch.tensor(encoded).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # The model returns one tensor per output head, keyed by head name, each
    # laid out as (batch, bins, tracks). Take the human head of the single
    # batch item and average over bins to get one value per track.
    avg_expression = output["human"][0].mean(dim=0).numpy()

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", so the handler does not depend on global pyplot state.
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(10), avg_expression[:10])
    ax.set_xticks(range(10))
    ax.set_xticklabels([f"Tissue {i}" for i in range(10)])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()

    return fig
|
50 |
|
51 |
+
# Gradio app: a multi-line textbox feeds predict_expression, and the figure it
# returns is shown in a plot component.
_sequence_input = gr.Textbox(lines=6, label="Paste DNA Sequence (200k bp)")
_expression_output = gr.Plot(
    label="Predicted Expression Tracks (first 10 tissues)"
)

demo = gr.Interface(
    fn=predict_expression,
    inputs=_sequence_input,
    outputs=_expression_output,
    title="Gene Expression Prediction with Enformer",
    description=(
        "Paste a 200kb DNA sequence and see predicted expression levels "
        "using Enformer."
    ),
)

# Start the web server for the interface.
demo.launch()
|