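# NOTE: captured from a Hugging Face Spaces page whose status banner read
# "Runtime error"; the banner text is not part of the program.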
# Import libraries
import faiss
import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the customer-support dataset from the Hugging Face Hub
support_data = load_dataset("rjac/e-commerce-customer-support-qa")

# Load FAQ data from a local CSV file directly
faq_data = pd.read_csv("Ecommerce_FAQs.csv")

# Preprocess and clean data: rename the CSV's 'prompt'/'response' columns to
# the 'Question'/'Answer' names used by the rest of the script, and drop any
# other columns.
faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'}, inplace=True)
faq_data = faq_data[['Question', 'Answer']]

# Materialise the 'train' split as a DataFrame so it can be combined with the FAQ data
support_data_df = pd.DataFrame(support_data['train'])
# Extract question-answer pairs from the conversation field
def extract_conversation(data):
    """Pull one question/answer pair out of a conversation transcript.

    The transcript is split on blank lines ("\\n\\n"). The question is the
    text after the first ": " in the second segment and the answer is the
    text after the first ": " in the third segment. A missing segment yields
    an empty string for that field.

    Args:
        data: The raw conversation text. Non-string values (e.g. None or a
            NaN produced by pandas for a missing cell) are treated as an
            empty conversation.

    Returns:
        A ``pd.Series`` with the keys "Question" and "Answer". Both are empty
        strings when the input is not a string or when a present segment has
        no ": " separator.
    """
    empty = pd.Series({"Question": "", "Answer": ""})

    # Missing cells arrive as None/NaN, which have no .split(); report them
    # as "nothing extracted" instead of raising AttributeError.
    if not isinstance(data, str):
        return empty

    parts = data.split("\n\n")
    try:
        # split(": ", 1)[1] raises IndexError when the segment has no
        # ": " separator; in that case the whole row is treated as empty.
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
    except IndexError:
        return empty
    return pd.Series({"Question": question, "Answer": answer})
# Apply extraction function: each conversation yields a Question/Answer pair
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)

# Combine FAQ data with support data
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)
# Missing cells are read as NaN (a float); replace them with empty strings so
# the encoder and later string checks on 'Answer' always receive text.
combined_data = combined_data.fillna("")

# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Generate and Index Embeddings for Combined Data
questions = combined_data['Question'].astype(str).tolist()
embeddings = model.encode(questions, convert_to_tensor=True)

# Create FAISS index (exact L2 search) sized to the embedding dimension.
# Row i of the index corresponds to row i of combined_data.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings.cpu().numpy())

# Load your fine-tuned DialoGPT model and tokenizer
tokenizer_gpt = AutoTokenizer.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path
model_gpt = AutoModelForCausalLM.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path
# Define Retrieval Function
def retrieve_answer(question):
    """Return the stored answer whose question is closest to *question*.

    The question is embedded with the module-level SBERT ``model`` and the
    single nearest neighbour is looked up in the FAISS ``index``. The matching
    row of ``combined_data`` supplies the answer.

    Args:
        question: The user's question text.

    Returns:
        The stored answer string, or the output of ``generate_response`` when
        the index returns no neighbour or the stored answer is missing/blank.
    """
    question_embedding = model.encode([question], convert_to_tensor=True)
    question_embedding_np = question_embedding.cpu().numpy()
    _, closest_index = index.search(question_embedding_np, k=1)
    best_match_idx = int(closest_index[0][0])

    # FAISS reports -1 when it has no neighbour to return (e.g. an empty
    # index); iloc[-1] would otherwise silently pick the last row.
    if best_match_idx < 0:
        return generate_response(question)

    answer = combined_data.iloc[best_match_idx]['Answer']
    # If the answer is missing (NaN is not a str) or empty, generate a
    # fallback response from DialoGPT.
    if not isinstance(answer, str) or not answer.strip():
        return generate_response(question)
    return answer
# Generate response using your fine-tuned DialoGPT model
def generate_response(user_input):
    """Generate a reply to *user_input* with the fine-tuned DialoGPT model.

    Args:
        user_input: The user's message.

    Returns:
        The decoded text the model produced after the prompt, or the fixed
        fallback message when the input is blank/non-text or the model
        produced only whitespace.
    """
    fallback = "Oops, I don't know the answer to that."

    # A blank prompt encodes to zero tokens, which gives the model nothing to
    # continue from; answer with the fallback instead.
    if not isinstance(user_input, str) or not user_input.strip():
        return fallback

    max_total_tokens = 100  # cap on prompt + reply, as passed to generate()
    min_reply_tokens = 20   # room always reserved for the reply

    input_ids = tokenizer_gpt.encode(user_input, return_tensors='pt')

    # max_length counts the prompt too, so a long prompt would leave no room
    # for a reply. Keep only the most recent prompt tokens in that case.
    prompt_budget = max_total_tokens - min_reply_tokens
    if input_ids.shape[-1] > prompt_budget:
        input_ids = input_ids[:, -prompt_budget:]

    chat_history_ids = model_gpt.generate(
        input_ids,
        # The pad token is set to EOS below, so the mask cannot be inferred
        # from the ids; the single unpadded prompt is all real tokens.
        attention_mask=input_ids.new_ones(input_ids.shape),
        max_length=max_total_tokens,
        pad_token_id=tokenizer_gpt.eos_token_id,
    )

    # The output starts with the prompt; decode only the newly generated tail.
    response = tokenizer_gpt.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response if response.strip() else fallback
# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware.
# Credentials stay disabled because every origin is allowed: combining a
# wildcard origin with credentialed requests would let any site make
# authenticated cross-origin calls. List explicit origins before enabling it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define FastAPI route for Gradio interface
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve a static landing page that points visitors to the Gradio chat.

    Registered on ``app`` at "/"; it is only reachable when ``app`` itself is
    served by an ASGI server.

    Returns:
        An ``HTMLResponse`` containing the welcome page.
    """
    return HTMLResponse("""<html>
<head>
<title>E-commerce Support Chatbot</title>
</head>
<body>
<h1>Welcome to the E-commerce Support Chatbot</h1>
<p>Use the Gradio interface to chat with the bot!</p>
</body>
</html>""")
# Gradio Chat Interface for E-commerce Support Chatbot
def chatbot_interface(user_input, chat_history=None):
    """Answer one user message and return the updated conversation.

    Args:
        user_input: The message typed by the user.
        chat_history: List of ``(sender, message)`` tuples from earlier turns,
            where sender is "User" or "Bot". It is extended in place. When
            omitted or None a fresh list is started.

    Returns:
        A tuple ``(markdown, chat_history)``: the whole conversation rendered
        as Markdown with turns separated by blank lines, and the updated
        history list.
    """
    # A literal [] default would be created once and shared by every call
    # that omits the argument, leaking history between conversations.
    if chat_history is None:
        chat_history = []

    # Retrieve response from the knowledge base or generate it
    response = retrieve_answer(user_input)
    chat_history.append(("User", user_input))
    chat_history.append(("Bot", response))

    # Format chat history for display
    chat_display = [
        f"**You**: {message}" if sender == "User" else f"**Bot**: {message}"
        for sender, message in chat_history
    ]
    return "\n\n".join(chat_display), chat_history
# Set up Gradio Chat Interface with conversational format.
# The two inputs map to chatbot_interface(user_input, chat_history) and the
# two outputs map to its (markdown, chat_history) return value.
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type your question here..."),
        gr.State([]),  # State variable to maintain chat history
    ],
    outputs=[
        gr.Markdown(),  # Display formatted chat history
        gr.State(),  # Update state
    ],
    title="E-commerce Support Chatbot",
    description="Ask questions about order tracking, returns, account help, and more!",
)

# Launch Gradio interface directly
iface.launch()