Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from transformers import AutoTokenizer,
|
4 |
import time
|
5 |
from functools import wraps
|
6 |
import sys
|
@@ -23,25 +23,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
23 |
if device == "cpu":
|
24 |
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
|
25 |
|
26 |
-
# Definir el mapeo de clases
|
27 |
-
class_mapping = {
|
28 |
-
0: 'Not Binding Site',
|
29 |
-
1: 'Binding Site',
|
30 |
-
}
|
31 |
-
|
32 |
# Cargar el modelo y el tokenizador
|
33 |
-
model_name = "
|
34 |
|
35 |
try:
|
36 |
print("Cargando el tokenizador...")
|
37 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
38 |
except ValueError as e:
|
39 |
print(f"Error al cargar el tokenizador: {e}")
|
40 |
sys.exit(1)
|
41 |
|
42 |
try:
|
43 |
-
print("Cargando el modelo
|
44 |
-
model =
|
45 |
model.to(device)
|
46 |
except Exception as e:
|
47 |
print(f"Error al cargar el modelo: {e}")
|
@@ -49,21 +43,26 @@ except Exception as e:
|
|
49 |
|
50 |
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
|
51 |
@medir_tiempo
|
52 |
-
def
|
53 |
"""
|
54 |
-
Función que predice
|
55 |
"""
|
56 |
try:
|
57 |
if not secuencias.strip():
|
58 |
-
return "Por favor, ingresa una o más secuencias válidas."
|
59 |
|
60 |
# Separar las secuencias por líneas y eliminar espacios vacíos
|
61 |
-
secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()]
|
62 |
resultados = []
|
63 |
|
64 |
for seq in secuencias_lista:
|
|
|
|
|
|
|
|
|
|
|
65 |
# Tokenizar la secuencia
|
66 |
-
inputs = tokenizer(seq,
|
67 |
input_ids = inputs["input_ids"].to(device)
|
68 |
attention_mask = inputs["attention_mask"].to(device)
|
69 |
|
@@ -71,55 +70,72 @@ def predecir_sitios_arn(secuencias):
|
|
71 |
with torch.no_grad():
|
72 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
# Obtener las predicciones seleccionando la clase con el logit más alto
|
75 |
predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
|
76 |
|
77 |
-
#
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
# Emparejar cada
|
81 |
-
|
82 |
|
83 |
# Formatear el resultado para mostrarlo en la interfaz
|
84 |
secuencia_resultado = []
|
85 |
-
for i, (
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
resultados.append("\n".join(secuencia_resultado))
|
91 |
|
|
|
92 |
return "\n\n".join(resultados)
|
93 |
|
94 |
except Exception as e:
|
95 |
print(f"Error durante la predicción: {e}")
|
96 |
-
return f"Error al predecir
|
97 |
|
98 |
# Definir la interfaz de Gradio
|
99 |
-
titulo = "
|
100 |
descripcion = (
|
101 |
-
"Ingresa una o más secuencias de
|
102 |
-
" El modelo utilizado es
|
103 |
)
|
104 |
|
105 |
iface = gr.Interface(
|
106 |
-
fn=
|
107 |
inputs=gr.Textbox(
|
108 |
lines=10,
|
109 |
-
placeholder="Escribe tus secuencias de
|
110 |
-
label="Secuencias de
|
111 |
),
|
112 |
-
outputs=gr.Textbox(label="Predicciones de
|
113 |
title=titulo,
|
114 |
description=descripcion,
|
115 |
examples=[
|
116 |
[
|
117 |
-
"
|
118 |
-
"
|
119 |
],
|
120 |
[
|
121 |
-
"
|
122 |
-
"
|
123 |
]
|
124 |
],
|
125 |
cache_examples=False,
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
import time
|
5 |
from functools import wraps
|
6 |
import sys
|
|
|
23 |
if device == "cpu":
|
24 |
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Cargar el modelo y el tokenizador
|
27 |
+
model_name = "yangheng/OmniGenome"
|
28 |
|
29 |
try:
|
30 |
print("Cargando el tokenizador...")
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
32 |
except ValueError as e:
|
33 |
print(f"Error al cargar el tokenizador: {e}")
|
34 |
sys.exit(1)
|
35 |
|
36 |
try:
|
37 |
+
print("Cargando el modelo...")
|
38 |
+
model = AutoModel.from_pretrained(model_name)
|
39 |
model.to(device)
|
40 |
except Exception as e:
|
41 |
print(f"Error al cargar el modelo: {e}")
|
|
|
43 |
|
44 |
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
|
45 |
@medir_tiempo
|
46 |
+
def predecir_estructura_rna(secuencias):
|
47 |
"""
|
48 |
+
Función que predice estructuras secundarias de ARN a partir de secuencias de ARN proporcionadas.
|
49 |
"""
|
50 |
try:
|
51 |
if not secuencias.strip():
|
52 |
+
return "Por favor, ingresa una o más secuencias de ARN válidas."
|
53 |
|
54 |
# Separar las secuencias por líneas y eliminar espacios vacíos
|
55 |
+
secuencias_lista = [seq.strip().upper() for seq in secuencias.strip().split('\n') if seq.strip()]
|
56 |
resultados = []
|
57 |
|
58 |
for seq in secuencias_lista:
|
59 |
+
# Validar la secuencia de ARN
|
60 |
+
if not all(residue in 'AUCG' for residue in seq):
|
61 |
+
resultados.append(f"Secuencia inválida: {seq}. Solo se permiten los nucleótidos A, U, C y G.")
|
62 |
+
continue
|
63 |
+
|
64 |
# Tokenizar la secuencia
|
65 |
+
inputs = tokenizer(seq, return_tensors="pt")
|
66 |
input_ids = inputs["input_ids"].to(device)
|
67 |
attention_mask = inputs["attention_mask"].to(device)
|
68 |
|
|
|
70 |
with torch.no_grad():
|
71 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
72 |
|
73 |
+
# Asumimos que el modelo devuelve logits para cada nucleótido que indican la estructura secundaria
|
74 |
+
# Debes ajustar esto según la arquitectura específica de OmniGenome
|
75 |
+
|
76 |
+
# Por ejemplo, supongamos que el modelo tiene una cabeza de clasificación con N etiquetas
|
77 |
+
# donde cada etiqueta representa una clase de estructura secundaria (e.g., Helix, Loop, etc.)
|
78 |
+
|
79 |
# Obtener las predicciones seleccionando la clase con el logit más alto
|
80 |
predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
|
81 |
|
82 |
+
# Definir el mapeo de clases según la documentación del modelo OmniGenome
|
83 |
+
# Este mapeo debe ajustarse a las clases específicas que OmniGenome predice
|
84 |
+
# Por ejemplo:
|
85 |
+
structure_mapping = {
|
86 |
+
0: 'Helix',
|
87 |
+
1: 'Loop',
|
88 |
+
2: 'Bulge',
|
89 |
+
3: 'Internal Loop',
|
90 |
+
# Agrega más clases si es necesario
|
91 |
+
}
|
92 |
+
|
93 |
+
# Convertir las predicciones numéricas a etiquetas legibles
|
94 |
+
predicted_structures = [structure_mapping.get(pred, "Unknown") for pred in predictions]
|
95 |
|
96 |
+
# Emparejar cada nucleótido con su etiqueta de estructura predicha
|
97 |
+
nucleotide_to_structure = list(zip(list(seq), predicted_structures))
|
98 |
|
99 |
# Formatear el resultado para mostrarlo en la interfaz
|
100 |
secuencia_resultado = []
|
101 |
+
for i, (nucleotide, structure) in enumerate(nucleotide_to_structure):
|
102 |
+
secuencia_resultado.append(f"Posición {i+1} - {nucleotide}: {structure}")
|
103 |
+
|
104 |
+
# Unir las predicciones en un solo string
|
|
|
105 |
resultados.append("\n".join(secuencia_resultado))
|
106 |
|
107 |
+
# Unir los resultados de todas las secuencias separadas por dos saltos de línea
|
108 |
return "\n\n".join(resultados)
|
109 |
|
110 |
except Exception as e:
|
111 |
print(f"Error durante la predicción: {e}")
|
112 |
+
return f"Error al predecir las estructuras de ARN: {e}"
|
113 |
|
114 |
# Definir la interfaz de Gradio
|
115 |
+
titulo = "OmniGenome: Predicción de Estructuras Secundarias de ARN"
|
116 |
descripcion = (
|
117 |
+
"Ingresa una o más secuencias de ARN (una por línea) y obtén predicciones de estructuras secundarias para cada nucleótido."
|
118 |
+
" El modelo utilizado es OmniGenome, un modelo de fundamentos basado en transformadores para alineación secuencia-estructura en tareas genómicas."
|
119 |
)
|
120 |
|
121 |
iface = gr.Interface(
|
122 |
+
fn=predecir_estructura_rna,
|
123 |
inputs=gr.Textbox(
|
124 |
lines=10,
|
125 |
+
placeholder="Escribe tus secuencias de ARN aquí, una por línea (solo A, U, C, G)...",
|
126 |
+
label="Secuencias de ARN"
|
127 |
),
|
128 |
+
outputs=gr.Textbox(label="Predicciones de Estructuras Secundarias de ARN"),
|
129 |
title=titulo,
|
130 |
description=descripcion,
|
131 |
examples=[
|
132 |
[
|
133 |
+
"AUGGCUACUUUCG",
|
134 |
+
"GCGCGAUCGACGUAGCUAGC"
|
135 |
],
|
136 |
[
|
137 |
+
"AUAUGCGGUAUCGUACGUA",
|
138 |
+
"GGAUACGUGAUCGUAGCAGU"
|
139 |
]
|
140 |
],
|
141 |
cache_examples=False,
|