Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM | |
import faiss | |
import numpy as np | |
import os | |
import pickle | |
import warnings | |
# Suppress FutureWarning messages raised from modules whose name matches "transformers".
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Load the embedding tokenizer/model and the generation model once per server process.

    Streamlit re-executes the whole script on every interaction; without this cache the
    models would be reloaded for every chat message. Exceptions are not cached, so a
    transient failure is retried on the next rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    embedding_model = AutoModel.from_pretrained("distilbert-base-uncased")
    generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
    return tokenizer, embedding_model, generation_model


def load_models():
    """Return ``(tokenizer, embedding_model, generation_model)``.

    ``tokenizer`` / ``embedding_model`` are DistilBERT (used for retrieval embeddings);
    ``generation_model`` is GPT-2. On any loading error, an ``st.error`` message is shown
    and ``(None, None, None)`` is returned so the caller can stop gracefully.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None
def load_and_process_text(file_path, chunk_size=512):
    """Read a UTF-8 text file and split it into consecutive fixed-size character chunks.

    Args:
        file_path: path of the text file to read.
        chunk_size: maximum number of characters per chunk (default 512, the value
            previously hard-coded). Chunks do not overlap; the last one may be shorter.

    Returns:
        List of string chunks. An empty file yields ``[]``. If the file cannot be opened
        or decoded, an ``st.error`` message is shown and ``[]`` is returned.

    Raises:
        ValueError: if ``chunk_size`` is not positive (``range`` would otherwise fail
            with a less helpful message, or loop forever on negative steps' logic).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
    except (OSError, UnicodeDecodeError) as e:
        st.error(f"Error loading text file: {str(e)}")
        return []
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
def create_embeddings(chunks, _embedding_model, _tokenizer=None):
    """Embed each text chunk as the mean of the encoder's last hidden states.

    Args:
        chunks: iterable of strings to embed (each truncated to 512 tokens).
        _embedding_model: encoder model called as ``_embedding_model(**inputs)`` whose
            output exposes ``last_hidden_state``.
        _tokenizer: tokenizer matching ``_embedding_model``. When omitted, the
            module-level ``tokenizer`` global is used. The original implementation always
            read that global implicitly; passing it explicitly removes the hidden
            dependency.

    Returns:
        ``float32`` NumPy array with one row per chunk (faiss requires float32).
        For an empty ``chunks`` the result is an empty 1-D array.
    """
    tok = _tokenizer if _tokenizer is not None else tokenizer
    vectors = []
    for chunk in chunks:
        inputs = tok(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = _embedding_model(**inputs)
            # Mean over the token axis, then [0] drops the size-1 batch axis. Unlike
            # squeeze(), this never collapses the hidden axis.
            vectors.append(outputs.last_hidden_state.mean(dim=1)[0].numpy())
    return np.asarray(vectors, dtype=np.float32)
def create_faiss_index(embeddings):
    """Build an exact (brute-force) L2 faiss index over the rows of ``embeddings``.

    Args:
        embeddings: array-like of shape ``(n_vectors, dim)`` with ``n_vectors >= 1``.

    Returns:
        ``faiss.IndexFlatL2`` holding every row in order, so search ids map back to
        row positions (and therefore to chunk positions).

    Raises:
        ValueError: if ``embeddings`` is not a non-empty 2-D matrix (e.g. the empty
            1-D array produced when there are no chunks).
    """
    # faiss' Python bindings require C-contiguous float32 input. Data reloaded from disk
    # or produced elsewhere may not satisfy that.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2 or vectors.shape[0] == 0:
        raise ValueError(f"expected a non-empty 2-D embedding matrix, got shape {vectors.shape}")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index
def generate_response(query, tokenizer, generation_model, embedding_model, index, chunks,
                      generation_tokenizer=None, top_k=3, max_new_tokens=100, temperature=0.7):
    """Answer ``query`` with the generation model, grounded in retrieved text chunks.

    Args:
        query: the user's question.
        tokenizer: tokenizer for ``embedding_model`` (used only to embed the query).
        generation_model: causal LM with ``generate`` and a ``config``.
        embedding_model: encoder used to build the index vectors.
        index: faiss index whose row ids correspond to positions in ``chunks``.
        chunks: list of text chunks that were indexed.
        generation_tokenizer: tokenizer matching ``generation_model``. If omitted, it is
            loaded with ``AutoTokenizer.from_pretrained`` from the model's
            ``config.name_or_path`` on every call; pass one in to avoid the reload.
        top_k: number of chunks to retrieve as context.
        max_new_tokens: generation length limit.
        temperature: sampling temperature (sampling is enabled so it takes effect).

    Returns:
        The generated continuation only (prompt excluded), stripped of whitespace.
    """
    # 1. Embed the query the same way chunks are embedded: mean of last hidden states.
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1)[0].numpy()

    # 2. Retrieve context. faiss pads results with id -1 when fewer than k vectors exist,
    #    and chunks[-1] would silently pick the last chunk, so clamp k and drop bad ids.
    k = min(top_k, index.ntotal, len(chunks))
    context = ""
    if k > 0:
        query_vec = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        _, neighbor_ids = index.search(query_vec, k)
        context = " ".join(chunks[i] for i in neighbor_ids[0] if 0 <= i < len(chunks))

    prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"

    # 3. Generate with the generation model's OWN tokenizer. The embedding tokenizer
    #    (DistilBERT WordPiece) has a different vocabulary from GPT-2's, so ids it
    #    produces are meaningless to the generator.
    gen_tokenizer = generation_tokenizer
    if gen_tokenizer is None:
        gen_tokenizer = AutoTokenizer.from_pretrained(generation_model.config.name_or_path or "gpt2")
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt")

    # Keep prompt + new tokens within the model's context window. Truncate from the left
    # so the trailing "Muse:" cue survives.
    max_positions = getattr(generation_model.config, "max_position_embeddings", None) or 1024
    budget = max(1, max_positions - max_new_tokens)
    input_ids = input_ids[:, -budget:]

    with torch.no_grad():
        output = generation_model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,  # temperature is ignored by greedy decoding
            temperature=temperature,
            pad_token_id=gen_tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens instead of re-splitting the whole text on
    # "Muse:", which breaks if that string appears in the query or the output.
    new_tokens = output[0, input_ids.shape[1]:]
    return gen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def save_data(chunks, embeddings, index):
    """Persist the retrieval cache to the current working directory.

    Writes, in this order: ``chunks.pkl`` (pickled chunk list), ``embeddings.npy``
    (NumPy array via ``np.save``) and ``faiss_index.bin`` (``faiss.write_index``).
    """
    with open('chunks.pkl', 'wb') as chunk_file:
        pickle.dump(chunks, chunk_file)

    np.save('embeddings.npy', embeddings)

    faiss.write_index(index, 'faiss_index.bin')
def load_data():
    """Load the chunk/embedding/index cache written by ``save_data``.

    Returns:
        ``(chunks, embeddings, index)`` when all three cache files exist, load without
        error and agree on the number of entries; otherwise ``(None, None, None)`` so
        the caller rebuilds the cache instead of crashing on a corrupt or partial one.
    """
    paths = ('chunks.pkl', 'embeddings.npy', 'faiss_index.bin')
    if not all(os.path.exists(p) for p in paths):
        return None, None, None
    try:
        # NOTE(security): pickle.load can execute arbitrary code from the file. It is
        # trusted here only because save_data writes it; never load user-supplied files.
        with open('chunks.pkl', 'rb') as f:
            chunks = pickle.load(f)
        embeddings = np.load('embeddings.npy')
        index = faiss.read_index('faiss_index.bin')
    except Exception as e:  # deliberate best-effort: any unreadable cache triggers a rebuild
        warnings.warn(f"Ignoring unreadable cache files: {e}")
        return None, None, None
    # An interrupted or mismatched save can leave the files out of sync. Retrieval would
    # then map index ids to the wrong chunks or out of range.
    if not (len(chunks) == len(embeddings) == index.ntotal):
        warnings.warn("Ignoring inconsistent cache files (entry counts differ).")
        return None, None, None
    return chunks, embeddings, index
# Streamlit UI: page configuration, title and introductory text.
st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="π")
st.title("A.R. Ammons' Muse Chatbot π")

