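"""Streamlit chat demo for Llama 2 (70B chat) built on the Hugging Face
transformers text-generation pipeline."""
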
import streamlit as st
from transformers import pipeline
import torch

class LlamaDemo:
    def __init__(self):
        # Hugging Face Hub ID of the transformers-compatible chat checkpoint.
        # Note: meta-llama models are gated; accept the license on the Hub and
        # authenticate (e.g. `huggingface-cli login`) before loading.
        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
        # Initialize the pipeline lazily on first use
        self._pipe = None
        
    @property
    def pipe(self):
        if self._pipe is None:
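            # First access: download/load the checkpoint. A 70B model needs
            # substantial GPU memory (roughly 140 GB in fp16); device_map="auto"
            # shards the weights across available devices via accelerate.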
            self._pipe = pipeline(
                "text-generation",
                model=self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        return self._pipe

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        # Format prompt for Llama 2 chat
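        # (The full Llama 2 chat template also supports an optional system
        # prompt wrapped in <<SYS>> ... <</SYS>> tags; omitted here for simplicity.)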
        formatted_prompt = f"[INST] {prompt} [/INST]"
        
        # Generate response using pipeline
        response = self.pipe(
            formatted_prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )[0]['generated_text']
        
        # Extract response after the instruction tag
        return response.split("[/INST]")[-1].strip()

def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )
    
    st.title("🦙 Llama 2 Chat Demo")
    
    # Initialize model
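    # Note: constructing LlamaDemo is cheap because the pipeline is created
    # lazily; the weights are actually downloaded/loaded on the first prompt.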
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()
    
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
        
    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])
    
        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            
            with st.chat_message("user"):
                st.write(prompt)
            
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
    
    with st.sidebar:
        st.markdown("""
        ### About
        This demo uses Llama-2-70B-chat, a large language model from Meta.
        
        The model is loaded in half precision (fp16) with automatic device mapping across the available GPUs.
        """)
        
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()  # st.experimental_rerun() is deprecated in recent Streamlit releases

if __name__ == "__main__":
    main()