import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Verificar si CUDA está disponible
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")

# Cargar el tokenizador y el modelo desde HuggingFace
model_name = "dmis-lab/selfbiorag_7b"
print("Cargando el tokenizador y el modelo desde HuggingFace...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

def generar_respuesta(consulta):
    """
    Función que genera una respuesta a partir de una consulta dada.
    """
    # Tokenizar la consulta
    inputs = tokenizer.encode(consulta, return_tensors="pt").to(device)
    
    # Configurar los parámetros de generación
    generation_kwargs = {
        "max_new_tokens": 200,
        "temperature": 0.0,
        "top_p": 1.0,
        "do_sample": False,
        "skip_special_tokens": True
    }
    
    # Generar la respuesta
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)
    
    # Decodificar la respuesta
    respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return respuesta

# Definir la interfaz de Gradio
titulo = "Generador de Respuestas con SelfBioRAG 7B"
descripcion = "Ingresa una consulta y el modelo generará una respuesta basada en el contenido proporcionado."

iface = gr.Interface(
    fn=generar_respuesta,
    inputs=gr.inputs.Textbox(lines=5, placeholder="Escribe tu consulta aquí..."),
    outputs=gr.outputs.Textbox(),
    title=titulo,
    description=descripcion,
    examples=[
        [
            "Clasifica el siguiente informe de radiología según la parte del cuerpo a la que se refiere (por ejemplo, pecho, abdomen, cerebro, etc.): Los discos intervertebrales en L4-L5 y L5-S1 muestran signos de degeneración con leve abultamiento que comprime la raíz nerviosa adyacente."
        ],
        [
            "Resume los puntos clave sobre el papel de las mutaciones en los genes BRCA1 y BRCA2 en el aumento del riesgo de cáncer de mama."
        ]
    ],
    cache_examples=False
)

# Ejecutar la interfaz
if __name__ == "__main__":
    iface.launch()