AIdeaText committed
Commit
f995cde
1 Parent(s): 2cf350a

Create app.py

Files changed (1)
app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
+import streamlit as st
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+
+class LlamaDemo:
+    def __init__(self):
+        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
+        # Initialize lazily so the app starts fast and the weights are
+        # only loaded on first use
+        self._model = None
+        self._tokenizer = None
+
+    @property
+    def model(self):
+        if self._model is None:
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float16,
+                device_map="auto"
+            )
+        return self._model
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        return self._tokenizer
+
+    def generate_response(self, prompt: str, max_length: int = 512) -> str:
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        # Generate response; max_new_tokens budgets only the generated
+        # tokens, so a long prompt does not eat into the response length
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_length,
+                num_return_sequences=1,
+                temperature=0.7,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+        # Decode only the newly generated tokens; slicing by input length
+        # is more reliable than stripping the prompt from the decoded text
+        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
+        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+
+
+def main():
+    st.set_page_config(
+        page_title="Llama 2 Demo",
+        page_icon="🦙",
+        layout="wide"
+    )
+
+    st.title("🦙 Llama 2 Demo")
+
+    # Initialize session state
+    if 'llama' not in st.session_state:
+        st.session_state.llama = LlamaDemo()
+
+    if 'chat_history' not in st.session_state:
+        st.session_state.chat_history = []
+
+    # Sidebar with settings; rendered before the chat logic so the
+    # slider value is defined when a new prompt is handled below
+    with st.sidebar:
+        st.header("Settings")
+        max_length = st.slider("Maximum response length", 64, 1024, 512)
+
+        if st.button("Clear Chat History"):
+            st.session_state.chat_history = []
+            st.rerun()
+
+    # Chat interface
+    with st.container():
+        # Display chat history
+        for message in st.session_state.chat_history:
+            role = message["role"]
+            content = message["content"]
+
+            with st.chat_message(role):
+                st.write(content)
+
+    # Input for new message
+    if prompt := st.chat_input("What would you like to discuss?"):
+        # Add user message to chat history
+        st.session_state.chat_history.append({
+            "role": "user",
+            "content": prompt
+        })
+
+        with st.chat_message("user"):
+            st.write(prompt)
+
+        # Show assistant response
+        with st.chat_message("assistant"):
+            message_placeholder = st.empty()
+
+            with st.spinner("Generating response..."):
+                response = st.session_state.llama.generate_response(
+                    prompt, max_length=max_length
+                )
+                message_placeholder.write(response)
+
+        # Add assistant response to chat history
+        st.session_state.chat_history.append({
+            "role": "assistant",
+            "content": response
+        })
+
+
+if __name__ == "__main__":
+    main()
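
The app is launched with streamlit run app.py. The meta-llama checkpoints are gated, so access typically requires an approved Hugging Face account and a huggingface-cli login beforehand. For a quick smoke test outside Streamlit, a minimal sketch that drives the class directly, assuming enough GPU memory for the fp16 7B model:

# Hypothetical standalone check; imports LlamaDemo from the app.py above
from app import LlamaDemo

demo = LlamaDemo()
print(demo.generate_response("Summarize lazy loading in one sentence.", max_length=64))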