user committed on
Commit 4e8606b · 1 Parent(s): b3b4e83
Files changed (1)
  1. app.py +90 -182
app.py CHANGED
@@ -42,194 +42,102 @@ MODEL_COMBINATIONS = {
      }
  }

- @st.cache_resource
- def load_models(model_combination):
      try:
-         embedding_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
-         embedding_model = AutoModel.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
-         generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
-         generation_model = AutoModelForCausalLM.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
-         return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
      except Exception as e:
-         st.error(f"Error loading models: {str(e)}")
-         return None, None, None, None

- @st.cache_data
- def load_and_process_text(file_path):
      try:
-         with open(file_path, 'r', encoding='utf-8') as file:
-             text = file.read()
-         chunks = [text[i:i+512] for i in range(0, len(text), 512)]
-         return chunks
      except Exception as e:
-         st.error(f"Error loading text file: {str(e)}")
-         return []
-
- @st.cache_data
- def create_embeddings(chunks, _embedding_model):
-     if isinstance(_embedding_model, str):
-         tokenizer = AutoTokenizer.from_pretrained(_embedding_model)
-         model = AutoModel.from_pretrained(_embedding_model)
-     else:
-         # Assume _embedding_model is already a model instance
-         model = _embedding_model
-         tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
-
-     embeddings = []
-     for chunk in chunks:
-         inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
-         with torch.no_grad():
-             outputs = model(**inputs)
-         embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
-
-     return np.vstack(embeddings)

  @st.cache_resource
- def create_faiss_index(embeddings):
-     index = faiss.IndexFlatL2(embeddings.shape[1])
-     index.add(embeddings)
-     return index
-
- def generate_response(query, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
-     inputs = embedding_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
-     with torch.no_grad():
-         outputs = embedding_model(**inputs)
-     query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
-
-     k = 3
-     _, I = index.search(query_embedding.reshape(1, -1), k)
-
-     context = " ".join([chunks[i] for i in I[0]])
-
-     prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
-
-     input_ids = generation_tokenizer.encode(prompt, return_tensors="pt")
-     output = generation_model.generate(
-         input_ids,
-         max_new_tokens=100,
-         num_return_sequences=1,
-         temperature=0.7,
-         do_sample=True,
-         top_k=50,
-         top_p=0.95,
-         no_repeat_ngram_size=2
-     )
-     response = generation_tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-
-     muse_response = response.split("Muse:")[-1].strip()
-
-     # Check if the response contains unused tokens
-     if "[unused" in muse_response:
-         muse_response = "I apologize, but I'm having trouble formulating a response. Let me try again with a simpler message: Hello! As the Muse of A.R. Ammons, I'm here to inspire and discuss poetry. How may I assist you today?"
-
-     return muse_response
-
- def save_data(chunks, embeddings, index):
-     with open('chunks.pkl', 'wb') as f:
-         pickle.dump(chunks, f)
-     np.save('embeddings.npy', embeddings)
-     faiss.write_index(index, 'faiss_index.bin')
-
- def load_data():
-     if os.path.exists('chunks.pkl') and os.path.exists('embeddings.npy') and os.path.exists('faiss_index.bin'):
          with open('chunks.pkl', 'rb') as f:
              chunks = pickle.load(f)
-         embeddings = np.load('embeddings.npy')
-         index = faiss.read_index('faiss_index.bin')
-         return chunks, embeddings, index
-     return None, None, None
-
- # Streamlit UI
- st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")
-
- st.title("A.R. Ammons' Muse Chatbot 🎭")
- st.markdown("""
- <style>
- .big-font {
-     font-size:20px !important;
-     font-weight: bold;
- }
- </style>
- """, unsafe_allow_html=True)
- st.markdown('<p class="big-font">Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!</p>', unsafe_allow_html=True)
-
- # Model selection
- if 'model_combination' not in st.session_state:
-     st.session_state.model_combination = "Fastest (30 seconds)"
-
- # Create a list of model options, with non-free models at the end
- free_models = [k for k, v in MODEL_COMBINATIONS.items() if v['free']]
- non_free_models = [k for k, v in MODEL_COMBINATIONS.items() if not v['free']]
- all_models = free_models + non_free_models
-
- # Custom CSS to grey out non-free options
- st.markdown("""
- <style>
- .stSelectbox div[role="option"][aria-selected="false"]:nth-last-child(-n+2) {
-     color: grey !important;
- }
- </style>
- """, unsafe_allow_html=True)
-
- selected_model = st.selectbox(
-     "Choose a model combination:",
-     all_models,
-     index=all_models.index(st.session_state.model_combination),
-     format_func=lambda x: f"{x} {'(Not Free)' if not MODEL_COMBINATIONS[x]['free'] else ''}"
- )
-
- # Prevent selection of non-free models
- if not MODEL_COMBINATIONS[selected_model]['free']:
-     st.warning("Premium models are not available in the free version.")
-     st.stop()
-
- st.session_state.model_combination = selected_model
-
- st.info(f"Potential time saved compared to slowest option: {MODEL_COMBINATIONS[selected_model]['time_saved']}")
-
- if st.button("Load Selected Models"):
-     with st.spinner("Loading models and data..."):
-         embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
-         chunks = load_and_process_text('ammons_muse.txt')
-         embeddings = create_embeddings(chunks, embedding_model)
-         index = create_faiss_index(embeddings)
-
-     st.session_state.models_loaded = True
-     st.success("Models loaded successfully!")
-
- if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
-     st.warning("Please load the models before chatting.")
-     st.stop()
-
- # Initialize chat history
- if 'messages' not in st.session_state:
-     st.session_state.messages = []
-
- # Display chat messages from history on app rerun
- for message in st.session_state.messages:
-     with st.chat_message(message["role"]):
-         st.markdown(message["content"])
-
- # React to user input
- if prompt := st.chat_input("What would you like to ask the Muse?"):
-     st.chat_message("user").markdown(prompt)
-     st.session_state.messages.append({"role": "user", "content": prompt})
-
-     with st.spinner("The Muse is contemplating..."):
-         try:
-             response = generate_response(prompt, tokenizer, generation_model, embedding_model, index, chunks)
-         except Exception as e:
-             response = f"I apologize, but I encountered an error: {str(e)}"
-
-     with st.chat_message("assistant"):
-         st.markdown(response)
-     st.session_state.messages.append({"role": "assistant", "content": response})
-
- # Add a button to clear chat history
- if st.button("Clear Chat History"):
-     st.session_state.messages = []
-     st.experimental_rerun()
-
- # Add a footer
- st.markdown("---")
- st.markdown("*Powered by the spirit of A.R. Ammons and the magic of AI*")
 
      }
  }

+ def load_model(model_name):
      try:
+         return AutoModel.from_pretrained(model_name)
      except Exception as e:
+         st.error(f"Error loading model {model_name}: {str(e)}")
+         return None

+ def load_tokenizer(model_name):
      try:
+         return AutoTokenizer.from_pretrained(model_name)
      except Exception as e:
+         st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
+         return None
+
+ @st.cache_resource
+ def load_embedding_model(model_name):
+     return load_model(model_name)

  @st.cache_resource
+ def load_generation_model(model_name):
+     try:
+         return AutoModelForCausalLM.from_pretrained(model_name)
+     except Exception as e:
+         st.error(f"Error loading generation model {model_name}: {str(e)}")
+         return None
+
+ def load_index_and_chunks():
+     try:
+         with open('faiss_index.pkl', 'rb') as f:
+             index = pickle.load(f)
          with open('chunks.pkl', 'rb') as f:
              chunks = pickle.load(f)
+         return index, chunks
+     except Exception as e:
+         st.error(f"Error loading index and chunks: {str(e)}")
+         return None, None
+
+ def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
+     try:
+         # Embed the prompt
+         prompt_embedding = embedding_model(embedding_tokenizer(prompt, return_tensors='pt')['input_ids']).last_hidden_state.mean(dim=1).detach().numpy()
+
+         # Search for similar chunks
+         D, I = index.search(prompt_embedding, k=5)
+         context = " ".join([chunks[i] for i in I[0]])
+
+         # Generate response
+         input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
+         input_ids = generation_tokenizer(input_text, return_tensors="pt").input_ids
+
+         output = generation_model.generate(input_ids, max_length=150, num_return_sequences=1, no_repeat_ngram_size=2)
+         response = generation_tokenizer.decode(output[0], skip_special_tokens=True)
+
+         return response
+     except Exception as e:
+         st.error(f"Error generating response: {str(e)}")
+         return "I apologize, but I encountered an error while generating a response."
+
+ def main():
+     st.title("Your Muse Chat App")
+
+     # Load models and data
+     selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
+     combo = MODEL_COMBINATIONS[selected_combo]
+
+     embedding_model = load_embedding_model(combo['embedding'])
+     generation_model = load_generation_model(combo['generation'])
+     embedding_tokenizer = load_tokenizer(combo['embedding'])
+     generation_tokenizer = load_tokenizer(combo['generation'])
+
+     index, chunks = load_index_and_chunks()
+
+     if not all([embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks]):
+         st.error("Some components failed to load. Please check the errors above.")
+         return
+
+     # Initialize chat history
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     # Display chat messages
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # Chat input
+     if prompt := st.chat_input("What would you like to ask the Muse?"):
+         st.chat_message("user").markdown(prompt)
+         st.session_state.messages.append({"role": "user", "content": prompt})
+
+         with st.spinner("The Muse is contemplating..."):
+             response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks)
+
+         with st.chat_message("assistant"):
+             st.markdown(response)
+         st.session_state.messages.append({"role": "assistant", "content": response})
+
+ if __name__ == "__main__":
+     main()
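
Note: the refactored load_index_and_chunks() expects faiss_index.pkl and chunks.pkl to already exist on disk, while this commit removes the code that produced such artifacts (load_and_process_text, create_embeddings, create_faiss_index, save_data). Below is a minimal offline sketch of how those two files could be regenerated. It assumes the same 512-character chunking and mean-pooled embeddings as the removed code, and that the installed faiss build can pickle Index objects; the helper name build_artifacts and the embedding model name in the example call are illustrative, not part of the commit.

# build_index.py -- illustrative helper, not part of this commit.
# Rebuilds the artifacts that load_index_and_chunks() expects
# ('chunks.pkl' and 'faiss_index.pkl'). Pickling the index assumes a faiss
# build whose Index objects support pickle; otherwise
# faiss.write_index/read_index would be needed instead.
import pickle

import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


def build_artifacts(text_path, embedding_model_name):
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
    model = AutoModel.from_pretrained(embedding_model_name)

    # Same 512-character chunking as the removed load_and_process_text
    with open(text_path, 'r', encoding='utf-8') as f:
        text = f.read()
    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]

    # Mean-pooled hidden states, as in the removed create_embeddings
    embeddings = []
    for chunk in chunks:
        inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
    embeddings = np.vstack(embeddings).astype("float32")

    # Exact L2 index, as in the removed create_faiss_index
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # Write the files under the names the new app loads
    with open('chunks.pkl', 'wb') as f:
        pickle.dump(chunks, f)
    with open('faiss_index.pkl', 'wb') as f:
        pickle.dump(index, f)


if __name__ == "__main__":
    # The embedding model name here is a placeholder; the app takes it from MODEL_COMBINATIONS.
    build_artifacts('ammons_muse.txt', 'sentence-transformers/all-MiniLM-L6-v2')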