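"""Your Muse Chat App: a Streamlit chat interface backed by a small
retrieval-augmented generation (RAG) pipeline. User prompts are embedded,
matched against a prebuilt FAISS index of text chunks, and the retrieved
context is passed to a causal language model to draft an answer."""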
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import os
import pickle
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
# Model combinations with speed ratings and estimated time savings
MODEL_COMBINATIONS = {
    "Fastest (30 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "distilgpt2",
        "free": True,
        "time_saved": "2.5 minutes"
    },
    "Balanced (1 minute)": {
        "embedding": "sentence-transformers/all-MiniLM-L12-v2",
        "generation": "facebook/opt-350m",
        "free": True,
        "time_saved": "2 minutes"
    },
    "High Quality (2 minutes)": {
        "embedding": "sentence-transformers/all-mpnet-base-v2",
        "generation": "gpt2",
        "free": True,
        "time_saved": "1 minute"
    },
    "Premium Speed (15 seconds)": {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "generation": "microsoft/phi-1_5",
        "free": False,
        "time_saved": "2.75 minutes"
    },
    "Premium Quality (1.5 minutes)": {
        "embedding": "openai-embedding-ada-002",
        "generation": "meta-llama/Llama-2-7b-chat-hf",
        "free": False,
        "time_saved": "1.5 minutes"
    }
}
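
# Note: the two "Premium" combinations are listed as non-free because they
# assume extra credentials or services. "openai-embedding-ada-002" refers to
# an OpenAI API embedding model and cannot be loaded with
# AutoModel.from_pretrained as written, and "meta-llama/Llama-2-7b-chat-hf"
# is a gated Hugging Face repository that requires an authorized access token.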

def load_model(model_name):
    """Load a Hugging Face encoder model, surfacing any failure in the UI."""
    try:
        return AutoModel.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading model {model_name}: {str(e)}")
        return None

def load_tokenizer(model_name):
    """Load the tokenizer matching a given model name."""
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
        return None

# Cache the heavyweight model objects across Streamlit reruns so each
# combination is only downloaded and loaded once per session.
@st.cache_resource
def load_embedding_model(model_name):
    return load_model(model_name)

@st.cache_resource
def load_generation_model(model_name):
    try:
        return AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading generation model {model_name}: {str(e)}")
        return None

def load_index_and_chunks():
    """Load the prebuilt FAISS index and the list of text chunks it was built from."""
    try:
        with open('faiss_index.pkl', 'rb') as f:
            index = pickle.load(f)
        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading index and chunks: {str(e)}")
        return None, None
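
# Note: faiss_index.pkl and chunks.pkl are assumed to be prebuilt artifacts
# shipped alongside the app: a FAISS index over embeddings of the source text
# and the matching list of chunk strings (chunks[i] corresponds to index row i).
# Depending on the faiss version, a raw Index object may not pickle directly;
# in that case the files would hold the output of faiss.serialize_index and
# loading would go through faiss.deserialize_index instead.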

def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
    try:
        # Embed the prompt (mean-pooled hidden states), cast to float32 for FAISS
        with torch.no_grad():
            inputs = embedding_tokenizer(prompt, return_tensors='pt')
            prompt_embedding = embedding_model(**inputs).last_hidden_state.mean(dim=1).numpy().astype('float32')
        # Retrieve the most similar chunks to use as context
        D, I = index.search(prompt_embedding, 5)
        context = " ".join([chunks[i] for i in I[0]])
        # Generate a response conditioned on the retrieved context
        input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        input_ids = generation_tokenizer(input_text, return_tensors="pt").input_ids
        output = generation_model.generate(
            input_ids,
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=generation_tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the context and question are not echoed back
        response = generation_tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
        return response
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error while generating a response."

def main():
    st.title("Your Muse Chat App")
    # Load models and data
    selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
    combo = MODEL_COMBINATIONS[selected_combo]
    embedding_model = load_embedding_model(combo['embedding'])
    generation_model = load_generation_model(combo['generation'])
    embedding_tokenizer = load_tokenizer(combo['embedding'])
    generation_tokenizer = load_tokenizer(combo['generation'])
    index, chunks = load_index_and_chunks()
    if not all([embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks]):
        st.error("Some components failed to load. Please check the errors above.")
        return
    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # Chat input
    if prompt := st.chat_input("What would you like to ask the Muse?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.spinner("The Muse is contemplating..."):
            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()
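
# To try the app locally (assuming this file is saved as app.py, the usual
# entry point for a Streamlit Space), place faiss_index.pkl and chunks.pkl in
# the working directory and start it with: streamlit run app.py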