# rag2/app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import os
import pickle
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
# Model combinations with speed ratings and estimated time savings
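# Note: the non-free ("Premium") combinations point at gated or API-only models
# and may not load with the local from_pretrained() helpers below without extra access.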
MODEL_COMBINATIONS = {
"Fastest (30 seconds)": {
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
"generation": "distilgpt2",
"free": True,
"time_saved": "2.5 minutes"
},
"Balanced (1 minute)": {
"embedding": "sentence-transformers/all-MiniLM-L12-v2",
"generation": "facebook/opt-350m",
"free": True,
"time_saved": "2 minutes"
},
"High Quality (2 minutes)": {
"embedding": "sentence-transformers/all-mpnet-base-v2",
"generation": "gpt2",
"free": True,
"time_saved": "1 minute"
},
"Premium Speed (15 seconds)": {
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
"generation": "microsoft/phi-1_5",
"free": False,
"time_saved": "2.75 minutes"
},
"Premium Quality (1.5 minutes)": {
"embedding": "openai-embedding-ada-002",
"generation": "meta-llama/Llama-2-7b-chat-hf",
"free": False,
"time_saved": "1.5 minutes"
}
}
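# Helpers that load models/tokenizers and surface failures in the Streamlit UI instead of crashing.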
def load_model(model_name):
try:
return AutoModel.from_pretrained(model_name)
except Exception as e:
st.error(f"Error loading model {model_name}: {str(e)}")
return None
@st.cache_resource
def load_tokenizer(model_name):
try:
return AutoTokenizer.from_pretrained(model_name)
except Exception as e:
st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
return None
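# Cache the heavyweight model objects across Streamlit reruns so they are loaded only once per process.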
@st.cache_resource
def load_embedding_model(model_name):
return load_model(model_name)
@st.cache_resource
def load_generation_model(model_name):
try:
return AutoModelForCausalLM.from_pretrained(model_name)
except Exception as e:
st.error(f"Error loading generation model {model_name}: {str(e)}")
return None
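# Load the prebuilt FAISS index and its source text chunks; both are expected to have been
# pickled to disk by a separate indexing step.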
@st.cache_resource
def load_index_and_chunks():
try:
with open('faiss_index.pkl', 'rb') as f:
index = pickle.load(f)
with open('chunks.pkl', 'rb') as f:
chunks = pickle.load(f)
return index, chunks
except Exception as e:
st.error(f"Error loading index and chunks: {str(e)}")
return None, None
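# Retrieval-augmented generation: embed the question, retrieve the top-k most similar chunks
# from the index, prepend them as context, and let the generation model complete the answer.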
def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
try:
        # Embed the prompt (mean-pooled last hidden state of the embedding model)
        inputs = embedding_tokenizer(prompt, return_tensors='pt', truncation=True)
        with torch.no_grad():
            prompt_embedding = embedding_model(**inputs).last_hidden_state.mean(dim=1).numpy().astype(np.float32)
        # Search the FAISS index for the most similar chunks (skip -1 ids returned when the index is small)
        D, I = index.search(prompt_embedding, k=5)
        context = " ".join(chunks[i] for i in I[0] if i != -1)
        # Build the RAG prompt and generate a continuation
        input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        input_ids = generation_tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
        output = generation_model.generate(
            input_ids,
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=generation_tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the prompt/context is not echoed back
        response = generation_tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
return response
except Exception as e:
st.error(f"Error generating response: {str(e)}")
return "I apologize, but I encountered an error while generating a response."
def main():
st.title("Your Muse Chat App")
# Load models and data
selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
combo = MODEL_COMBINATIONS[selected_combo]
embedding_model = load_embedding_model(combo['embedding'])
generation_model = load_generation_model(combo['generation'])
embedding_tokenizer = load_tokenizer(combo['embedding'])
generation_tokenizer = load_tokenizer(combo['generation'])
index, chunks = load_index_and_chunks()
if not all([embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks]):
st.error("Some components failed to load. Please check the errors above.")
return
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Chat input
if prompt := st.chat_input("What would you like to ask the Muse?"):
st.chat_message("user").markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
with st.spinner("The Muse is contemplating..."):
response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks)
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()