# Inline CSS defining the .big-font class used by the intro paragraph below.
_big_font_css = """
<style>
.big-font {
font-size:20px !important;
font-weight: bold;
}
</style>
"""
st.markdown(_big_font_css, unsafe_allow_html=True)

_intro_html = '<p class="big-font">Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!</p>'
st.markdown(_intro_html, unsafe_allow_html=True)
# Load models and data. Models are validated BEFORE any embedding work, and an empty
# text is detected BEFORE building the faiss index. Either failure then reaches the
# friendly error below instead of raising a traceback inside create_embeddings /
# create_faiss_index.
with st.spinner("Loading models and data..."):
    tokenizer, embedding_model, generation_model = load_models()
    models_loaded = all(m is not None for m in (tokenizer, embedding_model, generation_model))
    chunks, embeddings, index = None, None, None
    if models_loaded:
        chunks, embeddings, index = load_data()
        if chunks is None or embeddings is None or index is None:
            # No usable cache: rebuild it from the source text and persist it.
            chunks = load_and_process_text('ammons_muse.txt')
            if chunks:
                embeddings = create_embeddings(chunks, embedding_model)
                index = create_faiss_index(embeddings)
                save_data(chunks, embeddings, index)

if not models_loaded or not chunks:
    st.error("Failed to load necessary components. Please try again later.")
    st.stop()
# Chat history is kept in session_state so it persists across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay every stored message, since the script re-renders from scratch on each rerun.
for entry in st.session_state["messages"]:
    role, content = entry["role"], entry["content"]
    with st.chat_message(role):
        st.markdown(content)
# Handle a newly submitted user message (chat_input returns a falsy value otherwise).
prompt = st.chat_input("What would you like to ask the Muse?")
if prompt:
    # Echo the user's message and record it in the history.
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Any failure during retrieval/generation is shown to the user as the reply.
    with st.spinner("The Muse is contemplating..."):
        try:
            response = generate_response(prompt, tokenizer, generation_model, embedding_model, index, chunks)
        except Exception as e:
            response = f"I apologize, but I encountered an error: {str(e)}"

    assistant_entry = {"role": "assistant", "content": response}
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append(assistant_entry)
# Button that wipes the conversation and immediately re-renders the page.
if st.button("Clear Chat History"):
    st.session_state.messages = []
    # st.experimental_rerun is deprecated in favour of st.rerun and has been removed in
    # newer Streamlit releases. Fall back to it only on versions that predate st.rerun.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()
# Footer: a horizontal rule followed by an italic credit line.
for _footer_md in ("---", "*Powered by the spirit of A.R. Ammons and the magic of AI*"):
    st.markdown(_footer_md)