# "Spaces: Sleeping" — Hugging Face Space status banner captured during
# page extraction; not part of the program.
import gradio as gr | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
from transformers import pipeline | |
import numpy as np | |
import os | |
# File paths | |
INDEX_FILE = 'ammons_muse_index.faiss' | |
EMBEDDINGS_FILE = 'ammons_muse_embeddings.npy' | |
CHUNKS_FILE = 'ammons_muse_chunks.npy' | |
TEXT_FILE = 'ammons_muse.txt' | |
# Load and prepare the text | |
def prepare_text(chunk_size=1000):
    """Read TEXT_FILE and split it into fixed-size character chunks.

    Args:
        chunk_size: Number of characters per chunk (default 1000, the
            value previously hard-coded).

    Returns:
        list[str]: Consecutive, non-overlapping slices of the full text;
        the last chunk may be shorter than chunk_size.
    """
    with open(TEXT_FILE, 'r', encoding='utf-8') as file:
        text = file.read()
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
# Create or load embeddings and index | |
def get_embeddings_and_index(chunks):
    """Return (embeddings, FAISS index) for *chunks*, reusing disk caches.

    When both INDEX_FILE and EMBEDDINGS_FILE already exist they are loaded
    verbatim; otherwise the chunks are embedded with a MiniLM sentence
    encoder, put into an L2 flat index, and both artifacts are written to
    disk so later launches skip the encode step.

    NOTE(review): the cache is never invalidated — if *chunks* changes,
    the stale files must be deleted by hand to force a rebuild.
    """
    have_cache = os.path.exists(INDEX_FILE) and os.path.exists(EMBEDDINGS_FILE)
    if have_cache:
        print("Loading existing index and embeddings...")
        return np.load(EMBEDDINGS_FILE), faiss.read_index(INDEX_FILE)

    print("Creating new index and embeddings...")
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    vectors = encoder.encode(chunks)
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    # FAISS requires float32 input vectors.
    flat_index.add(vectors.astype('float32'))
    # Persist both artifacts for the next run.
    faiss.write_index(flat_index, INDEX_FILE)
    np.save(EMBEDDINGS_FILE, vectors)
    return vectors, flat_index
# Load or create chunks
# Module-level bootstrap: reuse the cached chunk list from a previous run
# when present, otherwise split the source text now and cache it.
if os.path.exists(CHUNKS_FILE):
    # allow_pickle is required because the saved array holds Python strings.
    chunks = np.load(CHUNKS_FILE, allow_pickle=True).tolist()
else:
    chunks = prepare_text()
    # dtype=object preserves variable-length strings in the .npy file.
    np.save(CHUNKS_FILE, np.array(chunks, dtype=object))
# Get embeddings and index
# Globals used by retrieve_relevant_chunks() below.
embeddings, index = get_embeddings_and_index(chunks)
# Set up text generation pipeline
# GPT-2 generates the persona's replies; the model downloads on first use.
generator = pipeline('text-generation', model='gpt2')
# Retrieval function | |
def retrieve_relevant_chunks(query, top_k=3):
    """Return the top_k text chunks most similar to *query*.

    Encodes the query with the same MiniLM model used to build the index,
    runs an L2 nearest-neighbour search in the module-level FAISS index,
    and maps the hit positions back to the module-level chunk list.

    Args:
        query: Free-text user question.
        top_k: Number of chunks to retrieve (default 3).

    Returns:
        list[str]: The retrieved chunks, nearest first.
    """
    # Cache the encoder on the function object: the original re-instantiated
    # the SentenceTransformer on every call, reloading model weights per query.
    model = getattr(retrieve_relevant_chunks, '_model', None)
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        retrieve_relevant_chunks._model = model
    query_vector = model.encode([query])
    # FAISS expects float32; search returns (distances, indices).
    _, indices = index.search(query_vector.astype('float32'), top_k)
    return [chunks[i] for i in indices[0]]
# Character response generation | |
def generate_character_response(query):
    """Generate an in-character reply from the "Muse" persona.

    Retrieves context chunks for *query*, builds a role-play prompt, lets
    GPT-2 continue it, and returns only the text after the final 'Muse:'
    marker.
    """
    relevant_chunks = retrieve_relevant_chunks(query)
    prompt = f"""As the Muse from A.R. Ammons' poetry, respond to this query:
Context: {' '.join(relevant_chunks)}
User: {query}
Muse:"""
    # BUG FIX: the original passed max_length=150, which limits prompt +
    # continuation combined. Three 1000-char context chunks alone exceed
    # 150 tokens, so generation was truncated to nothing or raised.
    # max_new_tokens bounds only the generated continuation; truncation=True
    # keeps an over-long prompt within GPT-2's context window.
    response = generator(
        prompt,
        max_new_tokens=150,
        num_return_sequences=1,
        truncation=True,
    )[0]['generated_text']
    # Keep only what the model wrote after the last persona marker.
    return response.split('Muse:')[-1].strip()
# Gradio interface
# Single-textbox UI: submitting runs retrieval + generation and shows the
# Muse's reply as plain text.
iface = gr.Interface(
    fn=generate_character_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="A.R. Ammons' Muse Chatbot",
    description="Ask a question and get a response from the Muse of A.R. Ammons' poetry."
)
# Starts the local web server (blocking call).
iface.launch()