File size: 5,169 Bytes
576b273
 
 
b300879
 
a4614bf
 
ecb7b4d
 
b300879
fe293b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e8606b
576b273
4e8606b
576b273
4e8606b
 
b300879
4e8606b
576b273
4e8606b
576b273
4e8606b
 
 
 
 
 
576b273
 
4e8606b
 
 
 
 
 
 
 
 
 
 
97426bb
a4614bf
4e8606b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import faiss
import numpy as np
import os
import pickle
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

# Selectable (embedding model, generation model) pairings shown in the UI,
# keyed by display label. The speed figures in each label and "time_saved"
# are the author's estimates; "free" and "time_saved" are not read anywhere
# else in this file.
# NOTE(review): "openai-embedding-ada-002" does not look like a Hugging Face
# Hub repo id, so AutoModel.from_pretrained will most likely reject it — confirm.
# NOTE(review): "meta-llama/Llama-2-7b-chat-hf" is a gated Hub repo and needs an
# accepted licence plus an auth token to download.
# NOTE(review): the FAISS index on disk was built with one embedding model;
# combos whose embedding width differs from it cannot query that index.
MODEL_COMBINATIONS = {
    label: {
        "embedding": embedding_id,
        "generation": generation_id,
        "free": is_free,
        "time_saved": saved,
    }
    for label, embedding_id, generation_id, is_free, saved in (
        # (label, embedding model, generation model, free?, estimated time saved)
        ("Fastest (30 seconds)",
         "sentence-transformers/all-MiniLM-L6-v2", "distilgpt2", True, "2.5 minutes"),
        ("Balanced (1 minute)",
         "sentence-transformers/all-MiniLM-L12-v2", "facebook/opt-350m", True, "2 minutes"),
        ("High Quality (2 minutes)",
         "sentence-transformers/all-mpnet-base-v2", "gpt2", True, "1 minute"),
        ("Premium Speed (15 seconds)",
         "sentence-transformers/all-MiniLM-L6-v2", "microsoft/phi-1_5", False, "2.75 minutes"),
        ("Premium Quality (1.5 minutes)",
         "openai-embedding-ada-002", "meta-llama/Llama-2-7b-chat-hf", False, "1.5 minutes"),
    )
}

def load_model(model_name):
    """Load a bare Hugging Face ``AutoModel`` by hub id or local path.

    Returns the model, or ``None`` after showing the exception message in the
    Streamlit page via ``st.error`` if loading raises for any reason.
    """
    model = None
    try:
        model = AutoModel.from_pretrained(model_name)
    except Exception as exc:
        st.error(f"Error loading model {model_name}: {str(exc)}")
    return model

@st.cache_resource
def _cached_tokenizer(model_name):
    """Load a tokenizer once per process; a raised exception is not cached."""
    return AutoTokenizer.from_pretrained(model_name)


def load_tokenizer(model_name):
    """Return the tokenizer for ``model_name``, or ``None`` if it cannot be loaded.

    Successful loads are cached with ``st.cache_resource`` so Streamlit reruns
    (one per chat message) do not call ``from_pretrained`` again. Failures
    are reported via ``st.error`` and are retried on the next call, because the
    cached helper raises instead of returning ``None``.
    """
    try:
        return _cached_tokenizer(model_name)
    except Exception as e:
        st.error(f"Error loading tokenizer for {model_name}: {str(e)}")
        return None

@st.cache_resource
def _cached_embedding_model(model_name):
    """Load the embedding ``AutoModel`` once per process; a raised exception is not cached."""
    return AutoModel.from_pretrained(model_name)


def load_embedding_model(model_name):
    """Return the embedding model ``model_name``, or ``None`` if it cannot be loaded.

    Only successful loads are cached. Previously the cached function returned
    ``None`` on error, so a single transient failure (e.g. a network hiccup
    during download) was cached for the lifetime of the server process. Errors
    are shown via ``st.error``.
    """
    try:
        return _cached_embedding_model(model_name)
    except Exception as e:
        st.error(f"Error loading model {model_name}: {str(e)}")
        return None

@st.cache_resource
def _cached_generation_model(model_name):
    """Load a causal LM once per process; a raised exception is not cached."""
    return AutoModelForCausalLM.from_pretrained(model_name)


def load_generation_model(model_name):
    """Return the causal language model ``model_name``, or ``None`` on failure.

    Only successful loads are cached. Catching the error outside the cached
    helper keeps a failed load from being memoised as ``None`` for the whole
    process. Errors are shown via ``st.error``.
    """
    try:
        return _cached_generation_model(model_name)
    except Exception as e:
        st.error(f"Error loading generation model {model_name}: {str(e)}")
        return None

@st.cache_resource
def _cached_index_and_chunks():
    """Unpickle the FAISS index and text chunks once per process.

    Exceptions propagate, so a failed load is not cached and is retried on
    the next call.
    """
    # SECURITY NOTE(review): pickle.load can execute arbitrary code. These
    # files are read from the current working directory and must come from a
    # trusted source.
    with open('faiss_index.pkl', 'rb') as f:
        index = pickle.load(f)
    with open('chunks.pkl', 'rb') as f:
        chunks = pickle.load(f)
    return index, chunks


def load_index_and_chunks():
    """Return ``(index, chunks)`` unpickled from ``faiss_index.pkl`` and ``chunks.pkl``.

    The successful result is cached with ``st.cache_resource`` instead of
    being re-read on every Streamlit rerun. Replacing the files on disk
    therefore requires clearing the cache or restarting the app. On any
    failure the error is shown via ``st.error`` and ``(None, None)`` is
    returned.
    """
    try:
        return _cached_index_and_chunks()
    except Exception as e:
        st.error(f"Error loading index and chunks: {str(e)}")
        return None, None

def generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks, top_k=5, max_new_tokens=150):
    """Answer ``prompt`` with retrieval-augmented generation.

    The prompt is embedded by mean-pooling ``embedding_model``'s last hidden
    state. The ``top_k`` nearest chunks are then fetched from the FAISS
    ``index``, and ``generation_model`` continues a
    "Context / Question / Answer" prompt built from them.

    Args:
        prompt: The user's question.
        embedding_tokenizer: Tokenizer paired with ``embedding_model``.
        generation_tokenizer: Tokenizer paired with ``generation_model``.
        generation_model: Causal LM exposing ``generate`` (and ``config``).
        embedding_model: Model whose output exposes ``last_hidden_state``.
        index: FAISS index; assumed to be built from the same embedding model
            over ``chunks`` — TODO confirm.
        chunks: Text chunks; assumed position ``i`` corresponds to index id ``i``.
        top_k: Number of neighbours to retrieve (default 5, as before).
        max_new_tokens: Maximum number of generated tokens, not counting the
            prompt.

    Returns:
        Only the newly generated answer text (the prompt/context is not
        echoed), or a fixed apology string if any step raises. In that case
        the error is shown via ``st.error``.
    """
    try:
        # Embed the prompt; inference needs no autograd graph.
        with torch.no_grad():
            embedding_ids = embedding_tokenizer(prompt, return_tensors='pt')['input_ids']
            hidden = embedding_model(embedding_ids).last_hidden_state
            pooled = hidden.mean(dim=1).detach().cpu().numpy()
        # FAISS expects a contiguous float32 matrix.
        prompt_embedding = np.ascontiguousarray(pooled, dtype=np.float32)

        # Search for similar chunks. FAISS pads the result ids with -1 when
        # fewer than top_k vectors are available, and chunks[-1] would silently
        # return the last chunk. Drop any id outside the chunk range.
        _, neighbor_ids = index.search(prompt_embedding, k=top_k)
        context = " ".join(chunks[i] for i in neighbor_ids[0] if 0 <= i < len(chunks))

        # Build the generation prompt.
        input_text = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        encoded = generation_tokenizer(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")

        # Keep prompt + new tokens inside the model's position limit. Trim from
        # the left so the question and the trailing "Answer:" cue survive.
        max_positions = getattr(getattr(generation_model, "config", None), "max_position_embeddings", None)
        if isinstance(max_positions, int) and max_positions > max_new_tokens:
            budget = max_positions - max_new_tokens
            if input_ids.shape[1] > budget:
                input_ids = input_ids[:, -budget:]
                if attention_mask is not None:
                    attention_mask = attention_mask[:, -budget:]

        # Some tokenizers (e.g. GPT-2 style) define no pad token; fall back to EOS.
        pad_token_id = getattr(generation_tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(generation_tokenizer, "eos_token_id", None)

        output = generation_model.generate(
            input_ids,
            attention_mask=attention_mask,
            # max_length would count the (long) prompt too and stop generation
            # almost immediately; bound only the continuation.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=pad_token_id,
        )
        # A causal LM returns prompt + continuation; decode only the continuation.
        new_tokens = output[0][input_ids.shape[1]:]
        return generation_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error while generating a response."

def main():
    """Render the Muse chat page.

    Lets the user pick a model combination and loads the selected models plus
    the FAISS index and chunks. It then replays the chat history kept in
    ``st.session_state.messages`` and answers each new chat input with
    ``generate_response``. Returns early, after an ``st.error``, when a
    component is missing or the embedding width does not match the index.
    """
    st.title("Your Muse Chat App")

    # Load models and data
    selected_combo = st.selectbox("Choose a model combination:", list(MODEL_COMBINATIONS.keys()))
    combo = MODEL_COMBINATIONS[selected_combo]

    embedding_model = load_embedding_model(combo['embedding'])
    generation_model = load_generation_model(combo['generation'])
    embedding_tokenizer = load_tokenizer(combo['embedding'])
    generation_tokenizer = load_tokenizer(combo['generation'])

    index, chunks = load_index_and_chunks()

    components = (embedding_model, generation_model, embedding_tokenizer, generation_tokenizer, index, chunks)
    # The loaders signal failure with None; test for that explicitly rather
    # than truth-testing (bool() raises on multi-element arrays/tensors). An
    # empty chunk collection is also rejected: retrieval would have nothing
    # to return.
    if any(component is None for component in components) or len(chunks) == 0:
        st.error("Some components failed to load. Please check the errors above.")
        return

    # The index has a fixed vector width (index.d); querying it with
    # mean-pooled embeddings of another width fails inside index.search, so
    # report the mismatch clearly up front.
    index_dim = getattr(index, "d", None)
    model_dim = getattr(getattr(embedding_model, "config", None), "hidden_size", None)
    if index_dim is not None and model_dim is not None and index_dim != model_dim:
        st.error(
            f"The '{selected_combo}' embedding model produces {model_dim}-dimensional "
            f"vectors, but the loaded index expects {index_dim}. Please choose a "
            "combination that matches the model used to build the index."
        )
        return

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("What would you like to ask the Muse?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.spinner("The Muse is contemplating..."):
            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

# Run the app only when this file is executed as the top-level script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()