import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login


# Check GPU availability at startup
def check_gpu():
    if torch.cuda.is_available():
        gpu_info = {
            "GPU Available": True,
            "GPU Name": torch.cuda.get_device_name(0),
            "Total Memory (GB)": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
            "CUDA Version": torch.version.cuda,
        }
        return gpu_info
    return {"GPU Available": False}


# Set up Hugging Face authentication (required for gated models such as Llama 2)
def setup_auth():
    if 'HUGGING_FACE_TOKEN' in st.secrets:
        login(st.secrets['HUGGING_FACE_TOKEN'])
        return True
    st.error("Hugging Face token not found in secrets")
    st.stop()
    return False


class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        # Lazy-load the model on first access
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # float16 to reduce memory usage
                device_map="auto",
                load_in_8bit=True  # 8-bit quantization to reduce memory usage
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Llama 2 chat models expect the [INST] ... [/INST] prompt format
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Free GPU memory after generation
        torch.cuda.empty_cache()

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Return only the text after the instruction tag
        return response.split("[/INST]")[-1].strip()


def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )

    st.title("🦙 Llama 2 Chat Demo")

    # Authenticate with Hugging Face before loading the gated model
    setup_auth()

    # Show GPU information
    gpu_info = check_gpu()
    with st.expander("💻 GPU Info", expanded=False):
        for key, value in gpu_info.items():
            st.write(f"{key}: {value}")

    # Initialize the model once per session
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })

            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")

    with st.sidebar:
        st.markdown("""
        ### Memory Management
        To optimize GPU usage and costs:
        - Model runs in 8-bit precision
        - Memory is cleared after each generation
        - Space sleeps after inactivity
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.experimental_rerun()


if __name__ == "__main__":
    main()