# TestOneLlama / app.py
import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Check GPU availability at startup
def check_gpu():
    if torch.cuda.is_available():
        gpu_info = {
            "GPU Available": True,
            "GPU Name": torch.cuda.get_device_name(0),
            "Total Memory (GB)": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
            "CUDA Version": torch.version.cuda
        }
        return gpu_info
    return {"GPU Available": False}

# Configure Hugging Face authentication
def setup_auth():
    if 'HUGGING_FACE_TOKEN' in st.secrets:
        login(st.secrets['HUGGING_FACE_TOKEN'])
        return True
    else:
        st.error("Hugging Face token not found in secrets")
        st.stop()
        return False

class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        # Model and tokenizer are loaded lazily on first access
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # Use float16 to reduce memory use
                device_map="auto",
                load_in_8bit=True  # 8-bit quantization to reduce memory use
            )
        return self._model
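
    # Note: newer transformers releases deprecate passing load_in_8bit directly to
    # from_pretrained in favor of a BitsAndBytesConfig. A minimal sketch of the
    # equivalent call, assuming a recent transformers plus bitsandbytes install
    # (not used above):
    #
    #     from transformers import BitsAndBytesConfig
    #     quant_config = BitsAndBytesConfig(load_in_8bit=True)
    #     model = AutoModelForCausalLM.from_pretrained(
    #         "meta-llama/Llama-2-7b-chat-hf",
    #         quantization_config=quant_config,
    #         device_map="auto",
    #     )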

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Wrap the prompt in the Llama 2 chat instruction format
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Free GPU memory after generation
        torch.cuda.empty_cache()
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Return only the text after the instruction marker
        return response.split("[/INST]")[-1].strip()
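
    # Note: the manual "[INST] ... [/INST]" wrapping above is the Llama 2 chat
    # format for a single user turn without a system prompt. On recent
    # transformers versions the tokenizer can build the same prompt itself; a
    # minimal sketch, assuming a release that ships apply_chat_template
    # (not used above):
    #
    #     messages = [{"role": "user", "content": prompt}]
    #     input_ids = self.tokenizer.apply_chat_template(
    #         messages, return_tensors="pt"
    #     ).to(self.model.device)
    #     outputs = self.model.generate(input_ids, max_new_tokens=max_new_tokens)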

def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Authenticate with Hugging Face before loading the gated model
    setup_auth()

    # Show GPU information
    gpu_info = check_gpu()
    with st.expander("💻 GPU Info", expanded=False):
        for key, value in gpu_info.items():
            st.write(f"{key}: {value}")

    # Initialize model
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface: replay previous turns, then handle new input
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

    if prompt := st.chat_input("What would you like to discuss?"):
        st.session_state.chat_history.append({
            "role": "user",
            "content": prompt
        })
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = st.session_state.llama.generate_response(prompt)
                    st.write(response)
                    st.session_state.chat_history.append({
                        "role": "assistant",
                        "content": response
                    })
                except Exception as e:
                    st.error(f"Error: {str(e)}")

    # Sidebar with usage notes and controls
    with st.sidebar:
        st.markdown("""
### Memory Management
To optimize GPU usage and costs:
- Model runs in 8-bit precision
- Memory is cleared after each generation
- Space sleeps after inactivity
""")
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.experimental_rerun()  # renamed to st.rerun() in newer Streamlit releases

if __name__ == "__main__":
    main()
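
# To run locally (a minimal sketch, assuming streamlit, torch, transformers,
# accelerate and bitsandbytes are installed, and that a HUGGING_FACE_TOKEN
# entry is available to st.secrets, e.g. via .streamlit/secrets.toml):
#
#     streamlit run app.py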