import streamlit as st
import torch
from transformers import pipeline


class LlamaDemo:
    def __init__(self):
        # Transformers-format checkpoint on the Hugging Face Hub
        # (the non "-hf" repo holds the original PyTorch weights,
        # which the pipeline cannot load directly)
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        # Lazy-load the pipeline so the app starts quickly;
        # the weights are only pulled on the first generation call
        self._pipe = None

    @property
    def pipe(self):
        if self._pipe is None:
            self._pipe = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return self._pipe

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Wrap the prompt in Llama 2's chat instruction tags
        formatted_prompt = f"[INST] {prompt} [/INST]"

        response = self.pipe(
            formatted_prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )[0]["generated_text"]

        # The pipeline echoes the prompt, so keep only the text
        # after the closing instruction tag
        return response.split("[/INST]")[-1].strip()


def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide",
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Construct the wrapper once per session. Note that construction is
    # cheap; the actual weight loading happens on the first prompt.
    if "llama" not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt,
            })

            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response,
                        })
                    except Exception as e:
                        st.error(f"Error: {e}")

    with st.sidebar:
        st.markdown("""
        ### About
        This demo uses Llama-2-70B-chat, a large language model from Meta.
        The model runs with automatic device mapping and float16 precision
        to reduce memory use.
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()


if __name__ == "__main__":
    main()
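
# Usage sketch (assumptions not stated in the script: the file is saved as
# app.py, and you have accepted Meta's Llama 2 license on the Hugging Face Hub
# and authenticated with `huggingface-cli login` so the gated weights can be
# downloaded):
#
#   pip install streamlit transformers accelerate torch
#   streamlit run app.py
#
# Sizing note: 70B parameters in float16 is roughly 140 GB of weights, so
# device_map="auto" will shard the model across all visible GPUs and spill
# the remainder to CPU RAM if the GPUs alone cannot hold it.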