|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, EsmForTokenClassification |
|
import time |
|
from functools import wraps |
|
import sys |
|
import spaces |
|
|
|
|
|
def medir_tiempo(func): |
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
inicio = time.time() |
|
resultado = func(*args, **kwargs) |
|
fin = time.time() |
|
tiempo_transcurrido = fin - inicio |
|
print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos") |
|
return resultado |
|
return wrapper |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
if device == "cpu": |
|
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.") |
|
|
|
|
|
class_mapping = { |
|
0: 'Not Binding Site', |
|
1: 'Binding Site', |
|
} |
|
|
|
|
|
model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor" |
|
|
|
try: |
|
print("Cargando el tokenizador...") |
|
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") |
|
except ValueError as e: |
|
print(f"Error al cargar el tokenizador: {e}") |
|
sys.exit(1) |
|
|
|
try: |
|
print("Cargando el modelo de predicción...") |
|
model = EsmForTokenClassification.from_pretrained(model_name) |
|
model.to(device) |
|
except Exception as e: |
|
print(f"Error al cargar el modelo: {e}") |
|
sys.exit(1) |
|
|
|
@spaces.GPU(duration=120) |
|
@medir_tiempo |
|
def predecir_sitios_arn(secuencias): |
|
""" |
|
Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas. |
|
""" |
|
try: |
|
if not secuencias.strip(): |
|
return "Por favor, ingresa una o más secuencias válidas." |
|
|
|
|
|
secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()] |
|
resultados = [] |
|
|
|
for seq in secuencias_lista: |
|
|
|
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt") |
|
input_ids = inputs["input_ids"].to(device) |
|
attention_mask = inputs["attention_mask"].to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask) |
|
|
|
|
|
predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist() |
|
|
|
|
|
predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions] |
|
|
|
|
|
residue_to_label = list(zip(list(seq), predicted_labels)) |
|
|
|
|
|
secuencia_resultado = [] |
|
for i, (residue, label) in enumerate(residue_to_label): |
|
|
|
if residue != 'PAD': |
|
secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}") |
|
|
|
resultados.append("\n".join(secuencia_resultado)) |
|
|
|
return "\n\n".join(resultados) |
|
|
|
except Exception as e: |
|
print(f"Error durante la predicción: {e}") |
|
return f"Error al predecir los sitios de ARN: {e}" |
|
|
|
|
|
titulo = "ESM-2 para Predicción de Sitios de Unión de ARN" |
|
descripcion = ( |
|
"Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo." |
|
" El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN." |
|
) |
|
|
|
iface = gr.Interface( |
|
fn=predecir_sitios_arn, |
|
inputs=gr.Textbox( |
|
lines=10, |
|
placeholder="Escribe tus secuencias de proteínas aquí, una por línea...", |
|
label="Secuencias de Proteínas" |
|
), |
|
outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"), |
|
title=titulo, |
|
description=descripcion, |
|
examples=[ |
|
[ |
|
"VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK", |
|
"SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF" |
|
], |
|
[ |
|
"MKAILVVLLYTFATANADAVAHVAA", |
|
"GATVQAAEEVTQGVVVVEEVAGGAA" |
|
] |
|
], |
|
cache_examples=False, |
|
allow_flagging="never" |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
iface.launch() |
|
|