import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import login
import os
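
# Required packages (an assumption based on the calls below): streamlit,
# transformers, torch, and huggingface_hub, plus accelerate (needed for
# device_map="auto") and bitsandbytes (needed for load_in_8bit=True).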

def init_huggingface():
    """Initialize Hugging Face authentication from the environment, Streamlit secrets, or user input."""
    if 'HUGGING_FACE_TOKEN' not in st.session_state:
        # First try to get the token from an environment variable
        token = os.getenv('HUGGINGFACE_TOKEN')
        # If not in the environment, check Streamlit secrets
        if not token and 'huggingface_token' in st.secrets:
            token = st.secrets['huggingface_token']
        # If still not found, ask the user
        if not token:
            token = st.text_input('Enter your Hugging Face token:', type='password')
            if not token:
                st.warning('Please enter your Hugging Face token to proceed')
                st.stop()
        st.session_state['HUGGING_FACE_TOKEN'] = token
    # Log in to Hugging Face
    login(st.session_state['HUGGING_FACE_TOKEN'])
    return True
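
# Usage note: on Hugging Face Spaces or in local development, the token can
# also be provided ahead of time, e.g. via the HUGGINGFACE_TOKEN environment
# variable or a Streamlit secrets file (.streamlit/secrets.toml):
#   huggingface_token = "hf_..."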

class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        # Model and tokenizer are lazy-loaded on first access
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                load_in_8bit=True  # 8-bit quantization to save memory (requires bitsandbytes)
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Format the prompt for Llama 2 chat
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the instruction delimiter
        return response.split("[/INST]")[-1].strip()
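
# Illustrative standalone usage (hypothetical; outside the Streamlit app, and
# assuming a GPU with enough memory for the 8-bit 70B weights):
#   demo = LlamaDemo()
#   print(demo.generate_response("Summarize 8-bit quantization in one sentence."))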

def main():
    st.set_page_config(
        page_title="Llama 2 Demo",
        page_icon="🦙",
        layout="wide"
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Initialize Hugging Face authentication
    if init_huggingface():
        st.success("Successfully authenticated with Hugging Face!")

    # Initialize the model wrapper once per session
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")

    with st.sidebar:
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()  # st.experimental_rerun() is deprecated in current Streamlit


if __name__ == "__main__":
    main()
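
# To launch the app locally (assuming Streamlit is installed):
#   streamlit run app.py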