BioRAG / app.py
C2MV's picture
Update app.py
69463c7 verified
raw
history blame
4.41 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from functools import wraps
import sys
# Intentar importar 'spaces' para usar el decorador GPU si está disponible
try:
import spaces
except ImportError:
# Si 'spaces' no está disponible, definir un decorador vacío
def GPU(duration):
def decorator(func):
return func
return decorator
spaces = type('spaces', (), {'GPU': GPU})
# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
@wraps(func)
def wrapper(*args, **kwargs):
inicio = time.time()
resultado = func(*args, **kwargs)
fin = time.time()
tiempo_transcurrido = fin - inicio
print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
return resultado
return wrapper
# Configurar el dispositivo (GPU si está disponible)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
# Ruta local al tokenizador
tokenizer_path = "tokenizer_bpe_1024"
# Cargar el tokenizador desde el directorio local
try:
print(f"Cargando el tokenizador desde el directorio local '{tokenizer_path}'...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
except ValueError as e:
print(f"Error al cargar el tokenizador: {e}")
sys.exit(1)
except Exception as e:
print(f"Error inesperado al cargar el tokenizador: {e}")
sys.exit(1)
# Ruta al modelo local
model_path = "model.pt.recombined"
# Cargar el modelo desde el archivo local
try:
print(f"Cargando el modelo GenerRNA desde '{model_path}'...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
model.eval()
print("Modelo GenerRNA cargado exitosamente.")
except FileNotFoundError:
print(f"Error: El archivo del modelo '{model_path}' no se encontró.")
sys.exit(1)
except Exception as e:
print(f"Error al cargar el modelo GenerRNA: {e}")
sys.exit(1)
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def generar_rna_sequence(prompt, max_length=256):
"""
Función que genera una secuencia de RNA a partir de una secuencia inicial dada.
"""
try:
if not prompt.strip():
return "Por favor, ingresa una secuencia de inicio válida."
# Tokenizar la entrada
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Generar la secuencia
with torch.no_grad():
outputs = model.generate(
inputs,
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
temperature=0.7,
top_k=50,
top_p=0.95,
do_sample=True
)
# Decodificar la secuencia generada
generated_sequence = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_sequence
except Exception as e:
print(f"Error durante la generación de secuencia: {e}")
return f"Error al generar la secuencia: {e}"
# Definir la interfaz de Gradio
titulo = "GenerRNA - Generador de Secuencias de RNA"
descripcion = (
"GenerRNA es un modelo generativo de RNA basado en una arquitectura Transformer. "
"Ingresa una secuencia inicial opcional y define la longitud máxima para generar nuevas secuencias de RNA."
)
iface = gr.Interface(
fn=generar_rna_sequence,
inputs=[
gr.Textbox(
lines=5,
placeholder="Ingresa una secuencia de RNA inicial (opcional)...",
label="Secuencia Inicial"
),
gr.Slider(
minimum=50,
maximum=1000,
step=50,
value=256,
label="Longitud Máxima de la Secuencia"
)
],
outputs=gr.Textbox(label="Secuencia de RNA Generada"),
title=titulo,
description=descripcion,
examples=[
[
"AUGGCUACGUAUCGACGUA"
],
[
"GCUAUGCUAGCUAGCUGAC"
]
],
cache_examples=False,
allow_flagging="never"
)
# Ejecutar la interfaz
if __name__ == "__main__":
iface.launch()