import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import os
import pickle
import warnings

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

# Model combinations with speed ratings and estimated time savings
MODEL_COMBINATIONS = {
    "Fastest (30 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "distilgpt2",
        "free": True,
        "time_saved": "2.5 minutes"
    },
    "Balanced (1 minute)": {
        "embedding": "sentence-transformers/all-MiniLM-L12-v2",
        "generation": "facebook/opt-350m",
        "free": True,
        "time_saved": "2 minutes"
    },
    "High Quality (2 minutes)": {
        "embedding": "sentence-transformers/all-mpnet-base-v2",
        "generation": "gpt2",
        "free": True,
        "time_saved": "1 minute"
    },
    "Premium Speed (15 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "microsoft/phi-1_5",
        "free": False,
        "time_saved": "2.75 minutes"
    },
    "Premium Quality (1.5 minutes)": {
        "embedding": "openai-embedding-ada-002",
        "generation": "meta-llama/Llama-2-7b-chat-hf",
        "free": False,
        "time_saved": "1.5 minutes"
    }
}

@st.cache_resource
def load_models(model_combination):
    try:
        embedding_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
        embedding_model = AutoModel.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
        generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
        generation_model = AutoModelForCausalLM.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
        return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None, None

@st.cache_data
def load_and_process_text(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        # Split the source text into fixed-size 512-character chunks
        chunks = [text[i:i+512] for i in range(0, len(text), 512)]
        return chunks
    except Exception as e:
        st.error(f"Error loading text file: {str(e)}")
        return []

@st.cache_data
def create_embeddings(chunks, _embedding_tokenizer, _embedding_model):
    # Mean-pool the last hidden state of each chunk into a single vector
    embeddings = []
    for chunk in chunks:
        inputs = _embedding_tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = _embedding_model(**inputs)
        embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
    # FAISS expects float32 vectors
    return np.array(embeddings).astype("float32")
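# Optional: the mean pooling above averages over every position, including padding,
# which only matters if chunks are ever embedded in padded batches. A mask-aware
# variant is sketched below; the helper name is illustrative and nothing in this
# script calls it.
def masked_mean_pool(last_hidden_state, attention_mask):
    # Zero out padded positions, then divide by the number of real tokens
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts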
@st.cache_resource
def create_faiss_index(embeddings):
    # Exact L2 (Euclidean) search over the chunk embeddings
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index

def generate_response(query, embedding_tokenizer, embedding_model, generation_tokenizer, generation_model, index, chunks):
    # Embed the query with the same model used for the chunks
    inputs = embedding_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

    # Retrieve the k nearest chunks as context
    k = 3
    _, I = index.search(query_embedding.reshape(1, -1).astype("float32"), k)
    context = " ".join([chunks[i] for i in I[0]])

    prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
    # Truncate the prompt so it plus 100 new tokens stays inside the ~1024-token
    # context window of GPT-2-family models
    input_ids = generation_tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=900)
    output = generation_model.generate(
        input_ids,
        max_new_tokens=100,
        num_return_sequences=1,
        do_sample=True,  # sampling so the temperature setting actually takes effect
        temperature=0.7,
        pad_token_id=generation_tokenizer.eos_token_id
    )
    response = generation_tokenizer.decode(output[0], skip_special_tokens=True)
    muse_response = response.split("Muse:")[-1].strip()
    return muse_response

def save_data(chunks, embeddings, index):
    with open('chunks.pkl', 'wb') as f:
        pickle.dump(chunks, f)
    np.save('embeddings.npy', embeddings)
    faiss.write_index(index, 'faiss_index.bin')

def load_data():
    if os.path.exists('chunks.pkl') and os.path.exists('embeddings.npy') and os.path.exists('faiss_index.bin'):
        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
        embeddings = np.load('embeddings.npy')
        index = faiss.read_index('faiss_index.bin')
        return chunks, embeddings, index
    return None, None, None
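# Note: save_data() and load_data() above are never called in this script. One way to
# wire them in (a sketch only; the exact placement inside the "Load Selected Models"
# branch below is an assumption) is to try the cached files before recomputing:
#
#     chunks, embeddings, index = load_data()
#     if chunks is None:
#         chunks = load_and_process_text('ammons_muse.txt')
#         embeddings = create_embeddings(chunks, embedding_tokenizer, embedding_model)
#         index = create_faiss_index(embeddings)
#         save_data(chunks, embeddings, index)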

# Streamlit UI
st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")

st.title("A.R. Ammons' Muse Chatbot 🎭")

st.markdown(""" """, unsafe_allow_html=True)

st.markdown('Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!', unsafe_allow_html=True)

# Model selection
if 'model_combination' not in st.session_state:
    st.session_state.model_combination = "Fastest (30 seconds)"

# Create a list of model options, with non-free models at the end
free_models = [k for k, v in MODEL_COMBINATIONS.items() if v['free']]
non_free_models = [k for k, v in MODEL_COMBINATIONS.items() if not v['free']]
all_models = free_models + non_free_models

# Custom CSS to grey out non-free options
st.markdown(""" """, unsafe_allow_html=True)

selected_model = st.selectbox(
    "Choose a model combination:",
    all_models,
    index=all_models.index(st.session_state.model_combination),
    format_func=lambda x: f"{x} {'(Not Free)' if not MODEL_COMBINATIONS[x]['free'] else ''}"
)

# Prevent selection of non-free models
if not MODEL_COMBINATIONS[selected_model]['free']:
    st.warning("Premium models are not available in the free version.")
    st.stop()

st.session_state.model_combination = selected_model

st.info(f"Potential time saved compared to the slowest option: {MODEL_COMBINATIONS[selected_model]['time_saved']}")

if st.button("Load Selected Models"):
    with st.spinner("Loading models and data..."):
        embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
        if embedding_model is None:
            st.stop()
        chunks = load_and_process_text('ammons_muse.txt')
        embeddings = create_embeddings(chunks, embedding_tokenizer, embedding_model)
        index = create_faiss_index(embeddings)
        # Keep everything in session state so it survives Streamlit reruns
        st.session_state.embedding_tokenizer = embedding_tokenizer
        st.session_state.embedding_model = embedding_model
        st.session_state.generation_tokenizer = generation_tokenizer
        st.session_state.generation_model = generation_model
        st.session_state.chunks = chunks
        st.session_state.index = index
        st.session_state.models_loaded = True
    st.success("Models loaded successfully!")

if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
    st.warning("Please load the models before chatting.")
    st.stop()

# Initialize chat history
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What would you like to ask the Muse?"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.spinner("The Muse is contemplating..."):
        try:
            response = generate_response(
                prompt,
                st.session_state.embedding_tokenizer,
                st.session_state.embedding_model,
                st.session_state.generation_tokenizer,
                st.session_state.generation_model,
                st.session_state.index,
                st.session_state.chunks
            )
        except Exception as e:
            response = f"I apologize, but I encountered an error: {str(e)}"

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.messages = []
    st.experimental_rerun()

# Add a footer
st.markdown("---")
st.markdown("*Powered by the spirit of A.R. Ammons and the magic of AI*")
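# Running the app (assumptions: this file is saved as app.py and a plain-text
# ammons_muse.txt sits next to it; package names are the usual PyPI ones):
#
#     pip install streamlit torch transformers faiss-cpu numpy
#     streamlit run app.py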