# rag2/app.py

import gradio as gr
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
import numpy as np
import os
# File paths
INDEX_FILE = 'ammons_muse_index.faiss'
EMBEDDINGS_FILE = 'ammons_muse_embeddings.npy'
CHUNKS_FILE = 'ammons_muse_chunks.npy'
TEXT_FILE = 'ammons_muse.txt'

# Load and prepare the text
def prepare_text():
    with open(TEXT_FILE, 'r', encoding='utf-8') as file:
        text = file.read()
    chunk_size = 1000
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
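
# Note: fixed-size character chunks can split words and lines of verse mid-way.
# A common variant (sketch only, not used here) overlaps consecutive chunks:
#   overlap = 200  # hypothetical value
#   [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]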

# Create or load embeddings and index
def get_embeddings_and_index(chunks):
    if os.path.exists(INDEX_FILE) and os.path.exists(EMBEDDINGS_FILE):
        print("Loading existing index and embeddings...")
        index = faiss.read_index(INDEX_FILE)
        embeddings = np.load(EMBEDDINGS_FILE)
    else:
        print("Creating new index and embeddings...")
        embeddings = embed_model.encode(chunks)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype('float32'))
        # Save index and embeddings so later runs skip re-encoding
        faiss.write_index(index, INDEX_FILE)
        np.save(EMBEDDINGS_FILE, embeddings)
    return embeddings, index
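
# Note: the cache above is keyed only on file existence, not content. If
# ammons_muse.txt changes, delete the cached .faiss/.npy files to force a rebuild.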

# Load or create chunks
if os.path.exists(CHUNKS_FILE):
    chunks = np.load(CHUNKS_FILE, allow_pickle=True).tolist()
else:
    chunks = prepare_text()
    np.save(CHUNKS_FILE, np.array(chunks, dtype=object))

# Load the embedding model once at startup; retrieve_relevant_chunks reuses it
# so each query does not reload the model from disk.
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

# Get embeddings and index
embeddings, index = get_embeddings_and_index(chunks)

# Set up text generation pipeline
generator = pipeline('text-generation', model='gpt2')
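
# Note: 'gpt2' is a small base model with a 1024-token context window, which the
# ~3,000-character prompt built below can approach. A larger checkpoint such as
# 'gpt2-medium' could be swapped in here at the cost of memory and latency.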

# Retrieval function: embed the query and return the top_k nearest chunks
def retrieve_relevant_chunks(query, top_k=3):
    query_vector = embed_model.encode([query])
    _, indices = index.search(query_vector.astype('float32'), top_k)
    return [chunks[i] for i in indices[0]]
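
# Note: IndexFlatL2 ranks chunks by Euclidean distance. Cosine similarity is a
# common alternative (sketch only; it assumes the index is rebuilt accordingly):
#   faiss.normalize_L2(embeddings)        # in-place unit-normalization
#   index = faiss.IndexFlatIP(dimension)  # inner product == cosine on unit vectors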

# Character response generation
def generate_character_response(query):
    relevant_chunks = retrieve_relevant_chunks(query)
    prompt = f"""As the Muse from A.R. Ammons' poetry, respond to this query:
Context: {' '.join(relevant_chunks)}
User: {query}
Muse:"""
    # max_new_tokens caps only the generated continuation; max_length would also
    # count the (long) prompt and can cut generation off entirely.
    response = generator(prompt, max_new_tokens=150, num_return_sequences=1)[0]['generated_text']
    return response.split('Muse:')[-1].strip()

# Gradio interface
iface = gr.Interface(
    fn=generate_character_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="A.R. Ammons' Muse Chatbot",
    description="Ask a question and get a response from the Muse of A.R. Ammons' poetry."
)

iface.launch()
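
# Note: when running outside Hugging Face Spaces, iface.launch(share=True) can
# expose a temporary public URL; the default launch() is sufficient on Spaces.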