import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
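# Note (assumption, not part of the original file): besides streamlit, torch, transformers,
# and huggingface_hub, the 8-bit loading used below (`load_in_8bit=True`) also requires
# `bitsandbytes` and `accelerate` to be installed, e.g. via the Space's requirements.txt.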
# Check GPU availability at startup
def check_gpu():
    if torch.cuda.is_available():
        gpu_info = {
            "GPU Available": True,
            "GPU Name": torch.cuda.get_device_name(0),
            "Total Memory (GB)": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
            "CUDA Version": torch.version.cuda
        }
        return gpu_info
    return {"GPU Available": False}
# Configure Hugging Face authentication (Llama 2 is a gated model)
def setup_auth():
    if 'HUGGING_FACE_TOKEN' in st.secrets:
        login(st.secrets['HUGGING_FACE_TOKEN'])
        return True
    else:
        st.error("Hugging Face token not found in secrets")
        st.stop()
        return False
class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        # Lazy-load the model on first access
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # float16 to reduce memory usage
                device_map="auto",
                load_in_8bit=True  # 8-bit quantization to reduce memory usage
            )
        return self._model

    @property
    def tokenizer(self):
        # Lazy-load the tokenizer on first access
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer
    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Llama 2 chat prompt format for a single turn
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Free GPU memory after generation
        torch.cuda.empty_cache()
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response.split("[/INST]")[-1].strip()
def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Authenticate with Hugging Face before loading the gated model
    setup_auth()

    # Show GPU information
    gpu_info = check_gpu()
    with st.expander("💻 GPU Info", expanded=False):
        for key, value in gpu_info.items():
            st.write(f"{key}: {value}")
    # Initialize model
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
    with st.sidebar:
        st.markdown("""
        ### Memory Management
        To optimize GPU usage and costs:
        - Model runs in 8-bit precision
        - Memory is cleared after each generation
        - Space sleeps after inactivity
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.experimental_rerun()
if __name__ == "__main__":
    main()