Spaces:
Paused
Paused
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Agregada esta importación | |
import torch | |
from huggingface_hub import login | |
import os | |
################################################################## | |
def setup_llama3_auth(): | |
"""Configurar autenticación para Llama 3""" | |
if 'HUGGING_FACE_TOKEN_3' in st.secrets: | |
token = st.secrets['HUGGING_FACE_TOKEN_3'] | |
login(token) | |
return True | |
else: | |
st.error("No se encontró el token de Llama 3 en los secrets") | |
st.stop() | |
return False | |
class Llama3Demo: | |
def __init__(self): | |
setup_llama3_auth() | |
self.model_name = "meta-llama/Llama-3.2-3B-Instruct" | |
self._model = None | |
self._tokenizer = None | |
# Configuración de cuantización | |
self.quantization_config = BitsAndBytesConfig( | |
load_in_8bit=True, | |
bnb_4bit_compute_dtype=torch.float16 | |
) | |
def model(self): | |
if self._model is None: | |
try: | |
self._model = AutoModelForCausalLM.from_pretrained( | |
self.model_name, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
quantization_config=self.quantization_config, # Nueva forma de configurar cuantización | |
token=st.secrets['HUGGING_FACE_TOKEN_3'] # Actualizado de use_auth_token a token | |
) | |
except Exception as e: | |
st.error(f"Error cargando el modelo: {str(e)}") | |
raise e | |
return self._model | |
def tokenizer(self): | |
if self._tokenizer is None: | |
try: | |
self._tokenizer = AutoTokenizer.from_pretrained( | |
self.model_name, | |
token=st.secrets['HUGGING_FACE_TOKEN_3'] # Actualizado de use_auth_token a token | |
) | |
except Exception as e: | |
st.error(f"Error cargando el tokenizer: {str(e)}") | |
raise e | |
return self._tokenizer | |
################################################################## | |
def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str: | |
formatted_prompt = f"""<|system|>You are a helpful AI assistant.</s> | |
<|user|>{prompt}</s> | |
<|assistant|>""" | |
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device) | |
# Asegurar que tenemos un pad_token_id válido | |
if self.tokenizer.pad_token_id is None: | |
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=max_new_tokens, | |
num_return_sequences=1, | |
temperature=0.7, | |
do_sample=True, | |
top_p=0.9, | |
pad_token_id=self.tokenizer.pad_token_id # Explícitamente establecer pad_token_id | |
) | |
torch.cuda.empty_cache() | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return response.split("<|assistant|>")[-1].strip() | |
################################################################## | |
def main(): | |
st.set_page_config(page_title="Llama 3.2 Chat", page_icon="🦙") | |
st.title("🦙 Llama 3.2 Chat") | |
# Verificar configuración | |
with st.expander("🔧 Status", expanded=True): | |
try: | |
token_status = setup_llama3_auth() | |
st.write("Token Llama 3:", "✅" if token_status else "❌") | |
if torch.cuda.is_available(): | |
st.write("GPU:", torch.cuda.get_device_name(0)) | |
st.write("Memoria GPU:", f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB") | |
else: | |
st.warning("GPU no disponible") | |
except Exception as e: | |
st.error(f"Error en configuración: {str(e)}") | |
# Sidebar con controles de generación | |
with st.sidebar: | |
st.markdown("### Parámetros de Generación") | |
generation_params = { | |
'temperature': st.slider( | |
"Temperatura (creatividad vs precisión)", | |
min_value=0.1, | |
max_value=1.0, | |
value=0.6, | |
step=0.1, | |
help="Valores más bajos = respuestas más precisas" | |
), | |
'max_new_tokens': st.slider( | |
"Longitud máxima", | |
min_value=64, | |
max_value=1024, | |
value=512, | |
step=64, | |
help="Longitud máxima de la respuesta" | |
), | |
'top_p': st.slider( | |
"Top-p (núcleo de probabilidad)", | |
min_value=0.1, | |
max_value=1.0, | |
value=0.85, | |
step=0.05 | |
) | |
} | |
with st.expander("Parámetros Avanzados"): | |
generation_params.update({ | |
'repetition_penalty': st.slider( | |
"Penalización por repetición", | |
min_value=1.0, | |
max_value=2.0, | |
value=1.2, | |
step=0.1 | |
), | |
'top_k': st.slider( | |
"Top-k tokens", | |
min_value=1, | |
max_value=100, | |
value=50, | |
step=1 | |
) | |
}) | |
st.markdown(""" | |
### Guía de Parámetros | |
- **Temperatura**: Menor = más preciso, Mayor = más creativo | |
- **Top-p**: Control sobre la variabilidad de respuestas | |
- **Longitud**: Ajustar según necesidad de detalle | |
""") | |
if st.button("Limpiar Chat"): | |
st.session_state.messages = [] | |
st.experimental_rerun() | |
# Inicializar el modelo | |
if 'llama' not in st.session_state: | |
with st.spinner("Inicializando Llama 3.2... esto puede tomar unos minutos..."): | |
try: | |
st.session_state.llama = Llama3Demo() | |
except Exception as e: | |
st.error("Error inicializando el modelo") | |
st.stop() | |
# Gestión del historial de chat | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
# Mostrar historial | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Interface de chat | |
if prompt := st.chat_input("Escribe tu mensaje aquí"): | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
with st.chat_message("assistant"): | |
try: | |
response = st.session_state.llama.generate_response(prompt, **generation_params) | |
st.markdown(response) | |
st.session_state.messages.append({"role": "assistant", "content": response}) | |
except Exception as e: | |
st.error(f"Error generando respuesta: {str(e)}") | |
if __name__ == "__main__": | |
main() | |