AIdeaText committed on
Commit fa0a856 • 1 Parent(s): ce284a4

Update app.py

Files changed (1)
  1. app.py +32 -12
app.py CHANGED
@@ -1,12 +1,13 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 from typing import List, Dict
 import time
 
 class LlamaDemo:
     def __init__(self):
-        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
+        # Using TinyLlama, which is open source and doesn't require authentication
+        self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
         # Initialize in lazy loading fashion
         self._model = None
         self._tokenizer = None
@@ -17,24 +18,31 @@ class LlamaDemo:
             self._model = AutoModelForCausalLM.from_pretrained(
                 self.model_name,
                 torch_dtype=torch.float16,
-                device_map="auto"
+                device_map="auto",
+                trust_remote_code=True
             )
         return self._model
 
     @property
     def tokenizer(self):
         if self._tokenizer is None:
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                trust_remote_code=True
+            )
         return self._tokenizer
 
     def generate_response(self, prompt: str, max_length: int = 512) -> str:
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+        # Format the prompt according to TinyLlama's chat template
+        chat_prompt = f"<|system|>You are a helpful AI assistant.</s><|user|>{prompt}</s><|assistant|>"
+
+        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)
 
         # Generate response
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
-                max_length=max_length,
+                max_new_tokens=max_length,
                 num_return_sequences=1,
                 temperature=0.7,
                 do_sample=True,
@@ -42,20 +50,23 @@ class LlamaDemo:
             )
 
         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response.replace(prompt, "").strip()
+        # Remove the prompt from the response
+        response = response.split("<|assistant|>")[-1].strip()
+        return response
 
 def main():
     st.set_page_config(
-        page_title="Llama 3.1 Demo",
+        page_title="Open Source Llama Demo",
        page_icon="🦙",
        layout="wide"
    )
 
-    st.title("🦙 Llama 3.1 Demo")
+    st.title("🦙 Open Source Llama Demo")
 
     # Initialize session state
     if 'llama' not in st.session_state:
-        st.session_state.llama = LlamaDemo()
+        with st.spinner("Loading model... This might take a few minutes..."):
+            st.session_state.llama = LlamaDemo()
 
     if 'chat_history' not in st.session_state:
         st.session_state.chat_history = []
@@ -85,7 +96,7 @@ def main():
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
 
-            with st.spinner("Generating response..."):
+            with st.spinner("Thinking..."):
                response = st.session_state.llama.generate_response(prompt)
                message_placeholder.write(response)
 
@@ -95,11 +106,20 @@ def main():
                "content": response
            })
 
-    # Sidebar with settings
+    # Sidebar with settings and info
     with st.sidebar:
        st.header("Settings")
        max_length = st.slider("Maximum response length", 64, 1024, 512)
 
+        st.markdown("---")
+        st.markdown("""
+        ### About
+        This demo uses TinyLlama, an open source language model that's smaller but
+        still capable. It's perfect for demonstrations and testing.
+
+        The model is loaded locally and doesn't require any authentication or API keys.
+        """)
+
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.experimental_rerun()
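
Note: for a quick sanity check of the new generation path outside Streamlit, the sketch below reproduces roughly what the updated generate_response() does. It is only an illustration, not part of the commit: it assumes a recent transformers release (so that tokenizer.apply_chat_template is available) and accelerate installed for device_map="auto", and the question string is just a placeholder. apply_chat_template builds an equivalent <|system|>/<|user|>/<|assistant|> prompt from the model's own template rather than the hand-written string in app.py, and slicing off the input tokens before decoding avoids the string-splitting cleanup step.

    # Standalone sketch of the TinyLlama generation path used by the updated app.py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",  # requires accelerate; falls back to CPU if no GPU is found
    )

    # Build the chat prompt from the tokenizer's built-in chat template
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "What is lazy loading?"},  # placeholder question
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))

As in the diff, max_new_tokens is used rather than max_length, so the prompt tokens do not count against the response-length budget.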