import streamlit as st
from transformers import pipeline
import torch
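
# Dependencies (assumed): streamlit, transformers, torch, and accelerate;
# device_map="auto" below relies on the accelerate package being installed.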


class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        # Defer pipeline construction until first use (lazy loading)
        self._pipe = None

    @property
    def pipe(self):
        if self._pipe is None:
            self._pipe = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return self._pipe
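
    # Note: device_map="auto" lets Accelerate place/shard the weights across
    # the available devices; float16 halves memory versus float32, but a 70B
    # checkpoint still needs roughly 140 GB for the weights alone in fp16.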

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Wrap the prompt in Llama 2's [INST] chat tags
        formatted_prompt = f"[INST] {prompt} [/INST]"
        # Generate a completion with the (lazily created) pipeline
        response = self.pipe(
            formatted_prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )[0]["generated_text"]
        # Keep only the text after the closing instruction tag
        return response.split("[/INST]")[-1].strip()
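
    # Standalone usage sketch (hypothetical, outside Streamlit):
    #
    #     demo = LlamaDemo()
    #     print(demo.generate_response("Explain top-p sampling in one line."))
    #
    # The first call triggers the full download and load, so expect a long
    # delay and substantial GPU memory use for a 70B checkpoint.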


def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide",
    )
    st.title("🦙 Llama 2 Chat Demo")

    # Create the model wrapper once per session (the heavy load happens
    # lazily on the first generation request)
    if "llama" not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Replay the conversation so far
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

    if prompt := st.chat_input("What would you like to discuss?"):
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = st.session_state.llama.generate_response(prompt)
                    st.write(response)
                    st.session_state.chat_history.append(
                        {"role": "assistant", "content": response}
                    )
                except Exception as e:
                    st.error(f"Error: {e}")

    with st.sidebar:
        st.markdown("""
        ### About
        This demo uses Llama-2-70b-chat, a large language model from Meta.
        The model is loaded with automatic device mapping and float16 (half)
        precision to reduce memory use.
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()


if __name__ == "__main__":
    main()
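
# Running this app (assuming the file is saved as app.py):
#
#     streamlit run app.py
#
# Note that the meta-llama checkpoints on the Hugging Face Hub are gated, so
# the first run needs an authenticated session with an approved access
# request. A minimal sketch using huggingface_hub:
#
#     from huggingface_hub import login
#     login()  # prompts for a token; `huggingface-cli login` also works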