# testchatbot / app.py
# Author: Yoxas (Hugging Face Space). Last commit: 86d4425 "Update app.py" (verified).
# Original file size: 2.28 kB.
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, BertTokenizer, BertModel
import faiss
import torch
import spaces
# The 'embeddings' column stores each vector as text. Dropping the first and
# last characters (presumably the brackets of a list literal -- confirm against
# the CSV) leaves ", "-separated numbers that numpy parses into an array.
def _parse_embedding(raw_text):
    """Convert one stored embedding string into a 1-D numpy array."""
    return np.fromstring(raw_text[1:-1], sep=', ')


# Load CSV data and replace each embedding string with its parsed array.
data = pd.read_csv('RB10kstats.csv')
data['embeddings'] = data['embeddings'].apply(_parse_embedding)
# Initialize FAISS index: exact (brute-force) L2 search over the document embeddings.
# FAISS works on contiguous float32 matrices; np.fromstring above yields float64,
# so convert once here instead of relying on the faiss wrapper to do it.
doc_matrix = np.ascontiguousarray(np.stack(data['embeddings'].values), dtype=np.float32)
# Take the dimension from the stacked matrix rather than label-based `[0]`,
# which breaks if the DataFrame index does not contain the label 0.
dimension = doc_matrix.shape[1]
index = faiss.IndexFlatL2(dimension)
# Use a GPU only when this faiss build has GPU support and a GPU is visible.
# Otherwise keep the identical flat index on the CPU so the app still starts
# (faiss-cpu builds have no StandardGpuResources).
if hasattr(faiss, 'StandardGpuResources') and faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()  # use a single GPU
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)  # move to GPU
else:
    res = None
    gpu_index = index  # name kept for callers; this is a CPU index in this branch
gpu_index.add(doc_matrix)
# Decide once whether CUDA is usable and reuse that answer for every model below.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')

# Extractive QA pipeline; transformers pipelines take a GPU ordinal (0) or -1 for CPU.
qa_model = pipeline(
    "question-answering",
    model="distilbert-base-uncased-distilled-squad",
    device=0 if _cuda_available else -1,
)

# BERT encoder and its tokenizer, used to embed incoming questions for the FAISS search.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)
# Function to embed the question using BERT
@spaces.GPU(duration=120)
def embed_question(question, model, tokenizer):
    """Embed *question* as the mean of the model's last-layer token vectors.

    Args:
        question: Text to embed.
        model: A Hugging Face encoder whose output exposes ``last_hidden_state``
            (the module-level ``BertModel`` in this app).
        tokenizer: The tokenizer matching *model*.

    Returns:
        A numpy array of shape (1, hidden_size), used as the FAISS query.
    """
    # truncation=True caps the input at the tokenizer's model_max_length (512 for
    # BERT). Without it, a long pasted question makes the model raise.
    # Move the inputs to the device of the model we were given, not a global.
    inputs = tokenizer(question, return_tensors='pt', truncation=True).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over the token axis (dim=1) to get one vector per input string.
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
# Function to retrieve the relevant document and generate a response
@spaces.GPU(duration=120)
def retrieve_and_generate(question):
    """Answer *question* extractively from the single closest document's abstract.

    The question is embedded, its nearest neighbour is looked up in the FAISS
    index, and that row's 'Abstract' is handed to the QA pipeline as context.

    Args:
        question: The user's question text.

    Returns:
        The answer span extracted by the QA model, or a short explanatory
        message when there is no question, no hit, or no abstract to read.
    """
    # A blank question would otherwise be rejected by the QA pipeline.
    if not question or not question.strip():
        return "Please enter a question."

    # Embed the question
    question_embedding = embed_question(question, model, tokenizer)
    # Search in FAISS index for the single nearest document
    _, indices = gpu_index.search(question_embedding, k=1)
    doc_pos = int(indices[0][0])
    # FAISS reports "no result" as -1; data.iloc[-1] would then silently
    # return the last row as if it were a real match.
    if doc_pos < 0:
        return "No relevant document found."

    # Retrieve the most relevant document
    relevant_doc = data.iloc[doc_pos]
    context = relevant_doc['Abstract']
    # Missing CSV cells come back as NaN, which is not usable QA context.
    if pd.isna(context) or not str(context).strip():
        return "The most relevant document has no abstract to answer from."

    # Use the QA model to extract the answer from the abstract
    response = qa_model(question=question, context=str(context))
    return response['answer']
# Create a Gradio interface that sends the textbox contents to retrieve_and_generate
# and shows the returned string as plain text.
interface = gr.Interface(
    fn=retrieve_and_generate,
    # `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4;
    # components now live directly on the `gr` namespace.
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about the documents..."),
    outputs="text",
    title="RAG Chatbot",
    description="Ask questions about the documents in the CSV file.",
)
# Launch the Gradio app only when this file is run as a script (e.g.
# `python app.py`), so importing the module does not start a web server.
if __name__ == "__main__":
    interface.launch()