import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import pickle
import warnings

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
# Model combinations with speed ratings and estimated time savings.
# Note: the two premium combinations are placeholders that this app cannot
# load as-is: "openai-embedding-ada-002" is an OpenAI API model, not a
# Hugging Face checkpoint, and Llama-2 is a gated repository that requires
# approved access.
MODEL_COMBINATIONS = {
    "Fastest (30 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "distilgpt2",
        "free": True,
        "time_saved": "2.5 minutes"
    },
    "Balanced (1 minute)": {
        "embedding": "sentence-transformers/all-MiniLM-L12-v2",
        "generation": "facebook/opt-350m",
        "free": True,
        "time_saved": "2 minutes"
    },
    "High Quality (2 minutes)": {
        "embedding": "sentence-transformers/all-mpnet-base-v2",
        "generation": "gpt2",
        "free": True,
        "time_saved": "1 minute"
    },
    "Premium Speed (15 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "microsoft/phi-1_5",
        "free": False,
        "time_saved": "2.75 minutes"
    },
    "Premium Quality (1.5 minutes)": {
        "embedding": "openai-embedding-ada-002",
        "generation": "meta-llama/Llama-2-7b-chat-hf",
        "free": False,
        "time_saved": "1.5 minutes"
    }
}
# Cache model loads so Streamlit does not re-instantiate the models on
# every rerun (each chat message triggers a full script rerun).
@st.cache_resource
def load_model(model_name):
    try:
        return AutoModel.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading model {model_name}: {str(e)}")
        return None
@st.cache_resource
def load_tokenizer(model_name):
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
        return None
def load_embedding_model(model_name):
    return load_model(model_name)
@st.cache_resource
def load_generation_model(model_name):
    try:
        return AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading generation model {model_name}: {str(e)}")
        return None
@st.cache_resource
def load_index_and_chunks():
    # Both pickle files are expected to sit next to the app in the working
    # directory; they are built offline (see the sketch below).
    try:
        with open('faiss_index.pkl', 'rb') as f:
            index = pickle.load(f)
        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading index and chunks: {str(e)}")
        return None, None
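# The pickle files above are assumed to be produced by an offline indexing
# step that is not part of this app. The helper below is a minimal sketch of
# how they might be built; the chunk source and default model choice are
# assumptions, not part of the original code.
def build_index_and_chunks(chunks, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed a list of text chunks and persist a FAISS index plus the chunks."""
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
    model = AutoModel.from_pretrained(embedding_model_name)
    embeddings = []
    with torch.no_grad():
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True)
            # Mean-pool the token embeddings into one vector per chunk,
            # matching the pooling used at query time in generate_response.
            embeddings.append(model(**inputs).last_hidden_state.mean(dim=1).squeeze(0).numpy())
    matrix = np.stack(embeddings).astype(np.float32)
    index = faiss.IndexFlatL2(matrix.shape[1])  # exact L2 search
    index.add(matrix)
    # The app reads the index back with pickle; depending on the faiss
    # version, faiss.write_index/read_index may be needed instead.
    with open('faiss_index.pkl', 'wb') as f:
        pickle.dump(index, f)
    with open('chunks.pkl', 'wb') as f:
        pickle.dump(chunks, f)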
def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
    try:
        # Embed the prompt with the same mean pooling used to build the index
        with torch.no_grad():
            inputs = embedding_tokenizer(prompt, return_tensors='pt', truncation=True)
            prompt_embedding = embedding_model(**inputs).last_hidden_state.mean(dim=1).numpy()
        # Retrieve the most similar chunks to use as context
        D, I = index.search(prompt_embedding.astype(np.float32), k=5)
        context = " ".join([chunks[i] for i in I[0]])
        # Generate a response conditioned on the retrieved context
        input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        input_ids = generation_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).input_ids
        output = generation_model.generate(
            input_ids,
            max_new_tokens=150,  # max_length would count the prompt tokens too
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=generation_tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt
        response = generation_tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error while generating a response."
def main():
    st.title("Your Muse Chat App")
    # Load models and data
    selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
    combo = MODEL_COMBINATIONS[selected_combo]
    embedding_model = load_embedding_model(combo['embedding'])
    generation_model = load_generation_model(combo['generation'])
    embedding_tokenizer = load_tokenizer(combo['embedding'])
    generation_tokenizer = load_tokenizer(combo['generation'])
    index, chunks = load_index_and_chunks()
    if not all([embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks]):
        st.error("Some components failed to load. Please check the errors above.")
        return
    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # Chat input
    if prompt := st.chat_input("What would you like to ask the Muse?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.spinner("The Muse is contemplating..."):
            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()
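# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py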