import os import gradio as gr import torch import torch._dynamo from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from threading import Thread import spaces # Desactivar TorchDynamo para evitar errores de compilación torch._dynamo.config.suppress_errors = True torch._dynamo.disable() # Configuración MODEL_ID = "somosnlp-hackathon-2025/iberotales-gemma-3-1b-it-es" MAX_MAX_NEW_TOKENS = 4096 DEFAULT_MAX_NEW_TOKENS = 2048 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048")) # System prompt personalizado DEFAULT_SYSTEM_MESSAGE = """Resuelve el siguiente problema. Primero, piensa en voz alta qué debes hacer, paso por paso y de forma resumida, entre y . Luego, da la respuesta final entre y . No escribas nada fuera de ese formato.""" # Variables globales model = None tokenizer = None def load_model(): """Cargar modelo y tokenizador""" global model, tokenizer if torch.cuda.is_available(): print(f"Cargando modelo: {MODEL_ID}") try: tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float32, device_map="auto", trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("¡Modelo cargado exitosamente!") return True except Exception as e: print(f"Error al cargar el modelo: {e}") return False else: print("CUDA no disponible") return False # Cargar modelo al iniciar model_loaded = load_model() @spaces.GPU def generate( message: str, history: list, system_message: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.2, ): """Generar historia con streaming""" global model, tokenizer if model is None or tokenizer is None: yield "Error: Modelo no disponible. Por favor, reinicia la aplicación." return conversation = [] if system_message: conversation.append({"role": "system", "content": system_message}) for msg in history: if isinstance(msg, dict) and "role" in msg and "content" in msg: conversation.append(msg) conversation.append({"role": "user", "content": message}) try: input_ids = tokenizer.apply_chat_template( conversation, return_tensors="pt", add_generation_prompt=True, ) if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Conversación recortada a {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) attention_mask = torch.ones_like(input_ids, device=model.device) streamer = TextIteratorStreamer( tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True ) generate_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "top_p": top_p, "top_k": top_k, "temperature": temperature, "repetition_penalty": repetition_penalty, "pad_token_id": tokenizer.eos_token_id, "eos_token_id": tokenizer.eos_token_id, } generation_thread = Thread(target=model.generate, kwargs=generate_kwargs) generation_thread.start() outputs = [] try: for new_text in streamer: outputs.append(new_text) yield "".join(outputs) except Exception as e: yield f"Error durante la generación: {str(e)}" finally: generation_thread.join(timeout=1) except Exception as e: yield f"Error: {str(e)}" # Crear interfaz de chat demo = gr.ChatInterface( fn=generate, title="Iberotales: Mitos y Leyendas Iberoamericanas", description="Genera historias y personajes basados en el patrimonio cultural de Iberoamérica usando GRPO.", chatbot=gr.Chatbot( height=600, show_copy_button=True, ), textbox=gr.Textbox( placeholder="Escribe una historia o personaje que quieras generar...", scale=7 ), additional_inputs=[ gr.Textbox( value=DEFAULT_SYSTEM_MESSAGE, label="Mensaje del sistema (formato estructurado requerido)" ), gr.Slider( label="Máximo de tokens", minimum=100, maximum=MAX_MAX_NEW_TOKENS, step=50, value=DEFAULT_MAX_NEW_TOKENS, ), gr.Slider( label="Temperatura", minimum=0.1, maximum=2.0, step=0.1, value=0.7, ), gr.Slider( label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95, ), gr.Slider( label="Top-k", minimum=1, maximum=100, step=1, value=50, ), gr.Slider( label="Penalización por repetición", minimum=1.0, maximum=2.0, step=0.05, value=1.2, ), ], examples=[ ["Crea una historia corta sobre el Pombero, un personaje de la mitología guaraní."], ["Genera un personaje basado en la leyenda del Cadejo."], ["Inventa una narrativa en torno al Nahual en un entorno contemporáneo."], ], cache_examples=False, ) if __name__ == "__main__": if model_loaded: print("Lanzando aplicación Gradio...") demo.launch( share=False, show_error=True ) else: print("Error al cargar el modelo. No se puede iniciar la aplicación.")