Spaces:
Runtime error
Runtime error
File size: 4,926 Bytes
ec99f56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Import libraries
import pandas as pd
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
# Load the Dataset from Hugging Face and FAQ CSV
support_data = load_dataset("rjac/e-commerce-customer-support-qa")
# The FAQ CSV stores its pairs as prompt/response; rename those to the shared
# Question/Answer schema and keep only these two columns.
faq_data = (
    pd.read_csv("Ecommerce_FAQs.csv")
    .rename(columns={'prompt': 'Question', 'response': 'Answer'})
    .loc[:, ['Question', 'Answer']]
)
# Materialise the Hugging Face "train" split as a pandas DataFrame.
support_data_df = pd.DataFrame(support_data['train'])
# Extract question-answer pairs from the conversation field
def extract_conversation(data):
    """Split a raw support ``conversation`` string into a question/answer pair.

    The text is split on blank lines. The second chunk supplies the question
    and the third chunk supplies the answer. Each is taken from after the first
    ``": "`` in its chunk, which presumably strips a speaker label such as
    ``"Customer: "`` (confirm against the dataset). A missing third chunk yields
    an empty answer.

    Args:
        data: The raw conversation text. Non-string values (e.g. NaN/None for a
            missing cell) are treated as unparseable.

    Returns:
        A ``pd.Series`` with ``"Question"`` and ``"Answer"`` entries. Both are
        empty strings when the text cannot be parsed.
    """
    # Missing cells arrive as float NaN / None, which have no .split(); previously
    # this raised an AttributeError that the IndexError handler did not catch.
    if not isinstance(data, str):
        return pd.Series({"Question": "", "Answer": ""})
    parts = data.split("\n\n")
    try:
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
    except IndexError:
        # A chunk without a ": " separator: treat the whole record as unparseable.
        return pd.Series({"Question": "", "Answer": ""})
    return pd.Series({"Question": question, "Answer": answer})
# Apply extraction function: derive Question/Answer columns from each raw conversation.
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
# Combine FAQ data with support data
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)
# Normalise every cell to a plain string. Missing CSV cells arrive as float NaN,
# which SentenceTransformer.encode cannot embed and which has no .strip() later on.
combined_data = combined_data.fillna("").astype(str)
# Drop rows without a question (e.g. conversations extract_conversation could not
# parse). They would all share the embedding of "" and only produce spurious matches.
# reset_index gives a clean 0..n-1 index matching the FAISS ids used with .iloc.
combined_data = combined_data[combined_data['Question'].str.strip() != ""].reset_index(drop=True)
# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
# Generate and Index Embeddings for Combined Data
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)
# Create FAISS index: exact (brute-force) L2 search over the question embeddings.
index = faiss.IndexFlatL2(embeddings.shape[1])
# FAISS consumes NumPy arrays; .cpu() first in case the tensor lives on a GPU.
index.add(embeddings.cpu().numpy())
# Load your fine-tuned DialoGPT model and tokenizer
tokenizer_gpt = AutoTokenizer.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path
model_gpt = AutoModelForCausalLM.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path
# Define Retrieval Function
def retrieve_answer(question):
    """Return the stored answer for the known question closest to *question*.

    The query is embedded with the module-level SBERT ``model`` and looked up
    (k=1) in the FAISS ``index`` built over ``combined_data['Question']``. If no
    neighbour is found, or the best match has no usable answer (blank or
    missing), a reply is generated with ``generate_response`` instead.

    Args:
        question: The user's question text.

    Returns:
        The retrieved answer string, or a generated reply.
    """
    query_vec = model.encode([question], convert_to_tensor=True).cpu().numpy()
    _, closest_index = index.search(query_vec, k=1)
    best_match_idx = int(closest_index[0][0])
    # FAISS reports id -1 when it has no neighbour to return (e.g. an empty index);
    # passed to .iloc that would silently select the last row.
    if best_match_idx < 0:
        return generate_response(question)
    answer = combined_data.iloc[best_match_idx]['Answer']
    # Answers read from the CSV can be NaN (a float). Treat anything that is not a
    # non-blank string as "no answer" and fall back to DialoGPT instead of crashing.
    if not isinstance(answer, str) or not answer.strip():
        return generate_response(question)
    return answer
