user committed
Commit fe293b8 · 1 Parent(s): ecb7b4d

model selection with speed ratings, time savings

Files changed (1)
  1. app.py +83 -14
app.py CHANGED
@@ -8,16 +8,51 @@ import pickle
 import warnings
 warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
 
+# Model combinations with speed ratings and estimated time savings
+MODEL_COMBINATIONS = {
+    "Fastest (30 seconds)": {
+        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
+        "generation": "distilgpt2",
+        "free": True,
+        "time_saved": "2.5 minutes"
+    },
+    "Balanced (1 minute)": {
+        "embedding": "sentence-transformers/all-MiniLM-L12-v2",
+        "generation": "facebook/opt-350m",
+        "free": True,
+        "time_saved": "2 minutes"
+    },
+    "High Quality (2 minutes)": {
+        "embedding": "sentence-transformers/all-mpnet-base-v2",
+        "generation": "gpt2",
+        "free": True,
+        "time_saved": "1 minute"
+    },
+    "Premium Speed (15 seconds)": {
+        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
+        "generation": "microsoft/phi-1_5",
+        "free": False,
+        "time_saved": "2.75 minutes"
+    },
+    "Premium Quality (1.5 minutes)": {
+        "embedding": "openai-embedding-ada-002",
+        "generation": "meta-llama/Llama-2-7b-chat-hf",
+        "free": False,
+        "time_saved": "1.5 minutes"
+    }
+}
+
 @st.cache_resource
-def load_models():
+def load_models(model_combination):
     try:
-        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
-        embedding_model = AutoModel.from_pretrained("distilbert-base-uncased")
-        generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
-        return tokenizer, embedding_model, generation_model
+        embedding_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
+        embedding_model = AutoModel.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
+        generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
+        generation_model = AutoModelForCausalLM.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
+        return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
     except Exception as e:
         st.error(f"Error loading models: {str(e)}")
-        return None, None, None
+        return None, None, None, None
 
 @st.cache_data
 def load_and_process_text(file_path):
@@ -95,18 +130,52 @@ st.markdown("""
 """, unsafe_allow_html=True)
 st.markdown('<p class="big-font">Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!</p>', unsafe_allow_html=True)
 
-# Load models and data
-with st.spinner("Loading models and data..."):
-    tokenizer, embedding_model, generation_model = load_models()
-    chunks, embeddings, index = load_data()
-    if chunks is None or embeddings is None or index is None:
-        chunks = load_and_process_text('ammons_muse.txt')
-        embeddings = create_embeddings(chunks, embedding_model)
-        index = create_faiss_index(embeddings)
-        save_data(chunks, embeddings, index)
-
-if tokenizer is None or embedding_model is None or generation_model is None or not chunks:
-    st.error("Failed to load necessary components. Please try again later.")
+# Model selection
+if 'model_combination' not in st.session_state:
+    st.session_state.model_combination = "Fastest (30 seconds)"
+
+# Create a list of model options, with non-free models at the end
+free_models = [k for k, v in MODEL_COMBINATIONS.items() if v['free']]
+non_free_models = [k for k, v in MODEL_COMBINATIONS.items() if not v['free']]
+all_models = free_models + non_free_models
+
+# Custom CSS to grey out non-free options
+st.markdown("""
+<style>
+.stSelectbox div[role="option"][aria-selected="false"]:nth-last-child(-n+2) {
+    color: grey !important;
+}
+</style>
+""", unsafe_allow_html=True)
+
+selected_model = st.selectbox(
+    "Choose a model combination:",
+    all_models,
+    index=all_models.index(st.session_state.model_combination),
+    format_func=lambda x: f"{x} {'(Not Free)' if not MODEL_COMBINATIONS[x]['free'] else ''}"
+)
+
+# Prevent selection of non-free models
+if not MODEL_COMBINATIONS[selected_model]['free']:
+    st.warning("Premium models are not available in the free version.")
+    st.stop()
+
+st.session_state.model_combination = selected_model
+
+st.info(f"Potential time saved compared to slowest option: {MODEL_COMBINATIONS[selected_model]['time_saved']}")
+
+if st.button("Load Selected Models"):
+    with st.spinner("Loading models and data..."):
+        embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
+        chunks = load_and_process_text('ammons_muse.txt')
+        embeddings = create_embeddings(chunks, embedding_model)
+        index = create_faiss_index(embeddings)
+
+        st.session_state.models_loaded = True
+        st.success("Models loaded successfully!")
+
+if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
+    st.warning("Please load the models before chatting.")
     st.stop()
 
 # Initialize chat history
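
Note: the diff wires the selected embedding model into create_embeddings and create_faiss_index, two helpers defined elsewhere in app.py and untouched by this commit. The following is a minimal sketch of what such helpers conventionally look like, not code from this repo: it assumes mean pooling over the last hidden state and an exact L2 FAISS index, and it adds an explicit tokenizer parameter to stay self-contained, whereas the app's own helper is called with only the chunks and a model.

# Hypothetical reconstruction of the helpers the diff relies on;
# the actual app.py definitions are not shown in this commit.
import faiss
import numpy as np
import torch

def create_embeddings(chunks, embedding_model, embedding_tokenizer):
    # Mean-pool the last hidden state of each chunk; with one chunk per
    # forward pass there is no padding, so a plain mean is exact.
    vectors = []
    with torch.no_grad():
        for chunk in chunks:
            inputs = embedding_tokenizer(chunk, return_tensors="pt",
                                         truncation=True, max_length=512)
            hidden = embedding_model(**inputs).last_hidden_state
            vectors.append(hidden.mean(dim=1).squeeze(0).numpy())
    return np.vstack(vectors).astype("float32")  # FAISS expects float32

def create_faiss_index(embeddings):
    # Exact (brute-force) L2 search; adequate for one chunked text file.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index

At query time the app would presumably embed the user's question the same way and call index.search(query_vector, k) to pull the nearest chunks into the generation prompt.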