Spaces:
Sleeping
Sleeping
user
commited on
Commit
·
a4614bf
1
Parent(s):
cc19159
Fix max_length error and implement data persistence
Browse files
app.py
CHANGED
@@ -3,6 +3,8 @@ import torch
|
|
3 |
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
|
4 |
import faiss
|
5 |
import numpy as np
|
|
|
|
|
6 |
|
7 |
@st.cache_resource
|
8 |
def load_models():
|
@@ -56,12 +58,27 @@ def generate_response(query, tokenizer, generation_model, embedding_model, index
|
|
56 |
prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
|
57 |
|
58 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
59 |
-
output = generation_model.generate(input_ids,
|
60 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
61 |
|
62 |
muse_response = response.split("Muse:")[-1].strip()
|
63 |
return muse_response
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
# Streamlit UI
|
66 |
st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")
|
67 |
|
@@ -79,9 +96,12 @@ st.markdown('<p class="big-font">Chat with the Muse of A.R. Ammons. Ask question
|
|
79 |
# Load models and data
|
80 |
with st.spinner("Loading models and data..."):
|
81 |
tokenizer, embedding_model, generation_model = load_models()
|
82 |
-
chunks =
|
83 |
-
embeddings
|
84 |
-
|
|
|
|
|
|
|
85 |
|
86 |
if tokenizer is None or embedding_model is None or generation_model is None or not chunks:
|
87 |
st.error("Failed to load necessary components. Please try again later.")
|
|
|
3 |
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
|
4 |
import faiss
|
5 |
import numpy as np
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
|
9 |
@st.cache_resource
|
10 |
def load_models():
|
|
|
58 |
prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
|
59 |
|
60 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
61 |
+
output = generation_model.generate(input_ids, max_new_tokens=100, num_return_sequences=1, temperature=0.7)
|
62 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
63 |
|
64 |
muse_response = response.split("Muse:")[-1].strip()
|
65 |
return muse_response
|
66 |
|
67 |
+
def save_data(chunks, embeddings, index):
|
68 |
+
with open('chunks.pkl', 'wb') as f:
|
69 |
+
pickle.dump(chunks, f)
|
70 |
+
np.save('embeddings.npy', embeddings)
|
71 |
+
faiss.write_index(index, 'faiss_index.bin')
|
72 |
+
|
73 |
+
def load_data():
|
74 |
+
if os.path.exists('chunks.pkl') and os.path.exists('embeddings.npy') and os.path.exists('faiss_index.bin'):
|
75 |
+
with open('chunks.pkl', 'rb') as f:
|
76 |
+
chunks = pickle.load(f)
|
77 |
+
embeddings = np.load('embeddings.npy')
|
78 |
+
index = faiss.read_index('faiss_index.bin')
|
79 |
+
return chunks, embeddings, index
|
80 |
+
return None, None, None
|
81 |
+
|
82 |
# Streamlit UI
|
83 |
st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")
|
84 |
|
|
|
96 |
# Load models and data
|
97 |
with st.spinner("Loading models and data..."):
|
98 |
tokenizer, embedding_model, generation_model = load_models()
|
99 |
+
chunks, embeddings, index = load_data()
|
100 |
+
if chunks is None or embeddings is None or index is None:
|
101 |
+
chunks = load_and_process_text('ammons_muse.txt')
|
102 |
+
embeddings = create_embeddings(chunks, embedding_model)
|
103 |
+
index = create_faiss_index(embeddings)
|
104 |
+
save_data(chunks, embeddings, index)
|
105 |
|
106 |
if tokenizer is None or embedding_model is None or generation_model is None or not chunks:
|
107 |
st.error("Failed to load necessary components. Please try again later.")
|