user committed on
Commit · 97426bb · 1 Parent(s): becd78e

Fix response generation and handle unused token errors

Browse files:
- app.py +35 -38
- requirements.txt +0 -1
app.py
CHANGED
@@ -43,16 +43,12 @@ MODEL_COMBINATIONS = {
}

@st.cache_resource
-def load_models(
+def load_models(model_combination):
    try:
-
-
-
-
-        embedding_model = AutoModel.from_pretrained(embedding_model_name)
-        generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
-        generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name)
-
+        embedding_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
+        embedding_model = AutoModel.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
+        generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
+        generation_model = AutoModelForCausalLM.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
        return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
@@ -99,27 +95,38 @@ def generate_response(query, embedding_tokenizer, generation_tokenizer, generati
    prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"

    input_ids = generation_tokenizer.encode(prompt, return_tensors="pt")
-    output = generation_model.generate(
+    output = generation_model.generate(
+        input_ids,
+        max_new_tokens=100,
+        num_return_sequences=1,
+        temperature=0.7,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        no_repeat_ngram_size=2
+    )
    response = generation_tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    muse_response = response.split("Muse:")[-1].strip()
+
+    # Check if the response contains unused tokens
+    if "[unused" in muse_response:
+        muse_response = "I apologize, but I'm having trouble formulating a response. Let me try again with a simpler message: Hello! As the Muse of A.R. Ammons, I'm here to inspire and discuss poetry. How may I assist you today?"
+
    return muse_response

-def save_data(chunks, embeddings, index
-
-    with open(f'data/chunks_{model_combination}.pkl', 'wb') as f:
+def save_data(chunks, embeddings, index):
+    with open('chunks.pkl', 'wb') as f:
        pickle.dump(chunks, f)
-    np.save(
-    faiss.write_index(index,
-
-def load_data(
-    if os.path.exists(
-
-       os.path.exists(f'data/faiss_index_{model_combination}.bin'):
-        with open(f'data/chunks_{model_combination}.pkl', 'rb') as f:
+def load_data():
+    if os.path.exists('chunks.pkl') and os.path.exists('embeddings.npy') and os.path.exists('faiss_index.bin'):
+        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
-        embeddings = np.load(
-        index = faiss.read_index(
+        embeddings = np.load('embeddings.npy')
+        index = faiss.read_index('faiss_index.bin')
        return chunks, embeddings, index
    return None, None, None

@@ -174,22 +181,12 @@ st.info(f"Potential time saved compared to slowest option: {MODEL_COMBINATIONS[s
if st.button("Load Selected Models"):
    with st.spinner("Loading models and data..."):
        embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
-
-
-
-
-        # If data doesn't exist, process it and save
-        if chunks is None or embeddings is None or index is None:
-            chunks = load_and_process_text('ammons_muse.txt')
-            embeddings = create_embeddings(chunks, embedding_model)
-            index = create_faiss_index(embeddings)
-            save_data(chunks, embeddings, index, st.session_state.model_combination)
+        chunks = load_and_process_text('ammons_muse.txt')
+        embeddings = create_embeddings(chunks, embedding_model)
+        index = create_faiss_index(embeddings)

        st.session_state.models_loaded = True
-        st.
-        st.session_state.embeddings = embeddings
-        st.session_state.index = index
-        st.success("Models and data loaded successfully!")
+        st.success("Models loaded successfully!")

if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
    st.warning("Please load the models before chatting.")
@@ -211,7 +208,7 @@ if prompt := st.chat_input("What would you like to ask the Muse?"):

    with st.spinner("The Muse is contemplating..."):
        try:
-            response = generate_response(prompt,
+            response = generate_response(prompt, tokenizer, generation_model, embedding_model, index, chunks)
        except Exception as e:
            response = f"I apologize, but I encountered an error: {str(e)}"

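For context, the sampling arguments and the "[unused" guard added to generate_response can be exercised on their own. The sketch below is a minimal, self-contained approximation of that path; the model name ("gpt2"), the prompt text, and the shortened fallback message are stand-ins, since the actual MODEL_COMBINATIONS entries are not shown in this diff.

# Minimal sketch of the updated generation path (assumed model: "gpt2";
# the real app resolves models from MODEL_COMBINATIONS).
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "As the Muse of A.R. Ammons, respond to this query: What is poetry?\nContext: ...\nMuse:"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Same sampling settings the commit passes to generation_model.generate(...)
output = model.generate(
    input_ids,
    max_new_tokens=100,
    num_return_sequences=1,
    temperature=0.7,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    no_repeat_ngram_size=2,
)
response = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
muse_response = response.split("Muse:")[-1].strip()

# The commit's fallback: if the decoder emits placeholder vocabulary entries
# such as "[unused0]", replace the reply instead of showing raw tokens.
if "[unused" in muse_response:
    muse_response = "I apologize, but I'm having trouble formulating a response."
print(muse_response)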
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
-
torch
transformers
sentence-transformers
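As a closing note on the save_data/load_data helpers changed in app.py above: the pickle/NumPy/FAISS round trip they rely on can be checked in isolation. The sketch below uses synthetic data; the file names mirror the new fixed paths ('chunks.pkl', 'embeddings.npy', 'faiss_index.bin'), while the index type (IndexFlatL2) and embedding dimension are assumptions, since create_faiss_index is not part of this diff.

import os
import pickle
import numpy as np
import faiss

# Synthetic stand-ins for the chunks/embeddings the app would produce.
chunks = ["first text chunk", "second text chunk"]
embeddings = np.random.rand(len(chunks), 384).astype("float32")  # FAISS expects float32

index = faiss.IndexFlatL2(embeddings.shape[1])  # assumed index type
index.add(embeddings)

# Save, as the new save_data() does (fixed file names, no model_combination suffix).
with open("chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
np.save("embeddings.npy", embeddings)
faiss.write_index(index, "faiss_index.bin")

# Load back, as the new load_data() does.
if os.path.exists("chunks.pkl") and os.path.exists("embeddings.npy") and os.path.exists("faiss_index.bin"):
    with open("chunks.pkl", "rb") as f:
        chunks_loaded = pickle.load(f)
    embeddings_loaded = np.load("embeddings.npy")
    index_loaded = faiss.read_index("faiss_index.bin")
    print(index_loaded.ntotal)  # -> 2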