File size: 4,362 Bytes
f995cde
fa0a856
f995cde
 
 
 
 
 
fa0a856
 
f995cde
 
 
 
 
 
 
 
 
 
fa0a856
 
f995cde
 
 
 
 
 
fa0a856
 
 
 
f995cde
 
 
fa0a856
 
 
 
f995cde
 
 
 
 
fa0a856
f995cde
 
 
 
 
 
 
fa0a856
 
 
f995cde
 
 
fa0a856
f995cde
 
 
 
fa0a856
f995cde
 
 
fa0a856
 
f995cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa0a856
f995cde
 
 
 
 
 
 
 
 
fa0a856
f995cde
 
 
 
fa0a856
 
 
 
 
 
 
 
 
f995cde
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict
import time

class LlamaDemo:
    """Chat wrapper around a Hugging Face causal language model, loaded lazily.

    Constructing an instance is cheap: the tokenizer and the model weights are
    only loaded the first time the ``tokenizer`` / ``model`` properties are read.
    """

    # TinyLlama is open source and doesn't require authentication.
    DEFAULT_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant."

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        trust_remote_code: bool = False,
    ):
        """Store configuration; no model or tokenizer is loaded here.

        Args:
            model_name: Hugging Face Hub id (or local path) of a causal chat model.
            system_prompt: System message placed before every user message.
            trust_remote_code: Forwarded to ``from_pretrained``. Off by default:
                TinyLlama uses the Llama architecture built into transformers,
                and enabling it would run Python code shipped with the model repo.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.trust_remote_code = trust_remote_code
        # Populated on first access of the corresponding property.
        self._model = None
        self._tokenizer = None

    @staticmethod
    def _default_dtype() -> torch.dtype:
        """Return float16 when a CUDA GPU is available, otherwise float32."""
        # Half precision halves GPU memory, but many CPU kernels lack float16
        # support (or are much slower), so use full precision on CPU.
        return torch.float16 if torch.cuda.is_available() else torch.float32

    @property
    def model(self):
        """The causal LM, loaded (and placed via ``device_map="auto"``) on first access."""
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=self._default_dtype(),
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
        return self._model

    @property
    def tokenizer(self):
        """The tokenizer matching ``model_name``, loaded on first access."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=self.trust_remote_code,
            )
        return self._tokenizer

    @staticmethod
    def _zephyr_prompt(system_prompt: str, user_prompt: str) -> str:
        """Format a single-turn prompt in the Zephyr style used by TinyLlama-Chat.

        Only used when the tokenizer does not ship its own chat template.
        """
        return (
            f"<|system|>\n{system_prompt}</s>\n"
            f"<|user|>\n{user_prompt}</s>\n"
            "<|assistant|>\n"
        )

    def _build_prompt(self, prompt: str) -> str:
        """Render the system + user messages into the model's chat prompt string."""
        tokenizer = self.tokenizer
        if getattr(tokenizer, "chat_template", None):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            # add_generation_prompt appends the assistant header so the model
            # starts writing the reply rather than continuing the user turn.
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return self._zephyr_prompt(self.system_prompt, prompt)

    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        """Generate the assistant's reply to a single user message.

        Args:
            prompt: The user's message.
            max_length: Maximum number of *new* tokens to generate (the prompt
                itself is not counted).

        Returns:
            The decoded reply with special tokens removed and surrounding
            whitespace stripped.
        """
        chat_prompt = self._build_prompt(prompt)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; decode only the new tokens
        # instead of splitting the full text on a role marker, which breaks if
        # the reply contains that marker.
        reply_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

@st.cache_resource(show_spinner=False)
def _load_llama():
    """Create a LlamaDemo whose tokenizer and weights are already loaded.

    Cached with ``st.cache_resource`` so the model is loaded once per server
    process and shared by all browser sessions, rather than once per session.
    """
    llama = LlamaDemo()
    # LlamaDemo loads lazily; touch both properties so the slow download/load
    # happens here (under the "Loading model" spinner) and not on the first
    # chat message.
    _ = llama.tokenizer
    _ = llama.model
    return llama


def main():
    """Render the chat UI: settings sidebar, chat history and message input."""
    # Page config is set before any other element is rendered.
    st.set_page_config(
        page_title="Open Source Llama Demo",
        page_icon="🦙",
        layout="wide"
    )

    st.title("🦙 Open Source Llama Demo")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Sidebar is built before the chat so that (a) the slider value is known
    # when a reply is generated and (b) "Clear Chat History" empties the
    # history before it is rendered below, so no explicit rerun is needed.
    with st.sidebar:
        st.header("Settings")
        max_length = st.slider("Maximum response length", 64, 1024, 512)

        st.markdown("---")
        st.markdown("""
        ### About
        This demo uses TinyLlama, an open source language model that's smaller but 
        still capable. It's perfect for demonstrations and testing.
        
        The model is loaded locally and doesn't require any authentication or API keys.
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []

    if 'llama' not in st.session_state:
        with st.spinner("Loading model... This might take a few minutes..."):
            st.session_state.llama = _load_llama()

    # Chat interface
    with st.container():
        # Replay the stored conversation on each script run.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Thinking..."):
                    # Only the latest message is sent; earlier turns are not
                    # passed to the model as context.
                    response = st.session_state.llama.generate_response(
                        prompt, max_length=max_length
                    )
                message_placeholder.write(response)

            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })

# Guard so that importing this module (e.g. to reuse LlamaDemo) does not
# launch the UI; main() only runs when the file is executed as a script.
if __name__ == "__main__":
    main()