import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import login
import os

def init_huggingface():
    """Initialize Hugging Face authentication either from secrets or user input."""
    if 'HUGGING_FACE_TOKEN' not in st.session_state:
        # First try to get the token from an environment variable
        token = os.getenv('HUGGINGFACE_TOKEN')

        # If not in the environment, check Streamlit secrets
        if not token and 'huggingface_token' in st.secrets:
            token = st.secrets['huggingface_token']

        # If still not found, ask the user
        if not token:
            token = st.text_input('Enter your Hugging Face token:', type='password')
            if not token:
                st.warning('Please enter your Hugging Face token to proceed')
                st.stop()

        st.session_state['HUGGING_FACE_TOKEN'] = token

    # Log in to Hugging Face
    login(st.session_state['HUGGING_FACE_TOKEN'])
    return True

class LlamaDemo:
    def __init__(self):
        # Note: this checkpoint is gated on the Hugging Face Hub (requires accepting
        # Meta's license) and needs tens of GB of GPU memory even in 8-bit.
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        # Lazily load the model on first access
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                load_in_8bit=True  # 8-bit quantization to reduce memory usage
            )
        return self._model

    @property
    def tokenizer(self):
        # Lazily load the tokenizer on first access
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Format the prompt with Llama 2's chat instruction tags
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Return only the text generated after the instruction block
        return response.split("[/INST]")[-1].strip()

def main():
    st.set_page_config(
        page_title="Llama 2 Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Initialize Hugging Face authentication
    if init_huggingface():
        st.success("Successfully authenticated with Hugging Face!")

    # Initialize the model wrapper once per session
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")

    with st.sidebar:
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()

if __name__ == "__main__":
    main()
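
A minimal requirements file for this Space could look like the sketch below. The package list is inferred from the imports above (accelerate is needed for device_map="auto" and bitsandbytes for load_in_8bit); exact version pins are an assumption and left to the reader.

# requirements.txt (suggested, unpinned)
streamlit
torch
transformers
accelerate
bitsandbytes
huggingface_hub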