File size: 4,717 Bytes
5a798cc
 
5b7b502
820a0dd
 
69463c7
1baae24
820a0dd
 
 
 
 
 
 
 
 
 
 
 
5a798cc
1baae24
ce0f331
 
 
 
5b7b502
 
 
 
 
 
 
 
ce0f331
 
5b7b502
 
ce0f331
 
 
69463c7
ce0f331
5b7b502
 
 
1baae24
 
69463c7
1baae24
ce0f331
4749da3
5b7b502
4749da3
5b7b502
4749da3
 
5b7b502
 
4749da3
5b7b502
 
 
1baae24
5b7b502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4749da3
5a798cc
5b7b502
 
 
 
 
165628b
5a798cc
5b7b502
 
 
 
 
 
 
5a798cc
 
 
 
5b7b502
 
5a798cc
 
5b7b502
 
5a798cc
 
50560b1
 
5a798cc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import torch
from transformers import AutoTokenizer, EsmForTokenClassification
import time
from functools import wraps
import sys
import spaces  # Asegúrate de que este módulo esté disponible y correctamente instalado

# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        inicio = time.time()
        resultado = func(*args, **kwargs)
        fin = time.time()
        tiempo_transcurrido = fin - inicio
        print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
        return resultado
    return wrapper

# Verificar si CUDA está disponible para el modelo principal
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")

# Definir el mapeo de clases
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Cargar el modelo y el tokenizador
model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor"

try:
    print("Cargando el tokenizador...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
except ValueError as e:
    print(f"Error al cargar el tokenizador: {e}")
    sys.exit(1)

try:
    print("Cargando el modelo de predicción...")
    model = EsmForTokenClassification.from_pretrained(model_name)
    model.to(device)
except Exception as e:
    print(f"Error al cargar el modelo: {e}")
    sys.exit(1)

@spaces.GPU(duration=120)  # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def predecir_sitios_arn(secuencias):
    """
    Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas.
    """
    try:
        if not secuencias.strip():
            return "Por favor, ingresa una o más secuencias válidas."

        # Separar las secuencias por líneas y eliminar espacios vacíos
        secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()]
        resultados = []

        for seq in secuencias_lista:
            # Tokenizar la secuencia
            inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)

            # Aplicar el modelo para obtener los logits
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Obtener las predicciones seleccionando la clase con el logit más alto
            predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()

            # Convertir las predicciones a etiquetas
            predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions]

            # Emparejar cada residuo con su etiqueta predicha
            residue_to_label = list(zip(list(seq), predicted_labels))

            # Formatear el resultado para mostrarlo en la interfaz
            secuencia_resultado = []
            for i, (residue, label) in enumerate(residue_to_label):
                # Omite los residuos 'PAD' que se agregan durante el padding
                if residue != 'PAD':
                    secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}")
            
            resultados.append("\n".join(secuencia_resultado))

        return "\n\n".join(resultados)

    except Exception as e:
        print(f"Error durante la predicción: {e}")
        return f"Error al predecir los sitios de ARN: {e}"

# Definir la interfaz de Gradio
titulo = "ESM-2 para Predicción de Sitios de Unión de ARN"
descripcion = (
    "Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo."
    " El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN."
)

iface = gr.Interface(
    fn=predecir_sitios_arn,
    inputs=gr.Textbox(
        lines=10, 
        placeholder="Escribe tus secuencias de proteínas aquí, una por línea...", 
        label="Secuencias de Proteínas"
    ),
    outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"),
    title=titulo,
    description=descripcion,
    examples=[
        [
            "VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK",
            "SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF"
        ],
        [
            "MKAILVVLLYTFATANADAVAHVAA",
            "GATVQAAEEVTQGVVVVEEVAGGAA"
        ]
    ],
    cache_examples=False,
    allow_flagging="never"
)

# Ejecutar la interfaz
if __name__ == "__main__":
    iface.launch()