import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import os
import pickle
import warnings

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

# Model combinations with speed ratings and estimated time savings
MODEL_COMBINATIONS = {
    "Fastest (30 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "distilgpt2",
        "free": True,
        "time_saved": "2.5 minutes"
    },
    "Balanced (1 minute)": {
        "embedding": "sentence-transformers/all-MiniLM-L12-v2",
        "generation": "facebook/opt-350m",
        "free": True,
        "time_saved": "2 minutes"
    },
    "High Quality (2 minutes)": {
        "embedding": "sentence-transformers/all-mpnet-base-v2",
        "generation": "gpt2",
        "free": True,
        "time_saved": "1 minute"
    },
    "Premium Speed (15 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "microsoft/phi-1_5",
        "free": False,
        "time_saved": "2.75 minutes"
    },
    "Premium Quality (1.5 minutes)": {
        "embedding": "openai-embedding-ada-002",
        "generation": "meta-llama/Llama-2-7b-chat-hf",
        "free": False,
        "time_saved": "1.5 minutes"
    }
}


def load_model(model_name):
    try:
        return AutoModel.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading model {model_name}: {str(e)}")
        return None


def load_tokenizer(model_name):
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
        return None


@st.cache_resource
def load_embedding_model(model_name):
    return load_model(model_name)


@st.cache_resource
def load_generation_model(model_name):
    try:
        return AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading generation model {model_name}: {str(e)}")
        return None


def load_index_and_chunks():
    try:
        with open('faiss_index.pkl', 'rb') as f:
            index = pickle.load(f)
        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading index and chunks: {str(e)}")
        return None, None


def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model,
                      embedding_model, index, chunks):
    try:
        # Embed the prompt (mean-pooled last hidden state), without tracking gradients
        with torch.no_grad():
            inputs = embedding_tokenizer(prompt, return_tensors='pt')
            prompt_embedding = embedding_model(**inputs).last_hidden_state.mean(dim=1).numpy()

        # Retrieve the most similar chunks from the FAISS index
        D, I = index.search(prompt_embedding.astype(np.float32), k=5)
        context = " ".join([chunks[i] for i in I[0]])

        # Build the RAG prompt and generate an answer.
        # max_new_tokens (rather than max_length) keeps generation from being cut off
        # when the retrieved context alone already exceeds 150 tokens.
        input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        input_ids = generation_tokenizer(input_text, return_tensors="pt").input_ids
        output = generation_model.generate(
            input_ids,
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=generation_tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the echoed prompt isn't shown to the user
        response = generation_tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error while generating a response."
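
# NOTE: load_index_and_chunks() expects `faiss_index.pkl` and `chunks.pkl` to already
# exist; the script that produces them is not part of this file. The helper below is a
# minimal sketch of one way they might be built, assuming the corpus is a plain-text
# file ("corpus.txt" is a hypothetical name), chunks are naive fixed-size slices, and
# the same mean-pooled embedding model is used at index time and query time.
def build_index_sketch(embedding_model, embedding_tokenizer, corpus_path="corpus.txt"):
    with open(corpus_path, "r", encoding="utf-8") as f:
        text = f.read()

    # Naive fixed-size character chunks; a real pipeline would split on sentences or tokens.
    chunks = [text[i:i + 500] for i in range(0, len(text), 500)]

    embeddings = []
    with torch.no_grad():
        for chunk in chunks:
            inputs = embedding_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
            emb = embedding_model(**inputs).last_hidden_state.mean(dim=1).numpy()
            embeddings.append(emb[0])
    embeddings = np.array(embeddings, dtype=np.float32)

    # Exact L2 index; dimension is taken from the embedding model's output
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # Stored with pickle to match the loader above; faiss.write_index is an alternative
    # if the index object is not picklable in your faiss version.
    with open("faiss_index.pkl", "wb") as f:
        pickle.dump(index, f)
    with open("chunks.pkl", "wb") as f:
        pickle.dump(chunks, f)
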
def main():
    st.title("Your Muse Chat App")

    # Load models and data
    selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
    combo = MODEL_COMBINATIONS[selected_combo]

    embedding_model = load_embedding_model(combo['embedding'])
    generation_model = load_generation_model(combo['generation'])
    embedding_tokenizer = load_tokenizer(combo['embedding'])
    generation_tokenizer = load_tokenizer(combo['generation'])
    index, chunks = load_index_and_chunks()

    if not all([embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks]):
        st.error("Some components failed to load. Please check the errors above.")
        return

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("What would you like to ask the Muse?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.spinner("The Muse is contemplating..."):
            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer,
                                         generation_model, embedding_model, index, chunks)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()