BioRAG / app.py
C2MV's picture
Update app.py
5b7b502 verified
raw
history blame
4.72 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, EsmForTokenClassification
import time
from functools import wraps
import sys
import spaces # Asegúrate de que este módulo esté disponible y correctamente instalado
# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
@wraps(func)
def wrapper(*args, **kwargs):
inicio = time.time()
resultado = func(*args, **kwargs)
fin = time.time()
tiempo_transcurrido = fin - inicio
print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
return resultado
return wrapper
# Verificar si CUDA está disponible para el modelo principal
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
# Definir el mapeo de clases
class_mapping = {
0: 'Not Binding Site',
1: 'Binding Site',
}
# Cargar el modelo y el tokenizador
model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor"
try:
print("Cargando el tokenizador...")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
except ValueError as e:
print(f"Error al cargar el tokenizador: {e}")
sys.exit(1)
try:
print("Cargando el modelo de predicción...")
model = EsmForTokenClassification.from_pretrained(model_name)
model.to(device)
except Exception as e:
print(f"Error al cargar el modelo: {e}")
sys.exit(1)
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def predecir_sitios_arn(secuencias):
"""
Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas.
"""
try:
if not secuencias.strip():
return "Por favor, ingresa una o más secuencias válidas."
# Separar las secuencias por líneas y eliminar espacios vacíos
secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()]
resultados = []
for seq in secuencias_lista:
# Tokenizar la secuencia
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Aplicar el modelo para obtener los logits
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
# Obtener las predicciones seleccionando la clase con el logit más alto
predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
# Convertir las predicciones a etiquetas
predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions]
# Emparejar cada residuo con su etiqueta predicha
residue_to_label = list(zip(list(seq), predicted_labels))
# Formatear el resultado para mostrarlo en la interfaz
secuencia_resultado = []
for i, (residue, label) in enumerate(residue_to_label):
# Omite los residuos 'PAD' que se agregan durante el padding
if residue != 'PAD':
secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}")
resultados.append("\n".join(secuencia_resultado))
return "\n\n".join(resultados)
except Exception as e:
print(f"Error durante la predicción: {e}")
return f"Error al predecir los sitios de ARN: {e}"
# Definir la interfaz de Gradio
titulo = "ESM-2 para Predicción de Sitios de Unión de ARN"
descripcion = (
"Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo."
" El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN."
)
iface = gr.Interface(
fn=predecir_sitios_arn,
inputs=gr.Textbox(
lines=10,
placeholder="Escribe tus secuencias de proteínas aquí, una por línea...",
label="Secuencias de Proteínas"
),
outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"),
title=titulo,
description=descripcion,
examples=[
[
"VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK",
"SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF"
],
[
"MKAILVVLLYTFATANADAVAHVAA",
"GATVQAAEEVTQGVVVVEEVAGGAA"
]
],
cache_examples=False,
allow_flagging="never"
)
# Ejecutar la interfaz
if __name__ == "__main__":
iface.launch()