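"""Streamlit chat demo for Llama 2 (7B, chat-tuned).

Loads the model in 8-bit precision on GPU, authenticates against the
Hugging Face Hub via st.secrets, and keeps the conversation in
st.session_state.
"""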
import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
# Check for a GPU at startup
def check_gpu():
    if torch.cuda.is_available():
        gpu_info = {
            "GPU Available": True,
            "GPU Name": torch.cuda.get_device_name(0),
            "Total Memory (GB)": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
            "CUDA Version": torch.version.cuda
        }
        return gpu_info
    return {"GPU Available": False}
# Set up Hugging Face authentication
def setup_auth():
    if 'HUGGINGFACE_TOKEN' in st.secrets:
        login(st.secrets['HUGGINGFACE_TOKEN'])
        return True
    else:
        st.error("Hugging Face token not found in secrets")
        st.stop()  # halts the script run, so the return below is never reached
        return False
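# Note: the meta-llama checkpoints are gated on the Hugging Face Hub, so the
# login above must succeed (with an account that has accepted the Llama 2
# license) before from_pretrained() can download the weights.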
class LlamaDemo:
    def __init__(self):
        # Use the "-hf" checkpoint, which is in transformers format
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        # Lazy-load on first access so the app starts quickly
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # float16 to save memory
                device_map="auto",
                load_in_8bit=True  # 8-bit quantization to save memory (requires bitsandbytes)
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer
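    # Rough memory math for the settings above: 7B parameters take ~14 GB
    # in float16 but ~7 GB in 8-bit, which is what lets the model fit on a
    # single ~16 GB GPU (e.g. a T4). Figures are approximate.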
    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Wrap the prompt in Llama 2's chat instruction format
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Free cached GPU memory after generating
        torch.cuda.empty_cache()
        # decode() returns prompt + completion; keep only the text after
        # the closing [/INST] tag
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response.split("[/INST]")[-1].strip()
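    # Sampling choices: temperature=0.7 with top_p=0.9 (nucleus sampling)
    # gives varied but mostly on-topic replies; values closer to 0 would
    # make the output more deterministic.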
def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")
    # Log in before the model is first loaded (the repo is gated)
    setup_auth()
    # Show GPU information
    gpu_info = check_gpu()
    with st.expander("💻 GPU Info", expanded=False):
        for key, value in gpu_info.items():
            st.write(f"{key}: {value}")
    # Initialize model
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])
    if prompt := st.chat_input("What would you like to discuss?"):
        st.session_state.chat_history.append({
            "role": "user",
            "content": prompt
        })
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = st.session_state.llama.generate_response(prompt)
                    st.write(response)
                    st.session_state.chat_history.append({
                        "role": "assistant",
                        "content": response
                    })
                except Exception as e:
                    st.error(f"Error: {str(e)}")
    with st.sidebar:
        st.markdown("""
        ### Memory Management
        To optimize GPU usage and costs:
        - Model runs in 8-bit precision
        - Memory is cleared after each generation
        - Space sleeps after inactivity
        """)
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.experimental_rerun()
if __name__ == "__main__":
    main()
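# To run locally (outside Spaces): `streamlit run app.py`, assuming this file
# is saved as app.py and HUGGINGFACE_TOKEN is set in .streamlit/secrets.toml.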