user committed on
Commit 576b273 · 1 Parent(s): b300879
Files changed (2)
  1. app.py +107 -67
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,81 +1,121 @@
- import gradio as gr
- from sentence_transformers import SentenceTransformer
  import faiss
- from transformers import pipeline
  import numpy as np
- import os
-
- # File paths
- INDEX_FILE = 'ammons_muse_index.faiss'
- EMBEDDINGS_FILE = 'ammons_muse_embeddings.npy'
- CHUNKS_FILE = 'ammons_muse_chunks.npy'
- TEXT_FILE = 'ammons_muse.txt'
-
- # Load and prepare the text
- def prepare_text():
-     with open(TEXT_FILE, 'r', encoding='utf-8') as file:
-         text = file.read()
-     chunk_size = 1000
-     return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
-
- # Create or load embeddings and index
- def get_embeddings_and_index(chunks):
-     if os.path.exists(INDEX_FILE) and os.path.exists(EMBEDDINGS_FILE):
-         print("Loading existing index and embeddings...")
-         index = faiss.read_index(INDEX_FILE)
-         embeddings = np.load(EMBEDDINGS_FILE)
-     else:
-         print("Creating new index and embeddings...")
-         model = SentenceTransformer('all-MiniLM-L6-v2')
-         embeddings = model.encode(chunks)
-         dimension = embeddings.shape[1]
-         index = faiss.IndexFlatL2(dimension)
-         index.add(embeddings.astype('float32'))
-
-         # Save index and embeddings
-         faiss.write_index(index, INDEX_FILE)
-         np.save(EMBEDDINGS_FILE, embeddings)
-
-     return embeddings, index
-
- # Load or create chunks
- if os.path.exists(CHUNKS_FILE):
-     chunks = np.load(CHUNKS_FILE, allow_pickle=True).tolist()
- else:
-     chunks = prepare_text()
-     np.save(CHUNKS_FILE, np.array(chunks, dtype=object))
-
- # Get embeddings and index
- embeddings, index = get_embeddings_and_index(chunks)
-
- # Set up text generation pipeline
- generator = pipeline('text-generation', model='gpt2')
-
- # Retrieval function
- def retrieve_relevant_chunks(query, top_k=3):
-     model = SentenceTransformer('all-MiniLM-L6-v2')
-     query_vector = model.encode([query])
-     _, indices = index.search(query_vector.astype('float32'), top_k)
-     return [chunks[i] for i in indices[0]]
-
- # Character response generation
- def generate_character_response(query):
-     relevant_chunks = retrieve_relevant_chunks(query)
-     prompt = f"""As the Muse from A.R. Ammons' poetry, respond to this query:
- Context: {' '.join(relevant_chunks)}
- User: {query}
- Muse:"""
-
-     response = generator(prompt, max_length=150, num_return_sequences=1)[0]['generated_text']
-     return response.split('Muse:')[-1].strip()
-
- # Gradio interface
- iface = gr.Interface(
-     fn=generate_character_response,
-     inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
-     outputs="text",
-     title="A.R. Ammons' Muse Chatbot",
-     description="Ask a question and get a response from the Muse of A.R. Ammons' poetry."
- )
-
- iface.launch()
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
  import faiss
  import numpy as np

+ @st.cache_resource
+ def load_models():
+     try:
+         # Load two tokenizers: DistilBERT for embeddings and GPT-2 for generation,
+         # since the DistilBERT tokenizer cannot correctly encode or decode GPT-2 text
+         embed_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+         embedding_model = AutoModel.from_pretrained("distilbert-base-uncased")
+         gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
+         return embed_tokenizer, embedding_model, gen_tokenizer, generation_model
+     except Exception as e:
+         st.error(f"Error loading models: {str(e)}")
+         return None, None, None, None
+
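+ # st.cache_resource keeps one shared copy of the loaded models across sessions,
+ # while st.cache_data (below) memoizes the chunk list and embeddings by value.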
+ @st.cache_data
+ def load_and_process_text(file_path):
+     try:
+         with open(file_path, 'r', encoding='utf-8') as file:
+             text = file.read()
+         chunks = [text[i:i+512] for i in range(0, len(text), 512)]
+         return chunks
+     except Exception as e:
+         st.error(f"Error loading text file: {str(e)}")
+         return []
+
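+ # Embed each chunk by mean-pooling DistilBERT's final hidden states into a single vector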
+ @st.cache_data
+ def create_embeddings(chunks, _tokenizer, _embedding_model):
+     # Leading underscores tell Streamlit's cache not to try hashing the model objects
+     embeddings = []
+     for chunk in chunks:
+         inputs = _tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
+         with torch.no_grad():
+             outputs = _embedding_model(**inputs)
+         embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
+     return np.array(embeddings)
+
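+ # IndexFlatL2 performs exact (brute-force) L2 nearest-neighbor search over the chunk vectors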
+ @st.cache_resource
+ def create_faiss_index(embeddings):
+     index = faiss.IndexFlatL2(embeddings.shape[1])
+     index.add(embeddings)
+     return index
+
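+ # Retrieval-augmented generation: embed the query, fetch the k nearest chunks,
+ # and prompt GPT-2 with those chunks as context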
+ def generate_response(query, embed_tokenizer, embedding_model, gen_tokenizer, generation_model, index, chunks):
+     # Embed the query with the same DistilBERT mean-pooling used for the chunks
+     inputs = embed_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = embedding_model(**inputs)
+     query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
+
+     k = 3
+     _, I = index.search(query_embedding.reshape(1, -1), k)
+
+     context = " ".join([chunks[i] for i in I[0]])
+
+     prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
+
+     # Truncate the prompt so prompt plus 200 new tokens fits GPT-2's 1024-token window;
+     # do_sample=True is required for temperature to have any effect
+     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=824)
+     output = generation_model.generate(input_ids, max_new_tokens=200, num_return_sequences=1, do_sample=True, temperature=0.7, pad_token_id=gen_tokenizer.eos_token_id)
+     response = gen_tokenizer.decode(output[0], skip_special_tokens=True)
+
+     muse_response = response.split("Muse:")[-1].strip()
+     return muse_response
+
+ # Streamlit UI
+ st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")
+
+ st.title("A.R. Ammons' Muse Chatbot 🎭")
+ st.markdown("""
+ <style>
+ .big-font {
+     font-size: 20px !important;
+     font-weight: bold;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+ st.markdown('<p class="big-font">Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!</p>', unsafe_allow_html=True)
+
+ # Load models and data
+ with st.spinner("Loading models and data..."):
+     embed_tokenizer, embedding_model, gen_tokenizer, generation_model = load_models()
+     chunks = load_and_process_text('ammons_muse.txt')
+
+ # Stop before building embeddings if anything failed to load
+ if embed_tokenizer is None or embedding_model is None or gen_tokenizer is None or generation_model is None or not chunks:
+     st.error("Failed to load necessary components. Please try again later.")
+     st.stop()
+
+ with st.spinner("Indexing the text..."):
+     embeddings = create_embeddings(chunks, embed_tokenizer, embedding_model)
+     index = create_faiss_index(embeddings)
+
+ # Initialize chat history
+ if 'messages' not in st.session_state:
+     st.session_state.messages = []
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # React to user input
+ if prompt := st.chat_input("What would you like to ask the Muse?"):
+     st.chat_message("user").markdown(prompt)
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     with st.spinner("The Muse is contemplating..."):
+         try:
+             response = generate_response(prompt, embed_tokenizer, embedding_model, gen_tokenizer, generation_model, index, chunks)
+         except Exception as e:
+             response = f"I apologize, but I encountered an error: {str(e)}"
+
+     with st.chat_message("assistant"):
+         st.markdown(response)
+     st.session_state.messages.append({"role": "assistant", "content": response})
+
+ # Add a button to clear chat history
+ if st.button("Clear Chat History"):
+     st.session_state.messages = []
+     st.rerun()
+
+ # Add a footer
+ st.markdown("---")
+ st.markdown("*Powered by the spirit of A.R. Ammons and the magic of AI*")
requirements.txt ADDED
@@ -0,0 +1,8 @@
+
+ torch
+ transformers
+ sentence-transformers
+ faiss-cpu
+ numpy
+ datasets
+ streamlit
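
To try the updated app locally (assuming ammons_muse.txt sits in the same directory as app.py, which is where the code looks for it):

pip install -r requirements.txt
streamlit run app.py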