File size: 5,662 Bytes
5a798cc
 
199e65d
820a0dd
 
69463c7
1baae24
820a0dd
 
 
 
 
 
 
 
 
 
 
 
5a798cc
1baae24
ce0f331
 
 
 
5b7b502
199e65d
ce0f331
 
5b7b502
199e65d
ce0f331
 
 
69463c7
ce0f331
199e65d
 
5b7b502
1baae24
 
69463c7
1baae24
ce0f331
4749da3
199e65d
4749da3
199e65d
4749da3
 
5b7b502
199e65d
4749da3
5b7b502
199e65d
5b7b502
1baae24
5b7b502
199e65d
 
 
 
 
5b7b502
199e65d
5b7b502
 
 
 
 
 
 
199e65d
 
 
 
 
 
5b7b502
 
 
199e65d
 
 
 
 
 
 
 
 
 
 
 
 
5b7b502
199e65d
 
5b7b502
 
 
199e65d
 
 
 
5b7b502
 
199e65d
5b7b502
 
 
 
199e65d
4749da3
5a798cc
199e65d
5b7b502
199e65d
 
5b7b502
165628b
5a798cc
199e65d
5b7b502
 
199e65d
 
5b7b502
199e65d
5a798cc
 
 
 
199e65d
 
5a798cc
 
199e65d
 
5a798cc
 
50560b1
 
5a798cc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import time
from functools import wraps
import sys
import spaces  # Asegúrate de que este módulo esté disponible y correctamente instalado

# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        inicio = time.time()
        resultado = func(*args, **kwargs)
        fin = time.time()
        tiempo_transcurrido = fin - inicio
        print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
        return resultado
    return wrapper

# Verificar si CUDA está disponible para el modelo principal
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")

# Cargar el modelo y el tokenizador
model_name = "yangheng/OmniGenome"

try:
    print("Cargando el tokenizador...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except ValueError as e:
    print(f"Error al cargar el tokenizador: {e}")
    sys.exit(1)

try:
    print("Cargando el modelo...")
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
except Exception as e:
    print(f"Error al cargar el modelo: {e}")
    sys.exit(1)

@spaces.GPU(duration=120)  # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def predecir_estructura_rna(secuencias):
    """
    Función que predice estructuras secundarias de ARN a partir de secuencias de ARN proporcionadas.
    """
    try:
        if not secuencias.strip():
            return "Por favor, ingresa una o más secuencias de ARN válidas."

        # Separar las secuencias por líneas y eliminar espacios vacíos
        secuencias_lista = [seq.strip().upper() for seq in secuencias.strip().split('\n') if seq.strip()]
        resultados = []

        for seq in secuencias_lista:
            # Validar la secuencia de ARN
            if not all(residue in 'AUCG' for residue in seq):
                resultados.append(f"Secuencia inválida: {seq}. Solo se permiten los nucleótidos A, U, C y G.")
                continue

            # Tokenizar la secuencia
            inputs = tokenizer(seq, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)

            # Aplicar el modelo para obtener los logits
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Asumimos que el modelo devuelve logits para cada nucleótido que indican la estructura secundaria
            # Debes ajustar esto según la arquitectura específica de OmniGenome

            # Por ejemplo, supongamos que el modelo tiene una cabeza de clasificación con N etiquetas
            # donde cada etiqueta representa una clase de estructura secundaria (e.g., Helix, Loop, etc.)

            # Obtener las predicciones seleccionando la clase con el logit más alto
            predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()

            # Definir el mapeo de clases según la documentación del modelo OmniGenome
            # Este mapeo debe ajustarse a las clases específicas que OmniGenome predice
            # Por ejemplo:
            structure_mapping = {
                0: 'Helix',
                1: 'Loop',
                2: 'Bulge',
                3: 'Internal Loop',
                # Agrega más clases si es necesario
            }

            # Convertir las predicciones numéricas a etiquetas legibles
            predicted_structures = [structure_mapping.get(pred, "Unknown") for pred in predictions]

            # Emparejar cada nucleótido con su etiqueta de estructura predicha
            nucleotide_to_structure = list(zip(list(seq), predicted_structures))

            # Formatear el resultado para mostrarlo en la interfaz
            secuencia_resultado = []
            for i, (nucleotide, structure) in enumerate(nucleotide_to_structure):
                secuencia_resultado.append(f"Posición {i+1} - {nucleotide}: {structure}")

            # Unir las predicciones en un solo string
            resultados.append("\n".join(secuencia_resultado))

        # Unir los resultados de todas las secuencias separadas por dos saltos de línea
        return "\n\n".join(resultados)

    except Exception as e:
        print(f"Error durante la predicción: {e}")
        return f"Error al predecir las estructuras de ARN: {e}"

# Definir la interfaz de Gradio
titulo = "OmniGenome: Predicción de Estructuras Secundarias de ARN"
descripcion = (
    "Ingresa una o más secuencias de ARN (una por línea) y obtén predicciones de estructuras secundarias para cada nucleótido."
    " El modelo utilizado es OmniGenome, un modelo de fundamentos basado en transformadores para alineación secuencia-estructura en tareas genómicas."
)

iface = gr.Interface(
    fn=predecir_estructura_rna,
    inputs=gr.Textbox(
        lines=10, 
        placeholder="Escribe tus secuencias de ARN aquí, una por línea (solo A, U, C, G)...", 
        label="Secuencias de ARN"
    ),
    outputs=gr.Textbox(label="Predicciones de Estructuras Secundarias de ARN"),
    title=titulo,
    description=descripcion,
    examples=[
        [
            "AUGGCUACUUUCG",
            "GCGCGAUCGACGUAGCUAGC"
        ],
        [
            "AUAUGCGGUAUCGUACGUA",
            "GGAUACGUGAUCGUAGCAGU"
        ]
    ],
    cache_examples=False,
    allow_flagging="never"
)

# Ejecutar la interfaz
if __name__ == "__main__":
    iface.launch()