BioRAG / app.py
C2MV's picture
Update app.py
820a0dd verified
raw
history blame
2.84 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from functools import wraps
# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
@wraps(func)
def wrapper(*args, **kwargs):
inicio = time.time()
resultado = func(*args, **kwargs)
fin = time.time()
tiempo_transcurrido = fin - inicio
print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
return resultado
return wrapper
# Verificar si CUDA está disponible
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
# Cargar el tokenizador y el modelo desde HuggingFace
model_name = "dmis-lab/selfbiorag_7b"
print("Cargando el tokenizador y el modelo desde HuggingFace...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
@medir_tiempo
def generar_respuesta(consulta):
"""
Función que genera una respuesta a partir de una consulta dada.
"""
# Tokenizar la consulta
inputs = tokenizer.encode(consulta, return_tensors="pt").to(device)
# Configurar los parámetros de generación
generation_kwargs = {
"max_new_tokens": 100, # Ajustado a 100
"temperature": 0.0,
"top_p": 1.0,
"do_sample": False,
"skip_special_tokens": True
}
# Generar la respuesta
with torch.no_grad():
outputs = model.generate(**inputs, **generation_kwargs)
# Decodificar la respuesta
respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
return respuesta
# Definir la interfaz de Gradio
titulo = "Generador de Respuestas con SelfBioRAG 7B"
descripcion = "Ingresa una consulta y el modelo generará una respuesta basada en el contenido proporcionado."
iface = gr.Interface(
fn=generar_respuesta,
inputs=gr.inputs.Textbox(lines=5, placeholder="Escribe tu consulta aquí..."),
outputs=gr.outputs.Textbox(),
title=titulo,
description=descripcion,
examples=[
[
"Clasifica el siguiente informe de radiología según la parte del cuerpo a la que se refiere (por ejemplo, pecho, abdomen, cerebro, etc.): Los discos intervertebrales en L4-L5 y L5-S1 muestran signos de degeneración con leve abultamiento que comprime la raíz nerviosa adyacente."
],
[
"Resume los puntos clave sobre el papel de las mutaciones en los genes BRCA1 y BRCA2 en el aumento del riesgo de cáncer de mama."
]
],
cache_examples=False
)
# Ejecutar la interfaz
if __name__ == "__main__":
iface.launch()