# Generate response using your fine-tuned DialoGPT model
def generate_response(user_input, max_new_tokens=100):
    """Generate a free-form reply to *user_input* with the fine-tuned DialoGPT model.

    Args:
        user_input: The user's message.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply (the prompt is not counted).

    Returns:
        The decoded reply. If the model produces only whitespace, returns the
        fixed string "Oops, I don't know the answer to that." instead.
    """
    # DialoGPT marks the end of each conversational turn with the EOS token. Without
    # it the model tends to continue the user's sentence rather than answer it.
    # NOTE(review): assumes the fine-tune kept DialoGPT's EOS-terminated turn format.
    input_ids = tokenizer_gpt.encode(user_input + tokenizer_gpt.eos_token, return_tensors='pt')
    output_ids = model_gpt.generate(
        input_ids,
        # Single unpadded sequence: every token is real. Passing the mask explicitly
        # matters because pad_token_id == eos_token_id, so it cannot be inferred.
        attention_mask=torch.ones_like(input_ids),
        # max_length (used previously) also counted the prompt, so prompts of 100+
        # tokens left no room for any reply; max_new_tokens bounds only the reply.
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer_gpt.eos_token_id,
    )
    # Decode only the newly generated tokens, dropping the echoed prompt.
    response = tokenizer_gpt.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return response if response.strip() else "Oops, I don't know the answer to that."
# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware: any origin may call the API, without credentials.
# The CORS spec forbids pairing a wildcard origin with credentials, and FastAPI's
# docs state allow_origins=["*"] cannot be used with allow_credentials=True
# (Starlette instead reflects the caller's origin, letting any site make
# credentialed requests). No route in this file reads cookies or auth headers,
# so credentialed cross-origin access is disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define FastAPI route for Gradio interface
@app.get("/")
async def read_root():
    """Serve the static landing page at ``/``.

    Returns:
        An ``HTMLResponse`` holding a short welcome page that points visitors
        to the Gradio chat interface.
    """
    landing_page = """<html>
<head>
<title>E-commerce Support Chatbot</title>
</head>
<body>
<h1>Welcome to the E-commerce Support Chatbot</h1>
<p>Use the Gradio interface to chat with the bot!</p>
</body>
</html>"""
    return HTMLResponse(content=landing_page)
# Gradio Chat Interface for E-commerce Support Chatbot
def chatbot_interface(user_input, chat_history=None):
    """Answer one user message and render the whole conversation so far.

    Args:
        user_input: The message typed by the user.
        chat_history: List of ``(sender, message)`` tuples from earlier turns,
            where sender is ``"User"`` or ``"Bot"``. It is appended to in place.
            Defaults to a fresh empty history.

    Returns:
        A ``(markdown, chat_history)`` tuple. The first item is the conversation
        rendered as Markdown, with turns separated by blank lines. The second is
        the updated history list.
    """
    # A literal ``[]`` default is created once and shared by every call that omits
    # chat_history, leaking one conversation into the next; build a fresh list instead.
    if chat_history is None:
        chat_history = []
    # Retrieve response from the knowledge base or generate it
    response = retrieve_answer(user_input)
    chat_history.append(("User", user_input))
    chat_history.append(("Bot", response))
    # Format chat history for display; any sender other than "User" renders as the bot.
    chat_display = [
        f"**You**: {message}" if sender == "User" else f"**Bot**: {message}"
        for sender, message in chat_history
    ]
    return "\n\n".join(chat_display), chat_history
# Set up Gradio Chat Interface with conversational format.
# Inputs: the question box plus a per-session history state that starts empty.
chat_inputs = [
    gr.Textbox(lines=2, placeholder="Type your question here..."),
    gr.State([]),
]
# Outputs: the rendered transcript plus the updated history fed back into the state.
chat_outputs = [
    gr.Markdown(),
    gr.State(),
]
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=chat_inputs,
    outputs=chat_outputs,
    title="E-commerce Support Chatbot",
    description="Ask questions about order tracking, returns, account help, and more!",
)
# Launch Gradio interface directly.
# NOTE(review): nothing in this file mounts or serves the FastAPI `app`, so its
# "/" route is not reachable through this launch; confirm that is intended.
iface.launch()
|