File size: 4,717 Bytes
5a798cc 5b7b502 820a0dd 69463c7 1baae24 820a0dd 5a798cc 1baae24 ce0f331 5b7b502 ce0f331 5b7b502 ce0f331 69463c7 ce0f331 5b7b502 1baae24 69463c7 1baae24 ce0f331 4749da3 5b7b502 4749da3 5b7b502 4749da3 5b7b502 4749da3 5b7b502 1baae24 5b7b502 4749da3 5a798cc 5b7b502 165628b 5a798cc 5b7b502 5a798cc 5b7b502 5a798cc 5b7b502 5a798cc 50560b1 5a798cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
import torch
from transformers import AutoTokenizer, EsmForTokenClassification
import time
from functools import wraps
import sys
import spaces # Asegúrate de que este módulo esté disponible y correctamente instalado
# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
@wraps(func)
def wrapper(*args, **kwargs):
inicio = time.time()
resultado = func(*args, **kwargs)
fin = time.time()
tiempo_transcurrido = fin - inicio
print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
return resultado
return wrapper
# Verificar si CUDA está disponible para el modelo principal
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
# Definir el mapeo de clases
class_mapping = {
0: 'Not Binding Site',
1: 'Binding Site',
}
# Cargar el modelo y el tokenizador
model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor"
try:
print("Cargando el tokenizador...")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
except ValueError as e:
print(f"Error al cargar el tokenizador: {e}")
sys.exit(1)
try:
print("Cargando el modelo de predicción...")
model = EsmForTokenClassification.from_pretrained(model_name)
model.to(device)
except Exception as e:
print(f"Error al cargar el modelo: {e}")
sys.exit(1)
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def predecir_sitios_arn(secuencias):
"""
Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas.
"""
try:
if not secuencias.strip():
return "Por favor, ingresa una o más secuencias válidas."
# Separar las secuencias por líneas y eliminar espacios vacíos
secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()]
resultados = []
for seq in secuencias_lista:
# Tokenizar la secuencia
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Aplicar el modelo para obtener los logits
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
# Obtener las predicciones seleccionando la clase con el logit más alto
predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
# Convertir las predicciones a etiquetas
predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions]
# Emparejar cada residuo con su etiqueta predicha
residue_to_label = list(zip(list(seq), predicted_labels))
# Formatear el resultado para mostrarlo en la interfaz
secuencia_resultado = []
for i, (residue, label) in enumerate(residue_to_label):
# Omite los residuos 'PAD' que se agregan durante el padding
if residue != 'PAD':
secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}")
resultados.append("\n".join(secuencia_resultado))
return "\n\n".join(resultados)
except Exception as e:
print(f"Error durante la predicción: {e}")
return f"Error al predecir los sitios de ARN: {e}"
# Definir la interfaz de Gradio
titulo = "ESM-2 para Predicción de Sitios de Unión de ARN"
descripcion = (
"Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo."
" El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN."
)
iface = gr.Interface(
fn=predecir_sitios_arn,
inputs=gr.Textbox(
lines=10,
placeholder="Escribe tus secuencias de proteínas aquí, una por línea...",
label="Secuencias de Proteínas"
),
outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"),
title=titulo,
description=descripcion,
examples=[
[
"VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK",
"SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF"
],
[
"MKAILVVLLYTFATANADAVAHVAA",
"GATVQAAEEVTQGVVVVEEVAGGAA"
]
],
cache_examples=False,
allow_flagging="never"
)
# Ejecutar la interfaz
if __name__ == "__main__":
iface.launch()
|