import streamlit as st from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Agregada esta importación import torch from huggingface_hub import login import os ################################################################## def setup_llama3_auth(): """Configurar autenticación para Llama 3""" if 'HUGGING_FACE_TOKEN_3' in st.secrets: token = st.secrets['HUGGING_FACE_TOKEN_3'] login(token) return True else: st.error("No se encontró el token de Llama 3 en los secrets") st.stop() return False class Llama3Demo: def __init__(self): setup_llama3_auth() self.model_name = "meta-llama/Llama-3.2-3B-Instruct" self._model = None self._tokenizer = None # Configuración de cuantización self.quantization_config = BitsAndBytesConfig( load_in_8bit=True, bnb_4bit_compute_dtype=torch.float16 ) @property def model(self): if self._model is None: try: self._model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.float16, device_map="auto", quantization_config=self.quantization_config, # Nueva forma de configurar cuantización token=st.secrets['HUGGING_FACE_TOKEN_3'] # Actualizado de use_auth_token a token ) except Exception as e: st.error(f"Error cargando el modelo: {str(e)}") raise e return self._model @property def tokenizer(self): if self._tokenizer is None: try: self._tokenizer = AutoTokenizer.from_pretrained( self.model_name, token=st.secrets['HUGGING_FACE_TOKEN_3'] # Actualizado de use_auth_token a token ) except Exception as e: st.error(f"Error cargando el tokenizer: {str(e)}") raise e return self._tokenizer ################################################################## def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str: formatted_prompt = f"""<|system|>You are a helpful AI assistant. <|user|>{prompt} <|assistant|>""" inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device) # Asegurar que tenemos un pad_token_id válido if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1, temperature=0.7, do_sample=True, top_p=0.9, pad_token_id=self.tokenizer.pad_token_id # Explícitamente establecer pad_token_id ) torch.cuda.empty_cache() response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) return response.split("<|assistant|>")[-1].strip() ################################################################## def main(): st.set_page_config(page_title="Llama 3.2 Chat", page_icon="🦙") st.title("🦙 Llama 3.2 Chat") # Verificar configuración with st.expander("🔧 Status", expanded=True): try: token_status = setup_llama3_auth() st.write("Token Llama 3:", "✅" if token_status else "❌") if torch.cuda.is_available(): st.write("GPU:", torch.cuda.get_device_name(0)) st.write("Memoria GPU:", f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB") else: st.warning("GPU no disponible") except Exception as e: st.error(f"Error en configuración: {str(e)}") # Sidebar con controles de generación with st.sidebar: st.markdown("### Parámetros de Generación") generation_params = { 'temperature': st.slider( "Temperatura (creatividad vs precisión)", min_value=0.1, max_value=1.0, value=0.6, step=0.1, help="Valores más bajos = respuestas más precisas" ), 'max_new_tokens': st.slider( "Longitud máxima", min_value=64, max_value=1024, value=512, step=64, help="Longitud máxima de la respuesta" ), 'top_p': st.slider( "Top-p (núcleo de probabilidad)", min_value=0.1, max_value=1.0, value=0.85, step=0.05 ) } with st.expander("Parámetros Avanzados"): generation_params.update({ 'repetition_penalty': st.slider( "Penalización por repetición", min_value=1.0, max_value=2.0, value=1.2, step=0.1 ), 'top_k': st.slider( "Top-k tokens", min_value=1, max_value=100, value=50, step=1 ) }) st.markdown(""" ### Guía de Parámetros - **Temperatura**: Menor = más preciso, Mayor = más creativo - **Top-p**: Control sobre la variabilidad de respuestas - **Longitud**: Ajustar según necesidad de detalle """) if st.button("Limpiar Chat"): st.session_state.messages = [] st.experimental_rerun() # Inicializar el modelo if 'llama' not in st.session_state: with st.spinner("Inicializando Llama 3.2... esto puede tomar unos minutos..."): try: st.session_state.llama = Llama3Demo() except Exception as e: st.error("Error inicializando el modelo") st.stop() # Gestión del historial de chat if 'messages' not in st.session_state: st.session_state.messages = [] # Mostrar historial for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) # Interface de chat if prompt := st.chat_input("Escribe tu mensaje aquí"): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) with st.chat_message("assistant"): try: response = st.session_state.llama.generate_response(prompt, **generation_params) st.markdown(response) st.session_state.messages.append({"role": "assistant", "content": response}) except Exception as e: st.error(f"Error generando respuesta: {str(e)}") if __name__ == "__main__": main()