import gradio as gr import torch from transformers import AutoTokenizer, EsmForTokenClassification import time from functools import wraps import sys import spaces # Asegúrate de que este módulo esté disponible y correctamente instalado # Decorador para medir el tiempo de ejecución def medir_tiempo(func): @wraps(func) def wrapper(*args, **kwargs): inicio = time.time() resultado = func(*args, **kwargs) fin = time.time() tiempo_transcurrido = fin - inicio print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos") return resultado return wrapper # Verificar si CUDA está disponible para el modelo principal device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cpu": print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.") # Definir el mapeo de clases class_mapping = { 0: 'Not Binding Site', 1: 'Binding Site', } # Cargar el modelo y el tokenizador model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor" try: print("Cargando el tokenizador...") tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") except ValueError as e: print(f"Error al cargar el tokenizador: {e}") sys.exit(1) try: print("Cargando el modelo de predicción...") model = EsmForTokenClassification.from_pretrained(model_name) model.to(device) except Exception as e: print(f"Error al cargar el modelo: {e}") sys.exit(1) @spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos @medir_tiempo def predecir_sitios_arn(secuencias): """ Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas. """ try: if not secuencias.strip(): return "Por favor, ingresa una o más secuencias válidas." # Separar las secuencias por líneas y eliminar espacios vacíos secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()] resultados = [] for seq in secuencias_lista: # Tokenizar la secuencia inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt") input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) # Aplicar el modelo para obtener los logits with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) # Obtener las predicciones seleccionando la clase con el logit más alto predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist() # Convertir las predicciones a etiquetas predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions] # Emparejar cada residuo con su etiqueta predicha residue_to_label = list(zip(list(seq), predicted_labels)) # Formatear el resultado para mostrarlo en la interfaz secuencia_resultado = [] for i, (residue, label) in enumerate(residue_to_label): # Omite los residuos 'PAD' que se agregan durante el padding if residue != 'PAD': secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}") resultados.append("\n".join(secuencia_resultado)) return "\n\n".join(resultados) except Exception as e: print(f"Error durante la predicción: {e}") return f"Error al predecir los sitios de ARN: {e}" # Definir la interfaz de Gradio titulo = "ESM-2 para Predicción de Sitios de Unión de ARN" descripcion = ( "Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo." " El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN." ) iface = gr.Interface( fn=predecir_sitios_arn, inputs=gr.Textbox( lines=10, placeholder="Escribe tus secuencias de proteínas aquí, una por línea...", label="Secuencias de Proteínas" ), outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"), title=titulo, description=descripcion, examples=[ [ "VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK", "SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF" ], [ "MKAILVVLLYTFATANADAVAHVAA", "GATVQAAEEVTQGVVVVEEVAGGAA" ] ], cache_examples=False, allow_flagging="never" ) # Ejecutar la interfaz if __name__ == "__main__": iface.launch()